/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstring>  // for std::memcpy used in Gather() below

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {
namespace {

constexpr int kInputTensor = 0;
constexpr int kInputPositions = 1;
constexpr int kOutputTensor = 0;

template <typename InputT, typename CoordsT = int32_t>
TfLiteStatus Gather(const TfLiteGatherParams* params,
                    const TfLiteEvalTensor* input,
                    const TfLiteEvalTensor* coords, TfLiteEvalTensor* output) {
  const InputT* input_data = tflite::micro::GetTensorData<InputT>(input);
  const CoordsT* coords_data = tflite::micro::GetTensorData<CoordsT>(coords);
  InputT* output_data = tflite::micro::GetTensorData<InputT>(output);
  const TfLiteIntArray* input_dims = input->dims;
  const int input_dims_size = input_dims->size;
  int axis = params->axis;
  if (axis < 0) {
    axis += input_dims_size;
  }
  TFLITE_DCHECK_GE(axis, 0);
  TFLITE_DCHECK_LT(axis, input_dims_size);

  int batch_dims = params->batch_dims;
  // batch_dims should be in range: [-rank(coords), rank(coords)].
  // A negative batch_dims is normalized by adding the rank of coords.
  const TfLiteIntArray* coords_dims = coords->dims;
  const int coords_dims_size = coords_dims->size;
  if (batch_dims < 0) {
    batch_dims += coords_dims_size;
  }
  TFLITE_DCHECK_GE(batch_dims, 0);
  TFLITE_DCHECK_LT(batch_dims, input_dims_size);
  TFLITE_DCHECK_LE(batch_dims, coords_dims_size);
  TFLITE_DCHECK_GE(axis, batch_dims);
  for (int i = 0; i < batch_dims; ++i) {
    TFLITE_DCHECK_EQ(input_dims->data[i], coords_dims->data[i]);
  }

  const int axis_size = input_dims->data[axis];

  // Flatten the shapes into the 1-D extents used by the copy loop below:
  // batch_size (shared leading batch dims), outer_size (input dims between
  // batch_dims and axis), inner_size (input dims after axis), and coord_size
  // (coords elements gathered per batch).
  int batch_size = 1;
  for (int i = 0; i < batch_dims; ++i) {
    batch_size *= input_dims->data[i];
  }
  int outer_size = 1;
  for (int i = batch_dims; i < axis; ++i) {
    outer_size *= input_dims->data[i];
  }
  int inner_size = 1;
  for (int i = axis + 1; i < input_dims_size; ++i) {
    inner_size *= input_dims->data[i];
  }
  int coord_size = 1;
  for (int i = batch_dims; i < coords_dims_size; ++i) {
    coord_size *= coords_dims->data[i];
  }

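  // Illustrative example (hypothetical shapes, not taken from a real model):
  // input [2, 3, 4], coords [2, 5], axis = 1, batch_dims = 1 gives
  // batch_size = 2, outer_size = 1, inner_size = 4, coord_size = 5 and
  // axis_size = 3. Each memcpy below then moves one contiguous row of 4
  // elements, selected along axis 1 by coords_data[batch * 5 + coord].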
  for (int batch = 0; batch < batch_size; ++batch) {
    for (int outer = 0; outer < outer_size; ++outer) {
      for (int coord = 0; coord < coord_size; ++coord) {
        // Bounds-check the same gathered index that is used in the copy below.
        TFLITE_DCHECK_GE(coords_data[batch * coord_size + coord], 0);
        TFLITE_DCHECK_LT(coords_data[batch * coord_size + coord], axis_size);
        // Copy one inner_size-long slice per gathered index.
        std::memcpy(output_data +
                        (((batch * outer_size) + outer) * coord_size + coord) *
                            inner_size,
                    input_data + (((batch * outer_size) + outer) * axis_size +
                                  coords_data[batch * coord_size + coord]) *
                                     inner_size,
                    sizeof(InputT) * inner_size);
      }
    }
  }
  return kTfLiteOk;
}


TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const auto* params =
      reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* coords;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputPositions, &coords));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  switch (coords->type) {
    case kTfLiteInt32:
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Positions of type '%s' are not supported by gather.",
                         TfLiteTypeGetName(coords->type));
      return kTfLiteError;
  }


  // The output tensor takes its type from the input tensor.
  output->type = input->type;

  // Only the input types below are supported.
  switch (input->type) {
    case kTfLiteFloat32:
    case kTfLiteInt8:
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }

  int axis = params->axis;
  if (axis < 0) {
    axis += NumDimensions(input);
  }
  TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));

  int batch_dims = params->batch_dims;
  // batch_dims should be in range: [-rank(coords), rank(coords)].
  // A negative batch_dims is normalized by adding the rank of coords.
  if (batch_dims < 0) {
    batch_dims += NumDimensions(coords);
  }
  TF_LITE_ENSURE(context, batch_dims <= axis);
  TF_LITE_ENSURE(context, 0 <= batch_dims && batch_dims < NumDimensions(input));
  TF_LITE_ENSURE(context, batch_dims <= NumDimensions(coords));
  for (int i = 0; i < batch_dims; ++i) {
    TF_LITE_ENSURE_EQ(context, input->dims->data[i], coords->dims->data[i]);
  }


  // GATHER updates the output tensor dimensions, but TfLiteTensor in the
  // MicroInterpreter is a temporary allocation. We must therefore relocate the
  // dims from the FlatBuffer to the persistent storage arena.
  TfLiteEvalTensor* output_eval =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
                                 context, output, output_eval));

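  // The output shape is input.shape[:axis] ++ coords.shape[batch_dims:] ++
  // input.shape[axis + 1:], i.e. rank(input) + rank(coords) - 1 - batch_dims
  // dimensions in total. Illustrative example (hypothetical shapes, not taken
  // from a real model): input [2, 3, 4], coords [2, 5], axis = 1,
  // batch_dims = 1 gives output [2, 5, 4].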
  TfLiteIntArray* output_shape = output->dims;
  output_shape->size =
      NumDimensions(input) + NumDimensions(coords) - 1 - batch_dims;
  int output_index = 0;
  for (int i = 0; i < axis; ++i) {
    output_shape->data[output_index++] = input->dims->data[i];
  }
  for (int i = batch_dims; i < coords->dims->size; ++i) {
    output_shape->data[output_index++] = coords->dims->data[i];
  }
  for (int i = axis + 1; i < input->dims->size; ++i) {
    output_shape->data[output_index++] = input->dims->data[i];
  }
  return kTfLiteOk;
}


TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* coords =
      tflite::micro::GetEvalInput(context, node, kInputPositions);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  if (coords->type == kTfLiteInt32) {
    switch (input->type) {
      case kTfLiteFloat32:
        return Gather<float, int32_t>(params, input, coords, output);
      case kTfLiteInt8:
        return Gather<int8_t, int32_t>(params, input, coords, output);
      default:
        TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
                           TfLiteTypeGetName(input->type));
        return kTfLiteError;
    }
  }
  return kTfLiteOk;
}
}  // namespace

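// Usage sketch (not part of this translation unit): applications typically
// pick up this kernel through an op resolver, e.g. something like
//   tflite::MicroMutableOpResolver<1> resolver;
//   resolver.AddGather();  // expected to forward to Register_GATHER()
// The exact resolver type and AddGather() method name are assumptions based
// on the usual TFLM registration pattern and may differ between releases.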
TfLiteRegistration Register_GATHER() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite