/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/sub.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/add.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace sub {

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
struct OpData {
  bool requires_broadcast;

  // These fields are used in both the general 8-bit -> 8-bit quantized path
  // and the special 16-bit -> 16-bit quantized path.
  int input1_shift;
  int input2_shift;
  int32_t output_activation_min;
  int32_t output_activation_max;

  // These fields are used only in the general 8-bit -> 8-bit quantized path.
  int32_t input1_multiplier;
  int32_t input2_multiplier;
  int32_t output_multiplier;
  int output_shift;
  int left_shift;
  int32_t input1_offset;
  int32_t input2_offset;
  int32_t output_offset;
};

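// Computes the broadcast requirement and, for quantized output types
// (uint8/int8/int16), the fixed-point multipliers, shifts, offsets and
// activation range that Eval will need.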
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteSubParams* params,
                             const TfLiteTensor* input1,
                             const TfLiteTensor* input2, TfLiteTensor* output,
                             OpData* data) {
  data->requires_broadcast = !HaveSameShapes(input1, input2);

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
      output->type == kTfLiteInt16) {
    // 8-bit -> 8-bit general quantized path, with general rescalings.
    data->input1_offset = -input1->params.zero_point;
    data->input2_offset = -input2->params.zero_point;
    data->output_offset = output->params.zero_point;

    // The left shift is 15 for the 16-bit case and 20 for the 8-bit case.
    // For 16-bit, 65535 << 15 is still less than 1 << 31, so the addition
    // still fits in a 32-bit accumulator.
    data->left_shift = output->type == kTfLiteInt16 ? 15 : 20;
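    // Each input is rescaled relative to twice the larger of the two input
    // scales, so each real input multiplier below is at most 0.5 and can be
    // encoded with QuantizeMultiplierSmallerThanOneExp.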
    const float twice_max_input_scale =
        2 * std::max(input1->params.scale, input2->params.scale);
    const double real_input1_multiplier =
        static_cast<double>(input1->params.scale / twice_max_input_scale);
    const double real_input2_multiplier =
        static_cast<double>(input2->params.scale / twice_max_input_scale);
    const double real_output_multiplier =
        static_cast<double>(twice_max_input_scale /
                            ((1 << data->left_shift) * output->params.scale));

    QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);

    // Use add kernel for 16-bit sub, since it supports output requantization.
    // This matches behavior in TFLite.
    data->input2_multiplier *= (output->type == kTfLiteInt16) ? -1 : 1;
    QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->output_multiplier, &data->output_shift);

    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  return kTfLiteOk;
}

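// Allocates the per-node OpData from the interpreter's persistent arena.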
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

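// Validates the input/output tensors and fills in OpData once at prepare
// time, so Eval does not recompute quantization parameters on every
// invocation.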
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TFLITE_DCHECK(node->builtin_data != nullptr);

  OpData* data = static_cast<OpData*>(node->user_data);
  auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  TF_LITE_ENSURE(context, input1 != nullptr);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TF_LITE_ENSURE(context, input2 != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_STATUS(
      CalculateOpData(context, params, input1, input2, output, data));
  return kTfLiteOk;
}

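// Float path: computes the activation range at eval time and dispatches to
// the broadcasting or elementwise reference sub implementation.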
void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
             const OpData* data, const TfLiteEvalTensor* input1,
             const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  tflite::ArithmeticParams op_params;
  SetActivationParams(output_activation_min, output_activation_max, &op_params);
  if (data->requires_broadcast) {
    tflite::reference_ops::BroadcastSubSlow(
        op_params, tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorData<float>(input1),
        tflite::micro::GetTensorShape(input2),
        tflite::micro::GetTensorData<float>(input2),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<float>(output));
  } else {
    tflite::reference_ops::SubWithActivation(
        op_params, tflite::micro::GetTensorShape(input1),
        tflite::micro::GetTensorData<float>(input1),
        tflite::micro::GetTensorShape(input2),
        tflite::micro::GetTensorData<float>(input2),
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<float>(output));
  }
}

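// Quantized path: copies the precomputed quantization parameters into
// ArithmeticParams and dispatches on the output type. Broadcasting is
// detected per invocation from the runtime shapes.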
TfLiteStatus EvalSubQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteSubParams* params, const OpData* data,
                              const TfLiteEvalTensor* input1,
                              const TfLiteEvalTensor* input2,
                              TfLiteEvalTensor* output) {
  tflite::ArithmeticParams op_params;
  op_params.left_shift = data->left_shift;
  op_params.input1_offset = data->input1_offset;
  op_params.input1_multiplier = data->input1_multiplier;
  op_params.input1_shift = data->input1_shift;
  op_params.input2_offset = data->input2_offset;
  op_params.input2_multiplier = data->input2_multiplier;
  op_params.input2_shift = data->input2_shift;
  op_params.output_offset = data->output_offset;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  SetActivationParams(data->output_activation_min, data->output_activation_max,
                      &op_params);
  bool need_broadcast = reference_ops::ProcessBroadcastShapes(
      tflite::micro::GetTensorShape(input1),
      tflite::micro::GetTensorShape(input2), &op_params);

  switch (output->type) {
    case kTfLiteInt8: {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastSubSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      } else {
        tflite::reference_ops::Sub(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int8_t>(output));
      }
      break;
    }
    case kTfLiteInt16: {
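      // The 16-bit path reuses the add kernel: CalculateOpData negated
      // input2_multiplier for int16, so the Add calls below effectively
      // compute input1 - input2 while still supporting output requantization.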
      if (need_broadcast) {
        tflite::reference_ops::BroadcastAdd4DSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int16_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int16_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output));
      } else {
        tflite::reference_ops::Add(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<int16_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<int16_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<int16_t>(output), false);
      }
      break;
    }
    case kTfLiteUInt8: {
      if (need_broadcast) {
        tflite::reference_ops::BroadcastSubSlow(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<uint8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<uint8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      } else {
        tflite::reference_ops::Sub(
            op_params, tflite::micro::GetTensorShape(input1),
            tflite::micro::GetTensorData<uint8_t>(input1),
            tflite::micro::GetTensorShape(input2),
            tflite::micro::GetTensorData<uint8_t>(input2),
            tflite::micro::GetTensorShape(output),
            tflite::micro::GetTensorData<uint8_t>(output));
      }
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Quantized type %s not currently supported.",
                         TfLiteTypeGetName(output->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

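// Top-level invoke: dispatches to the float or quantized implementation
// based on the output tensor type.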
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);

  const TfLiteEvalTensor* input1 =
      tflite::micro::GetEvalInput(context, node, kInputTensor1);
  const TfLiteEvalTensor* input2 =
      tflite::micro::GetEvalInput(context, node, kInputTensor2);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& data = *(static_cast<const OpData*>(node->user_data));

  if (output->type == kTfLiteFloat32) {
    EvalSub(context, node, params, &data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
             output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_OK(context, EvalSubQuantized(context, node, params, &data,
                                                input1, input2, output));
  } else {
    TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                       TfLiteTypeGetName(output->type), output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace sub

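// Returns the registration used by the op resolver. OpData lives in the
// persistent arena, so no free callback is provided.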
TfLiteRegistration Register_SUB() {
  return {/*init=*/sub::Init,
          /*free=*/nullptr,
          /*prepare=*/sub::Prepare,
          /*invoke=*/sub::Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite