Skip to content

Commit e08c43d

Browse files
Merge pull request #78 from MCW-Dev/tflite_pad_missing_datatype
TfLite Pad missing datatype support (#50)
2 parents 5e1ed15 + 658828b commit e08c43d

File tree

3 files changed

+173
-0
lines changed

3 files changed

+173
-0
lines changed

tensorflow/lite/kernels/pad.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
330330
}
331331
}
332332
} break;
333+
case kTfLiteFloat16: {
334+
Eigen::half pad_value =
335+
op_context.constant_values == nullptr
336+
? static_cast<Eigen::half>(0.f)
337+
: *GetTensorData<Eigen::half>(op_context.constant_values);
338+
if (kernel_type == kReference) {
339+
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
340+
TF_LITE_PAD(reference_ops, PadImageStyle, Eigen::half, pad_value);
341+
} else {
342+
TF_LITE_PAD(reference_ops, Pad, Eigen::half, pad_value);
343+
}
344+
} else if (kernel_type == kGenericOptimized) {
345+
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
346+
TF_LITE_PAD(optimized_ops, PadImageStyle, Eigen::half, pad_value);
347+
} else {
348+
TF_LITE_PAD(optimized_ops, Pad, Eigen::half, pad_value);
349+
}
350+
}
351+
} break;
352+
case kTfLiteBFloat16: {
353+
Eigen::bfloat16 pad_value =
354+
op_context.constant_values == nullptr
355+
? static_cast<Eigen::bfloat16>(0.f)
356+
: *GetTensorData<Eigen::bfloat16>(op_context.constant_values);
357+
if (kernel_type == kReference) {
358+
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
359+
TF_LITE_PAD(reference_ops, PadImageStyle, Eigen::bfloat16, pad_value);
360+
} else {
361+
TF_LITE_PAD(reference_ops, Pad, Eigen::bfloat16, pad_value);
362+
}
363+
} else if (kernel_type == kGenericOptimized) {
364+
if (op_context.resizing_category == ResizingCategory::kImageStyle) {
365+
TF_LITE_PAD(optimized_ops, PadImageStyle, Eigen::bfloat16, pad_value);
366+
} else {
367+
TF_LITE_PAD(optimized_ops, Pad, Eigen::bfloat16, pad_value);
368+
}
369+
}
370+
} break;
333371
case kTfLiteUInt8: {
334372
EvalInt<uint8_t>(context, op_context, op_params);
335373
} break;

tensorflow/lite/kernels/pad_test.cc

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,53 @@ TEST_F(PadV2OpTest, Int64PaddingSimpleConstFloat32ValuedTestInt8) {
771771
SimpleConstFloat32ValuedTestInt8<int64_t>();
772772
}
773773

774+
template <typename padding_integer_type>
775+
void SimpleConstFloat16ValuedTest() {
776+
PadV2OpConstModel<Eigen::half, padding_integer_type> m(
777+
{TensorType_FLOAT16, {1, 2, 2, 1}}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
778+
Eigen::half{4.0f}, {TensorType_FLOAT16});
779+
m.SetInput({Eigen::half{1.5f}, Eigen::half{2.5f}, Eigen::half{3.5f},
780+
Eigen::half{4.5}});
781+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
782+
EXPECT_THAT(
783+
m.GetOutput(),
784+
ElementsAreArray(ArrayFloatNear(
785+
{Eigen::half{4}, Eigen::half{4}, Eigen::half{4}, Eigen::half{4},
786+
Eigen::half{4}, Eigen::half{1.5}, Eigen::half{2.5}, Eigen::half{4},
787+
Eigen::half{4}, Eigen::half{3.5}, Eigen::half{4.5}, Eigen::half{4},
788+
Eigen::half{4}, Eigen::half{4}, Eigen::half{4}, Eigen::half{4}})));
789+
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
790+
}
791+
792+
TEST_F(PadV2OpTest, Int32PaddingSimpleConstFloat16) {
793+
SimpleConstFloat16ValuedTest<int32_t>();
794+
}
795+
796+
TEST_F(PadV2OpTest, Int64PaddingSimpleConstFloat16) {
797+
SimpleConstFloat16ValuedTest<int64_t>();
798+
}
799+
800+
template <typename padding_integer_type>
801+
void SimpleConstBFloat16ValuedTest() {
802+
PadV2OpConstModel<Eigen::bfloat16, padding_integer_type> m(
803+
{TensorType_BFLOAT16, {1, 2, 2, 1}}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
804+
Eigen::bfloat16{6.0f}, {TensorType_BFLOAT16});
805+
m.SetInput({Eigen::bfloat16{1.0f}, Eigen::bfloat16{2.0f},
806+
Eigen::bfloat16{3.0f}, Eigen::bfloat16{4.0}});
807+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
808+
EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 6, 6, 6, 6, 1, 2, 6, 6, 3, 4,
809+
6, 6, 6, 6, 6}));
810+
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
811+
}
812+
813+
TEST_F(PadV2OpTest, Int32PaddingSimpleConstBFloat16) {
814+
SimpleConstBFloat16ValuedTest<int32_t>();
815+
}
816+
817+
TEST_F(PadV2OpTest, Int64PaddingSimpleConstBFloat16) {
818+
SimpleConstBFloat16ValuedTest<int64_t>();
819+
}
820+
774821
template <typename padding_integer_type>
775822
void Simple4DConstFloat32ValuedTest() {
776823
// Padding is represented as four 2-D lists representing above padding and
@@ -792,6 +839,49 @@ TEST_F(PadV2OpTest, Int64PaddingSimple4DConstFloat32ValuedTest) {
792839
Simple4DConstFloat32ValuedTest<int64_t>();
793840
}
794841

842+
template <typename padding_integer_type>
843+
void Simple4DConstFloat16ValuedTest() {
844+
PadV2OpConstModel<Eigen::half, padding_integer_type> m(
845+
{TensorType_FLOAT16, {1, 1, 2, 1}}, {4, 2}, {0, 1, 0, 0, 0, 0, 0, 1},
846+
Eigen::half{7.0}, {TensorType_FLOAT16});
847+
m.SetInput({Eigen::half{3.0f}, Eigen::half{6.0f}});
848+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
849+
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 7, 6, 7, 7, 7, 7, 7}));
850+
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 2}));
851+
}
852+
853+
TEST_F(PadV2OpTest, Int32PaddingSimple4DConstFloat16ValuedTest) {
854+
Simple4DConstFloat16ValuedTest<int32_t>();
855+
}
856+
857+
TEST_F(PadV2OpTest, Int64PaddingSimple4DConstFloat16ValuedTest) {
858+
Simple4DConstFloat16ValuedTest<int64_t>();
859+
}
860+
861+
template <typename padding_integer_type>
862+
void Simple4DConstBFloat16ValuedTest() {
863+
PadV2OpConstModel<Eigen::bfloat16, padding_integer_type> m(
864+
{TensorType_BFLOAT16, {1, 1, 2, 1}}, {4, 2}, {0, 1, 0, 0, 0, 0, 0, 1},
865+
Eigen::bfloat16{5.0}, {TensorType_BFLOAT16});
866+
m.SetInput({Eigen::bfloat16{3.2f}, Eigen::bfloat16{6.4f}});
867+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
868+
EXPECT_THAT(
869+
m.GetOutput(),
870+
ElementsAreArray(ArrayFloatNear(
871+
{Eigen::bfloat16{3.2f}, Eigen::bfloat16{5.0f}, Eigen::bfloat16{6.4f},
872+
Eigen::bfloat16{5.0f}, Eigen::bfloat16{5.0f}, Eigen::bfloat16{5.0f},
873+
Eigen::bfloat16{5.0f}, Eigen::bfloat16{5.0f}})));
874+
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 2}));
875+
}
876+
877+
TEST_F(PadV2OpTest, Int32PaddingSimple4DConstBFloat16ValuedTest) {
878+
Simple4DConstBFloat16ValuedTest<int32_t>();
879+
}
880+
881+
TEST_F(PadV2OpTest, Int64PaddingSimple4DConstBFloat16ValuedTest) {
882+
Simple4DConstBFloat16ValuedTest<int64_t>();
883+
}
884+
795885
template <typename padding_integer_type>
796886
void SimpleConstInt32ValuedTest() {
797887
// Padding is represented as four 2-D lists representing above padding and
@@ -834,6 +924,50 @@ TEST_F(PadV2OpTest, Int64PaddingSimpleDynamicTest) {
834924
SimpleDynamicTestV2<int64_t>();
835925
}
836926

927+
template <typename padding_integer_type>
928+
void SimpleDynamicTestV2Float16() {
929+
PadV2OpDynamicModel<Eigen::half, padding_integer_type> m(
930+
{TensorType_FLOAT16, {1, 2, 2, 1}}, {4, 2}, Eigen::half{0.0},
931+
{TensorType_FLOAT16});
932+
m.SetInput({Eigen::half{1.0f}, Eigen::half{2.0f}, Eigen::half{3.0f},
933+
Eigen::half{4.0f}});
934+
m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
935+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
936+
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
937+
0, 0, 0, 0, 0}));
938+
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
939+
}
940+
941+
TEST_F(PadV2OpTest, Int32PaddingSimpleDynamicTestFloat16) {
942+
SimpleDynamicTestV2Float16<int32_t>();
943+
}
944+
945+
TEST_F(PadV2OpTest, Int64PaddingSimpleDynamicTestFloat16) {
946+
SimpleDynamicTestV2Float16<int64_t>();
947+
}
948+
949+
template <typename padding_integer_type>
950+
void SimpleDynamicTestV2BFloat16() {
951+
PadV2OpDynamicModel<Eigen::bfloat16, padding_integer_type> m(
952+
{TensorType_BFLOAT16, {1, 2, 2, 1}}, {4, 2}, Eigen::bfloat16{2.0},
953+
{TensorType_BFLOAT16});
954+
m.SetInput({Eigen::bfloat16{5.0f}, Eigen::bfloat16{6.0f},
955+
Eigen::bfloat16{7.0f}, Eigen::bfloat16{8.0f}});
956+
m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
957+
ASSERT_EQ(m.Invoke(), kTfLiteOk);
958+
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 2, 2, 2, 2, 5, 6, 2, 2, 7, 8,
959+
2, 2, 2, 2, 2}));
960+
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
961+
}
962+
963+
TEST_F(PadV2OpTest, Int32PaddingSimpleDynamicTestBFloat16) {
964+
SimpleDynamicTestV2BFloat16<int32_t>();
965+
}
966+
967+
TEST_F(PadV2OpTest, Int64PaddingSimpleDynamicTestBFloat16) {
968+
SimpleDynamicTestV2BFloat16<int64_t>();
969+
}
970+
837971
template <typename padding_integer_type>
838972
void PadV2OpDynamicUnequalDimensions() {
839973
if (SingleOpModel::GetForceUseNnapi()) {

tensorflow/lite/kernels/test_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,7 @@ TFLITE_TENSOR_TYPE_ASSOC(uint32_t, TensorType_UINT32);
11681168
TFLITE_TENSOR_TYPE_ASSOC(uint64_t, TensorType_UINT64);
11691169
TFLITE_TENSOR_TYPE_ASSOC(TfLiteFloat16, TensorType_FLOAT16);
11701170
TFLITE_TENSOR_TYPE_ASSOC(Eigen::half, TensorType_FLOAT16);
1171+
TFLITE_TENSOR_TYPE_ASSOC(Eigen::bfloat16, TensorType_BFLOAT16);
11711172
TFLITE_TENSOR_TYPE_ASSOC(float, TensorType_FLOAT32);
11721173
TFLITE_TENSOR_TYPE_ASSOC(double, TensorType_FLOAT64);
11731174
TFLITE_TENSOR_TYPE_ASSOC(std::string, TensorType_STRING);

0 commit comments

Comments
 (0)