@@ -15,6 +15,8 @@ limitations under the License.
15
15
16
16
// Some very simple unit tests of the (C++) XNNPack Delegate Plugin.
17
17
18
+ #include < memory>
19
+
18
20
#include < gtest/gtest.h>
19
21
#include " flatbuffers/buffer.h" // from @flatbuffers
20
22
#include " flatbuffers/flatbuffer_builder.h" // from @flatbuffers
@@ -28,20 +30,16 @@ namespace tflite {
28
30
class XnnpackPluginTest : public testing ::Test {
29
31
public:
30
32
static constexpr int kNumThreadsForTest = 7 ;
31
- static constexpr tflite::XNNPackFlags kFlagsForTest =
32
- tflite::XNNPackFlags::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8;
33
33
void SetUp () override {
34
34
// Construct a FlatBuffer that contains
35
35
// TFLiteSettings {
36
36
// delegate: Delegate.XNNPACK,
37
- // XNNPackSettings { num_threads: kNumThreadsForTest
38
- // flags: TFLITE_XNNPACK_DELEGATE_FLAG_QS8 |
39
- // TFLITE_XNNPACK_DELEGATE_FLAG_QU8
37
+ // XNNPackSettings {
38
+ // num_threads: kNumThreadsForTest
40
39
// }
41
40
// }.
42
41
XNNPackSettingsBuilder xnnpack_settings_builder (flatbuffer_builder_);
43
42
xnnpack_settings_builder.add_num_threads (kNumThreadsForTest );
44
- xnnpack_settings_builder.add_flags (kFlagsForTest );
45
43
flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
46
44
xnnpack_settings_builder.Finish ();
47
45
TFLiteSettingsBuilder tflite_settings_builder (flatbuffer_builder_);
@@ -58,7 +56,7 @@ class XnnpackPluginTest : public testing::Test {
58
56
ASSERT_NE (delegate_plugin_, nullptr );
59
57
}
60
58
void TearDown () override { delegate_plugin_.reset (); }
61
- ~XnnpackPluginTest () override {}
59
+ ~XnnpackPluginTest () override = default ;
62
60
63
61
protected:
64
62
// settings_ points into storage owned by flatbuffer_builder_.
@@ -88,4 +86,88 @@ TEST_F(XnnpackPluginTest, SetsCorrectThreadCount) {
88
86
EXPECT_EQ (thread_count, kNumThreadsForTest );
89
87
}
90
88
89
+ TEST_F (XnnpackPluginTest, UsesDefaultFlagsByDefault) {
90
+ delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create ();
91
+ int flags = TfLiteXNNPackDelegateGetFlags (delegate.get ());
92
+ EXPECT_EQ (flags, TfLiteXNNPackDelegateOptionsDefault ().flags );
93
+ }
94
+
95
+ TEST_F (XnnpackPluginTest, UsesSpecifiedFlagsWhenNonzero) {
96
+ XNNPackSettingsBuilder xnnpack_settings_builder (flatbuffer_builder_);
97
+ xnnpack_settings_builder.add_flags (
98
+ tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8);
99
+ flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
100
+ xnnpack_settings_builder.Finish ();
101
+ TFLiteSettingsBuilder tflite_settings_builder (flatbuffer_builder_);
102
+ tflite_settings_builder.add_xnnpack_settings (xnnpack_settings);
103
+ flatbuffers::Offset<TFLiteSettings> tflite_settings =
104
+ tflite_settings_builder.Finish ();
105
+ flatbuffer_builder_.Finish (tflite_settings);
106
+ tflite_settings_ = flatbuffers::GetRoot<TFLiteSettings>(
107
+ flatbuffer_builder_.GetBufferPointer ());
108
+ delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName (
109
+ " XNNPackPlugin" , *tflite_settings_);
110
+
111
+ delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create ();
112
+ int flags = TfLiteXNNPackDelegateGetFlags (delegate.get ());
113
+ EXPECT_EQ (flags, tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8);
114
+ }
115
+
116
+ // Settings flags to XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS (zero)
117
+ // causes flags to be set to their default values, not zero.
118
+ // This is potentially confusing behaviour, but we can't distinguish
119
+ // the case when flags isn't set from the case when flags is set to zero.
120
+ TEST_F (XnnpackPluginTest, UsesDefaultFlagsWhenZero) {
121
+ XNNPackSettingsBuilder xnnpack_settings_builder (flatbuffer_builder_);
122
+ xnnpack_settings_builder.add_flags (
123
+ tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS);
124
+ flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
125
+ xnnpack_settings_builder.Finish ();
126
+ TFLiteSettingsBuilder tflite_settings_builder (flatbuffer_builder_);
127
+ tflite_settings_builder.add_xnnpack_settings (xnnpack_settings);
128
+ flatbuffers::Offset<TFLiteSettings> tflite_settings =
129
+ tflite_settings_builder.Finish ();
130
+ flatbuffer_builder_.Finish (tflite_settings);
131
+ tflite_settings_ = flatbuffers::GetRoot<TFLiteSettings>(
132
+ flatbuffer_builder_.GetBufferPointer ());
133
+ delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName (
134
+ " XNNPackPlugin" , *tflite_settings_);
135
+
136
+ delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create ();
137
+ int flags = TfLiteXNNPackDelegateGetFlags (delegate.get ());
138
+ EXPECT_EQ (flags, TfLiteXNNPackDelegateOptionsDefault ().flags );
139
+ }
140
+
141
+ TEST_F (XnnpackPluginTest, DoesNotSetWeightCacheFilePathByDefault) {
142
+ delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create ();
143
+ const TfLiteXNNPackDelegateOptions *options =
144
+ TfLiteXNNPackDelegateGetOptions (delegate.get ());
145
+ EXPECT_EQ (options->weight_cache_file_path , nullptr );
146
+ }
147
+
148
+ TEST_F (XnnpackPluginTest, HonoursWeightCacheFilePathSetting) {
149
+ const char *const kWeightCachePath = " /tmp/wcfp" ;
150
+ const auto weight_cache_file_path_string =
151
+ flatbuffer_builder_.CreateString (kWeightCachePath );
152
+ XNNPackSettingsBuilder xnnpack_settings_builder (flatbuffer_builder_);
153
+ xnnpack_settings_builder.add_weight_cache_file_path (
154
+ weight_cache_file_path_string);
155
+ flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
156
+ xnnpack_settings_builder.Finish ();
157
+ TFLiteSettingsBuilder tflite_settings_builder (flatbuffer_builder_);
158
+ tflite_settings_builder.add_xnnpack_settings (xnnpack_settings);
159
+ flatbuffers::Offset<TFLiteSettings> tflite_settings =
160
+ tflite_settings_builder.Finish ();
161
+ flatbuffer_builder_.Finish (tflite_settings);
162
+ tflite_settings_ = flatbuffers::GetRoot<TFLiteSettings>(
163
+ flatbuffer_builder_.GetBufferPointer ());
164
+ delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName (
165
+ " XNNPackPlugin" , *tflite_settings_);
166
+
167
+ delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create ();
168
+ const TfLiteXNNPackDelegateOptions *options =
169
+ TfLiteXNNPackDelegateGetOptions (delegate.get ());
170
+ EXPECT_STREQ (options->weight_cache_file_path , kWeightCachePath );
171
+ }
172
+
91
173
} // namespace tflite
0 commit comments