1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/micro/all_ops_resolver.h"
19 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
20 #include "tensorflow/lite/micro/micro_interpreter.h"
21 #include "tensorflow/lite/micro/mock_micro_graph.h"
22 #include "tensorflow/lite/micro/test_helpers.h"
23 #include "tensorflow/lite/micro/testing/micro_test.h"
24
25 namespace tflite {
26 namespace testing {
27 namespace {
28
TestIf(int * input1_dims_data,const bool * input1_data,int * input2_dims_data,const float * input2_data,int * output_dims_data,const float * expected_output_data,const int subgraph1_invoke_count_golden,const int subgraph2_invoke_count_golden,float * output_data)29 void TestIf(int* input1_dims_data, const bool* input1_data,
30 int* input2_dims_data, const float* input2_data,
31 int* output_dims_data, const float* expected_output_data,
32 const int subgraph1_invoke_count_golden,
33 const int subgraph2_invoke_count_golden, float* output_data) {
34 TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
35 TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
36 TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
37 const int output_dims_count = ElementCount(*output_dims);
38
39 constexpr int inputs_size = 2;
40 constexpr int outputs_size = 1;
41 constexpr int tensors_size = inputs_size + outputs_size;
42 TfLiteTensor tensors[tensors_size] = {
43 CreateTensor(input1_data, input1_dims),
44 CreateTensor(input2_data, input2_dims),
45 CreateTensor(output_data, output_dims),
46 };
47
48 int inputs_array_data[] = {2, 0, 1};
49 TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
50 int outputs_array_data[] = {1, 2};
51 TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
52
53 TfLiteIfParams params;
54 params.then_subgraph_index = 1;
55 params.else_subgraph_index = 2;
56
57 const TfLiteRegistration registration = tflite::Register_IF();
58 micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
59 outputs_array, ¶ms);
60
61 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
62 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
63
64 TF_LITE_MICRO_EXPECT_EQ(output_dims_count, 2);
65 for (int i = 0; i < output_dims_count; ++i) {
66 TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
67 }
68
69 TF_LITE_MICRO_EXPECT_EQ(subgraph1_invoke_count_golden,
70 runner.GetMockGraph()->get_invoke_count(1));
71 TF_LITE_MICRO_EXPECT_EQ(subgraph2_invoke_count_golden,
72 runner.GetMockGraph()->get_invoke_count(2));
73 }
74
75 } // namespace
76 } // namespace testing
77 } // namespace tflite
78
79 TF_LITE_MICRO_TESTS_BEGIN
80
TF_LITE_MICRO_TEST(IfShouldInvokeSubgraphWithMockModelConditionTrue) {
  // A true condition should invoke subgraph 1 once and subgraph 2 never, with
  // the output equal to the float input.
  int cond_dims[] = {1, 1};
  const bool cond_value[] = {true};
  int io_dims[] = {2, 1, 2};
  const float payload[] = {5.0f, 2.0f};
  const float expected[] = {5.0f, 2.0f};
  float result[2] = {};

  tflite::testing::TestIf(cond_dims, cond_value, io_dims, payload, io_dims,
                          expected, /*subgraph1_invoke_count_golden=*/1,
                          /*subgraph2_invoke_count_golden=*/0, result);
}
91
TF_LITE_MICRO_TEST(IfShouldInvokeSubgraphWithMockModelConditionFalse) {
  // A false condition should invoke subgraph 2 once and subgraph 1 never, with
  // the output equal to the float input.
  int cond_dims[] = {1, 1};
  const bool cond_value[] = {false};
  int io_dims[] = {2, 1, 2};
  const float payload[] = {5.0f, 2.0f};
  const float expected[] = {5.0f, 2.0f};
  float result[2] = {};

  tflite::testing::TestIf(cond_dims, cond_value, io_dims, payload, io_dims,
                          expected, /*subgraph1_invoke_count_golden=*/0,
                          /*subgraph2_invoke_count_golden=*/1, result);
}
102
TF_LITE_MICRO_TEST(IfShouldInvokeSubgraphConditionTrue) {
  // End-to-end check: run the IF test model through a full MicroInterpreter
  // with the condition set to true and verify the then-branch result.
  constexpr int kArenaSize = 5000;
  uint8_t arena[kArenaSize];

  const tflite::Model* model =
      tflite::testing::GetSimpleModelWithSubgraphsAndIf();
  // IF plus ADD and MUL, presumably the ops used by the branch subgraphs.
  tflite::MicroMutableOpResolver<3> resolver;
  tflite::MicroErrorReporter reporter;
  resolver.AddIf();
  resolver.AddAdd();
  resolver.AddMul();
  tflite::MicroInterpreter interpreter(model, resolver, arena, kArenaSize,
                                       &reporter);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.AllocateTensors());

  TfLiteTensor* condition = interpreter.input(0);
  TfLiteTensor* input1 = interpreter.input(1);
  TfLiteTensor* input2 = interpreter.input(2);
  TfLiteTensor* output = interpreter.output(0);
  const float input1_data[] = {2.0f, 5.0f};
  const float input2_data[] = {3.0f, 7.0f};
  memcpy(input1->data.f, input1_data, sizeof(input1_data));
  memcpy(input2->data.f, input2_data, sizeof(input2_data));
  condition->data.b[0] = true;

  // Check the status so an invoke failure is reported directly rather than
  // only through stale output values.
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());

  // Expected values are {2 + 3, 5 + 7}: the then branch is expected to add the
  // two inputs element-wise.
  TF_LITE_MICRO_EXPECT_EQ(5.0f, output->data.f[0]);
  TF_LITE_MICRO_EXPECT_EQ(12.0f, output->data.f[1]);
}
132
TF_LITE_MICRO_TEST(IfShouldInvokeSubgraphConditionFalse) {
  // End-to-end check: run the IF test model through a full MicroInterpreter
  // with the condition set to false and verify the else-branch result.
  constexpr int kArenaSize = 5000;
  uint8_t arena[kArenaSize];

  const tflite::Model* model =
      tflite::testing::GetSimpleModelWithSubgraphsAndIf();
  // IF plus ADD and MUL, presumably the ops used by the branch subgraphs.
  tflite::MicroMutableOpResolver<3> resolver;
  tflite::MicroErrorReporter reporter;
  resolver.AddIf();
  resolver.AddAdd();
  resolver.AddMul();
  tflite::MicroInterpreter interpreter(model, resolver, arena, kArenaSize,
                                       &reporter);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.AllocateTensors());

  TfLiteTensor* condition = interpreter.input(0);
  TfLiteTensor* input1 = interpreter.input(1);
  TfLiteTensor* input2 = interpreter.input(2);
  TfLiteTensor* output = interpreter.output(0);
  const float input1_data[] = {2.0f, 5.0f};
  const float input2_data[] = {3.0f, 7.0f};
  memcpy(input1->data.f, input1_data, sizeof(input1_data));
  memcpy(input2->data.f, input2_data, sizeof(input2_data));
  condition->data.b[0] = false;

  // Check the status so an invoke failure is reported directly rather than
  // only through stale output values.
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());

  // Expected values are {2 * 3, 5 * 7}: the else branch is expected to
  // multiply the two inputs element-wise.
  TF_LITE_MICRO_EXPECT_EQ(6.0f, output->data.f[0]);
  TF_LITE_MICRO_EXPECT_EQ(35.0f, output->data.f[1]);
}
162
163 TF_LITE_MICRO_TESTS_END
164