Skip to content

Commit 0b4114a

Browse files
Reverts d4c44d1
PiperOrigin-RevId: 682079618
1 parent 4a29233 commit 0b4114a

File tree

7 files changed

+128
-12
lines changed

7 files changed

+128
-12
lines changed

tensorflow/lite/acceleration/configuration/configuration.proto

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,9 @@ enum XNNPackFlags {
330330

331331
message XNNPackSettings {
332332
optional int32 num_threads = 1;
333-
optional XNNPackFlags flags = 2 [default = TFLITE_XNNPACK_DELEGATE_NO_FLAGS];
333+
// If flags is unset or zero, it means use the default XNNPack delegate flags.
334+
// Any other value means use exactly (and only) the flags specified.
335+
optional XNNPackFlags flags = 2;
334336
// Path to the XNNPack cache file. XNNPack packed buffers are saved to and
335337
// reloaded from this cache which can reduce initialization time and the
336338
// packing memory footprint.
@@ -1129,4 +1131,4 @@ message BenchmarkEventStorage {
11291131
optional BenchmarkEvent benchmark_event = 2;
11301132
}
11311133

1132-
// LINT.ThenChange(//tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev)
1134+
// LINT.ThenChange(//tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev:all)

tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,20 @@ enum XNNPackFlags {
310310
TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8 = 3;
311311
// Force 16-bit floating point inference.
312312
TFLITE_XNNPACK_DELEGATE_FLAG_FORCE_FP16 = 4;
313+
// Enable XNNPACK acceleration for FULLY_CONNECTED operator with dynamic
314+
// weights.
315+
TFLITE_XNNPACK_DELEGATE_FLAG_DYNAMIC_FULLY_CONNECTED = 8;
316+
// Enable XNNPACK acceleration for VAR_HANDLE, READ_VARIABLE, and
317+
// ASSIGN_VARIABLE operators.
318+
TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS = 16;
319+
// Enable transient indirection buffer to reduce memory usage in selected
320+
// operators.
321+
TFLITE_XNNPACK_DELEGATE_FLAG_TRANSIENT_INDIRECTION_BUFFER = 32;
322+
// Enable the latest XNNPACK operators and features in the delegate which have
323+
// not yet been enabled by default.
324+
TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS = 64;
325+
// Enable XNNPack subgraph reshaping.
326+
TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING = 128;
313327
}
314328

315329
message XNNPackSettings {

tensorflow/lite/core/acceleration/configuration/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ cc_library(
108108
"//tensorflow/lite/acceleration/configuration:configuration_fbs",
109109
"//tensorflow/lite/c:c_api_types",
110110
"//tensorflow/lite/core/acceleration/configuration:delegate_registry",
111+
"//tensorflow/lite/core/c:common",
111112
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
112113
"@com_google_absl//absl/base:log_severity",
113114
"@com_google_absl//absl/memory",

tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ limitations under the License.
1717

1818
#include "tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h"
1919

20-
#include <memory>
21-
2220
#include "tensorflow/lite/acceleration/configuration/configuration_generated.h"
2321
#include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h"
2422
#include "tensorflow/lite/core/c/common.h"
@@ -30,6 +28,8 @@ static TfLiteDelegate* CreateDelegate(const void* settings) {
3028
const ::tflite::TFLiteSettings* tflite_settings =
3129
static_cast<const ::tflite::TFLiteSettings*>(settings);
3230
auto options(TfLiteXNNPackDelegateOptionsDefault());
31+
// The following code block is duplicated in the C++ XNNPack delegate plugin.
32+
// LINT.IfChange(tflite_settings_to_xnnpack_delegate_options)
3333
const auto* xnnpack_settings = tflite_settings->xnnpack_settings();
3434
if (xnnpack_settings) {
3535
options.num_threads = xnnpack_settings->num_threads();
@@ -45,6 +45,7 @@ static TfLiteDelegate* CreateDelegate(const void* settings) {
4545
xnnpack_settings->weight_cache_file_path()->c_str();
4646
}
4747
}
48+
// LINT.ThenChange(../xnnpack_plugin.cc:tflite_settings_to_xnnpack_delegate_options)
4849
return TfLiteXNNPackDelegateCreate(&options);
4950
}
5051

tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
#include "tensorflow/lite/acceleration/configuration/configuration_generated.h"
1818
#include "tensorflow/lite/c/c_api_types.h"
1919
#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h"
20+
#include "tensorflow/lite/core/c/common.h"
2021
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
2122

2223
namespace tflite {
@@ -34,11 +35,23 @@ class XNNPackPlugin : public DelegatePluginInterface {
3435
}
3536
explicit XNNPackPlugin(const TFLiteSettings& tflite_settings)
3637
: options_(TfLiteXNNPackDelegateOptionsDefault()) {
38+
// LINT.IfChange(tflite_settings_to_xnnpack_delegate_options)
3739
const auto* xnnpack_settings = tflite_settings.xnnpack_settings();
3840
if (xnnpack_settings) {
3941
options_.num_threads = xnnpack_settings->num_threads();
40-
options_.flags = xnnpack_settings->flags();
42+
// If xnnpack_settings->flags is zero, then leave options.flags
43+
// unmodified, i.e. use the default flags (not zero).
44+
// If xnnpack_settings->flags is nonzero, then use exactly
45+
// those flags (i.e. discard the default flags).
46+
if (xnnpack_settings->flags()) {
47+
options_.flags = xnnpack_settings->flags();
48+
}
49+
if (xnnpack_settings->weight_cache_file_path()) {
50+
options_.weight_cache_file_path =
51+
xnnpack_settings->weight_cache_file_path()->c_str();
52+
}
4153
}
54+
// LINT.ThenChange(c/xnnpack_plugin.cc:tflite_settings_to_xnnpack_delegate_options)
4255
}
4356

4457
private:

tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515

1616
// Some very simple unit tests of the (C++) XNNPack Delegate Plugin.
1717

18+
#include <memory>
19+
1820
#include <gtest/gtest.h>
1921
#include "flatbuffers/buffer.h" // from @flatbuffers
2022
#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers
@@ -28,20 +30,16 @@ namespace tflite {
2830
class XnnpackPluginTest : public testing::Test {
2931
public:
3032
static constexpr int kNumThreadsForTest = 7;
31-
static constexpr tflite::XNNPackFlags kFlagsForTest =
32-
tflite::XNNPackFlags::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8;
3333
void SetUp() override {
3434
// Construct a FlatBuffer that contains
3535
// TFLiteSettings {
3636
// delegate: Delegate.XNNPACK,
37-
// XNNPackSettings { num_threads: kNumThreadsForTest
38-
// flags: TFLITE_XNNPACK_DELEGATE_FLAG_QS8 |
39-
// TFLITE_XNNPACK_DELEGATE_FLAG_QU8
37+
// XNNPackSettings {
38+
// num_threads: kNumThreadsForTest
4039
// }
4140
// }.
4241
XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_);
4342
xnnpack_settings_builder.add_num_threads(kNumThreadsForTest);
44-
xnnpack_settings_builder.add_flags(kFlagsForTest);
4543
flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
4644
xnnpack_settings_builder.Finish();
4745
TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_);
@@ -58,7 +56,7 @@ class XnnpackPluginTest : public testing::Test {
5856
ASSERT_NE(delegate_plugin_, nullptr);
5957
}
6058
void TearDown() override { delegate_plugin_.reset(); }
61-
~XnnpackPluginTest() override {}
59+
~XnnpackPluginTest() override = default;
6260

6361
protected:
6462
// settings_ points into storage owned by flatbuffer_builder_.
@@ -88,4 +86,88 @@ TEST_F(XnnpackPluginTest, SetsCorrectThreadCount) {
8886
EXPECT_EQ(thread_count, kNumThreadsForTest);
8987
}
9088

89+
TEST_F(XnnpackPluginTest, UsesDefaultFlagsByDefault) {
90+
delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
91+
int flags = TfLiteXNNPackDelegateGetFlags(delegate.get());
92+
EXPECT_EQ(flags, TfLiteXNNPackDelegateOptionsDefault().flags);
93+
}
94+
95+
TEST_F(XnnpackPluginTest, UsesSpecifiedFlagsWhenNonzero) {
96+
XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_);
97+
xnnpack_settings_builder.add_flags(
98+
tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8);
99+
flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
100+
xnnpack_settings_builder.Finish();
101+
TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_);
102+
tflite_settings_builder.add_xnnpack_settings(xnnpack_settings);
103+
flatbuffers::Offset<TFLiteSettings> tflite_settings =
104+
tflite_settings_builder.Finish();
105+
flatbuffer_builder_.Finish(tflite_settings);
106+
tflite_settings_ = flatbuffers::GetRoot<TFLiteSettings>(
107+
flatbuffer_builder_.GetBufferPointer());
108+
delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName(
109+
"XNNPackPlugin", *tflite_settings_);
110+
111+
delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
112+
int flags = TfLiteXNNPackDelegateGetFlags(delegate.get());
113+
EXPECT_EQ(flags, tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8);
114+
}
115+
116+
// Settings flags to XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS (zero)
117+
// causes flags to be set to their default values, not zero.
118+
// This is potentially confusing behaviour, but we can't distinguish
119+
// the case when flags isn't set from the case when flags is set to zero.
120+
TEST_F(XnnpackPluginTest, UsesDefaultFlagsWhenZero) {
121+
XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_);
122+
xnnpack_settings_builder.add_flags(
123+
tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS);
124+
flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
125+
xnnpack_settings_builder.Finish();
126+
TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_);
127+
tflite_settings_builder.add_xnnpack_settings(xnnpack_settings);
128+
flatbuffers::Offset<TFLiteSettings> tflite_settings =
129+
tflite_settings_builder.Finish();
130+
flatbuffer_builder_.Finish(tflite_settings);
131+
tflite_settings_ = flatbuffers::GetRoot<TFLiteSettings>(
132+
flatbuffer_builder_.GetBufferPointer());
133+
delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName(
134+
"XNNPackPlugin", *tflite_settings_);
135+
136+
delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
137+
int flags = TfLiteXNNPackDelegateGetFlags(delegate.get());
138+
EXPECT_EQ(flags, TfLiteXNNPackDelegateOptionsDefault().flags);
139+
}
140+
141+
TEST_F(XnnpackPluginTest, DoesNotSetWeightCacheFilePathByDefault) {
142+
delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
143+
const TfLiteXNNPackDelegateOptions *options =
144+
TfLiteXNNPackDelegateGetOptions(delegate.get());
145+
EXPECT_EQ(options->weight_cache_file_path, nullptr);
146+
}
147+
148+
TEST_F(XnnpackPluginTest, HonoursWeightCacheFilePathSetting) {
149+
const char *const kWeightCachePath = "/tmp/wcfp";
150+
const auto weight_cache_file_path_string =
151+
flatbuffer_builder_.CreateString(kWeightCachePath);
152+
XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_);
153+
xnnpack_settings_builder.add_weight_cache_file_path(
154+
weight_cache_file_path_string);
155+
flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
156+
xnnpack_settings_builder.Finish();
157+
TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_);
158+
tflite_settings_builder.add_xnnpack_settings(xnnpack_settings);
159+
flatbuffers::Offset<TFLiteSettings> tflite_settings =
160+
tflite_settings_builder.Finish();
161+
flatbuffer_builder_.Finish(tflite_settings);
162+
tflite_settings_ = flatbuffers::GetRoot<TFLiteSettings>(
163+
flatbuffer_builder_.GetBufferPointer());
164+
delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName(
165+
"XNNPackPlugin", *tflite_settings_);
166+
167+
delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create();
168+
const TfLiteXNNPackDelegateOptions *options =
169+
TfLiteXNNPackDelegateGetOptions(delegate.get());
170+
EXPECT_STREQ(options->weight_cache_file_path, kWeightCachePath);
171+
}
172+
91173
} // namespace tflite

tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ limitations under the License.
1616
#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_XNNPACK_DELEGATE_H_
1717
#define TENSORFLOW_LITE_DELEGATES_XNNPACK_XNNPACK_DELEGATE_H_
1818

19+
#include <stddef.h>
20+
#include <stdint.h>
21+
1922
#include "tensorflow/lite/core/c/common.h"
2023

2124
#ifdef __cplusplus

0 commit comments

Comments
 (0)