/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

namespace tflite {
namespace testing {
namespace {

// Simple test data for a 2x10 input and 3x10 weights. Note that each dims
// array below is rank-prefixed: its first element is the number of dimensions.
const int simple_input_size = 20;
int simple_input_dims[] = {2, 2, 10};
const float simple_input_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
    1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
};
const int simple_weights_size = 30;
int simple_weights_dims[] = {2, 3, 10};
const float simple_weights_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
};
int simple_bias_dims[] = {1, 3};
const float simple_bias_data[] = {1, 2, 3};
const float simple_golden[] = {
    24, 25, 26, 58, 59, 60,
};
const int simple_output_size = 6;
int simple_output_dims[] = {2, 2, 3};
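// For reference, each golden value above follows
//   y[b][u] = sum_i(input[b][i] * weights[u][i]) + bias[u].
// For example, b = 0, u = 0: (1*1 + 2*2 + ... + 8*8) - 9*9 - 10*10 + 1
//   = 204 - 81 - 100 + 1 = 24.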

// Test data for a 2x10 input and 3x10 weights, with negative pre-activation
// outputs to exercise relu.
const int relu_input_size = 20;
int relu_input_dims[] = {2, 2, 10};
const float relu_input_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
    1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
};
const int relu_weights_size = 30;
int relu_weights_dims[] = {2, 3, 10};
const float relu_weights_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,            // u = 0
    -1, -2, -3, -4, -5, -6, -7, -8, -9, -10,  // u = 1
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,            // u = 2
};
int relu_bias_dims[] = {1, 3};
const float relu_bias_data[] = {1, -2, 3};
const float relu_golden[] = {
    24, 0, 26, 58, 0, 60,
};
const int relu_output_size = 6;
int relu_output_dims[] = {2, 2, 3};
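// Same computation as the simple case above, except that u = 1 has negated
// weights and bias, so its raw outputs (-25 and -59) are clamped to 0 by the
// fused relu.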

// Input and filter sizes are similar to those in a real model. The input shape
// is 1x64 and the output shape is 1x16.
const int representative_64x16_input_size = 64;
int representative_64x16_input_dims[] = {2, 1, 64};
const float representative_64x16_input_data[] = {
    0.0000, 0.1543, 0.0000, 0.0000, 1.8520, 0.0000, 4.7844, 1.1832,
    0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.5948, 0.0000,
    1.5948, 1.9549, 0.0000, 1.2347, 0.0000, 1.5948, 1.5948, 0.5145,
    0.0000, 0.0000, 0.0000, 0.0000, 2.6237, 0.0000, 0.0000, 0.0000,
    1.3890, 5.3503, 2.3665, 2.9838, 0.0000, 1.2861, 0.0000, 3.0867,
    0.9775, 0.0000, 5.9676, 0.0000, 0.0000, 1.4405, 0.5145, 2.5723,
    3.1896, 4.4757, 0.0000, 0.0000, 0.0000, 0.0000, 4.1671, 0.0000,
    2.8295, 3.0353, 0.0000, 2.7780, 0.0000, 0.0000, 0.0000, 0.0000};
const int representative_64x16_weights_size = 64 * 16;
int representative_64x16_weights_dims[] = {2, 16, 64};
const float representative_64x16_weights_data[] = {
    -0.1075, 0.1245, 0.1811, -0.1302, -0.1868, 0.0679, 0.1245, 0.2321,
    -0.1981, -0.2094, 0.1358, -0.1698, 0.0113, 0.0566, 0.1358, -0.2490,
    0.0000, -0.1189, -0.0170, -0.0396, -0.3113, 0.1641, -0.4188, 0.0566,
    -0.4471, 0.4754, -0.0396, 0.0113, -0.0340, 0.0170, 0.0170, 0.1811,
    -0.0792, 0.4981, 0.2490, -0.1924, 0.0792, 0.1868, -0.1075, -0.3962,
    0.1358, 0.2547, -0.1245, -0.0962, -0.0283, 0.4132, -0.0057, -0.5150,
    0.1019, 0.1585, -0.0962, -0.2207, -0.2377, 0.2830, 0.4471, 0.0170,
    0.0566, 0.2038, 0.1019, -0.0226, 0.2830, 0.1415, 0.0283, -0.0792,
    0.4301, 0.3226, -0.1132, 0.4981, -0.3849, -0.2943, -0.2547, -0.2264,
    0.0453, -0.0170, 0.0396, 0.1415, 0.3000, 0.2547, 0.0962, 0.2151,
    -0.1585, -0.1302, -0.0057, -0.2773, 0.0283, -0.0906, 0.1302, -0.1075,
    -0.0566, 0.1755, 0.2773, 0.0283, 0.0566, 0.1528, -0.0736, -0.2830,
    0.0792, 0.0962, -0.2321, -0.0113, 0.2660, -0.2887, -0.0566, 0.0057,
    -0.2547, -0.0679, -0.2321, 0.0340, 0.1868, 0.2490, 0.2264, -0.3509,
    0.1585, -0.0849, -0.0623, 0.1132, 0.3396, -0.2490, 0.1528, 0.0679,
    0.1755, 0.4754, -0.0057, -0.2151, -0.1415, -0.1302, -0.2717, 0.1641,
    0.5037, -0.2321, 0.0170, -0.1755, -0.1075, -0.0226, 0.2038, -0.0340,
    -0.5150, -0.3113, 0.1472, -0.0226, 0.1528, 0.1189, -0.1472, 0.0396,
    -0.3000, -0.1924, -0.0283, 0.0283, 0.1641, 0.0736, 0.1472, -0.1755,
    -0.1132, 0.0113, -0.1868, -0.2604, -0.3283, -0.0509, 0.0283, -0.0679,
    0.0623, 0.0792, -0.0283, -0.0962, 0.0396, 0.1641, 0.4584, 0.3226,
    0.0226, -0.1811, 0.2377, -0.1019, 0.2321, 0.1811, -0.1924, -0.0057,
    0.0736, 0.0113, 0.2547, -0.2264, -0.0170, -0.0396, 0.1245, -0.1415,
    0.1755, 0.3679, -0.2377, -0.0396, -0.1585, -0.3000, -0.1641, -0.1302,
    -0.0396, -0.1698, 0.1189, 0.2434, 0.1132, -0.1245, -0.1415, 0.0453,
    0.1868, -0.0906, -0.1189, -0.0509, 0.0057, -0.1189, -0.0057, 0.0170,
    -0.1924, 0.2207, 0.0792, -0.4641, -0.2660, 0.2943, 0.1358, -0.0340,
    -0.3339, -0.1189, 0.0906, -0.4358, 0.0453, -0.1755, 0.1415, 0.0340,
    0.1924, -0.0057, 0.2321, -0.2094, -0.1132, 0.0000, 0.1924, -0.3000,
    0.0340, -0.3396, -0.0906, -0.0340, 0.1641, -0.0226, -0.1472, -0.1019,
    0.2377, -0.0962, -0.3396, -0.5433, 0.0906, 0.2151, -0.0679, 0.1755,
    0.1528, 0.0283, -0.4188, -0.0340, -0.0057, -0.0679, 0.0509, 0.1472,
    -0.3849, -0.0113, 0.3962, 0.0849, 0.1472, 0.0340, -0.1358, 0.1641,
    -0.2038, 0.2151, -0.1189, -0.3679, 0.0906, -0.0679, 0.5716, -0.0057,
    -0.0736, 0.0113, 0.2830, -0.2887, 0.0396, 0.0849, -0.0736, -0.0736,
    -0.3679, 0.2264, 0.0113, -0.1641, 0.0396, -0.1132, -0.0623, 0.3113,
    0.5999, -0.1415, 0.1472, -0.2038, -0.1132, -0.2377, 0.0566, 0.1755,
    -0.0057, -0.0453, 0.0226, 0.1132, 0.1698, 0.0340, -0.0226, 0.0226,
    0.4415, -0.3792, 0.0792, 0.3736, -0.5999, -0.3056, -0.1924, -0.1132,
    -0.0962, 0.0283, 0.0000, -0.3339, -0.3226, 0.3679, -0.0453, -0.1641,
    0.0170, 0.1302, -0.0170, -0.0509, 0.1755, -0.0283, -0.1302, -0.2887,
    -0.0679, 0.0340, 0.4641, 0.2321, 0.7188, 0.3339, -0.1075, 0.4754,
    -0.0226, 0.3226, -0.1528, -0.0849, 0.0509, -0.1981, 0.0113, 0.2321,
    0.2773, -0.1019, 0.4075, 0.0396, 0.0792, 0.1132, -0.0906, -0.4188,
    0.1924, -0.3679, -0.6396, 0.1358, 0.4981, 0.4132, -0.0283, 0.3849,
    -0.3509, -0.0566, -0.0962, 0.3113, -0.1811, 0.4019, 0.0453, -0.0057,
    -0.1868, -0.2490, -0.0792, -0.3622, 0.1924, -0.0453, -0.1528, -0.1811,
    0.5943, -0.1302, 0.3170, -0.0170, 0.0509, -0.1528, -0.1755, 0.5547,
    0.2490, -0.0906, 0.0000, 0.1698, 0.0000, 0.0340, -0.1132, -0.0509,
    -0.1755, -0.2943, 0.1472, 0.0849, 0.0000, 0.1528, -0.0566, 0.1528,
    -0.5264, -0.5320, -0.0736, 0.0566, 0.2604, -0.4075, 0.0962, -0.3453,
    -0.1415, 0.0057, 0.3905, 0.2830, 0.3679, 0.5320, -0.2660, 0.0340,
    0.0736, 0.0057, 0.2207, 0.4471, 0.0849, 0.3000, -0.0057, -0.0623,
    0.1415, -0.0566, 0.5264, -0.0340, 0.0226, -0.0623, -0.0113, -0.5037,
    -0.4471, 0.0170, -0.0396, -0.1358, -0.1698, 0.1924, 0.0057, -0.1585,
    0.0849, -0.1698, 0.0057, -0.1245, -0.0170, -0.1755, -0.0792, 0.5264,
    0.1358, 0.2434, 0.1585, -0.4188, -0.1472, -0.1358, -0.0849, -0.1189,
    0.5037, 0.0736, -0.0453, -0.2434, 0.1868, -0.0679, 0.1415, -0.2717,
    0.2604, 0.0057, -0.1528, -0.1811, 0.0226, -0.1641, 0.3170, -0.1981,
    0.1245, 0.0226, 0.0566, 0.2830, -0.1755, 0.0396, -0.2094, 0.1924,
    0.1698, 0.0283, 0.1641, 0.0849, 0.0000, -0.1698, -0.1415, -0.3000,
    0.4471, 0.3056, -0.0283, -0.4245, -0.0453, 0.0226, 0.0000, -0.1075,
    -0.1528, -0.3226, 0.2773, -0.2264, -0.1811, 0.1755, -0.3566, -0.4188,
    0.1755, -0.0057, 0.2038, 0.1075, 0.3679, -0.0792, 0.2207, -0.0453,
    0.3736, 0.2943, -0.0113, -0.0623, 0.2264, 0.0113, -0.0396, -0.2207,
    0.0453, -0.2830, -0.1302, 0.0623, -0.1924, -0.1811, -0.2717, 0.2830,
    0.2094, 0.0170, -0.3170, -0.0283, -0.1189, -0.0509, -0.0566, -0.3622,
    0.1132, -0.0906, 0.1132, 0.4019, -0.4698, -0.1019, -0.1075, -0.2094,
    -0.2207, -0.0509, 0.0057, 0.1019, -0.0509, 0.2264, -0.5716, 0.0226,
    -0.4019, 0.1641, -0.3000, 0.3849, 0.1245, 0.0679, 0.3056, 0.2377,
    0.0679, -0.0170, -0.5377, -0.0170, 0.0057, 0.1358, -0.1132, -0.2038,
    0.0679, 0.1075, -0.2773, 0.5943, 0.0623, -0.1472, 0.3566, 0.0396,
    -0.2377, 0.2604, 0.0849, 0.1358, -0.3792, -0.0340, -0.1415, 0.3566,
    -0.3736, 0.1245, 0.0566, 0.3396, 0.0736, 0.4019, -0.1528, 0.1075,
    0.0792, -0.2547, 0.0453, -0.1755, 0.1868, -0.2547, 0.1075, 0.0623,
    0.1698, -0.0170, 0.1585, -0.0736, -0.4358, -0.0113, -0.6792, -0.0849,
    -0.0396, -0.6056, 0.1358, 0.1189, 0.2547, 0.1528, 0.2887, 0.0453,
    -0.1075, -0.3283, -0.0453, -0.0509, 0.2038, 0.2547, 0.0849, -0.0566,
    -0.1698, 0.0509, -0.0113, -0.1585, 0.1924, -0.0792, -0.1868, 0.0509,
    -0.1698, -0.0849, -0.0170, 0.0453, 0.3170, 0.0906, -0.5943, -0.1245,
    0.1585, -0.1755, -0.2151, 0.0906, 0.1924, 0.3170, -0.2490, -0.5660,
    -0.0283, 0.0962, -0.1358, 0.1585, 0.0057, -0.2604, 0.1189, -0.0170,
    0.3509, 0.0623, 0.0679, -0.1302, -0.0792, 0.0906, -0.0792, 0.0849,
    -0.1924, 0.2604, -0.1245, -0.3679, 0.0340, 0.0113, -0.1698, 0.2490,
    0.0283, 0.1019, -0.3736, 0.1019, -0.2207, -0.0340, 0.3170, 0.1755,
    0.0962, 0.3226, -0.0113, -0.1189, -0.2321, -0.0226, -0.2434, -0.0170,
    -0.1585, -0.0283, -0.1132, 0.0679, -0.4188, -0.0453, 0.1528, -0.1302,
    -0.3792, 0.1415, -0.1358, -0.1811, 0.1302, 0.1415, 0.5207, 0.0509,
    -0.1358, -0.0396, -0.2434, 0.0396, 0.0792, -0.2264, -0.1415, 0.0906,
    0.1245, 0.0170, 0.0623, -0.1415, 0.2773, -0.3566, -0.0396, 0.2887,
    0.4188, 0.1698, -0.2547, 0.1132, -0.0453, -0.0113, -0.1358, 0.1075,
    0.0566, 0.1075, 0.2604, -0.0849, -0.2490, 0.1415, 0.0509, -0.2151,
    0.0340, 0.1698, 0.0509, -0.0906, 0.0566, -0.1075, -0.2151, 0.2038,
    -0.1924, -0.0113, 0.2830, 0.1358, -0.1189, 0.0113, -0.5603, -0.2830,
    -0.2943, 0.0453, -0.0396, 0.1358, 0.0566, 0.2038, -0.3283, -0.0509,
    0.0509, 0.1641, 0.2094, -0.2038, -0.1868, -0.1585, -0.2207, -0.1302,
    0.0396, -0.1019, -0.0679, 0.1075, -0.4584, -0.2207, 0.2434, -0.0113,
    0.0849, 0.1755, -0.3056, 0.1585, -0.2547, 0.0453, 0.0906, -0.1358,
    -0.0679, -0.0509, 0.0679, -0.3509, 0.0057, 0.0453, 0.4132, -0.1981,
    0.2264, -0.0736, 0.1075, 0.0679, -0.0906, -0.3113, 0.0509, 0.0849,
    0.2604, 0.0623, -0.3113, 0.3849, 0.0000, 0.6396, -0.2038, -0.1019,
    0.1245, -0.0453, 0.1641, 0.1075, -0.1075, -0.2660, -0.4528, -0.0566,
    -0.0170, 0.0453, 0.0340, 0.1189, -0.2434, -0.0283, -0.1811, 0.2547,
    0.0000, -0.0226, 0.4471, 0.1019, -0.1472, 0.0849, 0.1075, 0.1075,
    0.0283, -0.2773, 0.4415, -0.1811, 0.2717, 0.3170, 0.0509, 0.0623,
    -0.0962, 0.1585, -0.0792, -0.1811, -0.0792, -0.3283, 0.0962, -0.1698,
    -0.0736, 0.0453, 0.0962, -0.3566, -0.4584, 0.3396, -0.4811, 0.3056,
    -0.1755, 0.2490, -0.1698, -0.2377, -0.3339, -0.0453, 0.1811, 0.0736,
    0.0340, -0.0962, -0.0113, -0.3056, -0.3339, 0.2038, 0.2038, -0.1924,
    0.2547, -0.4471, -0.0849, -0.2038, 0.3566, -0.4811, 0.3453, 0.0849,
    0.1189, 0.3170, -0.1358, 0.2717, 0.0113, -0.4754, -0.1924, 0.4245,
    -0.2773, 0.3453, 0.2264, 0.2943, 0.5320, 0.2773, -0.2264, -0.1019,
    -0.1132, -0.3962, 0.3679, 0.0509, -0.0623, -0.0906, -0.5603, -0.1641,
    -0.3170, -0.2377, 0.1415, -0.0509, 0.0792, 0.0170, -0.0226, -0.0057,
    -0.1358, -0.4245, 0.3905, 0.3113, 0.0340, -0.1189, 0.2887, -0.2943,
    -0.3056, 0.2434, 0.1019, -0.0170, 0.3849, 0.1528, -0.0736, -0.0170,
    0.0792, 0.1755, 0.0509, 0.3509, 0.1472, 0.1528, 0.1472, 0.0057,
    0.0113, -0.0113, -0.3283, -0.3962, -0.0792, -0.1245, -0.0283, -0.1868,
    0.4019, 0.2943, -0.0906, -0.2321, 0.6056, 0.1189, 0.0340, -0.2207,
    -0.0453, 0.3339, 0.2377, -0.1641, 0.3736, 0.2151, -0.2547, 0.0453,
    0.1924, -0.1019, -0.0340, -0.2207, 0.3962, -0.4471, -0.2547, -0.2151,
    -0.3736, 0.0283, 0.1189, 0.0283, 0.0736, 0.0396, 0.1019, 0.0283,
    0.0170, 0.2321, 0.3509, -0.0226, -0.0226, 0.0736, 0.0283, 0.1641,
    -0.0906, 0.1811, 0.0226, 0.5716, -0.0396, -0.0509, -0.1641, -0.0509,
    0.4132, -0.2604, 0.1019, -0.0283, -0.0340, 0.0453, 0.1472, -0.0057,
    0.2717, -0.2094, 0.3396, 0.0340, 0.1245, 0.2547, -0.5886, 0.2717,
    -0.0906, 0.1641, 0.0962, -0.0792, -0.0113, 0.2264, -0.0736, 0.3170,
    0.0623, 0.0679, 0.0623, -0.0792, -0.2207, 0.1924, 0.1245, -0.2773};
int representative_64x16_bias_dims[] = {1, 16};
const float representative_64x16_bias_data[] = {
    -0.0084, 0.0006, 0.0000, 0.0000, -0.0087, -0.0006, -0.0003, -0.0003,
    0.0006, -0.0003, -0.0003, -0.0003, -0.0253, 0.0012, 0.0000, 0.0000};
const float representative_64x16_golden[] = {
    3.8624, -2.9580, 4.3043, -1.2844, -1.5769, -2.7998, -0.1011, -3.4029,
    -1.0557, -7.1931, -1.4852, -0.4163, 1.7186, -0.6965, 0.3580, 2.7378};
const int representative_64x16_output_size = 16;
int representative_64x16_output_dims[] = {2, 1, 16};
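// The golden values above are too large to verify by hand; they were
// presumably recorded from the reference float implementation of
// y = x * transpose(w) + bias.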

template <typename T>
TfLiteStatus ValidateFullyConnectedGoldens(
    TfLiteTensor* tensors, const int tensors_size,
    const TfLiteFusedActivation activation, const float tolerance,
    const int output_len, const T* golden, T* output_data) {
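  // TfLiteFullyConnectedParams fields, in order: fused activation, weights
  // format (default), keep_num_dims = false, asymmetric_quantize_inputs =
  // false.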
  TfLiteFullyConnectedParams builtin_data = {
      activation, kTfLiteFullyConnectedWeightsFormatDefault, false, false};

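  // Tensor index arrays are length-prefixed, matching IntArrayFromInts:
  // indices 0 (input), 1 (weights), and 2 (bias) are the op's inputs, and
  // index 3 is its output.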
  int inputs_array_data[] = {3, 0, 1, 2};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  const TfLiteRegistration registration = Register_FULLY_CONNECTED();
  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
                             outputs_array,
                             reinterpret_cast<void*>(&builtin_data));

  TfLiteStatus status = runner.InitAndPrepare();
  if (status != kTfLiteOk) {
    return status;
  }

  status = runner.Invoke();
  if (status != kTfLiteOk) {
    return status;
  }

  for (int i = 0; i < output_len; ++i) {
    TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], tolerance);
  }
  return kTfLiteOk;
}

#if !defined(XTENSA)  // Needed to avoid build error from unused functions.
TfLiteStatus TestFullyConnectedFloat(
    int* input_dims_data, const float* input_data, int* weights_dims_data,
    const float* weights_data, int* bias_dims_data, const float* bias_data,
    const float* golden, int* output_dims_data,
    TfLiteFusedActivation activation, float* output_data) {
  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
  TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
  const int output_dims_count = ElementCount(*output_dims);

  constexpr int inputs_size = 3;
  constexpr int outputs_size = 1;
  constexpr int tensors_size = inputs_size + outputs_size;
  TfLiteTensor tensors[tensors_size] = {
      CreateTensor(input_data, input_dims),
      CreateTensor(weights_data, weights_dims),
      CreateTensor(bias_data, bias_dims),
      CreateTensor(output_data, output_dims),
  };

  return ValidateFullyConnectedGoldens(tensors, tensors_size, activation, 1e-4f,
                                       output_dims_count, golden, output_data);
}
#endif

template <typename T>
TfLiteStatus TestFullyConnectedQuantized(
    int* input_dims_data, const float* input_data, T* input_quantized,
    const float input_scale, const int input_zero_point, int* weights_dims_data,
    const float* weights_data, T* weights_quantized, const float weights_scale,
    const int weights_zero_point, int* bias_dims_data, const float* bias_data,
    int32_t* bias_quantized, const float* golden, T* golden_quantized,
    int* output_dims_data, const float output_scale,
    const int output_zero_point, TfLiteFusedActivation activation,
    T* output_data) {
  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
  TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
  const int output_dims_count = ElementCount(*output_dims);

  constexpr int inputs_size = 3;
  constexpr int outputs_size = 1;
  constexpr int tensors_size = inputs_size + outputs_size;
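  // Per the TFLite quantization spec, the int32 bias is quantized with scale
  // input_scale * weights_scale and zero point 0; CreateQuantizedBiasTensor
  // applies that convention.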
  TfLiteTensor tensors[tensors_size] = {
      CreateQuantizedTensor(input_data, input_quantized, input_dims,
                            input_scale, input_zero_point),
      CreateQuantizedTensor(weights_data, weights_quantized, weights_dims,
                            weights_scale, weights_zero_point),
      CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
                                input_scale, weights_scale),
      CreateQuantizedTensor(output_data, output_dims, output_scale,
                            output_zero_point),
  };

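  // Quantize the float goldens with the output scale and zero point so they
  // can be compared elementwise to the quantized output. The 0.0f tolerance
  // below requires an exact match at the quantized level.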
  Quantize(golden, golden_quantized, output_dims_count, output_scale,
           output_zero_point);

  return ValidateFullyConnectedGoldens(tensors, tensors_size, activation, 0.0f,
                                       output_dims_count, golden_quantized,
                                       output_data);
}

}  // namespace
}  // namespace testing
}  // namespace tflite

TF_LITE_MICRO_TESTS_BEGIN

#if !defined(XTENSA) && !defined(CEVA_BX1) && !defined(CEVA_SP500)
// TODO(b/170503075): xtensa kernels are less general
// than reference kernels and we ifdef out test cases that are currently known
// to fail.

// CEVA's fully connected implementation assumes weights_zero_point = 0 as
// described in TFLite's quantization specification. Tests that use a different
// zero point are therefore ifdefed out.
// See the TFLite quantization spec:
// https://www.tensorflow.org/lite/performance/quantization_spec
TF_LITE_MICRO_TEST(SimpleTest) {
  float output_data[tflite::testing::simple_output_size];
  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data,
          tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data,
          tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data,
          tflite::testing::simple_golden, tflite::testing::simple_output_dims,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}

#endif

TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8) {
  const float input_scale = 1.0f;
  const int input_zero_point = -1;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 0;
  const float output_scale = 0.5f;
  const int output_zero_point = -1;
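  // With output_scale = 0.5 and output_zero_point = -1, a float golden of 24
  // quantizes to 24 / 0.5 + (-1) = 47.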

  int8_t input_quantized[tflite::testing::simple_input_size];
  int8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  int8_t golden_quantized[tflite::testing::simple_output_size];
  int8_t output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data, input_quantized, input_scale,
          input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
          tflite::testing::simple_bias_data, bias_quantized,
          tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}

TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) {
  const float input_scale = 1.0f;
  const int input_zero_point = -1;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 0;

  const float output_scale = 0.5f;
  const int output_zero_point = -1;

  int input_dims_4d[] = {4, 1, 1, 2, 10};

  int8_t input_quantized[tflite::testing::simple_input_size];
  int8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  int8_t golden_quantized[tflite::testing::simple_output_size];
  int8_t output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          input_dims_4d, tflite::testing::simple_input_data, input_quantized,
          input_scale, input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
          tflite::testing::simple_bias_data, bias_quantized,
          tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}

TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8Relu) {
  const float input_scale = 1.0f;
  const int input_zero_point = -1;
  const float weights_scale = 1.0f;
  const int weights_zero_point = 0;

  const float output_scale = 0.5f;
  const int output_zero_point = -128;
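  // output_zero_point = -128 maps float 0.0 to the bottom of the int8 range,
  // presumably because the relu output is known to be non-negative.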

  int8_t input_quantized[tflite::testing::relu_input_size];
  int8_t weights_quantized[tflite::testing::relu_weights_size];
  int32_t bias_quantized[tflite::testing::relu_output_size];
  int8_t golden_quantized[tflite::testing::relu_output_size];
  int8_t output_data[tflite::testing::relu_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::relu_input_dims, tflite::testing::relu_input_data,
          input_quantized, input_scale, input_zero_point,
          tflite::testing::relu_weights_dims,
          tflite::testing::relu_weights_data, weights_quantized, weights_scale,
          weights_zero_point, tflite::testing::relu_bias_dims,
          tflite::testing::relu_bias_data, bias_quantized,
          tflite::testing::relu_golden, golden_quantized,
          tflite::testing::relu_output_dims, output_scale, output_zero_point,
          kTfLiteActRelu, output_data),
      kTfLiteOk);
}

#if !defined(XTENSA)  // TODO(b/170503075): xtensa kernels are less general than
                      // reference kernels and we ifdef out test cases that are
                      // currently known to fail.
TF_LITE_MICRO_TEST(SimpleTest4DInput) {
  int input_dims_4d[] = {4, 1, 1, 2, 10};

  float output_data[tflite::testing::simple_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          input_dims_4d, tflite::testing::simple_input_data,
          tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data,
          tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data,
          tflite::testing::simple_golden, tflite::testing::simple_output_dims,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}

TF_LITE_MICRO_TEST(Representative1x64Input1x16Output) {
  float output_data[tflite::testing::representative_64x16_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          tflite::testing::representative_64x16_input_dims,
          tflite::testing::representative_64x16_input_data,
          tflite::testing::representative_64x16_weights_dims,
          tflite::testing::representative_64x16_weights_data,
          tflite::testing::representative_64x16_bias_dims,
          tflite::testing::representative_64x16_bias_data,
          tflite::testing::representative_64x16_golden,
          tflite::testing::representative_64x16_output_dims, kTfLiteActNone,
          output_data),
      kTfLiteOk);
}

#endif

TF_LITE_MICRO_TEST(Representative1x64Input1x16OutputQuantizedInt8) {
  const float input_scale = 0.051445;
  const int input_zero_point = -128;
  const float weights_scale = 0.005660;
  const int weights_zero_point = 0;

  const float output_scale = 0.069785;
  const int output_zero_point = -9;

  int8_t input_quantized[tflite::testing::representative_64x16_input_size];
  int8_t weights_quantized[tflite::testing::representative_64x16_weights_size];
  int32_t bias_quantized[tflite::testing::representative_64x16_output_size];
  int8_t golden_quantized[tflite::testing::representative_64x16_output_size];
  int8_t output_data[tflite::testing::representative_64x16_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::representative_64x16_input_dims,
          tflite::testing::representative_64x16_input_data, input_quantized,
          input_scale, input_zero_point,
          tflite::testing::representative_64x16_weights_dims,
          tflite::testing::representative_64x16_weights_data, weights_quantized,
          weights_scale, weights_zero_point,
          tflite::testing::representative_64x16_bias_dims,
          tflite::testing::representative_64x16_bias_data, bias_quantized,
          tflite::testing::representative_64x16_golden, golden_quantized,
          tflite::testing::representative_64x16_output_dims, output_scale,
          output_zero_point, kTfLiteActNone, output_data),
      kTfLiteOk);
}

TF_LITE_MICRO_TESTS_END