1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
21 #include "tensorflow/lite/micro/test_helpers.h"
22 #include "tensorflow/lite/micro/testing/micro_test.h"
23
24 namespace tflite {
25 namespace testing {
26 namespace {
27
// Fixture for the "basic" case: four 2x2x1 batches are reassembled into a
// single 4x4x1 output using a 2x2 block and no cropping.
//
// The *_dims arrays list the rank first, followed by each extent. They are
// left non-const because the test helpers below accept them as `int*`.
constexpr int kBasicInputOutputSize = 16;

// Input shape [4, 2, 2, 1]; batch b holds the values 4b+1 .. 4b+4.
int basic_input_dims[] = {4, 4, 2, 2, 1};
constexpr float basic_input[kBasicInputOutputSize] = {
    1, 2, 3, 4,      // batch 0
    5, 6, 7, 8,      // batch 1
    9, 10, 11, 12,   // batch 2
    13, 14, 15, 16,  // batch 3
};

// Block shape: a rank-1 tensor of two elements, {2, 2}.
int basic_block_shape_dims[] = {1, 2};
constexpr int32_t basic_block_shape[] = {2, 2};

// Crops: a rank-1 tensor of four zeros (nothing is cropped).
int basic_crops_dims[] = {1, 4};
constexpr int32_t basic_crops[] = {0, 0, 0, 0};

// Output shape [1, 4, 4, 1]; expected values listed row by row.
int basic_output_dims[] = {4, 1, 4, 4, 1};
constexpr float basic_golden[kBasicInputOutputSize] = {
    1, 5, 2, 6,      // row 0
    9, 13, 10, 14,   // row 1
    3, 7, 4, 8,      // row 2
    11, 15, 12, 16,  // row 3
};
39
40 template <typename T>
ValidateBatchToSpaceNdGoldens(TfLiteTensor * tensors,int tensors_size,const T * golden,T * output,int output_size)41 TfLiteStatus ValidateBatchToSpaceNdGoldens(TfLiteTensor* tensors,
42 int tensors_size, const T* golden,
43 T* output, int output_size) {
44 int inputs_array_data[] = {3, 0, 1, 2};
45 TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
46 int outputs_array_data[] = {1, 3};
47 TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
48
49 const TfLiteRegistration registration = Register_BATCH_TO_SPACE_ND();
50 micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
51 outputs_array, nullptr);
52
53 TF_LITE_ENSURE_STATUS(runner.InitAndPrepare());
54 TF_LITE_ENSURE_STATUS(runner.Invoke());
55
56 for (int i = 0; i < output_size; ++i) {
57 // TODO(b/158102673): workaround for not having fatal test assertions.
58 TF_LITE_MICRO_EXPECT_EQ(golden[i], output[i]);
59 if (golden[i] != output[i]) {
60 return kTfLiteError;
61 }
62 }
63 return kTfLiteOk;
64 }
65
TestBatchToSpaceNdFloat(int * input_dims_data,const float * input_data,int * block_shape_dims_data,const int32_t * block_shape_data,int * crops_dims_data,const int32_t * crops_data,int * output_dims_data,const float * golden,float * output_data)66 TfLiteStatus TestBatchToSpaceNdFloat(
67 int* input_dims_data, const float* input_data, int* block_shape_dims_data,
68 const int32_t* block_shape_data, int* crops_dims_data,
69 const int32_t* crops_data, int* output_dims_data, const float* golden,
70 float* output_data) {
71 TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
72 TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
73 TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
74 TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
75
76 constexpr int inputs_size = 3;
77 constexpr int outputs_size = 1;
78 constexpr int tensors_size = inputs_size + outputs_size;
79 TfLiteTensor tensors[tensors_size] = {
80 CreateTensor(input_data, input_dims),
81 CreateTensor(block_shape_data, block_shape_dims),
82 CreateTensor(crops_data, crops_dims),
83 CreateTensor(output_data, output_dims),
84 };
85
86 return ValidateBatchToSpaceNdGoldens(tensors, tensors_size, golden,
87 output_data, ElementCount(*output_dims));
88 }
89
90 template <typename T>
TestBatchToSpaceNdQuantized(int * input_dims_data,const float * input_data,T * input_quantized,float input_scale,int input_zero_point,int * block_shape_dims_data,const int32_t * block_shape_data,int * crops_dims_data,const int32_t * crops_data,int * output_dims_data,const float * golden,T * golden_quantized,float output_scale,int output_zero_point,T * output_data)91 TfLiteStatus TestBatchToSpaceNdQuantized(
92 int* input_dims_data, const float* input_data, T* input_quantized,
93 float input_scale, int input_zero_point, int* block_shape_dims_data,
94 const int32_t* block_shape_data, int* crops_dims_data,
95 const int32_t* crops_data, int* output_dims_data, const float* golden,
96 T* golden_quantized, float output_scale, int output_zero_point,
97 T* output_data) {
98 TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
99 TfLiteIntArray* block_shape_dims = IntArrayFromInts(block_shape_dims_data);
100 TfLiteIntArray* crops_dims = IntArrayFromInts(crops_dims_data);
101 TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
102
103 constexpr int inputs_size = 3;
104 constexpr int outputs_size = 1;
105 constexpr int tensors_size = inputs_size + outputs_size;
106 TfLiteTensor tensors[tensors_size] = {
107 tflite::testing::CreateQuantizedTensor(input_data, input_quantized,
108 input_dims, input_scale,
109 input_zero_point),
110 tflite::testing::CreateTensor(block_shape_data, block_shape_dims),
111 tflite::testing::CreateTensor(crops_data, crops_dims),
112 tflite::testing::CreateQuantizedTensor(output_data, output_dims,
113 output_scale, output_zero_point),
114 };
115 tflite::Quantize(golden, golden_quantized, ElementCount(*output_dims),
116 output_scale, output_zero_point);
117
118 return ValidateBatchToSpaceNdGoldens(tensors, tensors_size, golden_quantized,
119 output_data, ElementCount(*output_dims));
120 }
121
122 } // namespace
123 } // namespace testing
124 } // namespace tflite
125
126 TF_LITE_MICRO_TESTS_BEGIN
127
// Float BATCH_TO_SPACE_ND on the basic fixture: [4, 2, 2, 1] input,
// 2x2 block, zero crops, expected [1, 4, 4, 1] output.
TF_LITE_MICRO_TEST(BatchToSpaceBasicFloat) {
  // Value-initialized so that any element the kernel does not write is read
  // as 0.0f rather than as an indeterminate value. None of the golden values
  // is 0.
  float output[tflite::testing::kBasicInputOutputSize] = {};
  // Keep the macro arguments textually unchanged: the EXPECT macro
  // stringifies them into its failure message.
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk,
      tflite::testing::TestBatchToSpaceNdFloat(
          tflite::testing::basic_input_dims, tflite::testing::basic_input,
          tflite::testing::basic_block_shape_dims,
          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
          tflite::testing::basic_golden, output));
}
139
// Int8 BATCH_TO_SPACE_ND on the basic fixture, using scale 1.0 and zero
// point 0 for both the input and the output tensor.
TF_LITE_MICRO_TEST(BatchToSpaceBasicInt8) {
  // Value-initialized so that any element the kernel does not write is read
  // as 0 rather than as an indeterminate value.
  int8_t output[tflite::testing::kBasicInputOutputSize] = {};
  // Scratch buffers that TestBatchToSpaceNdQuantized fills with the quantized
  // input and the quantized golden values.
  int8_t input_quantized[tflite::testing::kBasicInputOutputSize];
  int8_t golden_quantized[tflite::testing::kBasicInputOutputSize];
  // Keep the macro arguments textually unchanged: the EXPECT macro
  // stringifies them into its failure message.
  TF_LITE_MICRO_EXPECT_EQ(
      kTfLiteOk,
      tflite::testing::TestBatchToSpaceNdQuantized(
          tflite::testing::basic_input_dims, tflite::testing::basic_input,
          input_quantized, 1.0f, 0, tflite::testing::basic_block_shape_dims,
          tflite::testing::basic_block_shape, tflite::testing::basic_crops_dims,
          tflite::testing::basic_crops, tflite::testing::basic_output_dims,
          tflite::testing::basic_golden, golden_quantized, 1.0f, 0, output));
}
153
154 TF_LITE_MICRO_TESTS_END
155