/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/sub.h"

#include <algorithm>  // For std::max() used in CalculateOpData().

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/add.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace sub {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

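// Per-node data computed once in Prepare() and reused on every Eval().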
struct OpData {
  bool requires_broadcast;

  // These fields are used in both the general 8-bit -> 8-bit quantized path
  // and the special 16-bit -> 16-bit quantized path.
  int input1_shift;
  int input2_shift;
  int32_t output_activation_min;
  int32_t output_activation_max;

  // These fields are used only in the general 8-bit -> 8-bit quantized path.
  int32_t input1_multiplier;
  int32_t input2_multiplier;
  int32_t output_multiplier;
  int output_shift;
  int left_shift;
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
};

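// Precomputes the quantization parameters (offsets, fixed-point multipliers,
// shifts, and the clamped activation range) needed by the quantized kernels.
// For float outputs, only the broadcast check is performed.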
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteSubParams* params,
                             const TfLiteTensor* input1,
                             const TfLiteTensor* input2, TfLiteTensor* output,
                             OpData* data) {
  data->requires_broadcast = !HaveSameShapes(input1, input2);

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
      output->type == kTfLiteInt16) {
    // 8-bit -> 8-bit general quantized path, with general rescalings.
    data->input1_offset = -input1->params.zero_point;
    data->input2_offset = -input2->params.zero_point;
    data->output_offset = output->params.zero_point;

    // The shift is 15 for the 16-bit case and 20 for the 8-bit case. For
    // 16-bit, 65535 << 15 is still less than 1 << 31, so the addition fits in
    // a 32-bit accumulator.
    data->left_shift = output->type == kTfLiteInt16 ? 15 : 20;
    const float twice_max_input_scale =
        2 * std::max(input1->params.scale, input2->params.scale);
    const double real_input1_multiplier =
        static_cast<double>(input1->params.scale / twice_max_input_scale);
    const double real_input2_multiplier =
        static_cast<double>(input2->params.scale / twice_max_input_scale);
    const double real_output_multiplier =
        static_cast<double>(twice_max_input_scale /
                            ((1 << data->left_shift) * output->params.scale));

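    // Convert the real-valued rescaling factors into fixed-point
    // multiplier/shift pairs for the integer kernels.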
    QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);

    // Use the add kernel for 16-bit sub, since it supports output
    // requantization; negating input2's multiplier turns the addition into a
    // subtraction. This matches the behavior in TFLite.
    data->input2_multiplier *= (output->type == kTfLiteInt16) ? -1 : 1;
    QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->output_multiplier, &data->output_shift);

    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  return kTfLiteOk;
}

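// Allocates persistent storage for OpData; the contents are filled in later
// by Prepare().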
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

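// Validates the input and output tensors and fills in the OpData allocated in
// Init() via CalculateOpData().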
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TF_LITE_ENSURE(context, input2 != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_STATUS(
      CalculateOpData(context, params, input1, input2, output, data));
  return kTfLiteOk;
}

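// Float path: element-wise subtraction with optional broadcasting, clamped to
// the fused activation range.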
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
             const OpData* data, const TfLiteEvalTensor* input1,
             const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max,
                      &op_params);
  if (data->requires_broadcast) {
    tflite::reference_ops::BroadcastSubSlow(
        op_params, tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorData<float>(input1),
        tflite::micro::GetTensorShape(input2),
        tflite::micro::GetTensorData<float>(input2),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<float>(output));
  } else {
    tflite::reference_ops::SubWithActivation(
        op_params, tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorData<float>(input1),
        tflite::micro::GetTensorShape(input2),
        tflite::micro::GetTensorData<float>(input2),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<float>(output));
  }
}

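// Quantized path for uint8/int8/int16 outputs: copies the precomputed
// quantization parameters into ArithmeticParams and calls the reference
// kernels.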
TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteSubParams* params, const OpData* data,
                              const TfLiteEvalTensor* input1,
                              const TfLiteEvalTensor* input2,
                              TfLiteEvalTensor* output) {
  tflite::ArithmeticParams op_params;
  op_params.left_shift = data->left_shift;
  op_params.input1_offset = data->input1_offset;
  op_params.input1_multiplier = data->input1_multiplier;
  op_params.input1_shift = data->input1_shift;
  op_params.input2_offset = data->input2_offset;
  op_params.input2_multiplier = data->input2_multiplier;
  op_params.input2_shift = data->input2_shift;
  op_params.output_offset = data->output_offset;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  SetActivationParams(data->output_activation_min, data->output_activation_max,
                      &op_params);
  bool need_broadcast = reference_ops::ProcessBroadcastShapes(
      tflite::micro::GetTensorShape(input1),
      tflite::micro::GetTensorShape(input2), &op_params);

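  // The int16 cases below reuse the Add kernels; input2's multiplier was
  // negated in CalculateOpData(), so the addition computes a subtraction.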
  switch (output->type) {
    case kTfLiteInt8: {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastSubSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      } else {
        tflite::reference_ops::Sub(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      }
      break;
    }
    case kTfLiteInt16: {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastAdd4DSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int16_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int16_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output));
      } else {
        tflite::reference_ops::Add(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int16_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int16_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output), false);
      }
      break;
    }
    case kTfLiteUInt8: {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastSubSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<uint8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<uint8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      } else {
        tflite::reference_ops::Sub(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<uint8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<uint8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      }
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Quantized type %s not currently supported.",
                         TfLiteTypeGetName(output->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

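// Dispatches to the float or quantized implementation based on the output
// tensor type.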
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  if (output->type == kTfLiteFloat32) {
    EvalSub(context, node, params, &data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
             output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
                                                input1, input2, output));
  } else {
    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                       TfLiteTypeGetName(output->type), output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace sub

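// Returns the TfLiteRegistration used to register the SUB kernel with the op
// resolver.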
TfLiteRegistration Register_SUB() {
  return {/*init=*/sub::Init,
          /*free=*/nullptr,
          /*prepare=*/sub::Prepare,
          /*invoke=*/sub::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite