/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <cstdint>
#include <cstring>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"
21 
22 namespace tflite {
23 namespace {
24 
// Tensor slots used by the GATHER_ND operator.
constexpr int kParams = 0;        // input 0: the params tensor gathered from
constexpr int kIndices = 1;       // input 1: the index tuples
constexpr int kOutputTensor = 0;  // output 0: the gathered result
// Upper bound on the innermost indices dimension this kernel supports
// (sizes the fixed stride array used during Eval).
constexpr int MAX_INDICES_ND = 5;
29 
Prepare(TfLiteContext * context,TfLiteNode * node)30 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
31   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
32   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
33 
34   const TfLiteTensor* params;
35   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
36   const TfLiteTensor* indices;
37   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
38   TfLiteTensor* output;
39   TF_LITE_ENSURE_OK(context,
40                     GetOutputSafe(context, node, kOutputTensor, &output));
41 
42   switch (params->type) {
43     case kTfLiteFloat32:
44     case kTfLiteInt8:
45       break;
46     default:
47       TF_LITE_KERNEL_LOG(context,
48                          "Params of type '%s' are not supported by gather_nd.",
49                          TfLiteTypeGetName(params->type));
50       return kTfLiteError;
51       break;
52   }
53   switch (indices->type) {
54     case kTfLiteInt32:
55       break;
56     default:
57       TF_LITE_KERNEL_LOG(context,
58                          "Indices of type '%s' are not supported by gather_nd.",
59                          TfLiteTypeGetName(indices->type));
60       return kTfLiteError;
61   }
62 
63   const int params_rank = NumDimensions(params);
64   const int indices_rank = NumDimensions(indices);
65   const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
66   if (params_rank < 1) {
67     TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
68     return kTfLiteError;
69   }
70   if (indices_rank < 1) {
71     TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
72     return kTfLiteError;
73   }
74   if (indices_nd > params_rank) {
75     TF_LITE_KERNEL_LOG(
76         context, "Index innermost dimension length must be <= params rank.");
77     return kTfLiteError;
78   }
79   if (indices_nd > MAX_INDICES_ND) {
80     TF_LITE_KERNEL_LOG(context,
81                        "Index innermost dimension length must not exceed %d.",
82                        MAX_INDICES_ND);
83     return kTfLiteError;
84   }
85 
86   // Assign to output the input type.
87   output->type = params->type;
88 
89   // TFLM gather_nd does not create the output tensor, but it needs to ensure
90   // that the output shape is correct. The result shape is
91   // indices.shape[:-1] + params.shape[indices.shape[-1]:]
92   TfLiteIntArray* output_shape = output->dims;
93   int output_index = 0;
94   for (int i = 0; i < indices_rank - 1; ++i) {
95     output_shape->data[output_index++] = indices->dims->data[i];
96   }
97   for (int i = indices_nd; i < params_rank; ++i) {
98     output_shape->data[output_index++] = params->dims->data[i];
99   }
100   output_shape->size = output_index;
101   return kTfLiteOk;
102 }
103 
104 template <typename ParamsT, typename IndicesT>
GatherNd(const TfLiteEvalTensor * params,const TfLiteEvalTensor * indices,TfLiteEvalTensor * output)105 TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
106                       const TfLiteEvalTensor* indices,
107                       TfLiteEvalTensor* output) {
108   const int indices_dims = indices->dims->size;
109   const int indices_nd = indices->dims->data[indices_dims - 1];
110   const int params_dims = params->dims->size;
111   const IndicesT* index_data = tflite::micro::GetTensorData<IndicesT>(indices);
112   const ParamsT* param_data = tflite::micro::GetTensorData<ParamsT>(params);
113   ParamsT* output_data = tflite::micro::GetTensorData<ParamsT>(output);
114 
115   int n_slices = 1;
116   for (int i = 0; i < indices_dims - 1; ++i) {
117     n_slices *= indices->dims->data[i];
118   }
119 
120   // If indices[-1] == params.rank, fetch single elements.
121   // If indices[-1] < params.rank, fetch slices.
122   int slice_size = 1;
123   for (int i = indices_nd; i < params_dims; ++i) {
124     slice_size *= params->dims->data[i];
125   }
126 
127   int remain_flat_size = ElementCount(*params->dims);
128 
129   // Number of elements per dimension
130   int dims_to_count[MAX_INDICES_ND];
131   for (int i = 0; i < indices_nd; ++i) {
132     dims_to_count[i] = remain_flat_size / params->dims->data[i];
133     remain_flat_size = dims_to_count[i];
134   }
135 
136   for (int i = 0; i < n_slices; ++i) {
137     int from_pos = 0;
138     for (int j = 0; j < indices_nd; ++j) {
139       int offset = i * indices_nd + j;
140       IndicesT index = index_data[offset];
141       from_pos += index * dims_to_count[j];
142     }
143     std::memcpy(output_data + i * slice_size, param_data + from_pos,
144                 sizeof(ParamsT) * slice_size);
145   }
146   return kTfLiteOk;
147 }
148 
149 template <typename IndicesT>
EvalGatherNd(TfLiteContext * context,const TfLiteEvalTensor * params,const TfLiteEvalTensor * indices,TfLiteEvalTensor * output)150 TfLiteStatus EvalGatherNd(TfLiteContext* context,
151                           const TfLiteEvalTensor* params,
152                           const TfLiteEvalTensor* indices,
153                           TfLiteEvalTensor* output) {
154   switch (params->type) {
155     case kTfLiteFloat32:
156       return GatherNd<float, IndicesT>(params, indices, output);
157       break;
158     case kTfLiteInt8:
159       return GatherNd<int8_t, IndicesT>(params, indices, output);
160       break;
161     default:
162       TF_LITE_KERNEL_LOG(context,
163                          "Params type '%s' are not supported by gather_nd.",
164                          TfLiteTypeGetName(params->type));
165       return kTfLiteError;
166   }
167 }
168 
Eval(TfLiteContext * context,TfLiteNode * node)169 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
170   const TfLiteEvalTensor* params =
171       tflite::micro::GetEvalInput(context, node, kParams);
172   const TfLiteEvalTensor* indices =
173       tflite::micro::GetEvalInput(context, node, kIndices);
174   TfLiteEvalTensor* output =
175       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
176 
177   switch (indices->type) {
178     case kTfLiteInt32:
179       return EvalGatherNd<int32_t>(context, params, indices, output);
180       break;
181     default:
182       TF_LITE_KERNEL_LOG(context,
183                          "Indices of type '%s' are not supported by gather_nd.",
184                          TfLiteTypeGetName(indices->type));
185       return kTfLiteError;
186   }
187 }
188 }  // namespace
189 
Register_GATHER_ND()190 TfLiteRegistration Register_GATHER_ND() {
191   return {/*init=*/nullptr,
192           /*free=*/nullptr,
193           /*prepare=*/Prepare,
194           /*invoke=*/Eval,
195           /*profiling_string=*/nullptr,
196           /*builtin_code=*/0,
197           /*custom_name=*/nullptr,
198           /*version=*/0};
199 }
200 
201 }  // namespace tflite
202