1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cstdint>
#include <cstring>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"
22
23 namespace tflite {
24 namespace {
25
// Node tensor slots used by this kernel: input 0 is the data tensor being
// gathered from, input 1 holds the gather positions (indices), and output 0
// receives the gathered result.
constexpr int kInputTensor{0};
constexpr int kInputPositions{1};
constexpr int kOutputTensor{0};
29
30 template <typename InputT, typename CoordsT = int32_t>
Gather(const TfLiteGatherParams * params,const TfLiteEvalTensor * input,const TfLiteEvalTensor * coords,TfLiteEvalTensor * output)31 TfLiteStatus Gather(const TfLiteGatherParams* params,
32 const TfLiteEvalTensor* input,
33 const TfLiteEvalTensor* coords, TfLiteEvalTensor* output) {
34 const InputT* input_data = tflite::micro::GetTensorData<InputT>(input);
35 const CoordsT* coords_data = tflite::micro::GetTensorData<CoordsT>(coords);
36 InputT* output_data = tflite::micro::GetTensorData<InputT>(output);
37 const TfLiteIntArray* input_dims = input->dims;
38 const int input_dims_size = input_dims->size;
39 int axis = params->axis;
40 if (axis < 0) {
41 axis += input_dims_size;
42 }
43 TFLITE_DCHECK_GE(axis, 0);
44 TFLITE_DCHECK_LT(axis, input_dims_size);
45
46 int batch_dims = params->batch_dims;
47 // batch_dims should be in range: [-rank(coords), rank(coords)].
48 // Negative batch_dims is added with rank of coords.
49 const TfLiteIntArray* coords_dims = coords->dims;
50 const int coords_dims_size = coords_dims->size;
51 if (batch_dims < 0) {
52 batch_dims += coords_dims_size;
53 }
54 TFLITE_DCHECK_GE(batch_dims, 0);
55 TFLITE_DCHECK_LT(batch_dims, input_dims_size);
56 TFLITE_DCHECK_LE(batch_dims, coords_dims_size);
57 TFLITE_DCHECK_GE(axis, batch_dims);
58 for (int i = 0; i < batch_dims; ++i) {
59 TFLITE_DCHECK_EQ(input_dims->data[i], coords_dims->data[i]);
60 }
61
62 const int axis_size = input_dims->data[axis];
63
64 int batch_size = 1;
65 for (int i = 0; i < batch_dims; ++i) {
66 batch_size *= input_dims->data[i];
67 }
68 int outer_size = 1;
69 for (int i = batch_dims; i < axis; ++i) {
70 outer_size *= input_dims->data[i];
71 }
72 int inner_size = 1;
73 for (int i = axis + 1; i < input_dims_size; ++i) {
74 inner_size *= input_dims->data[i];
75 }
76 int coord_size = 1;
77 for (int i = batch_dims; i < coords_dims_size; ++i) {
78 coord_size *= coords_dims->data[i];
79 }
80
81 for (int batch = 0; batch < batch_size; ++batch) {
82 for (int outer = 0; outer < outer_size; ++outer) {
83 for (int coord = 0; coord < coord_size; ++coord) {
84 TFLITE_DCHECK_GE(coords_data[coord], 0);
85 TFLITE_DCHECK_LT(coords_data[coord], axis_size);
86 std::memcpy(output_data +
87 (((batch * outer_size) + outer) * coord_size + coord) *
88 inner_size,
89 input_data + (((batch * outer_size) + outer) * axis_size +
90 coords_data[batch * coord_size + coord]) *
91 inner_size,
92 sizeof(InputT) * inner_size);
93 }
94 }
95 }
96 return kTfLiteOk;
97 }
98
Prepare(TfLiteContext * context,TfLiteNode * node)99 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
100 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
101 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
102
103 const auto* params =
104 reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
105 const TfLiteTensor* input;
106 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
107 const TfLiteTensor* coords;
108 TF_LITE_ENSURE_OK(context,
109 GetInputSafe(context, node, kInputPositions, &coords));
110 TfLiteTensor* output;
111 TF_LITE_ENSURE_OK(context,
112 GetOutputSafe(context, node, kOutputTensor, &output));
113 switch (coords->type) {
114 case kTfLiteInt32:
115 break;
116 default:
117 TF_LITE_KERNEL_LOG(context,
118 "Positions of type '%s' are not supported by gather.",
119 TfLiteTypeGetName(coords->type));
120 return kTfLiteError;
121 break;
122 }
123
124 // Assign to output the input type.
125 output->type = input->type;
126
127 // Check conditions for different types.
128 switch (input->type) {
129 case kTfLiteFloat32:
130 case kTfLiteInt8:
131 break;
132 default:
133 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
134 TfLiteTypeGetName(input->type));
135 return kTfLiteError;
136 break;
137 }
138
139 int axis = params->axis;
140 if (axis < 0) {
141 axis += NumDimensions(input);
142 }
143 TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
144
145 int batch_dims = params->batch_dims;
146 // batch_dims should be in range: [-rank(coords), rank(coords)].
147 // Negative batch_dims is added with rank of coords.
148 if (batch_dims < 0) {
149 batch_dims += NumDimensions(coords);
150 }
151 TF_LITE_ENSURE(context, batch_dims <= axis);
152 TF_LITE_ENSURE(context, 0 <= batch_dims && batch_dims < NumDimensions(input));
153 TF_LITE_ENSURE(context, batch_dims <= NumDimensions(coords));
154 for (int i = 0; i < batch_dims; ++i) {
155 TF_LITE_ENSURE_EQ(context, input->dims->data[i], coords->dims->data[i]);
156 }
157
158 // GATHER updates the output tensor dimensions, but TfLiteTensor in the
159 // MicroInterpreter is a temporary allocation. We must therefore relocate the
160 // dims from the FlatBuffer to the persistant storage arena.
161 TfLiteEvalTensor* output_eval =
162 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
163 TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
164 context, output, output_eval));
165
166 TfLiteIntArray* output_shape = output->dims;
167 output_shape->size =
168 NumDimensions(input) + NumDimensions(coords) - 1 - batch_dims;
169 int output_index = 0;
170 for (int i = 0; i < axis; ++i) {
171 output_shape->data[output_index++] = input->dims->data[i];
172 }
173 for (int i = batch_dims; i < coords->dims->size; ++i) {
174 output_shape->data[output_index++] = coords->dims->data[i];
175 }
176 for (int i = axis + 1; i < input->dims->size; ++i) {
177 output_shape->data[output_index++] = input->dims->data[i];
178 }
179 return kTfLiteOk;
180 }
181
Eval(TfLiteContext * context,TfLiteNode * node)182 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
183 const auto* params =
184 reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
185 const TfLiteEvalTensor* input =
186 tflite::micro::GetEvalInput(context, node, kInputTensor);
187 const TfLiteEvalTensor* coords =
188 tflite::micro::GetEvalInput(context, node, kInputPositions);
189 TfLiteEvalTensor* output =
190 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
191
192 if (coords->type == kTfLiteInt32) {
193 switch (input->type) {
194 case kTfLiteFloat32:
195 return Gather<float, int32_t>(params, input, coords, output);
196 break;
197 case kTfLiteInt8:
198 return Gather<int8_t, int32_t>(params, input, coords, output);
199 break;
200 default:
201 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
202 TfLiteTypeGetName(input->type));
203 return kTfLiteError;
204 break;
205 }
206 }
207 return kTfLiteOk;
208 }
209 } // namespace
210
Register_GATHER()211 TfLiteRegistration Register_GATHER() {
212 return {/*init=*/nullptr,
213 /*free=*/nullptr,
214 /*prepare=*/Prepare,
215 /*invoke=*/Eval,
216 /*profiling_string=*/nullptr,
217 /*builtin_code=*/0,
218 /*custom_name=*/nullptr,
219 /*version=*/0};
220 }
221
222 } // namespace tflite
223