1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/add_n.h"
17 
18 #include <cstdint>
19 
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/quantization_util.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/micro/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace {
28 
// Index of the first input tensor of the node; the other inputs follow it.
constexpr int kInputTensor0{0};
// Index of the node's single output tensor.
constexpr int kOutputTensor{0};

// Value stored in OpData::left_shift for the INT8 path. CalculateOpData
// divides the real output multiplier by (1 << kAddNIntegerShift).
constexpr int kAddNIntegerShift{20};
33 
// Per-node state for the INT8 path only. It is filled in by CalculateOpData
// and read back by EvalAddNQuantized. (The FLOAT32 path does not use this
// struct; it keeps its scratch buffer index directly in node->user_data.)
struct OpData {
  // Clamp range for the quantized output, as computed by
  // CalculateActivationRangeQuantized.
  int32_t output_activation_min;
  int32_t output_activation_max;

  // Negated zero-point of the first input tensor.
  int32_t input_offset;
  // Zero-point of the output tensor.
  int32_t output_offset;

  // Fixed-point multipliers produced by QuantizeMultiplierSmallerThanOneExp
  // for the input and output rescaling factors.
  int32_t input_multiplier;
  int32_t output_multiplier;

  // Shifts that accompany the two multipliers above.
  int input_shift;
  int output_shift;

  // Set to kAddNIntegerShift.
  int left_shift;

  // Index returned by RequestScratchBufferInArena; the buffer holds one data
  // pointer per input tensor.
  int scratch_index;
};
47 
CalculateOpData(TfLiteContext * context,TfLiteNode * node)48 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
49   int num_inputs = NumInputs(node);
50   TF_LITE_ENSURE(context, num_inputs >= 2);
51   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
52 
53   const TfLiteTensor* input_tensor_first;
54   TF_LITE_ENSURE_OK(
55       context, GetInputSafe(context, node, kInputTensor0, &input_tensor_first));
56   TfLiteTensor* output;
57   TF_LITE_ENSURE_OK(context,
58                     GetOutputSafe(context, node, kOutputTensor, &output));
59 
60   // Check that all tensors have the same shape and type.
61   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_tensor_first->type);
62   for (int i = kInputTensor0 + 1; i < num_inputs; ++i) {
63     const TfLiteTensor* input;
64     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
65     TF_LITE_ENSURE(context, HaveSameShapes(input_tensor_first, input));
66     TF_LITE_ENSURE_TYPES_EQ(context, input_tensor_first->type, input->type);
67 
68     // Check that all INT8 input tensors have the same zero-point and scale.
69     if (input_tensor_first->type == kTfLiteInt8) {
70       TF_LITE_ENSURE(context, input_tensor_first->params.zero_point ==
71                                   input->params.zero_point);
72       TF_LITE_ENSURE(context,
73                      input_tensor_first->params.scale == input->params.scale);
74     }
75   }
76 
77   if (output->type == kTfLiteFloat32) {
78     // Allocate scratch buffer space for pointer to each tensor's data
79     // and store the scratch buffer index in the node's user_data
80     int scratch_index;
81     size_t scratch_size = sizeof(float*) * num_inputs;
82     TF_LITE_ENSURE_OK(context, context->RequestScratchBufferInArena(
83                                    context, scratch_size, &scratch_index));
84     node->user_data =
85         reinterpret_cast<decltype(node->user_data)>(scratch_index);
86   } else if (output->type == kTfLiteInt8) {
87     node->user_data =
88         context->AllocatePersistentBuffer(context, sizeof(OpData));
89     OpData* data = static_cast<OpData*>(node->user_data);
90 
91     // Allocate scratch buffer space for pointer to each tensor's data
92     // and store the scratch buffer index in OpData
93     size_t scratch_size = sizeof(int8_t*) * num_inputs;
94     TF_LITE_ENSURE_OK(
95         context, context->RequestScratchBufferInArena(context, scratch_size,
96                                                       &data->scratch_index));
97 
98     // 8bit -> 8bit general quantized path, with general rescalings
99     data->input_offset = -input_tensor_first->params.zero_point;
100     data->output_offset = output->params.zero_point;
101     data->left_shift = kAddNIntegerShift;
102     const double twice_max_input_scale =
103         2 * static_cast<double>(input_tensor_first->params.scale);
104     const double real_input_multiplier =
105         static_cast<double>(input_tensor_first->params.scale) /
106         twice_max_input_scale;
107     const double real_output_multiplier =
108         twice_max_input_scale /
109         ((1 << data->left_shift) * static_cast<double>(output->params.scale));
110 
111     QuantizeMultiplierSmallerThanOneExp(
112         real_input_multiplier, &data->input_multiplier, &data->input_shift);
113 
114     QuantizeMultiplierSmallerThanOneExp(
115         real_output_multiplier, &data->output_multiplier, &data->output_shift);
116 
117     TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
118         context, kTfLiteActNone, output, &data->output_activation_min,
119         &data->output_activation_max));
120   } else {
121     TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
122                        TfLiteTypeGetName(output->type));
123     return kTfLiteError;
124   }
125 
126   return kTfLiteOk;
127 }
128 
Prepare(TfLiteContext * context,TfLiteNode * node)129 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
130   return CalculateOpData(context, node);
131 }
132 
133 template <typename T>
CopyInputsToScratchBuffer(TfLiteContext * context,TfLiteNode * node,const int scratch_index)134 inline const T** CopyInputsToScratchBuffer(TfLiteContext* context,
135                                            TfLiteNode* node,
136                                            const int scratch_index) {
137   int num_inputs = NumInputs(node);
138   void* scratch_buffer = context->GetScratchBuffer(context, scratch_index);
139   const T** all_inputs = static_cast<decltype(all_inputs)>(scratch_buffer);
140   for (int i = 0; i < num_inputs; i++) {
141     const TfLiteEvalTensor* next_input =
142         tflite::micro::GetEvalInput(context, node, kInputTensor0 + i);
143     all_inputs[i] = tflite::micro::GetTensorData<T>(next_input);
144   }
145 
146   return all_inputs;
147 }
148 
149 template <typename T>
EvalAddN(TfLiteContext * context,TfLiteNode * node,TfLiteEvalTensor * output)150 void EvalAddN(TfLiteContext* context, TfLiteNode* node,
151               TfLiteEvalTensor* output) {
152   int num_inputs = NumInputs(node);
153 
154   int scratch_index =
155       static_cast<int>(reinterpret_cast<intptr_t>(node->user_data));
156   const T** all_inputs =
157       CopyInputsToScratchBuffer<T>(context, node, scratch_index);
158 
159   reference_ops::AddN<T>(tflite::micro::GetTensorShape(output), num_inputs,
160                          all_inputs, tflite::micro::GetTensorData<T>(output));
161 }
162 
163 template <typename T>
EvalAddNQuantized(TfLiteContext * context,TfLiteNode * node,TfLiteEvalTensor * output)164 void EvalAddNQuantized(TfLiteContext* context, TfLiteNode* node,
165                        TfLiteEvalTensor* output) {
166   int num_inputs = NumInputs(node);
167 
168   OpData* data = static_cast<OpData*>(node->user_data);
169   const T** all_inputs =
170       CopyInputsToScratchBuffer<T>(context, node, data->scratch_index);
171 
172   ArithmeticParams params;
173   params.left_shift = data->left_shift;
174   params.input1_offset = data->input_offset;
175   params.input1_multiplier = data->input_multiplier;
176   params.input1_shift = data->input_shift;
177   params.output_offset = data->output_offset;
178   params.output_multiplier = data->output_multiplier;
179   params.output_shift = data->output_shift;
180   SetActivationParams(data->output_activation_min, data->output_activation_max,
181                       &params);
182 
183   reference_ops::AddN(params, tflite::micro::GetTensorShape(output), num_inputs,
184                       all_inputs, tflite::micro::GetTensorData<T>(output));
185 }
186 
Eval(TfLiteContext * context,TfLiteNode * node)187 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
188   TfLiteEvalTensor* output =
189       tflite::micro::GetEvalOutput(context, node, kOutputTensor);
190   if (output->type == kTfLiteFloat32) {
191     EvalAddN<float>(context, node, output);
192   } else if (output->type == kTfLiteInt8) {
193     EvalAddNQuantized<int8_t>(context, node, output);
194   } else {
195     TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
196                        TfLiteTypeGetName(output->type));
197     return kTfLiteError;
198   }
199   return kTfLiteOk;
200 }
201 
202 }  // namespace
203 
Register_ADD_N()204 TfLiteRegistration Register_ADD_N() {
205   return {/*init=*/nullptr,
206           /*free=*/nullptr,
207           /*prepare=*/Prepare,
208           /*invoke=*/Eval,
209           /*profiling_string=*/nullptr,
210           /*builtin_code=*/0,
211           /*custom_name=*/nullptr,
212           /*version=*/0};
213 }
214 
215 }  // namespace tflite
216