/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/mock_micro_graph.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

namespace tflite {
namespace testing {
namespace {

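// Runs the IF kernel in isolation through KernelRunner against a mock graph.
// input1 is the boolean condition tensor, input2 is the data tensor. After a
// single Invoke(), the output is compared against expected_output_data and
// the invoke counts of subgraph 1 (then branch) and subgraph 2 (else branch)
// are checked against the golden counts.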
void TestIf(int* input1_dims_data, const bool* input1_data,
            int* input2_dims_data, const float* input2_data,
            int* output_dims_data, const float* expected_output_data,
            const int subgraph1_invoke_count_golden,
            const int subgraph2_invoke_count_golden, float* output_data) {
  TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
  TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
  const int output_dims_count = ElementCount(*output_dims);

  constexpr int inputs_size = 2;
  constexpr int outputs_size = 1;
  constexpr int tensors_size = inputs_size + outputs_size;
  TfLiteTensor tensors[tensors_size] = {
      CreateTensor(input1_data, input1_dims),
      CreateTensor(input2_data, input2_dims),
      CreateTensor(output_data, output_dims),
  };

  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteIfParams params;
  params.then_subgraph_index = 1;
  params.else_subgraph_index = 2;

  const TfLiteRegistration registration = tflite::Register_IF();
  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
                             outputs_array, &params);

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());

  TF_LITE_MICRO_EXPECT_EQ(output_dims_count, 2);
  for (int i = 0; i < output_dims_count; ++i) {
    TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
  }

  TF_LITE_MICRO_EXPECT_EQ(subgraph1_invoke_count_golden,
                          runner.GetMockGraph()->get_invoke_count(1));
  TF_LITE_MICRO_EXPECT_EQ(subgraph2_invoke_count_golden,
                          runner.GetMockGraph()->get_invoke_count(2));
}

}  // namespace
}  // namespace testing
}  // namespace tflite

TF_LITE_MICRO_TESTS_BEGIN

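// Kernel-level tests: the KernelRunner's MockMicroGraph records how many times
// each subgraph is invoked, so these tests verify that IF routes execution to
// subgraph 1 when the condition is true and to subgraph 2 when it is false.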
TF_LITE_MICRO_TEST(IfShouldInvokeSubgraphWithMockModelConditionTrue) {
  int shape[] = {2, 1, 2};
  int condition_shape[] = {1, 1};
  const bool condition[] = {true};
  const float input[] = {5.0, 2.0};
  const float golden[] = {5.0, 2.0};
  float output_data[2] = {0};
  tflite::testing::TestIf(condition_shape, condition, shape, input, shape,
                          golden, 1, 0, output_data);
}

TF_LITE_MICRO_TEST(IfShouldInvokeSubgraphWithMockModelConditionFalse) {
  int shape[] = {2, 1, 2};
  int condition_shape[] = {1, 1};
  const bool condition[] = {false};
  const float input[] = {5.0, 2.0};
  const float golden[] = {5.0, 2.0};
  float output_data[2] = {0};
  tflite::testing::TestIf(condition_shape, condition, shape, input, shape,
                          golden, 0, 1, output_data);
}

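// End-to-end tests: GetSimpleModelWithSubgraphsAndIf() provides a model whose
// 'then' subgraph adds the two input tensors element-wise and whose 'else'
// subgraph multiplies them, as reflected in the expected outputs below.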
TF_LITE_MICRO_TEST(IfShouldInvokeSubgraphConditionTrue) {
  constexpr int kArenaSize = 5000;
  uint8_t arena[kArenaSize];

  const tflite::Model* model =
      tflite::testing::GetSimpleModelWithSubgraphsAndIf();
  tflite::MicroMutableOpResolver<3> resolver;
  tflite::MicroErrorReporter reporter;
  resolver.AddIf();
  resolver.AddAdd();
  resolver.AddMul();
  tflite::MicroInterpreter interpreter(model, resolver, arena, kArenaSize,
                                       &reporter);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.AllocateTensors());
  TfLiteTensor* condition = interpreter.input(0);
  TfLiteTensor* input1 = interpreter.input(1);
  TfLiteTensor* input2 = interpreter.input(2);
  TfLiteTensor* output = interpreter.output(0);
  float input1_data[] = {2.0, 5.0};
  float input2_data[] = {3.0, 7.0};
  memcpy(input1->data.f, input1_data, 2 * sizeof(float));
  memcpy(input2->data.f, input2_data, 2 * sizeof(float));
  condition->data.b[0] = true;

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());

  // Condition is true, so the 'then' subgraph runs: {2+3, 5+7}.
  TF_LITE_MICRO_EXPECT_EQ(output->data.f[0], 5.0f);
  TF_LITE_MICRO_EXPECT_EQ(output->data.f[1], 12.0f);
}

TF_LITE_MICRO_TEST(IfShouldInvokeSubgraphConditionFalse) {
  constexpr int kArenaSize = 5000;
  uint8_t arena[kArenaSize];

  const tflite::Model* model =
      tflite::testing::GetSimpleModelWithSubgraphsAndIf();
  tflite::MicroMutableOpResolver<3> resolver;
  tflite::MicroErrorReporter reporter;
  resolver.AddIf();
  resolver.AddAdd();
  resolver.AddMul();
  tflite::MicroInterpreter interpreter(model, resolver, arena, kArenaSize,
                                       &reporter);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.AllocateTensors());
  TfLiteTensor* condition = interpreter.input(0);
  TfLiteTensor* input1 = interpreter.input(1);
  TfLiteTensor* input2 = interpreter.input(2);
  TfLiteTensor* output = interpreter.output(0);
  float input1_data[] = {2.0, 5.0};
  float input2_data[] = {3.0, 7.0};
  memcpy(input1->data.f, input1_data, 2 * sizeof(float));
  memcpy(input2->data.f, input2_data, 2 * sizeof(float));
  condition->data.b[0] = false;

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());

  // Condition is false, so the 'else' subgraph runs: {2*3, 5*7}.
  TF_LITE_MICRO_EXPECT_EQ(output->data.f[0], 6.0f);
  TF_LITE_MICRO_EXPECT_EQ(output->data.f[1], 35.0f);
}

TF_LITE_MICRO_TESTS_END