1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
16 
17 #include <cmath>
18 #include <cstring>
19 
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/kernels/op_macros.h"
25 #include "tensorflow/lite/micro/kernels/kernel_util.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace micro {
30 namespace strided_slice {
31 
// Slots in the node's input index list.
constexpr int kInputTensor{0};
constexpr int kBeginTensor{1};
constexpr int kEndTensor{2};
constexpr int kStridesTensor{3};

// Slot in the node's output index list.
constexpr int kOutputTensor{0};
37 
38 struct StridedSliceContext {
StridedSliceContexttflite::ops::micro::strided_slice::StridedSliceContext39   StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
40     params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
41     input = GetInput(context, node, kInputTensor);
42     begin = GetInput(context, node, kBeginTensor);
43     end = GetInput(context, node, kEndTensor);
44     strides = GetInput(context, node, kStridesTensor);
45     output = GetOutput(context, node, kOutputTensor);
46     dims = NumDimensions(input);
47   }
48   const TfLiteStridedSliceParams* params;
49   const TfLiteTensor* input;
50   const TfLiteTensor* begin;
51   const TfLiteTensor* end;
52   const TfLiteTensor* strides;
53   TfLiteTensor* output;
54   int dims;
55 };
56 
// Largest input rank accepted by this op. The kernel relies on the reference
// 4D implementation, onto which 1-3D tensors are mapped, so only 1-4D inputs
// are supported.
constexpr int kMaxDim = 4;
60 
BuildStridedSliceParams(StridedSliceContext * op_context)61 tflite::StridedSliceParams BuildStridedSliceParams(
62     StridedSliceContext* op_context) {
63   tflite::StridedSliceParams op_params;
64   op_params.start_indices_count = op_context->dims;
65   op_params.stop_indices_count = op_context->dims;
66   op_params.strides_count = op_context->dims;
67 
68   for (int i = 0; i < op_context->dims; ++i) {
69     op_params.start_indices[i] = GetTensorData<int32_t>(op_context->begin)[i];
70     op_params.stop_indices[i] = GetTensorData<int32_t>(op_context->end)[i];
71     op_params.strides[i] = GetTensorData<int32_t>(op_context->strides)[i];
72   }
73 
74   op_params.begin_mask = op_context->params->begin_mask;
75   op_params.ellipsis_mask = 0;
76   op_params.end_mask = op_context->params->end_mask;
77   op_params.new_axis_mask = 0;
78   op_params.shrink_axis_mask = op_context->params->shrink_axis_mask;
79   return op_params;
80 }
81 
82 // Processes the indexing tensors (begin, end and strides) to resize the
83 // output tensor. This function is callable from both Prepare() and Eval() as
84 // long as the caller ensures the indexing tensors are present.
CheckOutputSize(TfLiteContext * context,StridedSliceContext * op_context)85 TfLiteStatus CheckOutputSize(TfLiteContext* context,
86                              StridedSliceContext* op_context) {
87   using ::tflite::strided_slice::StartForAxis;
88   using ::tflite::strided_slice::StopForAxis;
89   TfLiteIntArray* output_shape = op_context->output->dims;
90   int shape_size = 0;
91   auto op_params = BuildStridedSliceParams(op_context);
92   auto input_shape = GetTensorShape(op_context->input);
93   for (int idx = 0; idx < op_context->dims; ++idx) {
94     int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
95     TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
96     int32_t begin = StartForAxis(op_params, input_shape, idx);
97     int32_t end = StopForAxis(op_params, input_shape, idx, begin);
98 
99     // When shrinking an axis, the end position does not matter (and can be
100     // incorrect when negative indexing is used, see Issue #19260). Always use
101     // begin + 1 to generate a length 1 slice, since begin has
102     // already been adjusted for negative indices by StartForAxis.
103     const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx);
104     if (shrink_axis) {
105       end = begin + 1;
106     }
107 
108     // This is valid for both positive and negative strides
109     int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
110     dim_shape = dim_shape < 0 ? 0 : dim_shape;
111     if (!shrink_axis) {
112       TF_LITE_ENSURE_EQ(context, output_shape->data[shape_size], dim_shape);
113       shape_size++;
114     }
115   }
116   TF_LITE_ENSURE_EQ(context, output_shape->size, shape_size);
117   return kTfLiteOk;
118 }
119 
Init(TfLiteContext * context,const char * buffer,size_t length)120 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
121   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
122   return context->AllocatePersistentBuffer(context, sizeof(StridedSliceParams));
123 }
124 
Prepare(TfLiteContext * context,TfLiteNode * node)125 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
126   TFLITE_DCHECK(node->user_data != nullptr);
127   StridedSliceParams* op_params =
128       static_cast<StridedSliceParams*>(node->user_data);
129   TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
130   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
131   StridedSliceContext op_context(context, node);
132   TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim,
133                      "input dim should not exceed 4");
134   auto params = BuildStridedSliceParams(&op_context);
135   memcpy(op_params, &params, sizeof(StridedSliceParams));
136   return CheckOutputSize(context, &op_context);
137 }
138 
Eval(TfLiteContext * context,TfLiteNode * node)139 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
140   TFLITE_DCHECK(node->user_data != nullptr);
141   const StridedSliceParams& op_params =
142       *(static_cast<const StridedSliceParams*>(node->user_data));
143 
144   const TfLiteEvalTensor* input =
145       tflite::micro::GetEvalInput(context, node, kInputTensor);
146   TfLiteEvalTensor* output =
147       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
148   switch (output->type) {
149     case kTfLiteFloat32:
150       reference_ops::StridedSlice(op_params,
151                                   tflite::micro::GetTensorShape(input),
152                                   tflite::micro::GetTensorData<float>(input),
153                                   tflite::micro::GetTensorShape(output),
154                                   tflite::micro::GetTensorData<float>(output));
155       break;
156     case kTfLiteUInt8:
157       reference_ops::StridedSlice(
158           op_params, tflite::micro::GetTensorShape(input),
159           tflite::micro::GetTensorData<uint8_t>(input),
160           tflite::micro::GetTensorShape(output),
161           tflite::micro::GetTensorData<uint8_t>(output));
162       break;
163     case kTfLiteInt8:
164       reference_ops::StridedSlice(op_params,
165                                   tflite::micro::GetTensorShape(input),
166                                   tflite::micro::GetTensorData<int8_t>(input),
167                                   tflite::micro::GetTensorShape(output),
168                                   tflite::micro::GetTensorData<int8_t>(output));
169       break;
170     case kTfLiteInt16:
171       reference_ops::StridedSlice(
172           op_params, tflite::micro::GetTensorShape(input),
173           tflite::micro::GetTensorData<int16_t>(input),
174           tflite::micro::GetTensorShape(output),
175           tflite::micro::GetTensorData<int16_t>(output));
176       break;
177     default:
178       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
179                          TfLiteTypeGetName(input->type), input->type);
180       return kTfLiteError;
181   }
182   return kTfLiteOk;
183 }
184 }  // namespace strided_slice
185 
Register_STRIDED_SLICE()186 TfLiteRegistration Register_STRIDED_SLICE() {
187   return {/*init=*/strided_slice::Init,
188           /*free=*/nullptr,
189           /*prepare=*/strided_slice::Prepare,
190           /*invoke=*/strided_slice::Eval,
191           /*profiling_string=*/nullptr,
192           /*builtin_code=*/0,
193           /*custom_name=*/nullptr,
194           /*version=*/0};
195 }
196 
197 }  // namespace micro
198 }  // namespace ops
199 }  // namespace tflite
200