/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/cumsum.h"

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace {

constexpr int kInputTensor = 0;
constexpr int kAxisTensor = 1;
constexpr int kOutputTensor = 0;

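// Amount the quantized inputs are shifted left before the scaled addition in
// the int8 path; a large shift preserves precision through the rescaling.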
constexpr int kCumSumIntegerShift = 20;

// only used with INT8 tensors
struct OpData {
  int32_t output_activation_min;
  int32_t output_activation_max;
  int32_t input_offset;
  int32_t output_offset;
  int32_t input_multiplier;
  int32_t output_multiplier;
  int input_shift;
  int output_shift;
  int left_shift;
};

TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
  const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);

  TF_LITE_ENSURE(context,
                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
  TF_LITE_ENSURE_EQ(context, axis->type, kTfLiteInt32);

  TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);

  TF_LITE_ENSURE(context, NumDimensions(input) >= 1);

  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_EQ(context, input->type, output->type);
  TF_LITE_ENSURE(context, HaveSameShapes(input, output));

  if (output->type == kTfLiteInt8) {
    node->user_data =
        context->AllocatePersistentBuffer(context, sizeof(OpData));
    OpData* data = static_cast<OpData*>(node->user_data);

    // 8bit -> 8bit general quantized path, with general rescalings
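    // The rescaling mirrors the quantized ADD path: the running sum and the
    // next element share the input scale (hence twice the input scale), and
    // values are shifted left by kCumSumIntegerShift before the scaled
    // addition so precision survives scaling back down to the output scale.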
    data->input_offset = -input->params.zero_point;
    data->output_offset = output->params.zero_point;
    data->left_shift = kCumSumIntegerShift;
    const double twice_max_input_scale =
        2 * static_cast<double>(input->params.scale);
    const double real_input_multiplier =
        static_cast<double>(input->params.scale) / twice_max_input_scale;
    const double real_output_multiplier =
        twice_max_input_scale /
        ((1 << data->left_shift) * static_cast<double>(output->params.scale));

    QuantizeMultiplierSmallerThanOneExp(
        real_input_multiplier, &data->input_multiplier, &data->input_shift);

    QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->output_multiplier, &data->output_shift);

    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, kTfLiteActNone, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  return kTfLiteOk;
}

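// Prepare only validates the inputs and precomputes the quantization
// parameters; the kernel needs no scratch buffers.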
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return CalculateOpData(context, node);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* axis_tensor =
      tflite::micro::GetEvalInput(context, node, kAxisTensor);

  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

  auto* cs_params = static_cast<TfLiteCumsumParams*>(node->builtin_data);
  auto input_shape = tflite::micro::GetTensorShape(input);

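  // A negative axis counts back from the last dimension; after normalization
  // the axis must index a valid dimension of the input.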
  int32_t axis = *tflite::micro::GetTensorData<int32_t>(axis_tensor);
  if (axis < 0) axis += input_shape.DimensionsCount();

  if (axis < 0 || axis >= input_shape.DimensionsCount()) {
    TF_LITE_KERNEL_LOG(context, "CUMSUM Invalid axis: %d", axis);
    return kTfLiteError;
  }

  switch (input->type) {
    case kTfLiteFloat32: {
      reference_ops::CumSum(tflite::micro::GetTensorData<float>(input),
                            input_shape, axis, cs_params->exclusive,
                            cs_params->reverse,
                            tflite::micro::GetTensorData<float>(output));
      return kTfLiteOk;
    } break;

    case kTfLiteInt8: {
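      // Repackage the precomputed quantization parameters into the
      // ArithmeticParams struct expected by the shared reference kernel.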
      auto* data = static_cast<OpData*>(node->user_data);
      ArithmeticParams params;
      params.left_shift = data->left_shift;
      params.input1_offset = data->input_offset;
      params.input1_multiplier = data->input_multiplier;
      params.input1_shift = data->input_shift;
      params.output_offset = data->output_offset;
      params.output_multiplier = data->output_multiplier;
      params.output_shift = data->output_shift;
      SetActivationParams(data->output_activation_min,
                          data->output_activation_max, &params);
      reference_ops::CumSum(params,
                            tflite::micro::GetTensorData<int8_t>(input),
                            input_shape, axis, cs_params->exclusive,
                            cs_params->reverse,
                            tflite::micro::GetTensorData<int8_t>(output));
      return kTfLiteOk;
    } break;

    default: {
      TF_LITE_KERNEL_LOG(context,
                         "CUMSUM only supports FLOAT32 and INT8, got %s.",
                         TfLiteTypeGetName(output->type));
      return kTfLiteError;
    }
  }

  return kTfLiteError;
}

}  // namespace

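// Registration hook picked up by the op resolver (e.g. via AddCumSum() on
// MicroMutableOpResolver in recent TFLM releases; exact method availability
// depends on the version). The kernel keeps no state beyond the persistent
// OpData buffer, so init and free are left null.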
TfLiteRegistration Register_CUMSUM() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/Prepare,
          /*invoke=*/Eval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace tflite