1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
16
17 #include <cmath>
18 #include <cstring>
19
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/kernels/op_macros.h"
25 #include "tensorflow/lite/micro/kernels/kernel_util.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace strided_slice {
31
// Positions of the operator's tensors within the node's input list
// (input, begin, end, strides) and output list (output).
constexpr int kInputTensor{0};
constexpr int kBeginTensor{1};
constexpr int kEndTensor{2};
constexpr int kStridesTensor{3};
constexpr int kOutputTensor{0};
37
38 struct StridedSliceContext {
StridedSliceContexttflite::ops::micro::strided_slice::StridedSliceContext39 StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
40 params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
41 input = GetInput(context, node, kInputTensor);
42 begin = GetInput(context, node, kBeginTensor);
43 end = GetInput(context, node, kEndTensor);
44 strides = GetInput(context, node, kStridesTensor);
45 output = GetOutput(context, node, kOutputTensor);
46 dims = NumDimensions(input);
47 }
48 const TfLiteStridedSliceParams* params;
49 const TfLiteTensor* input;
50 const TfLiteTensor* begin;
51 const TfLiteTensor* end;
52 const TfLiteTensor* strides;
53 TfLiteTensor* output;
54 int dims;
55 };
56
// Upper bound on the input rank accepted by this kernel. The reference 4D
// implementation is used, so 1-3D tensors are mapped to 4D.
constexpr int kMaxDim = 4;
60
BuildStridedSliceParams(StridedSliceContext * op_context)61 tflite::StridedSliceParams BuildStridedSliceParams(
62 StridedSliceContext* op_context) {
63 tflite::StridedSliceParams op_params;
64 op_params.start_indices_count = op_context->dims;
65 op_params.stop_indices_count = op_context->dims;
66 op_params.strides_count = op_context->dims;
67
68 for (int i = 0; i < op_context->dims; ++i) {
69 op_params.start_indices[i] = GetTensorData<int32_t>(op_context->begin)[i];
70 op_params.stop_indices[i] = GetTensorData<int32_t>(op_context->end)[i];
71 op_params.strides[i] = GetTensorData<int32_t>(op_context->strides)[i];
72 }
73
74 op_params.begin_mask = op_context->params->begin_mask;
75 op_params.ellipsis_mask = 0;
76 op_params.end_mask = op_context->params->end_mask;
77 op_params.new_axis_mask = 0;
78 op_params.shrink_axis_mask = op_context->params->shrink_axis_mask;
79 return op_params;
80 }
81
82 // Processes the indexing tensors (begin, end and strides) to resize the
83 // output tensor. This function is callable from both Prepare() and Eval() as
84 // long as the caller ensures the indexing tensors are present.
CheckOutputSize(TfLiteContext * context,StridedSliceContext * op_context)85 TfLiteStatus CheckOutputSize(TfLiteContext* context,
86 StridedSliceContext* op_context) {
87 using ::tflite::strided_slice::StartForAxis;
88 using ::tflite::strided_slice::StopForAxis;
89 TfLiteIntArray* output_shape = op_context->output->dims;
90 int shape_size = 0;
91 auto op_params = BuildStridedSliceParams(op_context);
92 auto input_shape = GetTensorShape(op_context->input);
93 for (int idx = 0; idx < op_context->dims; ++idx) {
94 int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
95 TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
96 int32_t begin = StartForAxis(op_params, input_shape, idx);
97 int32_t end = StopForAxis(op_params, input_shape, idx, begin);
98
99 // When shrinking an axis, the end position does not matter (and can be
100 // incorrect when negative indexing is used, see Issue #19260). Always use
101 // begin + 1 to generate a length 1 slice, since begin has
102 // already been adjusted for negative indices by StartForAxis.
103 const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx);
104 if (shrink_axis) {
105 end = begin + 1;
106 }
107
108 // This is valid for both positive and negative strides
109 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
110 dim_shape = dim_shape < 0 ? 0 : dim_shape;
111 if (!shrink_axis) {
112 TF_LITE_ENSURE_EQ(context, output_shape->data[shape_size], dim_shape);
113 shape_size++;
114 }
115 }
116 TF_LITE_ENSURE_EQ(context, output_shape->size, shape_size);
117 return kTfLiteOk;
118 }
119
Init(TfLiteContext * context,const char * buffer,size_t length)120 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
121 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
122 return context->AllocatePersistentBuffer(context, sizeof(StridedSliceParams));
123 }
124
Prepare(TfLiteContext * context,TfLiteNode * node)125 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
126 TFLITE_DCHECK(node->user_data != nullptr);
127 StridedSliceParams* op_params =
128 static_cast<StridedSliceParams*>(node->user_data);
129 TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
130 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
131 StridedSliceContext op_context(context, node);
132 TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim,
133 "input dim should not exceed 4");
134 auto params = BuildStridedSliceParams(&op_context);
135 memcpy(op_params, ¶ms, sizeof(StridedSliceParams));
136 return CheckOutputSize(context, &op_context);
137 }
138
Eval(TfLiteContext * context,TfLiteNode * node)139 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
140 TFLITE_DCHECK(node->user_data != nullptr);
141 const StridedSliceParams& op_params =
142 *(static_cast<const StridedSliceParams*>(node->user_data));
143
144 const TfLiteEvalTensor* input =
145 tflite::micro::GetEvalInput(context, node, kInputTensor);
146 TfLiteEvalTensor* output =
147 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
148 switch (output->type) {
149 case kTfLiteFloat32:
150 reference_ops::StridedSlice(op_params,
151 tflite::micro::GetTensorShape(input),
152 tflite::micro::GetTensorData<float>(input),
153 tflite::micro::GetTensorShape(output),
154 tflite::micro::GetTensorData<float>(output));
155 break;
156 case kTfLiteUInt8:
157 reference_ops::StridedSlice(
158 op_params, tflite::micro::GetTensorShape(input),
159 tflite::micro::GetTensorData<uint8_t>(input),
160 tflite::micro::GetTensorShape(output),
161 tflite::micro::GetTensorData<uint8_t>(output));
162 break;
163 case kTfLiteInt8:
164 reference_ops::StridedSlice(op_params,
165 tflite::micro::GetTensorShape(input),
166 tflite::micro::GetTensorData<int8_t>(input),
167 tflite::micro::GetTensorShape(output),
168 tflite::micro::GetTensorData<int8_t>(output));
169 break;
170 case kTfLiteInt16:
171 reference_ops::StridedSlice(
172 op_params, tflite::micro::GetTensorShape(input),
173 tflite::micro::GetTensorData<int16_t>(input),
174 tflite::micro::GetTensorShape(output),
175 tflite::micro::GetTensorData<int16_t>(output));
176 break;
177 default:
178 TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
179 TfLiteTypeGetName(input->type), input->type);
180 return kTfLiteError;
181 }
182 return kTfLiteOk;
183 }
184 } // namespace strided_slice
185
Register_STRIDED_SLICE()186 TfLiteRegistration Register_STRIDED_SLICE() {
187 return {/*init=*/strided_slice::Init,
188 /*free=*/nullptr,
189 /*prepare=*/strided_slice::Prepare,
190 /*invoke=*/strided_slice::Eval,
191 /*profiling_string=*/nullptr,
192 /*builtin_code=*/0,
193 /*custom_name=*/nullptr,
194 /*version=*/0};
195 }
196
197 } // namespace micro
198 } // namespace ops
199 } // namespace tflite
200