/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstring>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {
namespace {

constexpr int kParams = 0;
constexpr int kIndices = 1;
constexpr int kOutputTensor = 0;
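// Upper bound on the innermost dimension of indices. Keeping this fixed lets
// GatherNd() below hold the per-dimension strides in a small stack array
// instead of allocating at run time.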
constexpr int MAX_INDICES_ND = 5;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* params;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
  const TfLiteTensor* indices;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (params->type) {
    case kTfLiteFloat32:
    case kTfLiteInt8:
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Params of type '%s' are not supported by gather_nd.",
                         TfLiteTypeGetName(params->type));
      return kTfLiteError;
  }
  switch (indices->type) {
    case kTfLiteInt32:
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Indices of type '%s' are not supported by gather_nd.",
                         TfLiteTypeGetName(indices->type));
      return kTfLiteError;
  }

  const int params_rank = NumDimensions(params);
  const int indices_rank = NumDimensions(indices);
  const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
  if (params_rank < 1) {
    TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
    return kTfLiteError;
  }
  if (indices_rank < 1) {
    TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
    return kTfLiteError;
  }
  if (indices_nd > params_rank) {
    TF_LITE_KERNEL_LOG(
        context, "Index innermost dimension length must be <= params rank.");
    return kTfLiteError;
  }
  if (indices_nd > MAX_INDICES_ND) {
    TF_LITE_KERNEL_LOG(context,
                       "Index innermost dimension length must not exceed %d.",
                       MAX_INDICES_ND);
    return kTfLiteError;
  }

  // The output tensor has the same element type as params.
  output->type = params->type;

  // TFLM gather_nd does not create the output tensor, but it needs to ensure
  // that the output shape is correct. The result shape is
  // indices.shape[:-1] + params.shape[indices.shape[-1]:]
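  // For example (shapes chosen only for illustration): params of shape [4, 3]
  // and indices of shape [2, 1] give indices_nd == 1, so the output shape is
  // [2] + [3] = [2, 3].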
  TfLiteIntArray* output_shape = output->dims;
  int output_index = 0;
  for (int i = 0; i < indices_rank - 1; ++i) {
    output_shape->data[output_index++] = indices->dims->data[i];
  }
  for (int i = indices_nd; i < params_rank; ++i) {
    output_shape->data[output_index++] = params->dims->data[i];
  }
  output_shape->size = output_index;
  return kTfLiteOk;
}

template <typename ParamsT, typename IndicesT>
TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
                      const TfLiteEvalTensor* indices,
                      TfLiteEvalTensor* output) {
  const int indices_dims = indices->dims->size;
  const int indices_nd = indices->dims->data[indices_dims - 1];
  const int params_dims = params->dims->size;
  const IndicesT* index_data = tflite::micro::GetTensorData<IndicesT>(indices);
  const ParamsT* param_data = tflite::micro::GetTensorData<ParamsT>(params);
  ParamsT* output_data = tflite::micro::GetTensorData<ParamsT>(output);

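  // n_slices is the number of index tuples, i.e. the product of all but the
  // innermost dimension of indices. Each tuple selects one contiguous slice
  // of params to copy into the output.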
  int n_slices = 1;
  for (int i = 0; i < indices_dims - 1; ++i) {
    n_slices *= indices->dims->data[i];
  }

  // If indices[-1] == params.rank, fetch single elements.
  // If indices[-1] < params.rank, fetch slices.
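  // For example (illustrative shapes): with params of shape [4, 2, 3] and
  // indices_nd == 2, each index tuple addresses a row of 3 elements, so
  // slice_size below is 3; with indices_nd == 3 it is 1.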
  int slice_size = 1;
  for (int i = indices_nd; i < params_dims; ++i) {
    slice_size *= params->dims->data[i];
  }

  int remain_flat_size = ElementCount(*params->dims);

  // dims_to_count[i] is the row-major stride of params dimension i: the
  // number of elements spanned by one step along that dimension.
  int dims_to_count[MAX_INDICES_ND];
  for (int i = 0; i < indices_nd; ++i) {
    dims_to_count[i] = remain_flat_size / params->dims->data[i];
    remain_flat_size = dims_to_count[i];
  }

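  // For each index tuple, compute the flat element offset into params as the
  // dot product of the tuple with the strides above, then copy slice_size
  // contiguous elements into the output.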
  for (int i = 0; i < n_slices; ++i) {
    int from_pos = 0;
    for (int j = 0; j < indices_nd; ++j) {
      int offset = i * indices_nd + j;
      IndicesT index = index_data[offset];
      from_pos += index * dims_to_count[j];
    }
    std::memcpy(output_data + i * slice_size, param_data + from_pos,
                sizeof(ParamsT) * slice_size);
  }
  return kTfLiteOk;
}

template <typename IndicesT>
TfLiteStatus EvalGatherNd(TfLiteContext* context,
                          const TfLiteEvalTensor* params,
                          const TfLiteEvalTensor* indices,
                          TfLiteEvalTensor* output) {
  switch (params->type) {
    case kTfLiteFloat32:
      return GatherNd<float, IndicesT>(params, indices, output);
    case kTfLiteInt8:
      return GatherNd<int8_t, IndicesT>(params, indices, output);
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Params of type '%s' are not supported by gather_nd.",
                         TfLiteTypeGetName(params->type));
      return kTfLiteError;
  }
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* params =
      tflite::micro::GetEvalInput(context, node, kParams);
  const TfLiteEvalTensor* indices =
      tflite::micro::GetEvalInput(context, node, kIndices);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  switch (indices->type) {
    case kTfLiteInt32:
      return EvalGatherNd<int32_t>(context, params, indices, output);
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Indices of type '%s' are not supported by gather_nd.",
                         TfLiteTypeGetName(indices->type));
      return kTfLiteError;
  }
}
}  // namespace

TfLiteRegistration Register_GATHER_ND() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite