/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/micro_utils.h"

namespace tflite {
namespace ops {
namespace micro {
namespace arg_min_max {

constexpr int kInputTensor = 0;
constexpr int kAxis = 1;
constexpr int kOutputTensor = 0;

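// Dispatches to the reference ArgMinMax implementation with a Greater or Less
// comparator, so the same code path serves both ARG_MAX and ARG_MIN.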
template <typename T1, typename T2, typename T3>
inline void ArgMinMaxHelper(const RuntimeShape& input1_shape,
                            const T1* input1_data, const T3* input2_data,
                            const RuntimeShape& output_shape, T2* output_data,
                            bool is_arg_max) {
  if (is_arg_max) {
    reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
                             output_shape, output_data, micro::Greater());
  } else {
    reference_ops::ArgMinMax(input1_shape, input1_data, input2_data,
                             output_shape, output_data, micro::Less());
  }
}
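// Shared Eval for ARG_MIN and ARG_MAX: fetches the input, axis, and output
// tensors and dispatches on their element types.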
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  const TfLiteEvalTensor* axis =
      tflite::micro::GetEvalInput(context, node, kAxis);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);

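// Expands to an ArgMinMaxHelper call instantiated for the given input, axis,
// and output element types.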
#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type)       \
  ArgMinMaxHelper(tflite::micro::GetTensorShape(input),              \
                  tflite::micro::GetTensorData<data_type>(input),    \
                  tflite::micro::GetTensorData<axis_type>(axis),     \
                  tflite::micro::GetTensorShape(output),             \
                  tflite::micro::GetTensorData<output_type>(output), \
                  is_arg_max)
  if (axis->type == kTfLiteInt32) {
    if (output->type == kTfLiteInt32) {
      switch (input->type) {
        case kTfLiteFloat32:
          TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
          break;
        case kTfLiteUInt8:
          TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
          break;
        case kTfLiteInt8:
          TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
          break;
        default:
          TF_LITE_KERNEL_LOG(context,
                             "Only float32, uint8_t and int8_t are "
                             "supported currently, got %s.",
                             TfLiteTypeGetName(input->type));
          return kTfLiteError;
      }
    } else {
      TF_LITE_KERNEL_LOG(context,
                         "Only int32_t is supported currently, got %s.",
                         TfLiteTypeGetName(output->type));
      return kTfLiteError;
    }
  } else {
    TF_LITE_KERNEL_LOG(context, "Only int32_t is supported currently, got %s.",
                       TfLiteTypeGetName(axis->type));
    return kTfLiteError;
  }

#undef TF_LITE_ARG_MIN_MAX

  return kTfLiteOk;
}

TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(context, node, false);
}

TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(context, node, true);
}

}  // namespace arg_min_max

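// Registration structs for ARG_MAX and ARG_MIN. No init/free/prepare hooks
// are registered; all of the work happens in the Eval callbacks above.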
TfLiteRegistration Register_ARG_MAX() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/nullptr,
          /*invoke=*/arg_min_max::ArgMaxEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

TfLiteRegistration Register_ARG_MIN() {
  return {/*init=*/nullptr,
          /*free=*/nullptr,
          /*prepare=*/nullptr,
          /*invoke=*/arg_min_max::ArgMinEval,
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite