1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/micro/all_ops_resolver.h"
19 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
20 #include "tensorflow/lite/micro/test_helpers.h"
21 #include "tensorflow/lite/micro/testing/micro_test.h"
22
23 namespace tflite {
24 namespace testing {
25 namespace {
26
27 template <typename T>
TestExpandDims(int * input_dims,const T * input_data,int * axis_dims,const int32_t * axis_data,int * expected_output_dims,int * output_dims,const T * expected_output_data,T * output_data)28 void TestExpandDims(int* input_dims, const T* input_data, int* axis_dims,
29 const int32_t* axis_data, int* expected_output_dims,
30 int* output_dims, const T* expected_output_data,
31 T* output_data) {
32 TfLiteIntArray* in_dims = IntArrayFromInts(input_dims);
33 TfLiteIntArray* ax_dims = IntArrayFromInts(axis_dims);
34 TfLiteIntArray* out_dims = IntArrayFromInts(output_dims);
35 const int in_dims_size = in_dims->size;
36
37 constexpr int inputs_size = 2;
38 constexpr int outputs_size = 1;
39 constexpr int tensors_size = inputs_size + outputs_size;
40 TfLiteTensor tensors[tensors_size] = {
41 CreateTensor(input_data, in_dims),
42 CreateTensor(axis_data, ax_dims),
43 CreateTensor(output_data, out_dims, true),
44 };
45 int inputs_array_data[] = {2, 0, 1};
46 TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
47 int outputs_array_data[] = {1, 2};
48 TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
49
50 const TfLiteRegistration registration = Register_EXPAND_DIMS();
51 micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
52 outputs_array,
53 /*builtin_data=*/nullptr);
54 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
55 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
56
57 // The output tensor's data and shape have been updated by the kernel.
58 TfLiteTensor* actual_out_tensor = &tensors[2];
59 TfLiteIntArray* actual_out_dims = actual_out_tensor->dims;
60 const int actual_out_dims_size = actual_out_dims->size;
61 const int output_size = ElementCount(*actual_out_dims);
62 TF_LITE_MICRO_EXPECT_EQ(actual_out_dims_size, (in_dims_size + 1));
63 for (int i = 0; i < actual_out_dims_size; ++i) {
64 TF_LITE_MICRO_EXPECT_EQ(expected_output_dims[i], actual_out_dims->data[i]);
65 }
66 for (int i = 0; i < output_size; ++i) {
67 TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
68 }
69 }
70
71 } // namespace
72 } // namespace testing
73 } // namespace tflite
74
75 TF_LITE_MICRO_TESTS_BEGIN
76
TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest0) {
  // int8 input of shape [2, 2], new axis at position 0 -> [1, 2, 2].
  // The element values pass through unchanged.
  int input_shape[] = {2, 2, 2};
  const int8_t input_values[] = {-1, 1, -2, 2};
  int axis_shape[] = {1, 1};
  const int32_t axis_value[] = {0};
  int expected_shape[] = {1, 2, 2};
  const int8_t expected_values[] = {-1, 1, -2, 2};
  // Rank-3 placeholder; the kernel writes the actual dimensions.
  int output_shape[] = {3, 0, 0, 0};
  int8_t output_values[4];
  tflite::testing::TestExpandDims(input_shape, input_values, axis_shape,
                                  axis_value, expected_shape, output_shape,
                                  expected_values, output_values);
}
90
TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest1) {
  // float input of shape [2, 2], new axis at position 1 -> [2, 1, 2].
  // The element values pass through unchanged.
  int input_shape[] = {2, 2, 2};
  const float input_values[] = {-1.1, 1.2, -2.1, 2.2};
  int axis_shape[] = {1, 1};
  const int32_t axis_value[] = {1};
  int expected_shape[] = {2, 1, 2};
  const float expected_values[] = {-1.1, 1.2, -2.1, 2.2};
  // Rank-3 placeholder; the kernel writes the actual dimensions.
  int output_shape[] = {3, 0, 0, 0};
  float output_values[4];
  tflite::testing::TestExpandDims(input_shape, input_values, axis_shape,
                                  axis_value, expected_shape, output_shape,
                                  expected_values, output_values);
}
104
TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest2) {
  // int8 input of shape [2, 2], new axis appended at position 2 -> [2, 2, 1].
  // The element values pass through unchanged.
  int input_shape[] = {2, 2, 2};
  const int8_t input_values[] = {-1, 1, -2, 2};
  int axis_shape[] = {1, 1};
  const int32_t axis_value[] = {2};
  int expected_shape[] = {2, 2, 1};
  const int8_t expected_values[] = {-1, 1, -2, 2};
  // Rank-3 placeholder; the kernel writes the actual dimensions.
  int output_shape[] = {3, 0, 0, 0};
  int8_t output_values[4];
  tflite::testing::TestExpandDims(input_shape, input_values, axis_shape,
                                  axis_value, expected_shape, output_shape,
                                  expected_values, output_values);
}
118
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest4) {
  // int8 input of shape [3, 1, 2] with axis -4; the expected output shape
  // [1, 3, 1, 2] places the new dimension first.
  int input_shape[] = {3, 3, 1, 2};
  const int8_t input_values[] = {-1, 1, 2, -2, 0, 3};
  int axis_shape[] = {1, 1};
  const int32_t axis_value[] = {-4};
  int expected_shape[] = {1, 3, 1, 2};
  const int8_t expected_values[] = {-1, 1, 2, -2, 0, 3};
  // Rank-4 placeholder; the kernel writes the actual dimensions.
  int output_shape[] = {4, 0, 0, 0, 0};
  int8_t output_values[6];
  tflite::testing::TestExpandDims(input_shape, input_values, axis_shape,
                                  axis_value, expected_shape, output_shape,
                                  expected_values, output_values);
}
132
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest3) {
  // float input of shape [3, 1, 2] with axis -3; the expected output shape
  // [3, 1, 1, 2] places the new dimension second.
  int input_shape[] = {3, 3, 1, 2};
  const float input_values[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
  int axis_shape[] = {1, 1};
  const int32_t axis_value[] = {-3};
  int expected_shape[] = {3, 1, 1, 2};
  const float expected_values[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
  // Rank-4 placeholder; the kernel writes the actual dimensions.
  int output_shape[] = {4, 0, 0, 0, 0};
  float output_values[6];
  tflite::testing::TestExpandDims(input_shape, input_values, axis_shape,
                                  axis_value, expected_shape, output_shape,
                                  expected_values, output_values);
}
146
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest2) {
  // int8 input of shape [1, 2, 3] with axis -2; the expected output shape
  // [1, 2, 1, 3] places the new dimension third.
  int input_shape[] = {3, 1, 2, 3};
  const int8_t input_values[] = {-1, 1, 2, -2, 0, 3};
  int axis_shape[] = {1, 1};
  const int32_t axis_value[] = {-2};
  int expected_shape[] = {1, 2, 1, 3};
  const int8_t expected_values[] = {-1, 1, 2, -2, 0, 3};
  // Rank-4 placeholder; the kernel writes the actual dimensions.
  int output_shape[] = {4, 0, 0, 0, 0};
  int8_t output_values[6];
  tflite::testing::TestExpandDims(input_shape, input_values, axis_shape,
                                  axis_value, expected_shape, output_shape,
                                  expected_values, output_values);
}
160
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest1) {
  // float input of shape [1, 3, 2] with axis -1; the expected output shape
  // [1, 3, 2, 1] places the new dimension last.
  int input_shape[] = {3, 1, 3, 2};
  const float input_values[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
  int axis_shape[] = {1, 1};
  const int32_t axis_value[] = {-1};
  int expected_shape[] = {1, 3, 2, 1};
  const float expected_values[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
  // Rank-4 placeholder; the kernel writes the actual dimensions.
  int output_shape[] = {4, 0, 0, 0, 0};
  float output_values[6];
  tflite::testing::TestExpandDims(input_shape, input_values, axis_shape,
                                  axis_value, expected_shape, output_shape,
                                  expected_values, output_values);
}
174
175 TF_LITE_MICRO_TESTS_END
176