1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstddef>
17 #include <cstdint>
18 
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/micro/all_ops_resolver.h"
22 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
23 #include "tensorflow/lite/micro/micro_utils.h"
24 #include "tensorflow/lite/micro/test_helpers.h"
25 #include "tensorflow/lite/micro/testing/micro_test.h"
26 
27 namespace tflite {
28 namespace testing {
29 namespace {
30 
// Simple test data: a 2x10 input (batch of 2) and 3x10 weights (3 output
// units), producing a 2x3 output. Dims arrays use the IntArrayFromInts layout
// {rank, dim0, dim1, ...}, so {2, 2, 10} is a rank-2 shape of 2x10.
//
// Every weight row is {1, ..., 10}, so before the bias each unit sees the same
// dot product per batch: 23 for b = 0 and 57 for b = 1. Adding the bias
// {1, 2, 3} gives the golden values below.
const int simple_input_size = 20;
int simple_input_dims[] = {2, 2, 10};
const float simple_input_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
    1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
};
const int simple_weights_size = 30;
int simple_weights_dims[] = {2, 3, 10};
const float simple_weights_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
};
int simple_bias_dims[] = {1, 3};
const float simple_bias_data[] = {1, 2, 3};
const float simple_golden[] = {
    24, 25, 26, 58, 59, 60,
};
const int simple_output_size = 6;
int simple_output_dims[] = {2, 2, 3};

// The element counts above are maintained by hand and size the scratch buffers
// in the tests; catch any drift from the initializers at compile time rather
// than as an out-of-bounds access at run time.
static_assert(sizeof(simple_input_data) == simple_input_size * sizeof(float),
              "simple_input_size does not match simple_input_data");
static_assert(sizeof(simple_weights_data) ==
                  simple_weights_size * sizeof(float),
              "simple_weights_size does not match simple_weights_data");
static_assert(sizeof(simple_golden) == simple_output_size * sizeof(float),
              "simple_output_size does not match simple_golden");
52 
// Test data for a 2x10 input and 3x10 weights, producing a 2x3 output, where
// one unit goes negative so that ReLU is exercised. Dims arrays use the
// IntArrayFromInts layout {rank, dim0, dim1, ...}.
//
// Units 0 and 2 see dot products of 23 (b = 0) and 57 (b = 1). Unit 1 has
// negated weights and a bias of -2, giving -25 and -59, which ReLU clamps to 0.
const int relu_input_size = 20;
int relu_input_dims[] = {2, 2, 10};
const float relu_input_data[] = {
    1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
    1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
};
const int relu_weights_size = 30;
int relu_weights_dims[] = {2, 3, 10};
const float relu_weights_data[] = {
    1,  2,  3,  4,  5,  6,  7,  8,  9,  10,   // u = 0
    -1, -2, -3, -4, -5, -6, -7, -8, -9, -10,  // u = 1
    1,  2,  3,  4,  5,  6,  7,  8,  9,  10,   // u = 2
};
int relu_bias_dims[] = {1, 3};
const float relu_bias_data[] = {1, -2, 3};
const float relu_golden[] = {
    24, 0, 26, 58, 0, 60,
};
const int relu_output_size = 6;
int relu_output_dims[] = {2, 2, 3};

// Keep the hand-written element counts in sync with the initializers; these
// counts size the scratch buffers in the tests.
static_assert(sizeof(relu_input_data) == relu_input_size * sizeof(float),
              "relu_input_size does not match relu_input_data");
static_assert(sizeof(relu_weights_data) == relu_weights_size * sizeof(float),
              "relu_weights_size does not match relu_weights_data");
static_assert(sizeof(relu_golden) == relu_output_size * sizeof(float),
              "relu_output_size does not match relu_golden");
74 
// Input and filter similar to a real model. The input shape is 1x64, the
// weights are 16x64 (one row of 64 values per output unit) and the output is
// 1x16. Dims arrays use the IntArrayFromInts layout {rank, dim0, dim1, ...}.
const int representative_64x16_input_size = 64;
int representative_64x16_input_dims[] = {2, 1, 64};
const float representative_64x16_input_data[] = {
    0.0000, 0.1543, 0.0000, 0.0000, 1.8520, 0.0000, 4.7844, 1.1832,
    0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.5948, 0.0000,
    1.5948, 1.9549, 0.0000, 1.2347, 0.0000, 1.5948, 1.5948, 0.5145,
    0.0000, 0.0000, 0.0000, 0.0000, 2.6237, 0.0000, 0.0000, 0.0000,
    1.3890, 5.3503, 2.3665, 2.9838, 0.0000, 1.2861, 0.0000, 3.0867,
    0.9775, 0.0000, 5.9676, 0.0000, 0.0000, 1.4405, 0.5145, 2.5723,
    3.1896, 4.4757, 0.0000, 0.0000, 0.0000, 0.0000, 4.1671, 0.0000,
    2.8295, 3.0353, 0.0000, 2.7780, 0.0000, 0.0000, 0.0000, 0.0000};
const int representative_64x16_weights_size = 64 * 16;
int representative_64x16_weights_dims[] = {2, 16, 64};
const float representative_64x16_weights_data[] = {
    -0.1075, 0.1245,  0.1811,  -0.1302, -0.1868, 0.0679,  0.1245,  0.2321,
    -0.1981, -0.2094, 0.1358,  -0.1698, 0.0113,  0.0566,  0.1358,  -0.2490,
    0.0000,  -0.1189, -0.0170, -0.0396, -0.3113, 0.1641,  -0.4188, 0.0566,
    -0.4471, 0.4754,  -0.0396, 0.0113,  -0.0340, 0.0170,  0.0170,  0.1811,
    -0.0792, 0.4981,  0.2490,  -0.1924, 0.0792,  0.1868,  -0.1075, -0.3962,
    0.1358,  0.2547,  -0.1245, -0.0962, -0.0283, 0.4132,  -0.0057, -0.5150,
    0.1019,  0.1585,  -0.0962, -0.2207, -0.2377, 0.2830,  0.4471,  0.0170,
    0.0566,  0.2038,  0.1019,  -0.0226, 0.2830,  0.1415,  0.0283,  -0.0792,
    0.4301,  0.3226,  -0.1132, 0.4981,  -0.3849, -0.2943, -0.2547, -0.2264,
    0.0453,  -0.0170, 0.0396,  0.1415,  0.3000,  0.2547,  0.0962,  0.2151,
    -0.1585, -0.1302, -0.0057, -0.2773, 0.0283,  -0.0906, 0.1302,  -0.1075,
    -0.0566, 0.1755,  0.2773,  0.0283,  0.0566,  0.1528,  -0.0736, -0.2830,
    0.0792,  0.0962,  -0.2321, -0.0113, 0.2660,  -0.2887, -0.0566, 0.0057,
    -0.2547, -0.0679, -0.2321, 0.0340,  0.1868,  0.2490,  0.2264,  -0.3509,
    0.1585,  -0.0849, -0.0623, 0.1132,  0.3396,  -0.2490, 0.1528,  0.0679,
    0.1755,  0.4754,  -0.0057, -0.2151, -0.1415, -0.1302, -0.2717, 0.1641,
    0.5037,  -0.2321, 0.0170,  -0.1755, -0.1075, -0.0226, 0.2038,  -0.0340,
    -0.5150, -0.3113, 0.1472,  -0.0226, 0.1528,  0.1189,  -0.1472, 0.0396,
    -0.3000, -0.1924, -0.0283, 0.0283,  0.1641,  0.0736,  0.1472,  -0.1755,
    -0.1132, 0.0113,  -0.1868, -0.2604, -0.3283, -0.0509, 0.0283,  -0.0679,
    0.0623,  0.0792,  -0.0283, -0.0962, 0.0396,  0.1641,  0.4584,  0.3226,
    0.0226,  -0.1811, 0.2377,  -0.1019, 0.2321,  0.1811,  -0.1924, -0.0057,
    0.0736,  0.0113,  0.2547,  -0.2264, -0.0170, -0.0396, 0.1245,  -0.1415,
    0.1755,  0.3679,  -0.2377, -0.0396, -0.1585, -0.3000, -0.1641, -0.1302,
    -0.0396, -0.1698, 0.1189,  0.2434,  0.1132,  -0.1245, -0.1415, 0.0453,
    0.1868,  -0.0906, -0.1189, -0.0509, 0.0057,  -0.1189, -0.0057, 0.0170,
    -0.1924, 0.2207,  0.0792,  -0.4641, -0.2660, 0.2943,  0.1358,  -0.0340,
    -0.3339, -0.1189, 0.0906,  -0.4358, 0.0453,  -0.1755, 0.1415,  0.0340,
    0.1924,  -0.0057, 0.2321,  -0.2094, -0.1132, 0.0000,  0.1924,  -0.3000,
    0.0340,  -0.3396, -0.0906, -0.0340, 0.1641,  -0.0226, -0.1472, -0.1019,
    0.2377,  -0.0962, -0.3396, -0.5433, 0.0906,  0.2151,  -0.0679, 0.1755,
    0.1528,  0.0283,  -0.4188, -0.0340, -0.0057, -0.0679, 0.0509,  0.1472,
    -0.3849, -0.0113, 0.3962,  0.0849,  0.1472,  0.0340,  -0.1358, 0.1641,
    -0.2038, 0.2151,  -0.1189, -0.3679, 0.0906,  -0.0679, 0.5716,  -0.0057,
    -0.0736, 0.0113,  0.2830,  -0.2887, 0.0396,  0.0849,  -0.0736, -0.0736,
    -0.3679, 0.2264,  0.0113,  -0.1641, 0.0396,  -0.1132, -0.0623, 0.3113,
    0.5999,  -0.1415, 0.1472,  -0.2038, -0.1132, -0.2377, 0.0566,  0.1755,
    -0.0057, -0.0453, 0.0226,  0.1132,  0.1698,  0.0340,  -0.0226, 0.0226,
    0.4415,  -0.3792, 0.0792,  0.3736,  -0.5999, -0.3056, -0.1924, -0.1132,
    -0.0962, 0.0283,  0.0000,  -0.3339, -0.3226, 0.3679,  -0.0453, -0.1641,
    0.0170,  0.1302,  -0.0170, -0.0509, 0.1755,  -0.0283, -0.1302, -0.2887,
    -0.0679, 0.0340,  0.4641,  0.2321,  0.7188,  0.3339,  -0.1075, 0.4754,
    -0.0226, 0.3226,  -0.1528, -0.0849, 0.0509,  -0.1981, 0.0113,  0.2321,
    0.2773,  -0.1019, 0.4075,  0.0396,  0.0792,  0.1132,  -0.0906, -0.4188,
    0.1924,  -0.3679, -0.6396, 0.1358,  0.4981,  0.4132,  -0.0283, 0.3849,
    -0.3509, -0.0566, -0.0962, 0.3113,  -0.1811, 0.4019,  0.0453,  -0.0057,
    -0.1868, -0.2490, -0.0792, -0.3622, 0.1924,  -0.0453, -0.1528, -0.1811,
    0.5943,  -0.1302, 0.3170,  -0.0170, 0.0509,  -0.1528, -0.1755, 0.5547,
    0.2490,  -0.0906, 0.0000,  0.1698,  0.0000,  0.0340,  -0.1132, -0.0509,
    -0.1755, -0.2943, 0.1472,  0.0849,  0.0000,  0.1528,  -0.0566, 0.1528,
    -0.5264, -0.5320, -0.0736, 0.0566,  0.2604,  -0.4075, 0.0962,  -0.3453,
    -0.1415, 0.0057,  0.3905,  0.2830,  0.3679,  0.5320,  -0.2660, 0.0340,
    0.0736,  0.0057,  0.2207,  0.4471,  0.0849,  0.3000,  -0.0057, -0.0623,
    0.1415,  -0.0566, 0.5264,  -0.0340, 0.0226,  -0.0623, -0.0113, -0.5037,
    -0.4471, 0.0170,  -0.0396, -0.1358, -0.1698, 0.1924,  0.0057,  -0.1585,
    0.0849,  -0.1698, 0.0057,  -0.1245, -0.0170, -0.1755, -0.0792, 0.5264,
    0.1358,  0.2434,  0.1585,  -0.4188, -0.1472, -0.1358, -0.0849, -0.1189,
    0.5037,  0.0736,  -0.0453, -0.2434, 0.1868,  -0.0679, 0.1415,  -0.2717,
    0.2604,  0.0057,  -0.1528, -0.1811, 0.0226,  -0.1641, 0.3170,  -0.1981,
    0.1245,  0.0226,  0.0566,  0.2830,  -0.1755, 0.0396,  -0.2094, 0.1924,
    0.1698,  0.0283,  0.1641,  0.0849,  0.0000,  -0.1698, -0.1415, -0.3000,
    0.4471,  0.3056,  -0.0283, -0.4245, -0.0453, 0.0226,  0.0000,  -0.1075,
    -0.1528, -0.3226, 0.2773,  -0.2264, -0.1811, 0.1755,  -0.3566, -0.4188,
    0.1755,  -0.0057, 0.2038,  0.1075,  0.3679,  -0.0792, 0.2207,  -0.0453,
    0.3736,  0.2943,  -0.0113, -0.0623, 0.2264,  0.0113,  -0.0396, -0.2207,
    0.0453,  -0.2830, -0.1302, 0.0623,  -0.1924, -0.1811, -0.2717, 0.2830,
    0.2094,  0.0170,  -0.3170, -0.0283, -0.1189, -0.0509, -0.0566, -0.3622,
    0.1132,  -0.0906, 0.1132,  0.4019,  -0.4698, -0.1019, -0.1075, -0.2094,
    -0.2207, -0.0509, 0.0057,  0.1019,  -0.0509, 0.2264,  -0.5716, 0.0226,
    -0.4019, 0.1641,  -0.3000, 0.3849,  0.1245,  0.0679,  0.3056,  0.2377,
    0.0679,  -0.0170, -0.5377, -0.0170, 0.0057,  0.1358,  -0.1132, -0.2038,
    0.0679,  0.1075,  -0.2773, 0.5943,  0.0623,  -0.1472, 0.3566,  0.0396,
    -0.2377, 0.2604,  0.0849,  0.1358,  -0.3792, -0.0340, -0.1415, 0.3566,
    -0.3736, 0.1245,  0.0566,  0.3396,  0.0736,  0.4019,  -0.1528, 0.1075,
    0.0792,  -0.2547, 0.0453,  -0.1755, 0.1868,  -0.2547, 0.1075,  0.0623,
    0.1698,  -0.0170, 0.1585,  -0.0736, -0.4358, -0.0113, -0.6792, -0.0849,
    -0.0396, -0.6056, 0.1358,  0.1189,  0.2547,  0.1528,  0.2887,  0.0453,
    -0.1075, -0.3283, -0.0453, -0.0509, 0.2038,  0.2547,  0.0849,  -0.0566,
    -0.1698, 0.0509,  -0.0113, -0.1585, 0.1924,  -0.0792, -0.1868, 0.0509,
    -0.1698, -0.0849, -0.0170, 0.0453,  0.3170,  0.0906,  -0.5943, -0.1245,
    0.1585,  -0.1755, -0.2151, 0.0906,  0.1924,  0.3170,  -0.2490, -0.5660,
    -0.0283, 0.0962,  -0.1358, 0.1585,  0.0057,  -0.2604, 0.1189,  -0.0170,
    0.3509,  0.0623,  0.0679,  -0.1302, -0.0792, 0.0906,  -0.0792, 0.0849,
    -0.1924, 0.2604,  -0.1245, -0.3679, 0.0340,  0.0113,  -0.1698, 0.2490,
    0.0283,  0.1019,  -0.3736, 0.1019,  -0.2207, -0.0340, 0.3170,  0.1755,
    0.0962,  0.3226,  -0.0113, -0.1189, -0.2321, -0.0226, -0.2434, -0.0170,
    -0.1585, -0.0283, -0.1132, 0.0679,  -0.4188, -0.0453, 0.1528,  -0.1302,
    -0.3792, 0.1415,  -0.1358, -0.1811, 0.1302,  0.1415,  0.5207,  0.0509,
    -0.1358, -0.0396, -0.2434, 0.0396,  0.0792,  -0.2264, -0.1415, 0.0906,
    0.1245,  0.0170,  0.0623,  -0.1415, 0.2773,  -0.3566, -0.0396, 0.2887,
    0.4188,  0.1698,  -0.2547, 0.1132,  -0.0453, -0.0113, -0.1358, 0.1075,
    0.0566,  0.1075,  0.2604,  -0.0849, -0.2490, 0.1415,  0.0509,  -0.2151,
    0.0340,  0.1698,  0.0509,  -0.0906, 0.0566,  -0.1075, -0.2151, 0.2038,
    -0.1924, -0.0113, 0.2830,  0.1358,  -0.1189, 0.0113,  -0.5603, -0.2830,
    -0.2943, 0.0453,  -0.0396, 0.1358,  0.0566,  0.2038,  -0.3283, -0.0509,
    0.0509,  0.1641,  0.2094,  -0.2038, -0.1868, -0.1585, -0.2207, -0.1302,
    0.0396,  -0.1019, -0.0679, 0.1075,  -0.4584, -0.2207, 0.2434,  -0.0113,
    0.0849,  0.1755,  -0.3056, 0.1585,  -0.2547, 0.0453,  0.0906,  -0.1358,
    -0.0679, -0.0509, 0.0679,  -0.3509, 0.0057,  0.0453,  0.4132,  -0.1981,
    0.2264,  -0.0736, 0.1075,  0.0679,  -0.0906, -0.3113, 0.0509,  0.0849,
    0.2604,  0.0623,  -0.3113, 0.3849,  0.0000,  0.6396,  -0.2038, -0.1019,
    0.1245,  -0.0453, 0.1641,  0.1075,  -0.1075, -0.2660, -0.4528, -0.0566,
    -0.0170, 0.0453,  0.0340,  0.1189,  -0.2434, -0.0283, -0.1811, 0.2547,
    0.0000,  -0.0226, 0.4471,  0.1019,  -0.1472, 0.0849,  0.1075,  0.1075,
    0.0283,  -0.2773, 0.4415,  -0.1811, 0.2717,  0.3170,  0.0509,  0.0623,
    -0.0962, 0.1585,  -0.0792, -0.1811, -0.0792, -0.3283, 0.0962,  -0.1698,
    -0.0736, 0.0453,  0.0962,  -0.3566, -0.4584, 0.3396,  -0.4811, 0.3056,
    -0.1755, 0.2490,  -0.1698, -0.2377, -0.3339, -0.0453, 0.1811,  0.0736,
    0.0340,  -0.0962, -0.0113, -0.3056, -0.3339, 0.2038,  0.2038,  -0.1924,
    0.2547,  -0.4471, -0.0849, -0.2038, 0.3566,  -0.4811, 0.3453,  0.0849,
    0.1189,  0.3170,  -0.1358, 0.2717,  0.0113,  -0.4754, -0.1924, 0.4245,
    -0.2773, 0.3453,  0.2264,  0.2943,  0.5320,  0.2773,  -0.2264, -0.1019,
    -0.1132, -0.3962, 0.3679,  0.0509,  -0.0623, -0.0906, -0.5603, -0.1641,
    -0.3170, -0.2377, 0.1415,  -0.0509, 0.0792,  0.0170,  -0.0226, -0.0057,
    -0.1358, -0.4245, 0.3905,  0.3113,  0.0340,  -0.1189, 0.2887,  -0.2943,
    -0.3056, 0.2434,  0.1019,  -0.0170, 0.3849,  0.1528,  -0.0736, -0.0170,
    0.0792,  0.1755,  0.0509,  0.3509,  0.1472,  0.1528,  0.1472,  0.0057,
    0.0113,  -0.0113, -0.3283, -0.3962, -0.0792, -0.1245, -0.0283, -0.1868,
    0.4019,  0.2943,  -0.0906, -0.2321, 0.6056,  0.1189,  0.0340,  -0.2207,
    -0.0453, 0.3339,  0.2377,  -0.1641, 0.3736,  0.2151,  -0.2547, 0.0453,
    0.1924,  -0.1019, -0.0340, -0.2207, 0.3962,  -0.4471, -0.2547, -0.2151,
    -0.3736, 0.0283,  0.1189,  0.0283,  0.0736,  0.0396,  0.1019,  0.0283,
    0.0170,  0.2321,  0.3509,  -0.0226, -0.0226, 0.0736,  0.0283,  0.1641,
    -0.0906, 0.1811,  0.0226,  0.5716,  -0.0396, -0.0509, -0.1641, -0.0509,
    0.4132,  -0.2604, 0.1019,  -0.0283, -0.0340, 0.0453,  0.1472,  -0.0057,
    0.2717,  -0.2094, 0.3396,  0.0340,  0.1245,  0.2547,  -0.5886, 0.2717,
    -0.0906, 0.1641,  0.0962,  -0.0792, -0.0113, 0.2264,  -0.0736, 0.3170,
    0.0623,  0.0679,  0.0623,  -0.0792, -0.2207, 0.1924,  0.1245,  -0.2773};
int representative_64x16_bias_dims[] = {1, 16};
const float representative_64x16_bias_data[] = {
    -0.0084, 0.0006,  0.0000,  0.0000,  -0.0087, -0.0006, -0.0003, -0.0003,
    0.0006,  -0.0003, -0.0003, -0.0003, -0.0253, 0.0012,  0.0000,  0.0000};
const float representative_64x16_golden[] = {
    3.8624,  -2.9580, 4.3043,  -1.2844, -1.5769, -2.7998, -0.1011, -3.4029,
    -1.0557, -7.1931, -1.4852, -0.4163, 1.7186,  -0.6965, 0.3580,  2.7378};
const int representative_64x16_output_size = 16;
int representative_64x16_output_dims[] = {2, 1, 16};

// The element counts above are maintained by hand and size the scratch buffers
// in the tests; fail the build if they drift from the initializers.
static_assert(sizeof(representative_64x16_input_data) ==
                  representative_64x16_input_size * sizeof(float),
              "representative_64x16_input_size does not match its data");
static_assert(sizeof(representative_64x16_weights_data) ==
                  representative_64x16_weights_size * sizeof(float),
              "representative_64x16_weights_size does not match its data");
static_assert(sizeof(representative_64x16_bias_data) ==
                  representative_64x16_output_size * sizeof(float),
              "representative_64x16 bias must have one value per output");
static_assert(sizeof(representative_64x16_golden) ==
                  representative_64x16_output_size * sizeof(float),
              "representative_64x16_output_size does not match the golden");
228 
229 template <typename T>
ValidateFullyConnectedGoldens(TfLiteTensor * tensors,const int tensors_size,const TfLiteFusedActivation activation,const float tolerance,const int output_len,const T * golden,T * output_data)230 TfLiteStatus ValidateFullyConnectedGoldens(
231     TfLiteTensor* tensors, const int tensors_size,
232     const TfLiteFusedActivation activation, const float tolerance,
233     const int output_len, const T* golden, T* output_data) {
234   TfLiteFullyConnectedParams builtin_data = {
235       activation, kTfLiteFullyConnectedWeightsFormatDefault, false, false};
236 
237   int inputs_array_data[] = {3, 0, 1, 2};
238   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
239   int outputs_array_data[] = {1, 3};
240   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
241 
242   const TfLiteRegistration registration = Register_FULLY_CONNECTED();
243   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
244                              outputs_array,
245                              reinterpret_cast<void*>(&builtin_data));
246 
247   TfLiteStatus status = runner.InitAndPrepare();
248   if (status != kTfLiteOk) {
249     return status;
250   }
251 
252   status = runner.Invoke();
253   if (status != kTfLiteOk) {
254     return status;
255   }
256 
257   for (int i = 0; i < output_len; ++i) {
258     TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], tolerance);
259   }
260   return kTfLiteOk;
261 }
262 
263 #if !defined(XTENSA)  // Needed to avoid build error from unused functions.
TestFullyConnectedFloat(int * input_dims_data,const float * input_data,int * weights_dims_data,const float * weights_data,int * bias_dims_data,const float * bias_data,const float * golden,int * output_dims_data,TfLiteFusedActivation activation,float * output_data)264 TfLiteStatus TestFullyConnectedFloat(
265     int* input_dims_data, const float* input_data, int* weights_dims_data,
266     const float* weights_data, int* bias_dims_data, const float* bias_data,
267     const float* golden, int* output_dims_data,
268     TfLiteFusedActivation activation, float* output_data) {
269   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
270   TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
271   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
272   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
273   const int output_dims_count = ElementCount(*output_dims);
274 
275   constexpr int inputs_size = 3;
276   constexpr int outputs_size = 1;
277   constexpr int tensors_size = inputs_size + outputs_size;
278   TfLiteTensor tensors[tensors_size] = {
279       CreateTensor(input_data, input_dims),
280       CreateTensor(weights_data, weights_dims),
281       CreateTensor(bias_data, bias_dims),
282       CreateTensor(output_data, output_dims),
283   };
284 
285   return ValidateFullyConnectedGoldens(tensors, tensors_size, activation, 1e-4f,
286                                        output_dims_count, golden, output_data);
287 }
288 #endif
289 
290 template <typename T>
TestFullyConnectedQuantized(int * input_dims_data,const float * input_data,T * input_quantized,const float input_scale,const int input_zero_point,int * weights_dims_data,const float * weights_data,T * weights_quantized,const float weights_scale,const int weights_zero_point,int * bias_dims_data,const float * bias_data,int32_t * bias_quantized,const float * golden,T * golden_quantized,int * output_dims_data,const float output_scale,const int output_zero_point,TfLiteFusedActivation activation,T * output_data)291 TfLiteStatus TestFullyConnectedQuantized(
292     int* input_dims_data, const float* input_data, T* input_quantized,
293     const float input_scale, const int input_zero_point, int* weights_dims_data,
294     const float* weights_data, T* weights_quantized, const float weights_scale,
295     const int weights_zero_point, int* bias_dims_data, const float* bias_data,
296     int32_t* bias_quantized, const float* golden, T* golden_quantized,
297     int* output_dims_data, const float output_scale,
298     const int output_zero_point, TfLiteFusedActivation activation,
299     T* output_data) {
300   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
301   TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
302   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
303   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
304   const int output_dims_count = ElementCount(*output_dims);
305 
306   constexpr int inputs_size = 3;
307   constexpr int outputs_size = 1;
308   constexpr int tensors_size = inputs_size + outputs_size;
309   TfLiteTensor tensors[tensors_size] = {
310       CreateQuantizedTensor(input_data, input_quantized, input_dims,
311                             input_scale, input_zero_point),
312       CreateQuantizedTensor(weights_data, weights_quantized, weights_dims,
313                             weights_scale, weights_zero_point),
314       CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
315                                 input_scale, weights_scale),
316       CreateQuantizedTensor(output_data, output_dims, output_scale,
317                             output_zero_point),
318   };
319 
320   Quantize(golden, golden_quantized, output_dims_count, output_scale,
321            output_zero_point);
322 
323   return ValidateFullyConnectedGoldens(tensors, tensors_size, activation, 0.0f,
324                                        output_dims_count, golden_quantized,
325                                        output_data);
326 }
327 
328 }  // namespace
329 }  // namespace testing
330 }  // namespace tflite
331 
332 TF_LITE_MICRO_TESTS_BEGIN
333 
334 #if !defined(XTENSA) && !defined(CEVA_BX1) && !defined(CEVA_SP500)
335 // TODO(b/170503075): xtensa kernels are less general
336 // than reference kernels and we ifdef out test cases that are currently known
337 // to fail.
338 
// CEVA's fully connected implementation assumes weights_zero_point=0 as
// described in TFLite's quantization specification. Tests which use a
// different zero point are therefore ifdef'd out.
// See the TFLite quantization spec:
// https://www.tensorflow.org/lite/performance/quantization_spec
TF_LITE_MICRO_TEST(SimpleTest) {
  // Float path: 2x10 input, 3x10 weights, no fused activation.
  // NOTE: the argument text of TF_LITE_MICRO_EXPECT_EQ is stringified into the
  // failure message, so it is kept exactly as before.
  constexpr int kOutputElements = tflite::testing::simple_output_size;
  float output_data[kOutputElements];
  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data,
          tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data,
          tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data,
          tflite::testing::simple_golden, tflite::testing::simple_output_dims,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
357 
358 #endif
359 
TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8) {
  // Scratch buffers for the quantized copies of the float test data and for
  // the kernel output.
  int8_t input_quantized[tflite::testing::simple_input_size];
  int8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  int8_t golden_quantized[tflite::testing::simple_output_size];
  int8_t output_data[tflite::testing::simple_output_size];

  // Per-tensor quantization parameters.
  constexpr float input_scale = 1.0f;
  constexpr int input_zero_point = -1;
  constexpr float weights_scale = 1.0f;
  constexpr int weights_zero_point = 0;
  constexpr float output_scale = 0.5f;
  constexpr int output_zero_point = -1;

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::simple_input_dims,
          tflite::testing::simple_input_data, input_quantized, input_scale,
          input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
          tflite::testing::simple_bias_data, bias_quantized,
          tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
387 
TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) {
  // The 20 simple input values presented as a rank-4 tensor of shape
  // 1x1x2x10 ({rank, dims...} layout).
  int input_dims_4d[] = {4, 1, 1, 2, 10};

  // Scratch buffers for quantized data and kernel output.
  int8_t input_quantized[tflite::testing::simple_input_size];
  int8_t weights_quantized[tflite::testing::simple_weights_size];
  int32_t bias_quantized[tflite::testing::simple_output_size];
  int8_t golden_quantized[tflite::testing::simple_output_size];
  int8_t output_data[tflite::testing::simple_output_size];

  // Per-tensor quantization parameters.
  constexpr float input_scale = 1.0f;
  constexpr int input_zero_point = -1;
  constexpr float weights_scale = 1.0f;
  constexpr int weights_zero_point = 0;
  constexpr float output_scale = 0.5f;
  constexpr int output_zero_point = -1;

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          input_dims_4d, tflite::testing::simple_input_data, input_quantized,
          input_scale, input_zero_point, tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data, weights_quantized,
          weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
          tflite::testing::simple_bias_data, bias_quantized,
          tflite::testing::simple_golden, golden_quantized,
          tflite::testing::simple_output_dims, output_scale, output_zero_point,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
417 
TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8Relu) {
  // Scratch buffers for quantized data and kernel output.
  int8_t input_quantized[tflite::testing::relu_input_size];
  int8_t weights_quantized[tflite::testing::relu_weights_size];
  int32_t bias_quantized[tflite::testing::relu_output_size];
  int8_t golden_quantized[tflite::testing::relu_output_size];
  int8_t output_data[tflite::testing::relu_output_size];

  // Per-tensor quantization parameters. The output zero point sits at the int8
  // minimum, so (assuming the usual real = scale * (q - zero_point) mapping)
  // the lowest representable output is 0, the ReLU lower bound.
  constexpr float input_scale = 1.0f;
  constexpr int input_zero_point = -1;
  constexpr float weights_scale = 1.0f;
  constexpr int weights_zero_point = 0;
  constexpr float output_scale = 0.5f;
  constexpr int output_zero_point = -128;

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::relu_input_dims, tflite::testing::relu_input_data,
          input_quantized, input_scale, input_zero_point,
          tflite::testing::relu_weights_dims,
          tflite::testing::relu_weights_data, weights_quantized, weights_scale,
          weights_zero_point, tflite::testing::relu_bias_dims,
          tflite::testing::relu_bias_data, bias_quantized,
          tflite::testing::relu_golden, golden_quantized,
          tflite::testing::relu_output_dims, output_scale, output_zero_point,
          kTfLiteActRelu, output_data),
      kTfLiteOk);
}
446 
447 #if !defined(XTENSA)  // TODO(b/170503075): xtensa kernels are less general than
448                       // reference kernels and we ifdef out test cases that are
449                       // currently known to fail.
TF_LITE_MICRO_TEST(SimpleTest4DInput) {
  // Float path with the 20 simple input values laid out as a rank-4 tensor of
  // shape 1x1x2x10 ({rank, dims...} layout).
  int input_dims_4d[] = {4, 1, 1, 2, 10};

  constexpr int kOutputElements = tflite::testing::simple_output_size;
  float output_data[kOutputElements];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          input_dims_4d, tflite::testing::simple_input_data,
          tflite::testing::simple_weights_dims,
          tflite::testing::simple_weights_data,
          tflite::testing::simple_bias_dims, tflite::testing::simple_bias_data,
          tflite::testing::simple_golden, tflite::testing::simple_output_dims,
          kTfLiteActNone, output_data),
      kTfLiteOk);
}
465 
TF_LITE_MICRO_TEST(Representative1x64Input1x16Output) {
  // Float path on model-like data: 1x64 input, 16x64 weights, 1x16 output.
  constexpr int kOutputElements =
      tflite::testing::representative_64x16_output_size;
  float output_data[kOutputElements];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedFloat(
          tflite::testing::representative_64x16_input_dims,
          tflite::testing::representative_64x16_input_data,
          tflite::testing::representative_64x16_weights_dims,
          tflite::testing::representative_64x16_weights_data,
          tflite::testing::representative_64x16_bias_dims,
          tflite::testing::representative_64x16_bias_data,
          tflite::testing::representative_64x16_golden,
          tflite::testing::representative_64x16_output_dims, kTfLiteActNone,
          output_data),
      kTfLiteOk);
}
482 
483 #endif
484 
TF_LITE_MICRO_TEST(Representative1x64Input1x16OutputQuantizedInt8) {
  // Quantization parameters for the model-like 1x64 -> 1x16 case. The scales
  // are written as float literals (`f` suffix), consistent with the other
  // tests, rather than as double literals implicitly narrowed to float.
  const float input_scale = 0.051445f;
  const int input_zero_point = -128;
  const float weights_scale = 0.005660f;
  const int weights_zero_point = 0;

  const float output_scale = 0.069785f;
  const int output_zero_point = -9;

  // Scratch buffers for quantized data and kernel output. The bias has one
  // value per output unit.
  int8_t input_quantized[tflite::testing::representative_64x16_input_size];
  int8_t weights_quantized[tflite::testing::representative_64x16_weights_size];
  int32_t bias_quantized[tflite::testing::representative_64x16_output_size];
  int8_t golden_quantized[tflite::testing::representative_64x16_output_size];
  int8_t output_data[tflite::testing::representative_64x16_output_size];

  TF_LITE_MICRO_EXPECT_EQ(
      tflite::testing::TestFullyConnectedQuantized(
          tflite::testing::representative_64x16_input_dims,
          tflite::testing::representative_64x16_input_data, input_quantized,
          input_scale, input_zero_point,
          tflite::testing::representative_64x16_weights_dims,
          tflite::testing::representative_64x16_weights_data, weights_quantized,
          weights_scale, weights_zero_point,
          tflite::testing::representative_64x16_bias_dims,
          tflite::testing::representative_64x16_bias_data, bias_quantized,
          tflite::testing::representative_64x16_golden, golden_quantized,
          tflite::testing::representative_64x16_output_dims, output_scale,
          output_zero_point, kTfLiteActNone, output_data),
      kTfLiteOk);
}
515 
516 TF_LITE_MICRO_TESTS_END
517