/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/pooling.h"

namespace tflite {

const int kPoolingInputTensor = 0;
const int kPoolingOutputTensor = 0;

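// Computes the pooling window padding from the input spatial dimensions and
// the pooling parameters, storing the result in the per-op data.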
TfLiteStatus CalculateOpDataPooling(const TfLiteContext* context,
                                    const TfLitePoolParams* params,
                                    const TfLiteTensor* input,
                                    const TfLiteTensor* output,
                                    OpDataPooling* data) {
  // input: batch, height, width, channel
  int height = SizeOfDimension(input, 1);
  int width = SizeOfDimension(input, 2);

  int out_height, out_width;

  data->padding = ComputePaddingHeightWidth(
      params->stride_height, params->stride_width,
      /*dilation_rate_height=*/1,
      /*dilation_rate_width=*/1, height, width, params->filter_height,
      params->filter_width, params->padding, &out_height, &out_width);

  return kTfLiteOk;
}

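// Shared Prepare for the pooling kernels: validates the input and output
// tensors, computes padding, and derives the activation range (float or
// quantized) that the Eval functions below consume.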
TfLiteStatus PoolingPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->builtin_data != nullptr);
  auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);

  TFLITE_DCHECK(node->user_data != nullptr);
  OpDataPooling* data = static_cast<OpDataPooling*>(node->user_data);

  const TfLiteTensor* input = GetInput(context, node, kPoolingInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kPoolingOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  TF_LITE_ENSURE_STATUS(
      CalculateOpDataPooling(context, params, input, output, data));

  if (input->type == kTfLiteFloat32) {
    CalculateActivationRange(params->activation, &data->activation_min_f32,
                             &data->activation_max_f32);
  } else if (input->type == kTfLiteInt8) {
    CalculateActivationRangeQuantized(context, params->activation, output,
                                      &data->activation_min,
                                      &data->activation_max);
  }

  return kTfLiteOk;
}

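// Runs the float reference average-pooling kernel using the padding and
// activation range computed in Prepare.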
void AveragePoolingEvalFloat(const TfLiteContext* context,
                             const TfLiteNode* node,
                             const TfLitePoolParams* params,
                             const OpDataPooling* data,
                             const TfLiteEvalTensor* input,
                             TfLiteEvalTensor* output) {
  PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.float_activation_min = data->activation_min_f32;
  op_params.float_activation_max = data->activation_max_f32;
  reference_ops::AveragePool(op_params, tflite::micro::GetTensorShape(input),
                             tflite::micro::GetTensorData<float>(input),
                             tflite::micro::GetTensorShape(output),
                             tflite::micro::GetTensorData<float>(output));
}

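// Runs the int8 reference average-pooling kernel; the activation bounds are
// the quantized min/max computed in Prepare.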
void AveragePoolingEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
                                 const TfLitePoolParams* params,
                                 const OpDataPooling* data,
                                 const TfLiteEvalTensor* input,
                                 TfLiteEvalTensor* output) {
  TFLITE_DCHECK(input->type == kTfLiteInt8);

  PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.quantized_activation_min = data->activation_min;
  op_params.quantized_activation_max = data->activation_max;

  reference_integer_ops::AveragePool(
      op_params, tflite::micro::GetTensorShape(input),
      tflite::micro::GetTensorData<int8_t>(input),
      tflite::micro::GetTensorShape(output),
      tflite::micro::GetTensorData<int8_t>(output));
}

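// Runs the float reference max-pooling kernel using the padding and
// activation range computed in Prepare.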
void MaxPoolingEvalFloat(TfLiteContext* context, TfLiteNode* node,
                         TfLitePoolParams* params, const OpDataPooling* data,
                         const TfLiteEvalTensor* input,
                         TfLiteEvalTensor* output) {
  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.float_activation_min = data->activation_min_f32;
  op_params.float_activation_max = data->activation_max_f32;
  reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
                         tflite::micro::GetTensorData<float>(input),
                         tflite::micro::GetTensorShape(output),
                         tflite::micro::GetTensorData<float>(output));
}

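// Runs the int8 reference max-pooling kernel; the activation bounds are the
// quantized min/max computed in Prepare.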
void MaxPoolingEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                             TfLitePoolParams* params,
                             const OpDataPooling* data,
                             const TfLiteEvalTensor* input,
                             TfLiteEvalTensor* output) {
  tflite::PoolParams op_params;
  op_params.stride_height = params->stride_height;
  op_params.stride_width = params->stride_width;
  op_params.filter_height = params->filter_height;
  op_params.filter_width = params->filter_width;
  op_params.padding_values.height = data->padding.height;
  op_params.padding_values.width = data->padding.width;
  op_params.quantized_activation_min = data->activation_min;
  op_params.quantized_activation_max = data->activation_max;

  reference_integer_ops::MaxPool(op_params,
                                 tflite::micro::GetTensorShape(input),
                                 tflite::micro::GetTensorData<int8_t>(input),
                                 tflite::micro::GetTensorShape(output),
                                 tflite::micro::GetTensorData<int8_t>(output));
}

}  // namespace tflite