1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h"
18 #include "tensorflow/lite/kernels/internal/reference/pooling.h"
19 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/kernels/padding.h"
22 #include "tensorflow/lite/micro/kernels/kernel_util.h"
23 #include "tensorflow/lite/micro/kernels/pooling.h"
24 
25 namespace tflite {
26 
// Index of the (single) input tensor in a pooling node's input list.
const int kPoolingInputTensor{0};
// Index of the (single) output tensor in a pooling node's output list.
const int kPoolingOutputTensor{0};
29 
CalculateOpDataPooling(const TfLiteContext * context,const TfLitePoolParams * params,const TfLiteTensor * input,const TfLiteTensor * output,OpDataPooling * data)30 TfLiteStatus CalculateOpDataPooling(const TfLiteContext* context,
31                                     const TfLitePoolParams* params,
32                                     const TfLiteTensor* input,
33                                     const TfLiteTensor* output,
34                                     OpDataPooling* data) {
35   // input: batch, height, width, channel
36   int height = SizeOfDimension(input, 1);
37   int width = SizeOfDimension(input, 2);
38 
39   int out_height, out_width;
40 
41   data->padding = ComputePaddingHeightWidth(
42       params->stride_height, params->stride_width,
43       /*dilation_rate_height=*/1,
44       /*dilation_rate_width=*/1, height, width, params->filter_height,
45       params->filter_width, params->padding, &out_height, &out_width);
46 
47   return kTfLiteOk;
48 }
49 
PoolingPrepare(TfLiteContext * context,TfLiteNode * node)50 TfLiteStatus PoolingPrepare(TfLiteContext* context, TfLiteNode* node) {
51   TFLITE_DCHECK(node->builtin_data != nullptr);
52   auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
53 
54   TFLITE_DCHECK(node->user_data != nullptr);
55   OpDataPooling* data = static_cast<OpDataPooling*>(node->user_data);
56 
57   const TfLiteTensor* input = GetInput(context, node, kPoolingInputTensor);
58   TF_LITE_ENSURE(context, input != nullptr);
59   TfLiteTensor* output = GetOutput(context, node, kPoolingOutputTensor);
60   TF_LITE_ENSURE(context, output != nullptr);
61 
62   TF_LITE_ENSURE_STATUS(
63       CalculateOpDataPooling(context, params, input, output, data));
64 
65   if (input->type == kTfLiteFloat32) {
66     CalculateActivationRange(params->activation, &data->activation_min_f32,
67                              &data->activation_max_f32);
68   } else if (input->type == kTfLiteInt8) {
69     CalculateActivationRangeQuantized(context, params->activation, output,
70                                       &data->activation_min,
71                                       &data->activation_max);
72   }
73 
74   return kTfLiteOk;
75 }
76 
AveragePoolingEvalFloat(const TfLiteContext * context,const TfLiteNode * node,const TfLitePoolParams * params,const OpDataPooling * data,const TfLiteEvalTensor * input,TfLiteEvalTensor * output)77 void AveragePoolingEvalFloat(const TfLiteContext* context,
78                              const TfLiteNode* node,
79                              const TfLitePoolParams* params,
80                              const OpDataPooling* data,
81                              const TfLiteEvalTensor* input,
82                              TfLiteEvalTensor* output) {
83   PoolParams op_params;
84   op_params.stride_height = params->stride_height;
85   op_params.stride_width = params->stride_width;
86   op_params.filter_height = params->filter_height;
87   op_params.filter_width = params->filter_width;
88   op_params.padding_values.height = data->padding.height;
89   op_params.padding_values.width = data->padding.width;
90   op_params.float_activation_min = data->activation_min_f32;
91   op_params.float_activation_max = data->activation_max_f32;
92   reference_ops::AveragePool(op_params, tflite::micro::GetTensorShape(input),
93                              tflite::micro::GetTensorData<float>(input),
94                              tflite::micro::GetTensorShape(output),
95                              tflite::micro::GetTensorData<float>(output));
96 }
97 
AveragePoolingEvalQuantized(TfLiteContext * context,const TfLiteNode * node,const TfLitePoolParams * params,const OpDataPooling * data,const TfLiteEvalTensor * input,TfLiteEvalTensor * output)98 void AveragePoolingEvalQuantized(TfLiteContext* context, const TfLiteNode* node,
99                                  const TfLitePoolParams* params,
100                                  const OpDataPooling* data,
101                                  const TfLiteEvalTensor* input,
102                                  TfLiteEvalTensor* output) {
103   TFLITE_DCHECK(input->type == kTfLiteInt8);
104 
105   PoolParams op_params;
106   op_params.stride_height = params->stride_height;
107   op_params.stride_width = params->stride_width;
108   op_params.filter_height = params->filter_height;
109   op_params.filter_width = params->filter_width;
110   op_params.padding_values.height = data->padding.height;
111   op_params.padding_values.width = data->padding.width;
112   op_params.quantized_activation_min = data->activation_min;
113   op_params.quantized_activation_max = data->activation_max;
114 
115   reference_integer_ops::AveragePool(
116       op_params, tflite::micro::GetTensorShape(input),
117       tflite::micro::GetTensorData<int8_t>(input),
118       tflite::micro::GetTensorShape(output),
119       tflite::micro::GetTensorData<int8_t>(output));
120 }
121 
MaxPoolingEvalFloat(TfLiteContext * context,TfLiteNode * node,TfLitePoolParams * params,const OpDataPooling * data,const TfLiteEvalTensor * input,TfLiteEvalTensor * output)122 void MaxPoolingEvalFloat(TfLiteContext* context, TfLiteNode* node,
123                          TfLitePoolParams* params, const OpDataPooling* data,
124                          const TfLiteEvalTensor* input,
125                          TfLiteEvalTensor* output) {
126   tflite::PoolParams op_params;
127   op_params.stride_height = params->stride_height;
128   op_params.stride_width = params->stride_width;
129   op_params.filter_height = params->filter_height;
130   op_params.filter_width = params->filter_width;
131   op_params.padding_values.height = data->padding.height;
132   op_params.padding_values.width = data->padding.width;
133   op_params.float_activation_min = data->activation_min_f32;
134   op_params.float_activation_max = data->activation_max_f32;
135   reference_ops::MaxPool(op_params, tflite::micro::GetTensorShape(input),
136                          tflite::micro::GetTensorData<float>(input),
137                          tflite::micro::GetTensorShape(output),
138                          tflite::micro::GetTensorData<float>(output));
139 }
140 
MaxPoolingEvalQuantized(TfLiteContext * context,TfLiteNode * node,TfLitePoolParams * params,const OpDataPooling * data,const TfLiteEvalTensor * input,TfLiteEvalTensor * output)141 void MaxPoolingEvalQuantized(TfLiteContext* context, TfLiteNode* node,
142                              TfLitePoolParams* params,
143                              const OpDataPooling* data,
144                              const TfLiteEvalTensor* input,
145                              TfLiteEvalTensor* output) {
146   tflite::PoolParams op_params;
147   op_params.stride_height = params->stride_height;
148   op_params.stride_width = params->stride_width;
149   op_params.filter_height = params->filter_height;
150   op_params.filter_width = params->filter_width;
151   op_params.padding_values.height = data->padding.height;
152   op_params.padding_values.width = data->padding.width;
153   op_params.quantized_activation_min = data->activation_min;
154   op_params.quantized_activation_max = data->activation_max;
155 
156   reference_integer_ops::MaxPool(op_params,
157                                  tflite::micro::GetTensorShape(input),
158                                  tflite::micro::GetTensorData<int8_t>(input),
159                                  tflite::micro::GetTensorShape(output),
160                                  tflite::micro::GetTensorData<int8_t>(output));
161 }
162 
163 }  // namespace tflite
164