1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/micro/kernels/kernel_runner.h"
19 #include "tensorflow/lite/micro/test_helpers.h"
20 #include "tensorflow/lite/micro/testing/micro_test.h"
21 namespace tflite {
22 namespace testing {
23 namespace {
24 
25 // naming as follows: <tensor name>_<input size>x<batch size>x<batch count>
26 
27 // 10 inputs each with shape {2, 2}.
28 const float input_data_2x2x10[] = {
29     0.12609188,  -0.46347019, 0.35867718,  0.36897406,
30 
31     0.14278367,  -1.64410412, -0.57290924, 0.12729003,
32 
33     0.49837467,  0.19278903,  0.17660543,  0.52949083,
34 
35     -0.11186574, 0.13164264,  -0.72674477, -0.5683046,
36 
37     -0.68892461, 0.37783599,  -0.63690937, 0.44483393,
38 
39     -0.81299269, -0.86831826, -0.95760226, 1.82078898,
40 
41     -1.45006323, -0.82251364, -1.65087092, -1.89238167,
42 
43     0.03966608,  -0.24936394, 2.06740379,  -1.51439476,
44 
45     0.11771342,  -0.23761693, 0.31088525,  -1.55601168,
46 
47     -0.89477462, 1.67204106,  -0.6230064,  0.29819036,
48 };
49 
50 // Feature filter of shape {8, 2}.
51 const float feature_weights_data_2x2x10[] = {
52     -0.31930989, 0.0079667,  0.39296314,  0.37613347,  0.12416199,  0.15785322,
53     0.27901134,  0.3905206,  0.21931258,  -0.36137494, -0.10640851, 0.31053296,
54     -0.36118156, -0.0976817, -0.36916667, 0.22197971};
55 
56 // Time filter of shape {8, 10}.
57 const float time_weights_data_2x2x10[] = {
58     -0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
59     0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,
60 
61     0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
62     -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,
63 
64     -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
65     0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,
66 
67     -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
68     -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,
69 
70     -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
71     0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,
72 
73     -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
74     0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,
75 
76     -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
77     -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,
78 
79     0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
80     0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763};
81 
82 // Activation state with shape {2, 80}. These initial values must be copied into
83 // a mutable activation state tensor.
84 
85 const float initial_activation_state_data_2x2x10[] = {
86     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
87     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
88     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
89     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
90     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
91     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
92     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
93 
94 // Bias with shape {8}
95 const float bias_data_2x2x10[] = {0, 0, 0, 0, 0, 0, 0, 0};
96 
97 // 10 outputs each of shape {2, 4}
98 const float golden_output_2x2x10[] = {
99     -0.044205, -0.013757, 0.050369,  -0.018447,
100     0.073010,  0.025142,  -0.021154, 0.013551,
101 
102     -0.209613, -0.062421, 0.150209,  -0.108334,
103     0.028256,  -0.006950, -0.030885, 0.009603,
104 
105     -0.076800, -0.037075, -0.087198, -0.155183,
106     0.091069,  0.098446,  -0.016083, 0.106475,
107 
108     -0.082123, -0.162238, -0.084434, -0.141074,
109     -0.029340, -0.090685, 0.053302,  -0.030604,
110 
111     -0.201440, 0.088424,  0.139877,  0.012416,
112     -0.113212, 0.103893,  -0.100842, 0.122780,
113 
114     -0.166632, -0.116705, 0.175298,  -0.047163,
115     0.313077,  -0.166485, -0.285860, 0.129069,
116 
117     -0.625911, 0.046134,  0.138081,  -0.129581,
118     -0.521455, -0.061579, 0.230289,  0.114963,
119 
120     -0.216693, -0.161643, -0.179177, -0.052599,
121     -0.213239, 0.029502,  0.260858,  0.275045,
122 
123     -0.213689, -0.323608, -0.285635, -0.317687,
124     -0.324092, -0.317972, -0.208450, -0.462504,
125 
126     -0.255126, -0.218576, -0.041528, 0.179421,
127     -0.440583, 0.072127,  -0.284136, 0.241570};
128 
129 // Simulated real-world inputs, weights and expected outputs.
130 
131 // Input of shape {1x16}
132 const float input_data_16x1x1[] = {
133     -0.488494, 2.023762,  -2.233117, -0.488494, 3.559030, 9.490748,
134     -3.210106, -1.953977, -0.279140, 0.907204,  1.674838, 0.000000,
135     -0.279140, -0.628064, -0.069785, -0.628064,
136 };
137 
138 // Feature filter of shape {64, 16}.
139 const float feature_weights_data_16x1x1[] = {
140     0.173588,  0.173588,  -0.024798, 0.193426,  -0.099193, 0.044637,  0.183507,
141     0.183507,  0.044637,  0.198386,  -0.069435, 0.084314,  0.312458,  0.024798,
142     0.173588,  -0.049596, -0.352135, -0.550521, -0.009919, -0.099193, -0.074395,
143     -0.128951, 0.193426,  0.357095,  -0.317418, -0.119032, -0.218225, -0.004960,
144     -0.386853, -0.133911, 0.252942,  -0.019839, -0.024798, -0.054556, -0.069435,
145     -0.128951, 0.029758,  -0.099193, -0.312458, -0.029758, 0.064475,  0.183507,
146     0.114072,  -0.178547, -0.247982, -0.119032, 0.243023,  -0.119032, -0.034718,
147     -0.178547, 0.019839,  0.128951,  -0.223184, -0.009919, -0.213265, 0.168628,
148     -0.143830, -0.322377, -0.218225, -0.193426, -0.252942, -0.049596, 0.064475,
149     -0.267821, -0.580279, -0.099193, 0.213265,  0.119032,  -0.119032, -0.178547,
150     0.610037,  0.109112,  0.049596,  -0.014879, -0.049596, -0.193426, 0.039677,
151     -0.148789, -0.114072, -0.158709, -0.158709, 0.094233,  0.099193,  -0.114072,
152     0.104153,  -0.123991, 0.198386,  -0.173588, 0.089274,  -0.247982, -0.054556,
153     0.123991,  0.183507,  0.114072,  0.188467,  0.302539,  0.044637,  0.039677,
154     -0.099193, 0.168628,  -0.024798, -0.054556, -0.109112, 0.014879,  -0.009919,
155     0.069435,  -0.396772, -0.287660, -0.079354, -0.104153, 0.054556,  0.089274,
156     -0.099193, 0.114072,  0.034718,  0.119032,  0.282700,  -0.119032, -0.505884,
157     -0.233104, -0.114072, -0.257902, -0.233104, -0.178547, 0.153749,  0.128951,
158     0.143830,  -0.188467, -0.183507, 0.104153,  -0.024798, 0.193426,  -0.287660,
159     0.168628,  -0.009919, 0.119032,  -0.024798, -0.099193, -0.203346, 0.099193,
160     0.084314,  -0.168628, 0.123991,  -0.148789, 0.114072,  -0.029758, 0.228144,
161     -0.238063, 0.089274,  -0.064475, 0.307498,  -0.188467, -0.004960, -0.252942,
162     -0.173588, -0.158709, -0.044637, -0.009919, 0.312458,  -0.262861, 0.059516,
163     0.158709,  0.069435,  -0.282700, 0.074395,  -0.322377, -0.183507, -0.123991,
164     -0.233104, 0.009919,  0.252942,  -0.243023, 0.555481,  -0.099193, -0.119032,
165     -0.441409, 0.148789,  0.084314,  -0.168628, -0.183507, 0.188467,  0.024798,
166     -0.302539, 0.223184,  0.143830,  -0.193426, -0.054556, -0.218225, -0.297579,
167     0.104153,  0.272781,  -0.034718, 0.114072,  -0.059516, 0.044637,  0.342216,
168     0.421570,  0.138870,  -0.024798, -0.039677, -0.163668, -0.034718, 0.396772,
169     -0.128951, -0.044637, -0.173588, 0.302539,  0.079354,  0.049596,  0.133911,
170     -0.029758, -0.312458, -0.029758, 0.079354,  0.128951,  0.252942,  0.213265,
171     0.014879,  0.287660,  0.178547,  0.297579,  0.352135,  0.401732,  0.024798,
172     -0.277740, -0.411651, -0.069435, 0.342216,  -0.158709, -0.104153, -0.009919,
173     0.223184,  0.228144,  -0.019839, 0.059516,  -0.104153, -0.510844, 0.029758,
174     -0.406691, 0.089274,  0.421570,  0.163668,  -0.143830, -0.019839, -0.039677,
175     0.104153,  -0.044637, -0.128951, 0.203346,  0.079354,  -0.069435, 0.094233,
176     -0.138870, 0.466207,  -0.163668, 0.049596,  0.029758,  0.267821,  0.029758,
177     -0.049596, 0.009919,  0.004960,  -0.099193, 0.094233,  -0.262861, 0.089274,
178     -0.302539, 0.332297,  -0.307498, -0.014879, 0.168628,  -0.094233, -0.272781,
179     0.034718,  -0.133911, -0.228144, 0.094233,  0.257902,  -0.228144, 0.153749,
180     -0.054556, -0.252942, 0.054556,  0.218225,  -0.054556, 0.302539,  0.282700,
181     0.054556,  -0.044637, -0.133911, 0.233104,  -0.049596, 0.411651,  0.044637,
182     -0.297579, -0.029758, -0.114072, 0.114072,  -0.580279, 0.079354,  -0.024798,
183     -0.347175, -0.128951, -0.099193, 0.238063,  -0.104153, -0.009919, 0.158709,
184     -0.034718, 0.123991,  -0.163668, 0.059516,  0.342216,  0.009919,  0.064475,
185     -0.307498, -0.520763, -0.238063, 0.163668,  0.362054,  0.034718,  -0.178547,
186     -0.104153, -0.257902, 0.322377,  0.054556,  0.148789,  -0.178547, 0.084314,
187     0.004960,  0.257902,  0.029758,  0.079354,  -0.223184, -0.193426, 0.282700,
188     0.000000,  -0.019839, -0.114072, 0.491005,  -0.193426, -0.029758, -0.243023,
189     0.009919,  0.089274,  -0.277740, -0.089274, 0.104153,  0.337256,  0.138870,
190     -0.307498, -0.054556, 0.352135,  0.133911,  -0.044637, 0.133911,  -0.089274,
191     -0.357095, -0.272781, 0.069435,  0.059516,  -0.109112, 0.148789,  -0.044637,
192     -0.019839, -0.153749, 0.123991,  -0.223184, 0.322377,  0.074395,  -0.312458,
193     0.024798,  -0.223184, 0.109112,  -0.138870, 0.218225,  -0.074395, -0.406691,
194     0.009919,  -0.198386, -0.009919, 0.416611,  0.178547,  0.148789,  0.133911,
195     -0.004960, 0.069435,  -0.054556, -0.044637, 0.297579,  0.059516,  -0.456288,
196     -0.148789, -0.004960, 0.054556,  0.094233,  -0.104153, 0.198386,  -0.302539,
197     0.133911,  0.411651,  0.054556,  0.525723,  -0.089274, 0.079354,  0.238063,
198     0.079354,  -0.039677, 0.039677,  0.029758,  0.332297,  -0.014879, -0.367014,
199     -0.143830, -0.123991, -0.064475, 0.014879,  0.173588,  -0.168628, 0.386853,
200     0.009919,  0.173588,  0.163668,  0.123991,  0.163668,  0.198386,  0.203346,
201     -0.401732, -0.009919, 0.272781,  -0.173588, 0.044637,  0.238063,  0.133911,
202     0.049596,  0.208305,  -0.024798, 0.049596,  -0.049596, 0.034718,  -0.446368,
203     0.466207,  -0.089274, -0.099193, -0.128951, -0.228144, 0.014879,  -0.252942,
204     0.074395,  -0.223184, -0.168628, -0.292619, 0.178547,  0.153749,  -0.014879,
205     0.054556,  0.000000,  0.193426,  0.158709,  0.178547,  -0.327337, -0.138870,
206     -0.114072, 0.168628,  0.297579,  -0.109112, -0.029758, -0.029758, -0.416611,
207     0.059516,  0.000000,  -0.168628, -0.322377, 0.238063,  -0.128951, -0.029758,
208     0.500925,  0.292619,  0.123991,  -0.099193, 0.074395,  0.317418,  -0.148789,
209     0.064475,  -0.104153, -0.044637, -0.094233, 0.188467,  -0.044637, 0.213265,
210     -0.233104, -0.049596, 0.004960,  -0.198386, 0.287660,  -0.148789, -0.257902,
211     0.004960,  -0.218225, -0.044637, -0.386853, -0.243023, -0.163668, 0.094233,
212     0.029758,  -0.019839, -0.009919, -0.143830, -0.158709, 0.158709,  -0.243023,
213     -0.039677, -0.297579, 0.069435,  0.049596,  0.302539,  0.059516,  0.074395,
214     -0.019839, 0.352135,  -0.019839, -0.138870, -0.178547, -0.243023, 0.233104,
215     0.252942,  -0.228144, -0.049596, 0.173588,  0.173588,  -0.074395, -0.034718,
216     -0.292619, 0.362054,  0.183507,  0.243023,  -0.203346, -0.044637, 0.054556,
217     0.059516,  -0.158709, -0.158709, 0.000000,  0.327337,  0.119032,  0.034718,
218     -0.044637, -0.089274, 0.089274,  -0.233104, 0.000000,  -0.317418, 0.371974,
219     0.213265,  0.307498,  -0.178547, -0.367014, 0.039677,  -0.059516, 0.168628,
220     -0.014879, 0.143830,  0.123991,  -0.084314, -0.332297, -0.416611, 0.183507,
221     0.109112,  -0.039677, 0.014879,  0.292619,  -0.213265, -0.054556, 0.004960,
222     0.123991,  0.119032,  0.000000,  -0.332297, -0.312458, -0.198386, -0.213265,
223     0.119032,  0.322377,  0.168628,  0.104153,  -0.262861, 0.327337,  -0.049596,
224     -0.228144, -0.074395, 0.168628,  0.123991,  0.396772,  0.044637,  0.322377,
225     0.193426,  0.267821,  -0.178547, 0.297579,  0.148789,  -0.218225, -0.138870,
226     0.044637,  0.049596,  0.133911,  0.064475,  0.069435,  0.064475,  -0.158709,
227     -0.044637, -0.173588, 0.267821,  0.327337,  0.079354,  -0.228144, 0.029758,
228     0.014879,  0.198386,  -0.109112, -0.133911, 0.431490,  0.099193,  0.421570,
229     0.233104,  -0.054556, 0.054556,  -0.317418, -0.133911, -0.123991, -0.287660,
230     0.342216,  -0.049596, -0.153749, 0.228144,  -0.213265, 0.262861,  0.406691,
231     -0.084314, -0.004960, 0.193426,  0.188467,  -0.099193, -0.223184, 0.163668,
232     -0.257902, -0.153749, 0.441409,  0.099193,  0.128951,  -0.089274, -0.208305,
233     -0.009919, -0.004960, -0.109112, 0.024798,  -0.119032, 0.019839,  0.391812,
234     -0.024798, 0.198386,  0.327337,  -0.505884, -0.099193, 0.510844,  -0.148789,
235     0.094233,  -0.153749, -0.039677, 0.352135,  0.272781,  -0.228144, -0.287660,
236     -0.272781, 0.148789,  0.277740,  0.074395,  0.109112,  -0.064475, 0.044637,
237     0.074395,  -0.292619, 0.153749,  -0.064475, -0.114072, 0.198386,  -0.039677,
238     -0.128951, -0.004960, 0.257902,  -0.228144, -0.094233, 0.064475,  0.014879,
239     0.188467,  -0.416611, 0.099193,  0.362054,  -0.208305, 0.198386,  -0.079354,
240     0.009919,  0.119032,  0.332297,  0.243023,  -0.168628, 0.158709,  0.039677,
241     0.143830,  0.277740,  -0.168628, 0.009919,  0.099193,  -0.004960, -0.257902,
242     -0.297579, 0.208305,  -0.104153, 0.119032,  0.247982,  0.381893,  -0.223184,
243     -0.367014, -0.327337, -0.168628, -0.094233, 0.208305,  -0.019839, 0.183507,
244     0.084314,  0.133911,  0.109112,  -0.148789, -0.183507, -0.411651, -0.024798,
245     -0.114072, -0.029758, -0.009919, 0.173588,  -0.059516, -0.049596, 0.039677,
246     0.317418,  0.138870,  -0.247982, -0.084314, 0.158709,  0.054556,  -0.084314,
247     -0.049596, 0.074395,  0.019839,  -0.282700, -0.119032, -0.262861, 0.163668,
248     -0.069435, -0.064475, -0.059516, 0.094233,  0.123991,  -0.079354, -0.272781,
249     -0.267821, 0.233104,  0.114072,  -0.218225, 0.540602,  0.089274,  0.262861,
250     0.079354,  0.267821,  -0.119032, -0.109112, -0.128951, 0.128951,  -0.044637,
251     -0.272781, 0.277740,  0.297579,  -0.054556, -0.084314, -0.049596, 0.123991,
252     0.059516,  0.238063,  -0.168628, -0.009919, 0.163668,  -0.307498, 0.109112,
253     -0.064475, 0.218225,  -0.168628, -0.004960, -0.168628, 0.119032,  0.094233,
254     -0.183507, -0.089274, -0.292619, -0.094233, 0.064475,  -0.183507, -0.168628,
255     0.089274,  0.074395,  -0.367014, -0.024798, -0.069435, 0.119032,  -0.302539,
256     -0.376933, -0.123991, -0.009919, -0.069435, -0.208305, -0.119032, 0.014879,
257     -0.183507, -0.238063, 0.163668,  -0.332297, -0.148789, -0.391812, -0.024798,
258     -0.133911, -0.059516, -0.123991, 0.123991,  -0.292619, -0.044637, 0.059516,
259     -0.069435, 0.049596,  -0.069435, 0.034718,  0.158709,  -0.347175, -0.044637,
260     0.352135,  -0.347175, -0.282700, -0.054556, 0.307498,  0.029758,  0.357095,
261     -0.148789, 0.208305,  -0.317418, 0.009919,  0.004960,  -0.243023, 0.049596,
262     -0.099193, 0.213265,  -0.342216, 0.158709,  0.123991,  -0.332297, 0.386853,
263     -0.262861, -0.208305, 0.123991,  -0.044637, 0.148789,  0.084314,  -0.297579,
264     -0.307498, -0.163668, 0.337256,  -0.014879, 0.074395,  0.178547,  -0.004960,
265     -0.257902, -0.019839, -0.228144, -0.034718, -0.277740, -0.158709, -0.119032,
266     -0.153749, 0.629876,  0.277740,  0.178547,  -0.267821, -0.004960, 0.247982,
267     0.084314,  -0.094233, 0.000000,  -0.039677, 0.332297,  0.178547,  0.009919,
268     -0.213265, -0.208305, -0.044637, 0.019839,  0.218225,  -0.297579, 0.014879,
269     -0.247982, -0.004960, -0.128951, 0.421570,  -0.059516, 0.362054,  -0.203346,
270     -0.143830, -0.099193, -0.024798, 0.094233,  -0.123991, 0.163668,  0.109112,
271     -0.104153, -0.233104, 0.009919,  -0.218225, 0.376933,  0.104153,  -0.059516,
272     0.049596,  -0.054556, 0.019839,  -0.044637, -0.019839, 0.371974,  -0.019839,
273     0.104153,  0.168628,  -0.024798, -0.272781, -0.158709, 0.223184,  0.044637,
274     0.039677,  -0.168628, -0.287660, -0.109112, 0.094233,  -0.089274, -0.148789,
275     0.178547,  -0.039677, -0.089274, -0.049596, -0.024798, 0.064475,  -0.158709,
276     0.089274,  0.029758,  -0.247982, 0.362054,  0.024798,  -0.004960, -0.099193,
277     0.173588,  -0.059516, 0.188467,  -0.629876, 0.094233,  0.371974,  0.069435,
278     0.252942,  -0.357095, -0.272781, -0.367014, 0.014879,  -0.049596, -0.262861,
279     0.009919,  -0.094233, -0.094233, 0.059516,  0.223184,  0.133911,  0.411651,
280     -0.044637, -0.044637, 0.109112,  0.228144,  0.386853,  -0.233104, 0.069435,
281     0.228144,  -0.302539, 0.029758,  0.089274,  0.044637,  -0.238063, -0.138870,
282     -0.158709, -0.019839, 0.049596,  0.039677,  0.000000,  -0.069435, 0.109112,
283     -0.213265, -0.188467, -0.262861, -0.267821, -0.094233, 0.133911,  0.391812,
284     0.123991,  -0.317418, 0.233104,  -0.029758, -0.099193, -0.193426, 0.074395,
285     -0.009919, 0.252942,  0.322377,  -0.530683, 0.208305,  0.252942,  0.203346,
286     -0.069435, -0.262861};
287 
288 // Time filter of shape {64, 8}.
289 const float time_weights_data_16x1x1[] = {
290     -0.052026, 0.043107,  0.053512,  0.013378,  0.011892,  -0.182834, -0.108511,
291     0.153105,  0.050539,  -0.173915, 0.145672,  0.208103,  -0.221481, 0.108511,
292     -0.496475, 0.181347,  -0.016351, -0.132294, -0.234859, -0.243778, 0.028243,
293     -0.228914, -0.130808, -0.167969, -0.041621, -0.306209, -0.193239, -0.028243,
294     -0.057972, -0.057972, -0.497962, 0.054999,  0.181347,  0.047566,  -0.099592,
295     -0.111484, -0.130808, -0.071350, 0.380532,  0.010405,  0.041621,  0.052026,
296     0.022297,  0.081755,  0.098106,  0.099592,  -0.584176, -0.023783, 0.062431,
297     -0.090674, -0.279453, -0.486070, -0.273507, 0.004459,  -0.062431, 0.095133,
298     0.056485,  0.022297,  -0.105538, -0.184320, 0.358235,  0.254183,  0.049053,
299     0.084728,  0.218508,  0.078782,  -0.136754, -0.017837, -0.124862, -0.118916,
300     -0.001486, 0.043107,  0.254183,  0.087701,  0.261616,  0.309182,  -0.404315,
301     -0.040134, -0.046080, -0.052026, -0.034188, -0.475665, -0.025270, -0.049053,
302     -0.046080, -0.062431, 0.020810,  0.040134,  -0.135267, -0.169456, -0.050539,
303     -0.576743, 0.034188,  0.075809,  0.101079,  0.136754,  0.083241,  0.077296,
304     -0.050539, 0.761064,  -0.335938, -0.080268, 0.025270,  0.257156,  0.227427,
305     0.252697,  0.065404,  0.115943,  0.222968,  -0.026756, -0.054999, 0.107025,
306     -0.093646, 0.041621,  -0.092160, -0.474178, -0.016351, 0.004459,  0.049053,
307     0.019324,  0.019324,  0.074323,  0.038648,  -0.613905, 0.182834,  0.075809,
308     0.028243,  0.019324,  0.010405,  -0.011892, 0.001486,  -0.492016, -0.224454,
309     -0.474178, -0.147159, 0.002973,  0.102565,  0.136754,  -0.267561, -0.001486,
310     -0.095133, -0.040134, 0.066890,  0.074323,  0.104052,  0.532150,  0.090674,
311     0.072836,  -0.053512, -0.004459, 0.020810,  0.046080,  0.062431,  0.477151,
312     0.133781,  -0.029729, -0.026756, 0.031215,  0.156077,  0.096619,  0.251210,
313     0.352289,  0.657012,  0.047566,  -0.014865, -0.072836, -0.016351, 0.008919,
314     -0.053512, 0.016351,  0.300263,  0.047566,  0.020810,  0.169456,  0.001486,
315     0.007432,  0.111484,  0.044594,  -0.188779, -0.096619, 0.074323,  -0.040134,
316     0.160537,  0.138240,  0.184320,  0.377559,  -0.092160, -0.049053, 0.056485,
317     -0.032702, 0.001486,  -0.083241, -0.472692, -0.114457, -0.117430, -0.075809,
318     0.026756,  0.163510,  0.172428,  0.127835,  -0.199185, -0.218508, -0.057972,
319     -0.132294, -0.162023, -0.019324, -0.245265, -0.395396, -0.254183, 0.084728,
320     0.248238,  0.191752,  0.221481,  0.173915,  0.173915,  -0.208103, -0.077296,
321     0.384991,  -0.313641, -0.313641, -0.147159, -0.090674, 0.035675,  0.059458,
322     -0.010405, 0.019324,  0.087701,  0.016351,  0.037161,  0.469719,  -0.074323,
323     0.092160,  0.026756,  0.090674,  0.098106,  0.004459,  -0.034188, 0.492016,
324     -0.367154, -0.093646, -0.063917, 0.041621,  0.017837,  0.026756,  -0.062431,
325     -0.350803, 0.425125,  0.002973,  0.083241,  0.075809,  0.016351,  0.047566,
326     -0.185807, -0.107025, -0.098106, -0.144186, 0.255670,  0.020810,  0.105538,
327     0.029729,  0.129321,  0.156077,  0.141213,  0.334452,  0.147159,  -0.066890,
328     0.035675,  0.115943,  0.240805,  0.328506,  0.162023,  -0.237832, 0.218508,
329     0.233373,  0.214049,  0.099592,  0.026756,  -0.322560, -0.236346, -0.166483,
330     0.225941,  0.109997,  -0.147159, 0.147159,  -0.266075, 0.111484,  0.078782,
331     -0.120403, 0.022297,  -0.075809, -0.148645, -0.251210, -0.176888, -0.044594,
332     -0.023783, 0.016351,  0.026756,  -0.013378, -0.069863, -0.112970, 0.013378,
333     0.086214,  0.014865,  0.352289,  -0.240805, -0.135267, -0.114457, -0.472692,
334     0.334452,  0.095133,  0.047566,  0.130808,  -0.068377, -0.007432, -0.130808,
335     -0.121889, -0.053512, -0.245265, -0.371613, -0.083241, 0.000000,  -0.028243,
336     0.029729,  -0.093646, -0.004459, -0.038648, -0.108511, -0.475665, -0.169456,
337     -0.047566, -0.010405, -0.114457, -0.353776, -0.034188, -0.044594, 0.041621,
338     -0.047566, -0.107025, 0.004459,  0.053512,  0.047566,  -0.358235, -0.193239,
339     0.040134,  -0.096619, -0.054999, 0.099592,  0.032702,  0.205130,  -0.170942,
340     -0.237832, -0.405801, -0.126348, -0.072836, -0.203644, -0.169456, -0.093646,
341     -0.074323, 0.078782,  0.607959,  -0.437017, -0.164996, -0.166483, 0.043107,
342     -0.016351, 0.258643,  0.065404,  -0.057972, 0.017837,  0.080268,  0.050539,
343     -0.013378, -0.215536, -0.524718, 0.260129,  0.040134,  -0.002973, -0.046080,
344     0.020810,  0.025270,  0.145672,  0.515799,  0.233373,  0.011892,  0.139727,
345     0.126348,  0.065404,  -0.007432, -0.008919, 0.035675,  0.083241,  0.040134,
346     -0.005946, 0.503907,  -0.490529, -0.181347, -0.092160, -0.038648, 0.019324,
347     0.133781,  -0.011892, 0.041621,  0.062431,  -0.062431, -0.040134, -0.092160,
348     -0.111484, -0.133781, -0.130808, -0.484583, -0.248238, 0.037161,  -0.092160,
349     -0.056485, -0.041621, 0.112970,  0.248238,  0.438503,  0.258643,  -0.013378,
350     0.004459,  0.043107,  0.040134,  0.017837,  0.101079,  0.264589,  0.212563,
351     0.014865,  0.285399,  0.153105,  0.170942,  0.358235,  0.334452,  0.086214,
352     0.132294,  0.098106,  -0.001486, 0.107025,  0.200671,  -0.026756, 0.344857,
353     0.227427,  -0.041621, 0.098106,  0.063917,  -0.093646, 0.130808,  0.285399,
354     -0.319587, 0.035675,  -0.017837, -0.319587, 0.016351,  -0.098106, -0.017837,
355     0.083241,  0.074323,  -0.054999, 0.276480,  0.316614,  -0.099592, -0.059458,
356     0.156077,  -0.043107, 0.035675,  0.056485,  -0.022297, 0.017837,  -0.001486,
357     0.340398,  0.492016,  0.004459,  0.057972,  -0.150132, -0.206617, -0.257156,
358     -0.248238, -0.080268, -0.164996, 0.352289,  -0.054999, -0.056485, 0.010405,
359     -0.049053, -0.041621, -0.099592, 0.013378,  -0.089187, 0.057972,  -0.413234,
360     0.217022,  0.013378,  -0.080268, -0.035675, 0.035675,  0.007432,  0.002973,
361     -0.469719, 0.141213,  0.136754,  0.153105,  0.130808,  -0.104052, -0.508367,
362     -0.291345, -0.072836, -0.019324, -0.252697, -0.214049, -0.214049, 0.130808,
363     0.484583};
364 
365 // Bias of shape {64}
366 const float bias_data_16x1x1[] = {
367     -0.245395, -0.083545, -0.262522, -0.407912, -0.560898, -0.364789, -0.037964,
368     -0.378594, 0.178152,  0.400380,  -0.301349, -0.240913, -0.159454, -0.158757,
369     -0.073665, 0.455906,  -0.061232, 0.318907,  -0.226993, -0.344644, 0.140316,
370     0.559608,  0.109774,  0.437391,  0.113849,  -0.162068, 0.039572,  0.569472,
371     0.460205,  0.113459,  0.370469,  0.176811,  0.203063,  -0.296975, -0.271655,
372     0.059862,  -0.159912, -0.077310, -0.338314, -0.195477, -0.256762, 0.233834,
373     0.083172,  0.029040,  -0.236288, -0.267054, -0.166627, 0.188319,  -0.271391,
374     -0.222920, 0.106463,  0.263614,  0.384986,  -0.125957, -0.095890, 0.363686,
375     -0.036990, -0.358884, -0.178254, 0.305596,  0.390088,  -0.189437, 0.613409,
376     0.399639};
377 
378 // Activation state with shape {64, 8}. These initial values must be copied into
379 // a mutable activation state tensor.
380 const float initial_activation_state_data_16x1x1[] = {
381     -0.582275, -0.586623, -1.262373, -1.277279, -1.542175, -1.271999, -1.429757,
382     -1.184425, -0.462094, -1.443421, 0.230736,  -0.494701, -0.354955, -2.534061,
383     -4.277471, -4.218467, 0.403711,  -0.248748, -0.330111, -0.467683, 0.549047,
384     0.733511,  -0.230115, 0.793136,  -1.126353, -0.984123, -0.081984, -0.222351,
385     0.692830,  0.517060,  1.367958,  2.118860,  -0.116766, -0.826365, -2.402700,
386     -2.313884, -2.898954, -2.076005, -2.405185, -2.755481, 0.329490,  0.085400,
387     -1.485966, -2.034702, -2.161405, -1.269515, -1.151818, -1.823841, 0.561469,
388     1.109273,  1.693411,  -0.082605, -0.069252, -1.225107, -1.330693, -1.411435,
389     0.253406,  -0.357439, -1.593415, -0.879779, -1.111136, 1.821357,  2.471952,
390     1.236908,  -4.014127, -2.810448, -2.944604, -1.930980, -1.566398, -0.838166,
391     -0.319242, 0.749349,  1.156476,  0.658670,  1.997437,  2.080663,  2.912618,
392     2.677224,  2.642442,  2.796163,  -0.272349, -0.473273, 3.120063,  2.747097,
393     3.595510,  1.874150,  2.049919,  2.093396,  -1.049959, 0.277939,  -1.255541,
394     -1.052443, -1.810177, -0.883505, -0.538178, 0.524203,  -1.017662, -0.269244,
395     0.039129,  -0.227941, -0.114592, -2.018243, -2.548968, -0.706804, 0.890959,
396     0.102480,  0.349986,  0.405885,  1.287216,  0.756181,  0.319242,  -0.641590,
397     -3.841774, -2.716042, -4.342065, -3.826557, -2.924729, -1.643724, -1.237839,
398     -0.597492, -1.954892, -1.215169, -1.528201, -1.018904, -0.863941, -0.293467,
399     0.039439,  0.672023,  1.408019,  1.362679,  1.467644,  1.006171,  0.310236,
400     -0.249990, -1.048406, -0.752144, -1.831605, -1.058033, -1.096541, -0.293467,
401     0.051551,  0.232600,  0.088816,  2.570395,  0.704009,  2.465120,  3.010751,
402     2.139357,  0.630410,  1.006171,  1.545281,  1.486898,  -1.162998, -2.344317,
403     -4.593918, -3.522842, -2.872247, -1.416714, -0.642521, -0.230115, 0.315205,
404     -0.368930, -0.162726, 0.396879,  0.505570,  0.534451,  0.554947,  1.270447,
405     0.388805,  0.531967,  -1.243119, -0.671713, -1.214859, -0.238189, 0.016459,
406     -1.164550, 0.609603,  3.293348,  2.600208,  1.454290,  -1.034121, -1.760179,
407     -1.192500, -0.613951, 3.449553,  2.912618,  1.917937,  1.435968,  0.879158,
408     1.118279,  0.102791,  -0.502465, -0.239121, -0.092853, 1.786265,  1.943091,
409     2.547104,  2.630641,  2.585302,  2.965411,  -0.945615, -2.538720, -2.474126,
410     -1.088156, 0.056209,  0.864873,  0.170490,  0.457435,  0.545941,  0.752765,
411     1.569503,  1.129459,  0.662086,  -0.527929, -0.810838, -1.662978, 1.285042,
412     1.653040,  4.130893,  2.961995,  4.147041,  3.256393,  3.881524,  2.522571,
413     -0.875431, -1.112378, 2.105817,  2.180970,  3.121926,  1.577577,  1.639376,
414     2.906407,  -0.142230, 0.421101,  2.212335,  2.311399,  3.993321,  3.651719,
415     4.206666,  4.678387,  -1.304917, -1.130701, -2.543067, -2.500212, -2.197118,
416     -1.197158, -0.949652, -0.282908, 0.320795,  -1.543728, 1.290322,  1.788128,
417     3.957297,  3.205774,  2.892432,  2.297114,  0.138814,  -0.139435, 0.936920,
418     0.344707,  0.723263,  -1.772290, -3.138385, -2.287177, -2.405806, -1.859864,
419     -4.572801, -3.410424, -3.855748, -2.239663, -2.269786, -1.582857, 4.238342,
420     3.858543,  2.499901,  1.087535,  0.290051,  -0.026086, -0.880400, -2.602692,
421     -1.404292, 0.253096,  -0.665502, -1.443421, -0.925119, -0.096580, 1.115484,
422     1.846200,  -1.604284, -1.244671, -0.464888, 0.326385,  0.168006,  -0.262723,
423     -0.744691, 0.953379,  -0.407127, -0.349986, -1.154302, 0.831023,  1.590931,
424     2.538720,  2.063583,  3.697680,  -0.752455, -1.293117, -1.330693, -1.869802,
425     -0.592523, 0.631652,  1.198089,  -0.481347, 3.738983,  4.153252,  2.782499,
426     2.244321,  0.709289,  1.650245,  1.700865,  0.385078,  2.192460,  2.610456,
427     4.009780,  3.492719,  2.574743,  2.116687,  1.856138,  1.205853,  2.722563,
428     4.075305,  5.415935,  3.009198,  2.715421,  1.571056,  0.897170,  -2.430339,
429     0.749970,  0.425760,  -0.302783, 0.817359,  1.031636,  1.913589,  2.686229,
430     1.631923,  -1.459259, -1.793097, -1.187531, -1.553355, -0.844998, -1.296843,
431     -1.805519, -0.486627, 0.909591,  2.082837,  -1.473855, -2.456735, -3.851401,
432     -2.760139, -3.060438, -2.605487, -2.138735, -2.441519, -1.333177, -1.353984,
433     -0.245642, -0.588486, 0.033850,  2.084700,  0.076084,  0.690035,  0.747797,
434     0.594697,  -1.016109, -1.348083, -1.201195, -1.088466, 2.045571,  2.460772,
435     0.717984,  0.041613,  -0.721711, 1.134738,  2.322269,  1.112378,  -0.307441,
436     -0.581033, -0.868599, -0.018633, 0.856488,  0.919839,  0.303094,  -0.433213,
437     0.811148,  -0.508986, -1.060828, -1.227591, -1.566087, -1.117968, -1.385038,
438     -2.011101, -0.490353, -1.849616, -0.594697, -1.055859, 1.110205,  0.622646,
439     0.145957,  0.359303,  1.012072,  0.774814,  -0.400295, -1.484103, -2.007374,
440     -1.441247, -0.997787, -0.581033, -0.545941, -0.306510, 0.693451,  0.087264,
441     -0.227320, -1.211753, -1.532859, -1.688753, 0.065215,  0.134777,  0.608051,
442     -0.393152, -0.214588, -0.635689, -1.499320, 0.069562,  -1.555839, -2.633126,
443     -2.966032, -1.550870, -0.101549, 0.874189,  0.436318,  0.299367,  2.289972,
444     2.339659,  2.602071,  1.564535,  0.019254,  -0.583207, -1.295912, -2.424749,
445     -1.221070, -1.175109, -0.577306, -0.102791, 1.877876,  2.568222,  2.173827,
446     3.131243,  2.637784,  2.088737,  3.679047,  3.218506,  2.483442,  1.650556,
447     1.363611,  -0.027328, 1.486898,  -0.721711, -3.684327, -3.006093, -3.777491,
448     -2.327548, -2.737470, -4.549510, -0.060867, 0.127635,  0.680408,  0.581344,
449     0.320174,  -0.403090, -0.838166, 0.293777,  -0.995613, -0.165521, -0.419859,
450     1.110515,  1.203679,  1.749931,  2.467294,  4.276539,  0.031055,  -0.967664,
451     1.167035,  1.865144,  3.221923,  3.248630,  4.121266,  4.187723,  0.749039,
452     -1.571056, 0.785994,  1.568572,  3.759479,  3.588678,  4.116608,  3.864444,
453     -0.290051, -0.271107, 0.375140,  0.537556,  0.536314,  0.095959,  0.054656,
454     0.088816};
455 
// One output with shape {1, 64}: the expected output of the 1-batch,
// 16-input, 64-unit SVDF test cases run with kTfLiteActNone. The float test
// compares against these values directly; the quantized test first quantizes
// them with its output scale / zero point.
const float golden_output_16x1x1[] = {
    -0.087914, 1.145864,  -0.418088, -1.556392, -0.925298, 0.205252,  0.289119,
    1.331180,  -0.218010, 0.963057,  -2.225886, 1.248478,  1.448983,  0.355467,
    1.682174,  0.803739,  0.449738,  0.543566,  1.916269,  -2.975136, 0.222774,
    0.241589,  -0.104216, 1.561748,  0.936818,  -0.089907, -0.520117, -0.870353,
    1.606074,  0.895770,  0.521297,  -0.369994, -0.889351, -2.809309, 2.404628,
    1.069754,  -0.195456, -1.105652, 1.272715,  -1.233177, 1.271416,  -1.691805,
    -1.058125, -0.716227, 0.052540,  1.262483,  0.540555,  1.735760,  -0.539197,
    -0.014367, -0.243002, 1.072254,  0.528985,  -0.731151, -1.262649, 2.338702,
    -0.603093, 0.970736,  -3.567897, 0.035085,  -0.201711, -0.550400, 1.545573,
    -1.805005};
468 
// One output with shape {1, 64}: the expected output of the same 1-batch,
// 16-input, 64-unit SVDF test cases run with kTfLiteActRelu. Each entry is the
// corresponding golden_output_16x1x1 entry with negative values clamped to 0.
const float golden_output_relu_16x1x1[] = {
    0.000000, 1.145864, 0.000000, 0.000000, 0.000000, 0.205252, 0.289119,
    1.331180, 0.000000, 0.963057, 0.000000, 1.248478, 1.448983, 0.355467,
    1.682174, 0.803739, 0.449738, 0.543566, 1.916269, 0.000000, 0.222774,
    0.241589, 0.000000, 1.561748, 0.936818, 0.000000, 0.000000, 0.000000,
    1.606074, 0.895770, 0.521297, 0.000000, 0.000000, 0.000000, 2.404628,
    1.069754, 0.000000, 0.000000, 1.272715, 0.000000, 1.271416, 0.000000,
    0.000000, 0.000000, 0.052540, 1.262483, 0.540555, 1.735760, 0.000000,
    0.000000, 0.000000, 1.072254, 0.528985, 0.000000, 0.000000, 2.338702,
    0.000000, 0.970736, 0.000000, 0.035085, 0.000000, 0.000000, 1.545573,
    0.000000};
481 
482 template <typename T>
ValidateSVDFGoldens(const int batch_size,const int num_units,const int input_size,const int rank,TfLiteTensor * tensors,const int tensor_count,TfLiteFusedActivation activaiton,const T * input_sequences_data,const int input_sequences_len,T * output_data,const T * expected_output,float tolerance=1e-5f)483 void ValidateSVDFGoldens(const int batch_size, const int num_units,
484                          const int input_size, const int rank,
485                          TfLiteTensor* tensors, const int tensor_count,
486                          TfLiteFusedActivation activaiton,
487                          const T* input_sequences_data,
488                          const int input_sequences_len, T* output_data,
489                          const T* expected_output, float tolerance = 1e-5f) {
490   TfLiteSVDFParams params;
491   params.rank = rank;
492   params.activation = activaiton;
493 
494   int inputs_array_data[] = {5, 0, 1, 2, 3, 4};
495   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
496 
497   int outputs_array_data[] = {1, 5};
498   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
499 
500   const TfLiteRegistration registration = Register_SVDF();
501   micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
502                              outputs_array, &params);
503 
504   TfLiteStatus init_and_prepare_status = runner.InitAndPrepare();
505   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, init_and_prepare_status);
506 
507   // Abort early to make it clear init and prepare failed.
508   if (init_and_prepare_status != kTfLiteOk) {
509     return;
510   }
511 
512   int num_inputs = input_sequences_len / (input_size * batch_size);
513 
514   for (int i = 0; i < num_inputs; ++i) {
515     const T* input_batch_start =
516         input_sequences_data + i * input_size * batch_size;
517 
518     memcpy(tensors[0].data.raw, input_batch_start, tensors[0].bytes);
519     TfLiteStatus status = runner.Invoke();
520     TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status);
521 
522     // Only validate outputs when invoke has succeeded.
523     if (status == kTfLiteOk) {
524       int output_idx = 0;
525       int golden_idx = i * batch_size * num_units;
526       for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
527         TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx],
528                                   tolerance);
529         output_idx++;
530       }
531     }
532   }
533 }
534 
535 #if !defined(XTENSA)  // Needed to avoid build errors from unused functions.
TestSVDF(const int batch_size,const int num_units,const int input_size,const int memory_size,const int rank,TfLiteFusedActivation activation,float * input_data,const float * feature_weights_data,const float * time_weights_data,float * activation_state_data,const float * bias_data,float * scratch_data,float * output_data,const float * input_sequences_data,int input_sequences_len,const float * expected_output,float tolerance=1e-5f)536 void TestSVDF(const int batch_size, const int num_units, const int input_size,
537               const int memory_size, const int rank,
538               TfLiteFusedActivation activation, float* input_data,
539               const float* feature_weights_data, const float* time_weights_data,
540               float* activation_state_data, const float* bias_data,
541               float* scratch_data, float* output_data,
542               const float* input_sequences_data, int input_sequences_len,
543               const float* expected_output, float tolerance = 1e-5f) {
544   const int num_filters = num_units * rank;
545 
546   int input_dims_arg[] = {2, batch_size, input_size};
547   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);
548 
549   int feature_weights_dims_args[] = {2, num_filters, input_size};
550   TfLiteIntArray* feature_weights_dims =
551       IntArrayFromInts(feature_weights_dims_args);
552 
553   int time_weights_dims_args[] = {2, num_filters, memory_size};
554   TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);
555 
556   int activation_state_dims_args[] = {2, batch_size, memory_size * num_filters};
557   TfLiteIntArray* activation_state_dims =
558       IntArrayFromInts(activation_state_dims_args);
559 
560   int bias_dims_args[] = {1, num_units};
561   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_args);
562 
563   int output_dims_args[] = {2, batch_size, num_units};
564   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);
565 
566   const int tensor_count = 6;  // 5 inputs, 1 output
567   TfLiteTensor tensors[] = {
568       CreateTensor(input_data, input_dims),
569       CreateTensor(feature_weights_data, feature_weights_dims),
570       CreateTensor(time_weights_data, time_weights_dims),
571       CreateTensor(bias_data, bias_dims),
572       CreateTensor(activation_state_data, activation_state_dims,
573                    /*is_variable=*/true),
574       CreateTensor(output_data, output_dims),
575   };
576 
577   ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
578                       tensor_count, activation, input_sequences_data,
579                       input_sequences_len, output_data, expected_output,
580                       tolerance);
581 }
582 #endif
583 
584 // The pattern to this method's arguemnts is:
585 // <kernel metadata>
586 // for each tensor in
587 //     {input, feature weights, time weights, bias, activation state, output}:
588 //   <tensor float values> <tensor quantized buffer> <tensor quantization data>
TestIntegerSVDF(const int batch_size,const int num_units,const int input_size,const int memory_size,const int rank,TfLiteFusedActivation activation,int8_t * input_quantized,float input_scale,int input_zero_point,const float * feature_weights_data,int8_t * feature_weights_quantized,const float feature_weights_scale,const float * time_weights_data,int16_t * time_weights_quantized,float time_weights_scale,const float * bias_data,int32_t * bias_quantized,const float * initial_activation_state_data,int16_t * activation_state_quantized,float activation_state_scale,int8_t * output_data,float output_scale,int output_zero_point,const float * input_sequences_data,int8_t * input_sequences_quantized,const int input_sequences_len,const float * golden_output,int8_t * golden_output_quantized,int golden_output_len)589 inline void TestIntegerSVDF(
590     const int batch_size, const int num_units, const int input_size,
591     const int memory_size, const int rank, TfLiteFusedActivation activation,
592     int8_t* input_quantized, float input_scale, int input_zero_point,
593     const float* feature_weights_data, int8_t* feature_weights_quantized,
594     const float feature_weights_scale, const float* time_weights_data,
595     int16_t* time_weights_quantized, float time_weights_scale,
596     const float* bias_data, int32_t* bias_quantized,
597     const float* initial_activation_state_data,
598     int16_t* activation_state_quantized, float activation_state_scale,
599     int8_t* output_data, float output_scale, int output_zero_point,
600     const float* input_sequences_data, int8_t* input_sequences_quantized,
601     const int input_sequences_len, const float* golden_output,
602     int8_t* golden_output_quantized, int golden_output_len) {
603   const int num_filters = num_units * rank;
604 
605   int input_dims_arg[] = {2, batch_size, input_size};
606   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);
607 
608   int feature_weights_dims_args[] = {2, num_filters, input_size};
609   TfLiteIntArray* feature_weights_dims =
610       IntArrayFromInts(feature_weights_dims_args);
611 
612   int time_weights_dims_args[] = {2, num_filters, memory_size};
613   TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);
614 
615   int bias_dims_data[] = {1, num_units};
616   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
617 
618   int activation_state_dims_args[] = {2, batch_size, memory_size * num_filters};
619   TfLiteIntArray* activation_state_dims =
620       IntArrayFromInts(activation_state_dims_args);
621 
622   int output_dims_args[] = {2, batch_size, num_units};
623   TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);
624 
625   const int tensor_count = 6;  // 5 inputs, 1 output
626 
627   TfLiteTensor tensors[] = {
628       CreateQuantizedTensor(input_quantized, input_dims, input_scale,
629                             input_zero_point),
630       CreateQuantizedTensor(feature_weights_data, feature_weights_quantized,
631                             feature_weights_dims, feature_weights_scale, 0),
632       CreateQuantizedTensor(time_weights_data, time_weights_quantized,
633                             time_weights_dims, time_weights_scale, 0),
634       CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
635                                 time_weights_scale, activation_state_scale),
636       CreateQuantizedTensor(initial_activation_state_data,
637                             activation_state_quantized, activation_state_dims,
638                             activation_state_scale, 0,
639                             /*is_variable=*/true),
640       CreateQuantizedTensor(output_data, output_dims, output_scale,
641                             output_zero_point)};
642 
643   tflite::Quantize(golden_output, golden_output_quantized, golden_output_len,
644                    output_scale, output_zero_point);
645   tflite::Quantize(input_sequences_data, input_sequences_quantized,
646                    input_sequences_len, input_scale, input_zero_point);
647 
648   ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
649                       tensor_count, activation, input_sequences_quantized,
650                       input_sequences_len, output_data, golden_output_quantized,
651                       /*tolerance*/ 1);
652 }
653 
654 }  // namespace
655 }  // namespace testing
656 }  // namespace tflite
657 
658 TF_LITE_MICRO_TESTS_BEGIN
659 
660 #if !defined(XTENSA)  // TODO(b/170332589): xtensa kernels are less general than
661                       // reference kernels and we ifdef out test cases that are
662                       // currently known to fail.
TF_LITE_MICRO_TEST(SvdfFloat2x2Input2x4OutputShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Kernel metadata: 2 batches of 2 inputs feeding 4 rank-2 units.
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 2;
  constexpr int memory_size = 10;
  constexpr int rank = 2;
  constexpr int num_filters = num_units * rank;

  // Working buffers backing the tensors.
  float input_data[batch_size * input_size];
  float output_data[batch_size * num_units];
  float scratch_data[batch_size * num_filters];

  // The activation state is mutable; seed it with the shared initial values.
  float activation_state_data[batch_size * memory_size * num_filters];
  memcpy(activation_state_data, tt::initial_activation_state_data_2x2x10,
         sizeof(tt::initial_activation_state_data_2x2x10));

  constexpr int input_sequence_len =
      sizeof(tt::input_data_2x2x10) / sizeof(float);

  tt::TestSVDF(batch_size, num_units, input_size, memory_size, rank,
               kTfLiteActNone, input_data, tt::feature_weights_data_2x2x10,
               tt::time_weights_data_2x2x10, activation_state_data,
               tt::bias_data_2x2x10, scratch_data, output_data,
               tt::input_data_2x2x10, input_sequence_len,
               tt::golden_output_2x2x10);
}
697 #endif
698 
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Kernel metadata: 2 batches of 2 inputs feeding 4 rank-2 units.
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 2;
  constexpr int memory_size = 10;
  constexpr int rank = 2;
  constexpr int num_filters = num_units * rank;

  // Each scale maps the stated float range onto the integer type's range.
  const float input_scale = 2.5f / INT8_MAX;              // Range is [-2.5, 2.5]
  const float feature_weights_scale = 1.f / INT8_MAX;     // Range is [-1, 1]
  const float time_weights_scale = 1.f / INT16_MAX;       // Range is [-1, 1]
  const float activation_state_scale = 16.f / INT16_MAX;  // Range is [-16, 16]
  const float output_scale = 1.f / INT8_MAX;              // Range is [-1, 1]
  const int input_zero_point = 0;
  const int output_zero_point = 0;

  // Element counts of the float reference arrays.
  constexpr int input_sequence_len =
      sizeof(tt::input_data_2x2x10) / sizeof(float);
  constexpr int golden_len = sizeof(tt::golden_output_2x2x10) / sizeof(float);

  // Destination buffers for the quantized tensors and reference data.
  int8_t input_quantized[batch_size * input_size];
  int8_t input_sequences_quantized[input_sequence_len];
  int8_t feature_weights_quantized[sizeof(tt::feature_weights_data_2x2x10) /
                                   sizeof(float)];
  int16_t time_weights_quantized[sizeof(tt::time_weights_data_2x2x10) /
                                 sizeof(float)];
  int16_t activation_state_quantized[batch_size * memory_size * num_filters];
  int32_t bias_quantized[sizeof(tt::bias_data_2x2x10) / sizeof(float)];
  int8_t golden_quantized[golden_len];
  int8_t output_data[batch_size * num_units];

  tt::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      tt::feature_weights_data_2x2x10, feature_weights_quantized,
      feature_weights_scale, tt::time_weights_data_2x2x10,
      time_weights_quantized, time_weights_scale, tt::bias_data_2x2x10,
      bias_quantized, tt::initial_activation_state_data_2x2x10,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tt::input_data_2x2x10,
      input_sequences_quantized, input_sequence_len, tt::golden_output_2x2x10,
      golden_quantized, golden_len);
}
753 
754 #if !defined(XTENSA)  // TODO(b/170332589): xtensa kernels are less general than
755                       // reference kernels and we ifdef out test cases that are
756                       // currently known to fail.
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  // Derive the sequence length from the data array, as the other test cases
  // do, instead of hard-coding it to a single step (input_size).
  constexpr int input_sequences_len =
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float);
  constexpr int golden_len =
      sizeof(tflite::testing::golden_output_16x1x1) / sizeof(float);
  // Every input step must have a matching slice of golden output; otherwise
  // ValidateSVDFGoldens would read past the end of the golden array.
  static_assert((input_sequences_len / input_dims_count) * output_dims_count <=
                    golden_len,
                "golden_output_16x1x1 must cover every input step");

  float input_data[input_dims_count];
  float output_data[output_dims_count];
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];

  // Initialize activation state to starting values.
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(tflite::testing::initial_activation_state_data_16x1x1));

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_sequences_len,
      tflite::testing::golden_output_16x1x1);
}
787 
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputReluShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  // Derive the sequence length from the data array, as the other test cases
  // do, instead of hard-coding it to a single step (input_size).
  constexpr int input_sequences_len =
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float);
  constexpr int golden_len =
      sizeof(tflite::testing::golden_output_relu_16x1x1) / sizeof(float);
  // Every input step must have a matching slice of golden output; otherwise
  // ValidateSVDFGoldens would read past the end of the golden array.
  static_assert((input_sequences_len / input_dims_count) * output_dims_count <=
                    golden_len,
                "golden_output_relu_16x1x1 must cover every input step");

  float input_data[input_dims_count];
  float output_data[output_dims_count];
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];

  // Initialize activation state to starting values.
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(tflite::testing::initial_activation_state_data_16x1x1));

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_sequences_len,
      tflite::testing::golden_output_relu_16x1x1);
}
818 #endif
819 
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Kernel metadata: one batch of 16 inputs feeding 64 rank-1 units.
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;

  // Quantization parameters for this model.
  const float input_scale = 0.10075444;
  const float feature_weights_scale = 0.00649388;
  const float time_weights_scale = 0.001571355;
  const float activation_state_scale = 0.00045896982;
  const float output_scale = 0.051445257;
  const int input_zero_point = 2;
  const int output_zero_point = 0;

  // Element counts of the float reference arrays.
  constexpr int input_sequence_len =
      sizeof(tt::input_data_16x1x1) / sizeof(float);
  constexpr int golden_len = sizeof(tt::golden_output_16x1x1) / sizeof(float);

  // Destination buffers for the quantized tensors and reference data.
  int8_t input_quantized[batch_size * input_size];
  int8_t input_sequences_quantized[input_sequence_len];
  int8_t feature_weights_quantized[sizeof(tt::feature_weights_data_16x1x1) /
                                   sizeof(float)];
  int16_t time_weights_quantized[sizeof(tt::time_weights_data_16x1x1) /
                                 sizeof(float)];
  int16_t activation_state_quantized[batch_size * memory_size * num_filters];
  int32_t bias_quantized[sizeof(tt::bias_data_16x1x1) / sizeof(float)];
  int8_t golden_quantized[golden_len];
  int8_t output_data[batch_size * num_units];

  tt::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_quantized, input_scale, input_zero_point,
      tt::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, tt::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale, tt::bias_data_16x1x1,
      bias_quantized, tt::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tt::input_data_16x1x1,
      input_sequences_quantized, input_sequence_len, tt::golden_output_16x1x1,
      golden_quantized, golden_len);
}
872 
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden) {
  namespace tt = tflite::testing;

  // Kernel metadata: one batch of 16 inputs feeding 64 rank-1 units.
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;

  // Quantization parameters for this model. With ReLU the output is
  // non-negative, so the output zero point sits at the bottom of int8.
  const float input_scale = 0.10075444;
  const float feature_weights_scale = 0.00649388;
  const float time_weights_scale = 0.001571355;
  const float activation_state_scale = 0.00045896982;
  const float output_scale = 0.051445257;
  const int input_zero_point = 2;
  const int output_zero_point = -128;

  // Element counts of the float reference arrays.
  constexpr int input_sequence_len =
      sizeof(tt::input_data_16x1x1) / sizeof(float);
  constexpr int golden_len =
      sizeof(tt::golden_output_relu_16x1x1) / sizeof(float);

  // Destination buffers for the quantized tensors and reference data.
  int8_t input_quantized[batch_size * input_size];
  int8_t input_sequences_quantized[input_sequence_len];
  int8_t feature_weights_quantized[sizeof(tt::feature_weights_data_16x1x1) /
                                   sizeof(float)];
  int16_t time_weights_quantized[sizeof(tt::time_weights_data_16x1x1) /
                                 sizeof(float)];
  int16_t activation_state_quantized[batch_size * memory_size * num_filters];
  int32_t bias_quantized[sizeof(tt::bias_data_16x1x1) / sizeof(float)];
  int8_t golden_quantized[golden_len];
  int8_t output_data[batch_size * num_units];

  tt::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      tt::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, tt::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale, tt::bias_data_16x1x1,
      bias_quantized, tt::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tt::input_data_16x1x1,
      input_sequences_quantized, input_sequence_len,
      tt::golden_output_relu_16x1x1, golden_quantized, golden_len);
}
925 
926 TF_LITE_MICRO_TESTS_END
927