1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/micro/all_ops_resolver.h"
19 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
20 #include "tensorflow/lite/micro/test_helpers.h"
21 #include "tensorflow/lite/micro/testing/micro_test.h"
22 
23 namespace tflite {
24 namespace testing {
25 namespace {
26 
27 template <typename T>
TestExpandDims(int * input_dims,const T * input_data,int * axis_dims,const int32_t * axis_data,int * expected_output_dims,int * output_dims,const T * expected_output_data,T * output_data)28 void TestExpandDims(int* input_dims, const T* input_data, int* axis_dims,
29                     const int32_t* axis_data, int* expected_output_dims,
30                     int* output_dims, const T* expected_output_data,
31                     T* output_data) {
32   TfLiteIntArray* in_dims = IntArrayFromInts(input_dims);
33   TfLiteIntArray* ax_dims = IntArrayFromInts(axis_dims);
34   TfLiteIntArray* out_dims = IntArrayFromInts(output_dims);
35   const int in_dims_size = in_dims->size;
36 
37   constexpr int inputs_size = 2;
38   constexpr int outputs_size = 1;
39   constexpr int tensors_size = inputs_size + outputs_size;
40   TfLiteTensor tensors[tensors_size] = {
41       CreateTensor(input_data, in_dims),
42       CreateTensor(axis_data, ax_dims),
43       CreateTensor(output_data, out_dims, true),
44   };
45   int inputs_array_data[] = {2, 0, 1};
46   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
47   int outputs_array_data[] = {1, 2};
48   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
49 
50   const TfLiteRegistration registration = Register_EXPAND_DIMS();
51   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
52                              outputs_array,
53                              /*builtin_data=*/nullptr);
54   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
55   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
56 
57   // The output tensor's data and shape have been updated by the kernel.
58   TfLiteTensor* actual_out_tensor = &tensors[2];
59   TfLiteIntArray* actual_out_dims = actual_out_tensor->dims;
60   const int actual_out_dims_size = actual_out_dims->size;
61   const int output_size = ElementCount(*actual_out_dims);
62   TF_LITE_MICRO_EXPECT_EQ(actual_out_dims_size, (in_dims_size + 1));
63   for (int i = 0; i < actual_out_dims_size; ++i) {
64     TF_LITE_MICRO_EXPECT_EQ(expected_output_dims[i], actual_out_dims->data[i]);
65   }
66   for (int i = 0; i < output_size; ++i) {
67     TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
68   }
69 }
70 
71 }  // namespace
72 }  // namespace testing
73 }  // namespace tflite
74 
75 TF_LITE_MICRO_TESTS_BEGIN
76 
// int8 input of shape [2, 2], axis 0: expected output shape [1, 2, 2].
TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest0) {
  // Shape arrays given to the helper start with the rank; the golden shape
  // lists only the extents.
  int in_shape[] = {2, 2, 2};
  int axis_shape[] = {1, 1};
  int out_shape[] = {3, 0, 0, 0};  // extents are written by the kernel
  int want_shape[] = {1, 2, 2};

  const int32_t axis[] = {0};
  const int8_t in_values[] = {-1, 1, -2, 2};
  const int8_t want_values[] = {-1, 1, -2, 2};
  int8_t got_values[4];

  tflite::testing::TestExpandDims<int8_t>(in_shape, in_values, axis_shape,
                                          axis, want_shape, out_shape,
                                          want_values, got_values);
}
90 
// float input of shape [2, 2], axis 1: expected output shape [2, 1, 2].
TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest1) {
  // Shape arrays given to the helper start with the rank; the golden shape
  // lists only the extents.
  int in_shape[] = {2, 2, 2};
  int axis_shape[] = {1, 1};
  int out_shape[] = {3, 0, 0, 0};  // extents are written by the kernel
  int want_shape[] = {2, 1, 2};

  const int32_t axis[] = {1};
  const float in_values[] = {-1.1, 1.2, -2.1, 2.2};
  const float want_values[] = {-1.1, 1.2, -2.1, 2.2};
  float got_values[4];

  tflite::testing::TestExpandDims<float>(in_shape, in_values, axis_shape, axis,
                                         want_shape, out_shape, want_values,
                                         got_values);
}
104 
// int8 input of shape [2, 2], axis 2: expected output shape [2, 2, 1].
TF_LITE_MICRO_TEST(ExpandDimsPositiveAxisTest2) {
  // Shape arrays given to the helper start with the rank; the golden shape
  // lists only the extents.
  int in_shape[] = {2, 2, 2};
  int axis_shape[] = {1, 1};
  int out_shape[] = {3, 0, 0, 0};  // extents are written by the kernel
  int want_shape[] = {2, 2, 1};

  const int32_t axis[] = {2};
  const int8_t in_values[] = {-1, 1, -2, 2};
  const int8_t want_values[] = {-1, 1, -2, 2};
  int8_t got_values[4];

  tflite::testing::TestExpandDims<int8_t>(in_shape, in_values, axis_shape,
                                          axis, want_shape, out_shape,
                                          want_values, got_values);
}
118 
// int8 input of shape [3, 1, 2], axis -4: expected output shape [1, 3, 1, 2].
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest4) {
  // Shape arrays given to the helper start with the rank; the golden shape
  // lists only the extents.
  int in_shape[] = {3, 3, 1, 2};
  int axis_shape[] = {1, 1};
  int out_shape[] = {4, 0, 0, 0, 0};  // extents are written by the kernel
  int want_shape[] = {1, 3, 1, 2};

  const int32_t axis[] = {-4};
  const int8_t in_values[] = {-1, 1, 2, -2, 0, 3};
  const int8_t want_values[] = {-1, 1, 2, -2, 0, 3};
  int8_t got_values[6];

  tflite::testing::TestExpandDims<int8_t>(in_shape, in_values, axis_shape,
                                          axis, want_shape, out_shape,
                                          want_values, got_values);
}
132 
// float input of shape [3, 1, 2], axis -3: expected output shape [3, 1, 1, 2].
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest3) {
  // Shape arrays given to the helper start with the rank; the golden shape
  // lists only the extents.
  int in_shape[] = {3, 3, 1, 2};
  int axis_shape[] = {1, 1};
  int out_shape[] = {4, 0, 0, 0, 0};  // extents are written by the kernel
  int want_shape[] = {3, 1, 1, 2};

  const int32_t axis[] = {-3};
  const float in_values[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
  const float want_values[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
  float got_values[6];

  tflite::testing::TestExpandDims<float>(in_shape, in_values, axis_shape, axis,
                                         want_shape, out_shape, want_values,
                                         got_values);
}
146 
// int8 input of shape [1, 2, 3], axis -2: expected output shape [1, 2, 1, 3].
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest2) {
  // Shape arrays given to the helper start with the rank; the golden shape
  // lists only the extents.
  int in_shape[] = {3, 1, 2, 3};
  int axis_shape[] = {1, 1};
  int out_shape[] = {4, 0, 0, 0, 0};  // extents are written by the kernel
  int want_shape[] = {1, 2, 1, 3};

  const int32_t axis[] = {-2};
  const int8_t in_values[] = {-1, 1, 2, -2, 0, 3};
  const int8_t want_values[] = {-1, 1, 2, -2, 0, 3};
  int8_t got_values[6];

  tflite::testing::TestExpandDims<int8_t>(in_shape, in_values, axis_shape,
                                          axis, want_shape, out_shape,
                                          want_values, got_values);
}
160 
// float input of shape [1, 3, 2], axis -1: expected output shape [1, 3, 2, 1].
TF_LITE_MICRO_TEST(ExpandDimsNegativeAxisTest1) {
  // Shape arrays given to the helper start with the rank; the golden shape
  // lists only the extents.
  int in_shape[] = {3, 1, 3, 2};
  int axis_shape[] = {1, 1};
  int out_shape[] = {4, 0, 0, 0, 0};  // extents are written by the kernel
  int want_shape[] = {1, 3, 2, 1};

  const int32_t axis[] = {-1};
  const float in_values[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
  const float want_values[] = {0.1, -0.8, -1.2, -0.5, 0.9, 1.3};
  float got_values[6];

  tflite::testing::TestExpandDims<float>(in_shape, in_values, axis_shape, axis,
                                         want_shape, out_shape, want_values,
                                         got_values);
}
174 
175 TF_LITE_MICRO_TESTS_END
176