1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/kernels/internal/reference/add_n.h"
17
18 #include <cstdint>
19
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/quantization_util.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 #include "tensorflow/lite/micro/kernels/kernel_util.h"
25
26 namespace tflite {
27 namespace {
28
29 constexpr int kInputTensor0 = 0;
30 constexpr int kOutputTensor = 0;
31
32 constexpr int kAddNIntegerShift = 20;
33
// Per-op data precomputed in Prepare. Only allocated/used when the tensors
// are INT8; for FLOAT32 the scratch index is stored directly in user_data.
struct OpData {
  int32_t output_activation_min;  // lower clamp bound for the output
  int32_t output_activation_max;  // upper clamp bound for the output
  int32_t input_offset;           // negated zero-point shared by all inputs
  int32_t output_offset;          // zero-point of the output tensor
  int32_t input_multiplier;       // fixed-point multiplier for input rescale
  int32_t output_multiplier;      // fixed-point multiplier for output rescale
  int input_shift;                // exponent paired with input_multiplier
  int output_shift;               // exponent paired with output_multiplier
  int left_shift;                 // kAddNIntegerShift: headroom for summation
  int scratch_index;              // arena scratch buffer for input pointers
};
47
CalculateOpData(TfLiteContext * context,TfLiteNode * node)48 TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
49 int num_inputs = NumInputs(node);
50 TF_LITE_ENSURE(context, num_inputs >= 2);
51 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
52
53 const TfLiteTensor* input_tensor_first;
54 TF_LITE_ENSURE_OK(
55 context, GetInputSafe(context, node, kInputTensor0, &input_tensor_first));
56 TfLiteTensor* output;
57 TF_LITE_ENSURE_OK(context,
58 GetOutputSafe(context, node, kOutputTensor, &output));
59
60 // Check that all tensors have the same shape and type.
61 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_tensor_first->type);
62 for (int i = kInputTensor0 + 1; i < num_inputs; ++i) {
63 const TfLiteTensor* input;
64 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
65 TF_LITE_ENSURE(context, HaveSameShapes(input_tensor_first, input));
66 TF_LITE_ENSURE_TYPES_EQ(context, input_tensor_first->type, input->type);
67
68 // Check that all INT8 input tensors have the same zero-point and scale.
69 if (input_tensor_first->type == kTfLiteInt8) {
70 TF_LITE_ENSURE(context, input_tensor_first->params.zero_point ==
71 input->params.zero_point);
72 TF_LITE_ENSURE(context,
73 input_tensor_first->params.scale == input->params.scale);
74 }
75 }
76
77 if (output->type == kTfLiteFloat32) {
78 // Allocate scratch buffer space for pointer to each tensor's data
79 // and store the scratch buffer index in the node's user_data
80 int scratch_index;
81 size_t scratch_size = sizeof(float*) * num_inputs;
82 TF_LITE_ENSURE_OK(context, context->RequestScratchBufferInArena(
83 context, scratch_size, &scratch_index));
84 node->user_data =
85 reinterpret_cast<decltype(node->user_data)>(scratch_index);
86 } else if (output->type == kTfLiteInt8) {
87 node->user_data =
88 context->AllocatePersistentBuffer(context, sizeof(OpData));
89 OpData* data = static_cast<OpData*>(node->user_data);
90
91 // Allocate scratch buffer space for pointer to each tensor's data
92 // and store the scratch buffer index in OpData
93 size_t scratch_size = sizeof(int8_t*) * num_inputs;
94 TF_LITE_ENSURE_OK(
95 context, context->RequestScratchBufferInArena(context, scratch_size,
96 &data->scratch_index));
97
98 // 8bit -> 8bit general quantized path, with general rescalings
99 data->input_offset = -input_tensor_first->params.zero_point;
100 data->output_offset = output->params.zero_point;
101 data->left_shift = kAddNIntegerShift;
102 const double twice_max_input_scale =
103 2 * static_cast<double>(input_tensor_first->params.scale);
104 const double real_input_multiplier =
105 static_cast<double>(input_tensor_first->params.scale) /
106 twice_max_input_scale;
107 const double real_output_multiplier =
108 twice_max_input_scale /
109 ((1 << data->left_shift) * static_cast<double>(output->params.scale));
110
111 QuantizeMultiplierSmallerThanOneExp(
112 real_input_multiplier, &data->input_multiplier, &data->input_shift);
113
114 QuantizeMultiplierSmallerThanOneExp(
115 real_output_multiplier, &data->output_multiplier, &data->output_shift);
116
117 TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
118 context, kTfLiteActNone, output, &data->output_activation_min,
119 &data->output_activation_max));
120 } else {
121 TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
122 TfLiteTypeGetName(output->type));
123 return kTfLiteError;
124 }
125
126 return kTfLiteOk;
127 }
128
// Prepare hook: all validation and precomputation is delegated to
// CalculateOpData; Eval does only per-invocation work.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return CalculateOpData(context, node);
}
132
133 template <typename T>
CopyInputsToScratchBuffer(TfLiteContext * context,TfLiteNode * node,const int scratch_index)134 inline const T** CopyInputsToScratchBuffer(TfLiteContext* context,
135 TfLiteNode* node,
136 const int scratch_index) {
137 int num_inputs = NumInputs(node);
138 void* scratch_buffer = context->GetScratchBuffer(context, scratch_index);
139 const T** all_inputs = static_cast<decltype(all_inputs)>(scratch_buffer);
140 for (int i = 0; i < num_inputs; i++) {
141 const TfLiteEvalTensor* next_input =
142 tflite::micro::GetEvalInput(context, node, kInputTensor0 + i);
143 all_inputs[i] = tflite::micro::GetTensorData<T>(next_input);
144 }
145
146 return all_inputs;
147 }
148
149 template <typename T>
EvalAddN(TfLiteContext * context,TfLiteNode * node,TfLiteEvalTensor * output)150 void EvalAddN(TfLiteContext* context, TfLiteNode* node,
151 TfLiteEvalTensor* output) {
152 int num_inputs = NumInputs(node);
153
154 int scratch_index =
155 static_cast<int>(reinterpret_cast<intptr_t>(node->user_data));
156 const T** all_inputs =
157 CopyInputsToScratchBuffer<T>(context, node, scratch_index);
158
159 reference_ops::AddN<T>(tflite::micro::GetTensorShape(output), num_inputs,
160 all_inputs, tflite::micro::GetTensorData<T>(output));
161 }
162
163 template <typename T>
EvalAddNQuantized(TfLiteContext * context,TfLiteNode * node,TfLiteEvalTensor * output)164 void EvalAddNQuantized(TfLiteContext* context, TfLiteNode* node,
165 TfLiteEvalTensor* output) {
166 int num_inputs = NumInputs(node);
167
168 OpData* data = static_cast<OpData*>(node->user_data);
169 const T** all_inputs =
170 CopyInputsToScratchBuffer<T>(context, node, data->scratch_index);
171
172 ArithmeticParams params;
173 params.left_shift = data->left_shift;
174 params.input1_offset = data->input_offset;
175 params.input1_multiplier = data->input_multiplier;
176 params.input1_shift = data->input_shift;
177 params.output_offset = data->output_offset;
178 params.output_multiplier = data->output_multiplier;
179 params.output_shift = data->output_shift;
180 SetActivationParams(data->output_activation_min, data->output_activation_max,
181 ¶ms);
182
183 reference_ops::AddN(params, tflite::micro::GetTensorShape(output), num_inputs,
184 all_inputs, tflite::micro::GetTensorData<T>(output));
185 }
186
Eval(TfLiteContext * context,TfLiteNode * node)187 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
188 TfLiteEvalTensor* output =
189 tflite::micro::GetEvalOutput(context, node, kOutputTensor);
190 if (output->type == kTfLiteFloat32) {
191 EvalAddN<float>(context, node, output);
192 } else if (output->type == kTfLiteInt8) {
193 EvalAddNQuantized<int8_t>(context, node, output);
194 } else {
195 TF_LITE_KERNEL_LOG(context, "ADD_N only supports FLOAT32 and INT8, got %s.",
196 TfLiteTypeGetName(output->type));
197 return kTfLiteError;
198 }
199 return kTfLiteOk;
200 }
201
202 } // namespace
203
Register_ADD_N()204 TfLiteRegistration Register_ADD_N() {
205 return {/*init=*/nullptr,
206 /*free=*/nullptr,
207 /*prepare=*/Prepare,
208 /*invoke=*/Eval,
209 /*profiling_string=*/nullptr,
210 /*builtin_code=*/0,
211 /*custom_name=*/nullptr,
212 /*version=*/0};
213 }
214
215 } // namespace tflite
216