1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <cstring>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
21 namespace tflite {
22 namespace testing {
23 namespace {
24
25 // Naming convention: <tensor name>_<input size>x<batch size>x<batch count>.
26
27 // 10 inputs each with shape {2, 2}.
// ValidateSVDFGoldens feeds one blank-line-separated group of four values
// (one {2, 2} batch) per Invoke() call.
28 const float input_data_2x2x10[] = {
29 0.12609188, -0.46347019, 0.35867718, 0.36897406,
30
31 0.14278367, -1.64410412, -0.57290924, 0.12729003,
32
33 0.49837467, 0.19278903, 0.17660543, 0.52949083,
34
35 -0.11186574, 0.13164264, -0.72674477, -0.5683046,
36
37 -0.68892461, 0.37783599, -0.63690937, 0.44483393,
38
39 -0.81299269, -0.86831826, -0.95760226, 1.82078898,
40
41 -1.45006323, -0.82251364, -1.65087092, -1.89238167,
42
43 0.03966608, -0.24936394, 2.06740379, -1.51439476,
44
45 0.11771342, -0.23761693, 0.31088525, -1.55601168,
46
47 -0.89477462, 1.67204106, -0.6230064, 0.29819036,
48 };
49
50 // Feature filter of shape {8, 2}.
// NOTE(review): 8 presumably equals num_units * rank for this test case —
// confirm against the call site.
51 const float feature_weights_data_2x2x10[] = {
52 -0.31930989, 0.0079667, 0.39296314, 0.37613347, 0.12416199, 0.15785322,
53 0.27901134, 0.3905206, 0.21931258, -0.36137494, -0.10640851, 0.31053296,
54 -0.36118156, -0.0976817, -0.36916667, 0.22197971};
55
56 // Time filter of shape {8, 10}.
// One blank-line-separated pair of lines (10 values) per filter row.
57 const float time_weights_data_2x2x10[] = {
58 -0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
59 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
60
61 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
62 -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
63
64 -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
65 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
66
67 -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
68 -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
69
70 -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
71 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
72
73 -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
74 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
75
76 -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
77 -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
78
79 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
80 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763};
81
82 // Activation state with shape {2, 80}. These initial values must be copied into
83 // a mutable activation state tensor.
// All zeros: the recurrent state starts empty at the first time step.
84
85 const float initial_activation_state_data_2x2x10[] = {
86 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
87 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
88 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
89 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
90 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
91 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
92 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
93
94 // Bias with shape {8} (all zeros).
95 const float bias_data_2x2x10[] = {0, 0, 0, 0, 0, 0, 0, 0};
96
97 // 10 outputs each of shape {2, 4}.
// Compared against the kernel output with the default 1e-5 tolerance of
// ValidateSVDFGoldens; one blank-line-separated group per time step.
98 const float golden_output_2x2x10[] = {
99 -0.044205, -0.013757, 0.050369, -0.018447,
100 0.073010, 0.025142, -0.021154, 0.013551,
101
102 -0.209613, -0.062421, 0.150209, -0.108334,
103 0.028256, -0.006950, -0.030885, 0.009603,
104
105 -0.076800, -0.037075, -0.087198, -0.155183,
106 0.091069, 0.098446, -0.016083, 0.106475,
107
108 -0.082123, -0.162238, -0.084434, -0.141074,
109 -0.029340, -0.090685, 0.053302, -0.030604,
110
111 -0.201440, 0.088424, 0.139877, 0.012416,
112 -0.113212, 0.103893, -0.100842, 0.122780,
113
114 -0.166632, -0.116705, 0.175298, -0.047163,
115 0.313077, -0.166485, -0.285860, 0.129069,
116
117 -0.625911, 0.046134, 0.138081, -0.129581,
118 -0.521455, -0.061579, 0.230289, 0.114963,
119
120 -0.216693, -0.161643, -0.179177, -0.052599,
121 -0.213239, 0.029502, 0.260858, 0.275045,
122
123 -0.213689, -0.323608, -0.285635, -0.317687,
124 -0.324092, -0.317972, -0.208450, -0.462504,
125
126 -0.255126, -0.218576, -0.041528, 0.179421,
127 -0.440583, 0.072127, -0.284136, 0.241570};
128
129 // Simulated real-world inputs, weights and expected outputs.
130
131 // Input of shape {1, 16} (batch 1, input size 16).
132 const float input_data_16x1x1[] = {
133 -0.488494, 2.023762, -2.233117, -0.488494, 3.559030, 9.490748,
134 -3.210106, -1.953977, -0.279140, 0.907204, 1.674838, 0.000000,
135 -0.279140, -0.628064, -0.069785, -0.628064,
136 };
137
138 // Feature filter of shape {64, 16}.
139 const float feature_weights_data_16x1x1[] = {
140 0.173588, 0.173588, -0.024798, 0.193426, -0.099193, 0.044637, 0.183507,
141 0.183507, 0.044637, 0.198386, -0.069435, 0.084314, 0.312458, 0.024798,
142 0.173588, -0.049596, -0.352135, -0.550521, -0.009919, -0.099193, -0.074395,
143 -0.128951, 0.193426, 0.357095, -0.317418, -0.119032, -0.218225, -0.004960,
144 -0.386853, -0.133911, 0.252942, -0.019839, -0.024798, -0.054556, -0.069435,
145 -0.128951, 0.029758, -0.099193, -0.312458, -0.029758, 0.064475, 0.183507,
146 0.114072, -0.178547, -0.247982, -0.119032, 0.243023, -0.119032, -0.034718,
147 -0.178547, 0.019839, 0.128951, -0.223184, -0.009919, -0.213265, 0.168628,
148 -0.143830, -0.322377, -0.218225, -0.193426, -0.252942, -0.049596, 0.064475,
149 -0.267821, -0.580279, -0.099193, 0.213265, 0.119032, -0.119032, -0.178547,
150 0.610037, 0.109112, 0.049596, -0.014879, -0.049596, -0.193426, 0.039677,
151 -0.148789, -0.114072, -0.158709, -0.158709, 0.094233, 0.099193, -0.114072,
152 0.104153, -0.123991, 0.198386, -0.173588, 0.089274, -0.247982, -0.054556,
153 0.123991, 0.183507, 0.114072, 0.188467, 0.302539, 0.044637, 0.039677,
154 -0.099193, 0.168628, -0.024798, -0.054556, -0.109112, 0.014879, -0.009919,
155 0.069435, -0.396772, -0.287660, -0.079354, -0.104153, 0.054556, 0.089274,
156 -0.099193, 0.114072, 0.034718, 0.119032, 0.282700, -0.119032, -0.505884,
157 -0.233104, -0.114072, -0.257902, -0.233104, -0.178547, 0.153749, 0.128951,
158 0.143830, -0.188467, -0.183507, 0.104153, -0.024798, 0.193426, -0.287660,
159 0.168628, -0.009919, 0.119032, -0.024798, -0.099193, -0.203346, 0.099193,
160 0.084314, -0.168628, 0.123991, -0.148789, 0.114072, -0.029758, 0.228144,
161 -0.238063, 0.089274, -0.064475, 0.307498, -0.188467, -0.004960, -0.252942,
162 -0.173588, -0.158709, -0.044637, -0.009919, 0.312458, -0.262861, 0.059516,
163 0.158709, 0.069435, -0.282700, 0.074395, -0.322377, -0.183507, -0.123991,
164 -0.233104, 0.009919, 0.252942, -0.243023, 0.555481, -0.099193, -0.119032,
165 -0.441409, 0.148789, 0.084314, -0.168628, -0.183507, 0.188467, 0.024798,
166 -0.302539, 0.223184, 0.143830, -0.193426, -0.054556, -0.218225, -0.297579,
167 0.104153, 0.272781, -0.034718, 0.114072, -0.059516, 0.044637, 0.342216,
168 0.421570, 0.138870, -0.024798, -0.039677, -0.163668, -0.034718, 0.396772,
169 -0.128951, -0.044637, -0.173588, 0.302539, 0.079354, 0.049596, 0.133911,
170 -0.029758, -0.312458, -0.029758, 0.079354, 0.128951, 0.252942, 0.213265,
171 0.014879, 0.287660, 0.178547, 0.297579, 0.352135, 0.401732, 0.024798,
172 -0.277740, -0.411651, -0.069435, 0.342216, -0.158709, -0.104153, -0.009919,
173 0.223184, 0.228144, -0.019839, 0.059516, -0.104153, -0.510844, 0.029758,
174 -0.406691, 0.089274, 0.421570, 0.163668, -0.143830, -0.019839, -0.039677,
175 0.104153, -0.044637, -0.128951, 0.203346, 0.079354, -0.069435, 0.094233,
176 -0.138870, 0.466207, -0.163668, 0.049596, 0.029758, 0.267821, 0.029758,
177 -0.049596, 0.009919, 0.004960, -0.099193, 0.094233, -0.262861, 0.089274,
178 -0.302539, 0.332297, -0.307498, -0.014879, 0.168628, -0.094233, -0.272781,
179 0.034718, -0.133911, -0.228144, 0.094233, 0.257902, -0.228144, 0.153749,
180 -0.054556, -0.252942, 0.054556, 0.218225, -0.054556, 0.302539, 0.282700,
181 0.054556, -0.044637, -0.133911, 0.233104, -0.049596, 0.411651, 0.044637,
182 -0.297579, -0.029758, -0.114072, 0.114072, -0.580279, 0.079354, -0.024798,
183 -0.347175, -0.128951, -0.099193, 0.238063, -0.104153, -0.009919, 0.158709,
184 -0.034718, 0.123991, -0.163668, 0.059516, 0.342216, 0.009919, 0.064475,
185 -0.307498, -0.520763, -0.238063, 0.163668, 0.362054, 0.034718, -0.178547,
186 -0.104153, -0.257902, 0.322377, 0.054556, 0.148789, -0.178547, 0.084314,
187 0.004960, 0.257902, 0.029758, 0.079354, -0.223184, -0.193426, 0.282700,
188 0.000000, -0.019839, -0.114072, 0.491005, -0.193426, -0.029758, -0.243023,
189 0.009919, 0.089274, -0.277740, -0.089274, 0.104153, 0.337256, 0.138870,
190 -0.307498, -0.054556, 0.352135, 0.133911, -0.044637, 0.133911, -0.089274,
191 -0.357095, -0.272781, 0.069435, 0.059516, -0.109112, 0.148789, -0.044637,
192 -0.019839, -0.153749, 0.123991, -0.223184, 0.322377, 0.074395, -0.312458,
193 0.024798, -0.223184, 0.109112, -0.138870, 0.218225, -0.074395, -0.406691,
194 0.009919, -0.198386, -0.009919, 0.416611, 0.178547, 0.148789, 0.133911,
195 -0.004960, 0.069435, -0.054556, -0.044637, 0.297579, 0.059516, -0.456288,
196 -0.148789, -0.004960, 0.054556, 0.094233, -0.104153, 0.198386, -0.302539,
197 0.133911, 0.411651, 0.054556, 0.525723, -0.089274, 0.079354, 0.238063,
198 0.079354, -0.039677, 0.039677, 0.029758, 0.332297, -0.014879, -0.367014,
199 -0.143830, -0.123991, -0.064475, 0.014879, 0.173588, -0.168628, 0.386853,
200 0.009919, 0.173588, 0.163668, 0.123991, 0.163668, 0.198386, 0.203346,
201 -0.401732, -0.009919, 0.272781, -0.173588, 0.044637, 0.238063, 0.133911,
202 0.049596, 0.208305, -0.024798, 0.049596, -0.049596, 0.034718, -0.446368,
203 0.466207, -0.089274, -0.099193, -0.128951, -0.228144, 0.014879, -0.252942,
204 0.074395, -0.223184, -0.168628, -0.292619, 0.178547, 0.153749, -0.014879,
205 0.054556, 0.000000, 0.193426, 0.158709, 0.178547, -0.327337, -0.138870,
206 -0.114072, 0.168628, 0.297579, -0.109112, -0.029758, -0.029758, -0.416611,
207 0.059516, 0.000000, -0.168628, -0.322377, 0.238063, -0.128951, -0.029758,
208 0.500925, 0.292619, 0.123991, -0.099193, 0.074395, 0.317418, -0.148789,
209 0.064475, -0.104153, -0.044637, -0.094233, 0.188467, -0.044637, 0.213265,
210 -0.233104, -0.049596, 0.004960, -0.198386, 0.287660, -0.148789, -0.257902,
211 0.004960, -0.218225, -0.044637, -0.386853, -0.243023, -0.163668, 0.094233,
212 0.029758, -0.019839, -0.009919, -0.143830, -0.158709, 0.158709, -0.243023,
213 -0.039677, -0.297579, 0.069435, 0.049596, 0.302539, 0.059516, 0.074395,
214 -0.019839, 0.352135, -0.019839, -0.138870, -0.178547, -0.243023, 0.233104,
215 0.252942, -0.228144, -0.049596, 0.173588, 0.173588, -0.074395, -0.034718,
216 -0.292619, 0.362054, 0.183507, 0.243023, -0.203346, -0.044637, 0.054556,
217 0.059516, -0.158709, -0.158709, 0.000000, 0.327337, 0.119032, 0.034718,
218 -0.044637, -0.089274, 0.089274, -0.233104, 0.000000, -0.317418, 0.371974,
219 0.213265, 0.307498, -0.178547, -0.367014, 0.039677, -0.059516, 0.168628,
220 -0.014879, 0.143830, 0.123991, -0.084314, -0.332297, -0.416611, 0.183507,
221 0.109112, -0.039677, 0.014879, 0.292619, -0.213265, -0.054556, 0.004960,
222 0.123991, 0.119032, 0.000000, -0.332297, -0.312458, -0.198386, -0.213265,
223 0.119032, 0.322377, 0.168628, 0.104153, -0.262861, 0.327337, -0.049596,
224 -0.228144, -0.074395, 0.168628, 0.123991, 0.396772, 0.044637, 0.322377,
225 0.193426, 0.267821, -0.178547, 0.297579, 0.148789, -0.218225, -0.138870,
226 0.044637, 0.049596, 0.133911, 0.064475, 0.069435, 0.064475, -0.158709,
227 -0.044637, -0.173588, 0.267821, 0.327337, 0.079354, -0.228144, 0.029758,
228 0.014879, 0.198386, -0.109112, -0.133911, 0.431490, 0.099193, 0.421570,
229 0.233104, -0.054556, 0.054556, -0.317418, -0.133911, -0.123991, -0.287660,
230 0.342216, -0.049596, -0.153749, 0.228144, -0.213265, 0.262861, 0.406691,
231 -0.084314, -0.004960, 0.193426, 0.188467, -0.099193, -0.223184, 0.163668,
232 -0.257902, -0.153749, 0.441409, 0.099193, 0.128951, -0.089274, -0.208305,
233 -0.009919, -0.004960, -0.109112, 0.024798, -0.119032, 0.019839, 0.391812,
234 -0.024798, 0.198386, 0.327337, -0.505884, -0.099193, 0.510844, -0.148789,
235 0.094233, -0.153749, -0.039677, 0.352135, 0.272781, -0.228144, -0.287660,
236 -0.272781, 0.148789, 0.277740, 0.074395, 0.109112, -0.064475, 0.044637,
237 0.074395, -0.292619, 0.153749, -0.064475, -0.114072, 0.198386, -0.039677,
238 -0.128951, -0.004960, 0.257902, -0.228144, -0.094233, 0.064475, 0.014879,
239 0.188467, -0.416611, 0.099193, 0.362054, -0.208305, 0.198386, -0.079354,
240 0.009919, 0.119032, 0.332297, 0.243023, -0.168628, 0.158709, 0.039677,
241 0.143830, 0.277740, -0.168628, 0.009919, 0.099193, -0.004960, -0.257902,
242 -0.297579, 0.208305, -0.104153, 0.119032, 0.247982, 0.381893, -0.223184,
243 -0.367014, -0.327337, -0.168628, -0.094233, 0.208305, -0.019839, 0.183507,
244 0.084314, 0.133911, 0.109112, -0.148789, -0.183507, -0.411651, -0.024798,
245 -0.114072, -0.029758, -0.009919, 0.173588, -0.059516, -0.049596, 0.039677,
246 0.317418, 0.138870, -0.247982, -0.084314, 0.158709, 0.054556, -0.084314,
247 -0.049596, 0.074395, 0.019839, -0.282700, -0.119032, -0.262861, 0.163668,
248 -0.069435, -0.064475, -0.059516, 0.094233, 0.123991, -0.079354, -0.272781,
249 -0.267821, 0.233104, 0.114072, -0.218225, 0.540602, 0.089274, 0.262861,
250 0.079354, 0.267821, -0.119032, -0.109112, -0.128951, 0.128951, -0.044637,
251 -0.272781, 0.277740, 0.297579, -0.054556, -0.084314, -0.049596, 0.123991,
252 0.059516, 0.238063, -0.168628, -0.009919, 0.163668, -0.307498, 0.109112,
253 -0.064475, 0.218225, -0.168628, -0.004960, -0.168628, 0.119032, 0.094233,
254 -0.183507, -0.089274, -0.292619, -0.094233, 0.064475, -0.183507, -0.168628,
255 0.089274, 0.074395, -0.367014, -0.024798, -0.069435, 0.119032, -0.302539,
256 -0.376933, -0.123991, -0.009919, -0.069435, -0.208305, -0.119032, 0.014879,
257 -0.183507, -0.238063, 0.163668, -0.332297, -0.148789, -0.391812, -0.024798,
258 -0.133911, -0.059516, -0.123991, 0.123991, -0.292619, -0.044637, 0.059516,
259 -0.069435, 0.049596, -0.069435, 0.034718, 0.158709, -0.347175, -0.044637,
260 0.352135, -0.347175, -0.282700, -0.054556, 0.307498, 0.029758, 0.357095,
261 -0.148789, 0.208305, -0.317418, 0.009919, 0.004960, -0.243023, 0.049596,
262 -0.099193, 0.213265, -0.342216, 0.158709, 0.123991, -0.332297, 0.386853,
263 -0.262861, -0.208305, 0.123991, -0.044637, 0.148789, 0.084314, -0.297579,
264 -0.307498, -0.163668, 0.337256, -0.014879, 0.074395, 0.178547, -0.004960,
265 -0.257902, -0.019839, -0.228144, -0.034718, -0.277740, -0.158709, -0.119032,
266 -0.153749, 0.629876, 0.277740, 0.178547, -0.267821, -0.004960, 0.247982,
267 0.084314, -0.094233, 0.000000, -0.039677, 0.332297, 0.178547, 0.009919,
268 -0.213265, -0.208305, -0.044637, 0.019839, 0.218225, -0.297579, 0.014879,
269 -0.247982, -0.004960, -0.128951, 0.421570, -0.059516, 0.362054, -0.203346,
270 -0.143830, -0.099193, -0.024798, 0.094233, -0.123991, 0.163668, 0.109112,
271 -0.104153, -0.233104, 0.009919, -0.218225, 0.376933, 0.104153, -0.059516,
272 0.049596, -0.054556, 0.019839, -0.044637, -0.019839, 0.371974, -0.019839,
273 0.104153, 0.168628, -0.024798, -0.272781, -0.158709, 0.223184, 0.044637,
274 0.039677, -0.168628, -0.287660, -0.109112, 0.094233, -0.089274, -0.148789,
275 0.178547, -0.039677, -0.089274, -0.049596, -0.024798, 0.064475, -0.158709,
276 0.089274, 0.029758, -0.247982, 0.362054, 0.024798, -0.004960, -0.099193,
277 0.173588, -0.059516, 0.188467, -0.629876, 0.094233, 0.371974, 0.069435,
278 0.252942, -0.357095, -0.272781, -0.367014, 0.014879, -0.049596, -0.262861,
279 0.009919, -0.094233, -0.094233, 0.059516, 0.223184, 0.133911, 0.411651,
280 -0.044637, -0.044637, 0.109112, 0.228144, 0.386853, -0.233104, 0.069435,
281 0.228144, -0.302539, 0.029758, 0.089274, 0.044637, -0.238063, -0.138870,
282 -0.158709, -0.019839, 0.049596, 0.039677, 0.000000, -0.069435, 0.109112,
283 -0.213265, -0.188467, -0.262861, -0.267821, -0.094233, 0.133911, 0.391812,
284 0.123991, -0.317418, 0.233104, -0.029758, -0.099193, -0.193426, 0.074395,
285 -0.009919, 0.252942, 0.322377, -0.530683, 0.208305, 0.252942, 0.203346,
286 -0.069435, -0.262861};
287
288 // Time filter of shape {64, 8}.
289 const float time_weights_data_16x1x1[] = {
290 -0.052026, 0.043107, 0.053512, 0.013378, 0.011892, -0.182834, -0.108511,
291 0.153105, 0.050539, -0.173915, 0.145672, 0.208103, -0.221481, 0.108511,
292 -0.496475, 0.181347, -0.016351, -0.132294, -0.234859, -0.243778, 0.028243,
293 -0.228914, -0.130808, -0.167969, -0.041621, -0.306209, -0.193239, -0.028243,
294 -0.057972, -0.057972, -0.497962, 0.054999, 0.181347, 0.047566, -0.099592,
295 -0.111484, -0.130808, -0.071350, 0.380532, 0.010405, 0.041621, 0.052026,
296 0.022297, 0.081755, 0.098106, 0.099592, -0.584176, -0.023783, 0.062431,
297 -0.090674, -0.279453, -0.486070, -0.273507, 0.004459, -0.062431, 0.095133,
298 0.056485, 0.022297, -0.105538, -0.184320, 0.358235, 0.254183, 0.049053,
299 0.084728, 0.218508, 0.078782, -0.136754, -0.017837, -0.124862, -0.118916,
300 -0.001486, 0.043107, 0.254183, 0.087701, 0.261616, 0.309182, -0.404315,
301 -0.040134, -0.046080, -0.052026, -0.034188, -0.475665, -0.025270, -0.049053,
302 -0.046080, -0.062431, 0.020810, 0.040134, -0.135267, -0.169456, -0.050539,
303 -0.576743, 0.034188, 0.075809, 0.101079, 0.136754, 0.083241, 0.077296,
304 -0.050539, 0.761064, -0.335938, -0.080268, 0.025270, 0.257156, 0.227427,
305 0.252697, 0.065404, 0.115943, 0.222968, -0.026756, -0.054999, 0.107025,
306 -0.093646, 0.041621, -0.092160, -0.474178, -0.016351, 0.004459, 0.049053,
307 0.019324, 0.019324, 0.074323, 0.038648, -0.613905, 0.182834, 0.075809,
308 0.028243, 0.019324, 0.010405, -0.011892, 0.001486, -0.492016, -0.224454,
309 -0.474178, -0.147159, 0.002973, 0.102565, 0.136754, -0.267561, -0.001486,
310 -0.095133, -0.040134, 0.066890, 0.074323, 0.104052, 0.532150, 0.090674,
311 0.072836, -0.053512, -0.004459, 0.020810, 0.046080, 0.062431, 0.477151,
312 0.133781, -0.029729, -0.026756, 0.031215, 0.156077, 0.096619, 0.251210,
313 0.352289, 0.657012, 0.047566, -0.014865, -0.072836, -0.016351, 0.008919,
314 -0.053512, 0.016351, 0.300263, 0.047566, 0.020810, 0.169456, 0.001486,
315 0.007432, 0.111484, 0.044594, -0.188779, -0.096619, 0.074323, -0.040134,
316 0.160537, 0.138240, 0.184320, 0.377559, -0.092160, -0.049053, 0.056485,
317 -0.032702, 0.001486, -0.083241, -0.472692, -0.114457, -0.117430, -0.075809,
318 0.026756, 0.163510, 0.172428, 0.127835, -0.199185, -0.218508, -0.057972,
319 -0.132294, -0.162023, -0.019324, -0.245265, -0.395396, -0.254183, 0.084728,
320 0.248238, 0.191752, 0.221481, 0.173915, 0.173915, -0.208103, -0.077296,
321 0.384991, -0.313641, -0.313641, -0.147159, -0.090674, 0.035675, 0.059458,
322 -0.010405, 0.019324, 0.087701, 0.016351, 0.037161, 0.469719, -0.074323,
323 0.092160, 0.026756, 0.090674, 0.098106, 0.004459, -0.034188, 0.492016,
324 -0.367154, -0.093646, -0.063917, 0.041621, 0.017837, 0.026756, -0.062431,
325 -0.350803, 0.425125, 0.002973, 0.083241, 0.075809, 0.016351, 0.047566,
326 -0.185807, -0.107025, -0.098106, -0.144186, 0.255670, 0.020810, 0.105538,
327 0.029729, 0.129321, 0.156077, 0.141213, 0.334452, 0.147159, -0.066890,
328 0.035675, 0.115943, 0.240805, 0.328506, 0.162023, -0.237832, 0.218508,
329 0.233373, 0.214049, 0.099592, 0.026756, -0.322560, -0.236346, -0.166483,
330 0.225941, 0.109997, -0.147159, 0.147159, -0.266075, 0.111484, 0.078782,
331 -0.120403, 0.022297, -0.075809, -0.148645, -0.251210, -0.176888, -0.044594,
332 -0.023783, 0.016351, 0.026756, -0.013378, -0.069863, -0.112970, 0.013378,
333 0.086214, 0.014865, 0.352289, -0.240805, -0.135267, -0.114457, -0.472692,
334 0.334452, 0.095133, 0.047566, 0.130808, -0.068377, -0.007432, -0.130808,
335 -0.121889, -0.053512, -0.245265, -0.371613, -0.083241, 0.000000, -0.028243,
336 0.029729, -0.093646, -0.004459, -0.038648, -0.108511, -0.475665, -0.169456,
337 -0.047566, -0.010405, -0.114457, -0.353776, -0.034188, -0.044594, 0.041621,
338 -0.047566, -0.107025, 0.004459, 0.053512, 0.047566, -0.358235, -0.193239,
339 0.040134, -0.096619, -0.054999, 0.099592, 0.032702, 0.205130, -0.170942,
340 -0.237832, -0.405801, -0.126348, -0.072836, -0.203644, -0.169456, -0.093646,
341 -0.074323, 0.078782, 0.607959, -0.437017, -0.164996, -0.166483, 0.043107,
342 -0.016351, 0.258643, 0.065404, -0.057972, 0.017837, 0.080268, 0.050539,
343 -0.013378, -0.215536, -0.524718, 0.260129, 0.040134, -0.002973, -0.046080,
344 0.020810, 0.025270, 0.145672, 0.515799, 0.233373, 0.011892, 0.139727,
345 0.126348, 0.065404, -0.007432, -0.008919, 0.035675, 0.083241, 0.040134,
346 -0.005946, 0.503907, -0.490529, -0.181347, -0.092160, -0.038648, 0.019324,
347 0.133781, -0.011892, 0.041621, 0.062431, -0.062431, -0.040134, -0.092160,
348 -0.111484, -0.133781, -0.130808, -0.484583, -0.248238, 0.037161, -0.092160,
349 -0.056485, -0.041621, 0.112970, 0.248238, 0.438503, 0.258643, -0.013378,
350 0.004459, 0.043107, 0.040134, 0.017837, 0.101079, 0.264589, 0.212563,
351 0.014865, 0.285399, 0.153105, 0.170942, 0.358235, 0.334452, 0.086214,
352 0.132294, 0.098106, -0.001486, 0.107025, 0.200671, -0.026756, 0.344857,
353 0.227427, -0.041621, 0.098106, 0.063917, -0.093646, 0.130808, 0.285399,
354 -0.319587, 0.035675, -0.017837, -0.319587, 0.016351, -0.098106, -0.017837,
355 0.083241, 0.074323, -0.054999, 0.276480, 0.316614, -0.099592, -0.059458,
356 0.156077, -0.043107, 0.035675, 0.056485, -0.022297, 0.017837, -0.001486,
357 0.340398, 0.492016, 0.004459, 0.057972, -0.150132, -0.206617, -0.257156,
358 -0.248238, -0.080268, -0.164996, 0.352289, -0.054999, -0.056485, 0.010405,
359 -0.049053, -0.041621, -0.099592, 0.013378, -0.089187, 0.057972, -0.413234,
360 0.217022, 0.013378, -0.080268, -0.035675, 0.035675, 0.007432, 0.002973,
361 -0.469719, 0.141213, 0.136754, 0.153105, 0.130808, -0.104052, -0.508367,
362 -0.291345, -0.072836, -0.019324, -0.252697, -0.214049, -0.214049, 0.130808,
363 0.484583};
364
365 // Bias of shape {64} (one value per output unit).
366 const float bias_data_16x1x1[] = {
367 -0.245395, -0.083545, -0.262522, -0.407912, -0.560898, -0.364789, -0.037964,
368 -0.378594, 0.178152, 0.400380, -0.301349, -0.240913, -0.159454, -0.158757,
369 -0.073665, 0.455906, -0.061232, 0.318907, -0.226993, -0.344644, 0.140316,
370 0.559608, 0.109774, 0.437391, 0.113849, -0.162068, 0.039572, 0.569472,
371 0.460205, 0.113459, 0.370469, 0.176811, 0.203063, -0.296975, -0.271655,
372 0.059862, -0.159912, -0.077310, -0.338314, -0.195477, -0.256762, 0.233834,
373 0.083172, 0.029040, -0.236288, -0.267054, -0.166627, 0.188319, -0.271391,
374 -0.222920, 0.106463, 0.263614, 0.384986, -0.125957, -0.095890, 0.363686,
375 -0.036990, -0.358884, -0.178254, 0.305596, 0.390088, -0.189437, 0.613409,
376 0.399639};
377
378 // Activation state with shape {64, 8}. These initial values must be copied into
379 // a mutable activation state tensor.
// Unlike the 2x2x10 case, this state is non-zero: it represents a kernel that
// has already processed earlier time steps.
380 const float initial_activation_state_data_16x1x1[] = {
381 -0.582275, -0.586623, -1.262373, -1.277279, -1.542175, -1.271999, -1.429757,
382 -1.184425, -0.462094, -1.443421, 0.230736, -0.494701, -0.354955, -2.534061,
383 -4.277471, -4.218467, 0.403711, -0.248748, -0.330111, -0.467683, 0.549047,
384 0.733511, -0.230115, 0.793136, -1.126353, -0.984123, -0.081984, -0.222351,
385 0.692830, 0.517060, 1.367958, 2.118860, -0.116766, -0.826365, -2.402700,
386 -2.313884, -2.898954, -2.076005, -2.405185, -2.755481, 0.329490, 0.085400,
387 -1.485966, -2.034702, -2.161405, -1.269515, -1.151818, -1.823841, 0.561469,
388 1.109273, 1.693411, -0.082605, -0.069252, -1.225107, -1.330693, -1.411435,
389 0.253406, -0.357439, -1.593415, -0.879779, -1.111136, 1.821357, 2.471952,
390 1.236908, -4.014127, -2.810448, -2.944604, -1.930980, -1.566398, -0.838166,
391 -0.319242, 0.749349, 1.156476, 0.658670, 1.997437, 2.080663, 2.912618,
392 2.677224, 2.642442, 2.796163, -0.272349, -0.473273, 3.120063, 2.747097,
393 3.595510, 1.874150, 2.049919, 2.093396, -1.049959, 0.277939, -1.255541,
394 -1.052443, -1.810177, -0.883505, -0.538178, 0.524203, -1.017662, -0.269244,
395 0.039129, -0.227941, -0.114592, -2.018243, -2.548968, -0.706804, 0.890959,
396 0.102480, 0.349986, 0.405885, 1.287216, 0.756181, 0.319242, -0.641590,
397 -3.841774, -2.716042, -4.342065, -3.826557, -2.924729, -1.643724, -1.237839,
398 -0.597492, -1.954892, -1.215169, -1.528201, -1.018904, -0.863941, -0.293467,
399 0.039439, 0.672023, 1.408019, 1.362679, 1.467644, 1.006171, 0.310236,
400 -0.249990, -1.048406, -0.752144, -1.831605, -1.058033, -1.096541, -0.293467,
401 0.051551, 0.232600, 0.088816, 2.570395, 0.704009, 2.465120, 3.010751,
402 2.139357, 0.630410, 1.006171, 1.545281, 1.486898, -1.162998, -2.344317,
403 -4.593918, -3.522842, -2.872247, -1.416714, -0.642521, -0.230115, 0.315205,
404 -0.368930, -0.162726, 0.396879, 0.505570, 0.534451, 0.554947, 1.270447,
405 0.388805, 0.531967, -1.243119, -0.671713, -1.214859, -0.238189, 0.016459,
406 -1.164550, 0.609603, 3.293348, 2.600208, 1.454290, -1.034121, -1.760179,
407 -1.192500, -0.613951, 3.449553, 2.912618, 1.917937, 1.435968, 0.879158,
408 1.118279, 0.102791, -0.502465, -0.239121, -0.092853, 1.786265, 1.943091,
409 2.547104, 2.630641, 2.585302, 2.965411, -0.945615, -2.538720, -2.474126,
410 -1.088156, 0.056209, 0.864873, 0.170490, 0.457435, 0.545941, 0.752765,
411 1.569503, 1.129459, 0.662086, -0.527929, -0.810838, -1.662978, 1.285042,
412 1.653040, 4.130893, 2.961995, 4.147041, 3.256393, 3.881524, 2.522571,
413 -0.875431, -1.112378, 2.105817, 2.180970, 3.121926, 1.577577, 1.639376,
414 2.906407, -0.142230, 0.421101, 2.212335, 2.311399, 3.993321, 3.651719,
415 4.206666, 4.678387, -1.304917, -1.130701, -2.543067, -2.500212, -2.197118,
416 -1.197158, -0.949652, -0.282908, 0.320795, -1.543728, 1.290322, 1.788128,
417 3.957297, 3.205774, 2.892432, 2.297114, 0.138814, -0.139435, 0.936920,
418 0.344707, 0.723263, -1.772290, -3.138385, -2.287177, -2.405806, -1.859864,
419 -4.572801, -3.410424, -3.855748, -2.239663, -2.269786, -1.582857, 4.238342,
420 3.858543, 2.499901, 1.087535, 0.290051, -0.026086, -0.880400, -2.602692,
421 -1.404292, 0.253096, -0.665502, -1.443421, -0.925119, -0.096580, 1.115484,
422 1.846200, -1.604284, -1.244671, -0.464888, 0.326385, 0.168006, -0.262723,
423 -0.744691, 0.953379, -0.407127, -0.349986, -1.154302, 0.831023, 1.590931,
424 2.538720, 2.063583, 3.697680, -0.752455, -1.293117, -1.330693, -1.869802,
425 -0.592523, 0.631652, 1.198089, -0.481347, 3.738983, 4.153252, 2.782499,
426 2.244321, 0.709289, 1.650245, 1.700865, 0.385078, 2.192460, 2.610456,
427 4.009780, 3.492719, 2.574743, 2.116687, 1.856138, 1.205853, 2.722563,
428 4.075305, 5.415935, 3.009198, 2.715421, 1.571056, 0.897170, -2.430339,
429 0.749970, 0.425760, -0.302783, 0.817359, 1.031636, 1.913589, 2.686229,
430 1.631923, -1.459259, -1.793097, -1.187531, -1.553355, -0.844998, -1.296843,
431 -1.805519, -0.486627, 0.909591, 2.082837, -1.473855, -2.456735, -3.851401,
432 -2.760139, -3.060438, -2.605487, -2.138735, -2.441519, -1.333177, -1.353984,
433 -0.245642, -0.588486, 0.033850, 2.084700, 0.076084, 0.690035, 0.747797,
434 0.594697, -1.016109, -1.348083, -1.201195, -1.088466, 2.045571, 2.460772,
435 0.717984, 0.041613, -0.721711, 1.134738, 2.322269, 1.112378, -0.307441,
436 -0.581033, -0.868599, -0.018633, 0.856488, 0.919839, 0.303094, -0.433213,
437 0.811148, -0.508986, -1.060828, -1.227591, -1.566087, -1.117968, -1.385038,
438 -2.011101, -0.490353, -1.849616, -0.594697, -1.055859, 1.110205, 0.622646,
439 0.145957, 0.359303, 1.012072, 0.774814, -0.400295, -1.484103, -2.007374,
440 -1.441247, -0.997787, -0.581033, -0.545941, -0.306510, 0.693451, 0.087264,
441 -0.227320, -1.211753, -1.532859, -1.688753, 0.065215, 0.134777, 0.608051,
442 -0.393152, -0.214588, -0.635689, -1.499320, 0.069562, -1.555839, -2.633126,
443 -2.966032, -1.550870, -0.101549, 0.874189, 0.436318, 0.299367, 2.289972,
444 2.339659, 2.602071, 1.564535, 0.019254, -0.583207, -1.295912, -2.424749,
445 -1.221070, -1.175109, -0.577306, -0.102791, 1.877876, 2.568222, 2.173827,
446 3.131243, 2.637784, 2.088737, 3.679047, 3.218506, 2.483442, 1.650556,
447 1.363611, -0.027328, 1.486898, -0.721711, -3.684327, -3.006093, -3.777491,
448 -2.327548, -2.737470, -4.549510, -0.060867, 0.127635, 0.680408, 0.581344,
449 0.320174, -0.403090, -0.838166, 0.293777, -0.995613, -0.165521, -0.419859,
450 1.110515, 1.203679, 1.749931, 2.467294, 4.276539, 0.031055, -0.967664,
451 1.167035, 1.865144, 3.221923, 3.248630, 4.121266, 4.187723, 0.749039,
452 -1.571056, 0.785994, 1.568572, 3.759479, 3.588678, 4.116608, 3.864444,
453 -0.290051, -0.271107, 0.375140, 0.537556, 0.536314, 0.095959, 0.054656,
454 0.088816};
455
456 // One output with shape {1, 64}.
// NOTE(review): presumably produced with no fused activation (kTfLiteActNone),
// since golden_output_relu_16x1x1 below is this data with negatives zeroed —
// confirm against the test call sites.
457 const float golden_output_16x1x1[] = {
458 -0.087914, 1.145864, -0.418088, -1.556392, -0.925298, 0.205252, 0.289119,
459 1.331180, -0.218010, 0.963057, -2.225886, 1.248478, 1.448983, 0.355467,
460 1.682174, 0.803739, 0.449738, 0.543566, 1.916269, -2.975136, 0.222774,
461 0.241589, -0.104216, 1.561748, 0.936818, -0.089907, -0.520117, -0.870353,
462 1.606074, 0.895770, 0.521297, -0.369994, -0.889351, -2.809309, 2.404628,
463 1.069754, -0.195456, -1.105652, 1.272715, -1.233177, 1.271416, -1.691805,
464 -1.058125, -0.716227, 0.052540, 1.262483, 0.540555, 1.735760, -0.539197,
465 -0.014367, -0.243002, 1.072254, 0.528985, -0.731151, -1.262649, 2.338702,
466 -0.603093, 0.970736, -3.567897, 0.035085, -0.201711, -0.550400, 1.545573,
467 -1.805005};
468
469 // One output with shape {1, 64}.
// Identical to golden_output_16x1x1 but with all negative values clamped to
// zero, i.e. the expected result with a fused ReLU activation.
470 const float golden_output_relu_16x1x1[] = {
471 0.000000, 1.145864, 0.000000, 0.000000, 0.000000, 0.205252, 0.289119,
472 1.331180, 0.000000, 0.963057, 0.000000, 1.248478, 1.448983, 0.355467,
473 1.682174, 0.803739, 0.449738, 0.543566, 1.916269, 0.000000, 0.222774,
474 0.241589, 0.000000, 1.561748, 0.936818, 0.000000, 0.000000, 0.000000,
475 1.606074, 0.895770, 0.521297, 0.000000, 0.000000, 0.000000, 2.404628,
476 1.069754, 0.000000, 0.000000, 1.272715, 0.000000, 1.271416, 0.000000,
477 0.000000, 0.000000, 0.052540, 1.262483, 0.540555, 1.735760, 0.000000,
478 0.000000, 0.000000, 1.072254, 0.528985, 0.000000, 0.000000, 2.338702,
479 0.000000, 0.970736, 0.000000, 0.035085, 0.000000, 0.000000, 1.545573,
480 0.000000};
481
482 template <typename T>
ValidateSVDFGoldens(const int batch_size,const int num_units,const int input_size,const int rank,TfLiteTensor * tensors,const int tensor_count,TfLiteFusedActivation activaiton,const T * input_sequences_data,const int input_sequences_len,T * output_data,const T * expected_output,float tolerance=1e-5f)483 void ValidateSVDFGoldens(const int batch_size, const int num_units,
484 const int input_size, const int rank,
485 TfLiteTensor* tensors, const int tensor_count,
486 TfLiteFusedActivation activaiton,
487 const T* input_sequences_data,
488 const int input_sequences_len, T* output_data,
489 const T* expected_output, float tolerance = 1e-5f) {
490 TfLiteSVDFParams params;
491 params.rank = rank;
492 params.activation = activaiton;
493
494 int inputs_array_data[] = {5, 0, 1, 2, 3, 4};
495 TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
496
497 int outputs_array_data[] = {1, 5};
498 TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
499
500 const TfLiteRegistration registration = Register_SVDF();
501 micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
502 outputs_array, ¶ms);
503
504 TfLiteStatus init_and_prepare_status = runner.InitAndPrepare();
505 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, init_and_prepare_status);
506
507 // Abort early to make it clear init and prepare failed.
508 if (init_and_prepare_status != kTfLiteOk) {
509 return;
510 }
511
512 int num_inputs = input_sequences_len / (input_size * batch_size);
513
514 for (int i = 0; i < num_inputs; ++i) {
515 const T* input_batch_start =
516 input_sequences_data + i * input_size * batch_size;
517
518 memcpy(tensors[0].data.raw, input_batch_start, tensors[0].bytes);
519 TfLiteStatus status = runner.Invoke();
520 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, status);
521
522 // Only validate outputs when invoke has succeeded.
523 if (status == kTfLiteOk) {
524 int output_idx = 0;
525 int golden_idx = i * batch_size * num_units;
526 for (int j = golden_idx; j < golden_idx + batch_size * num_units; ++j) {
527 TF_LITE_MICRO_EXPECT_NEAR(expected_output[j], output_data[output_idx],
528 tolerance);
529 output_idx++;
530 }
531 }
532 }
533 }
534
535 #if !defined(XTENSA) // Needed to avoid build errors from unused functions.
TestSVDF(const int batch_size,const int num_units,const int input_size,const int memory_size,const int rank,TfLiteFusedActivation activation,float * input_data,const float * feature_weights_data,const float * time_weights_data,float * activation_state_data,const float * bias_data,float * scratch_data,float * output_data,const float * input_sequences_data,int input_sequences_len,const float * expected_output,float tolerance=1e-5f)536 void TestSVDF(const int batch_size, const int num_units, const int input_size,
537 const int memory_size, const int rank,
538 TfLiteFusedActivation activation, float* input_data,
539 const float* feature_weights_data, const float* time_weights_data,
540 float* activation_state_data, const float* bias_data,
541 float* scratch_data, float* output_data,
542 const float* input_sequences_data, int input_sequences_len,
543 const float* expected_output, float tolerance = 1e-5f) {
544 const int num_filters = num_units * rank;
545
546 int input_dims_arg[] = {2, batch_size, input_size};
547 TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);
548
549 int feature_weights_dims_args[] = {2, num_filters, input_size};
550 TfLiteIntArray* feature_weights_dims =
551 IntArrayFromInts(feature_weights_dims_args);
552
553 int time_weights_dims_args[] = {2, num_filters, memory_size};
554 TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);
555
556 int activation_state_dims_args[] = {2, batch_size, memory_size * num_filters};
557 TfLiteIntArray* activation_state_dims =
558 IntArrayFromInts(activation_state_dims_args);
559
560 int bias_dims_args[] = {1, num_units};
561 TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_args);
562
563 int output_dims_args[] = {2, batch_size, num_units};
564 TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);
565
566 const int tensor_count = 6; // 5 inputs, 1 output
567 TfLiteTensor tensors[] = {
568 CreateTensor(input_data, input_dims),
569 CreateTensor(feature_weights_data, feature_weights_dims),
570 CreateTensor(time_weights_data, time_weights_dims),
571 CreateTensor(bias_data, bias_dims),
572 CreateTensor(activation_state_data, activation_state_dims,
573 /*is_variable=*/true),
574 CreateTensor(output_data, output_dims),
575 };
576
577 ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
578 tensor_count, activation, input_sequences_data,
579 input_sequences_len, output_data, expected_output,
580 tolerance);
581 }
582 #endif
583
// The pattern to this method's arguments is:
// <kernel metadata>
// for each tensor in
// {input, feature weights, time weights, bias, activation state, output}:
//   <tensor float values> <tensor quantized buffer> <tensor quantization data>
// Quantized SVDF test driver: int8 input/output, int16 time weights and
// activation state, int32 bias. Quantizes the float input sequences and
// golden outputs into the caller-provided buffers, builds the quantized
// tensor set, and validates with an absolute tolerance of 1 quantized unit
// (allowing for rounding differences).
inline void TestIntegerSVDF(
    const int batch_size, const int num_units, const int input_size,
    const int memory_size, const int rank, TfLiteFusedActivation activation,
    int8_t* input_quantized, float input_scale, int input_zero_point,
    const float* feature_weights_data, int8_t* feature_weights_quantized,
    const float feature_weights_scale, const float* time_weights_data,
    int16_t* time_weights_quantized, float time_weights_scale,
    const float* bias_data, int32_t* bias_quantized,
    const float* initial_activation_state_data,
    int16_t* activation_state_quantized, float activation_state_scale,
    int8_t* output_data, float output_scale, int output_zero_point,
    const float* input_sequences_data, int8_t* input_sequences_quantized,
    const int input_sequences_len, const float* golden_output,
    int8_t* golden_output_quantized, int golden_output_len) {
  const int num_filters = num_units * rank;

  // Shape arrays use TfLiteIntArray layout: {rank, dim0, dim1, ...}.
  int input_dims_arg[] = {2, batch_size, input_size};
  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_arg);

  int feature_weights_dims_args[] = {2, num_filters, input_size};
  TfLiteIntArray* feature_weights_dims =
      IntArrayFromInts(feature_weights_dims_args);

  int time_weights_dims_args[] = {2, num_filters, memory_size};
  TfLiteIntArray* time_weights_dims = IntArrayFromInts(time_weights_dims_args);

  int bias_dims_data[] = {1, num_units};
  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);

  int activation_state_dims_args[] = {2, batch_size, memory_size * num_filters};
  TfLiteIntArray* activation_state_dims =
      IntArrayFromInts(activation_state_dims_args);

  int output_dims_args[] = {2, batch_size, num_units};
  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_args);

  const int tensor_count = 6;  // 5 inputs, 1 output

  // Tensor order matches the SVDF op's input indices 0..4, then the output.
  TfLiteTensor tensors[] = {
      CreateQuantizedTensor(input_quantized, input_dims, input_scale,
                            input_zero_point),
      CreateQuantizedTensor(feature_weights_data, feature_weights_quantized,
                            feature_weights_dims, feature_weights_scale, 0),
      CreateQuantizedTensor(time_weights_data, time_weights_quantized,
                            time_weights_dims, time_weights_scale, 0),
      // Bias scale is derived from the time-weights and activation-state
      // scales inside CreateQuantizedBiasTensor.
      CreateQuantizedBiasTensor(bias_data, bias_quantized, bias_dims,
                                time_weights_scale, activation_state_scale),
      // Variable tensor: state persists and is updated across invocations.
      CreateQuantizedTensor(initial_activation_state_data,
                            activation_state_quantized, activation_state_dims,
                            activation_state_scale, 0,
                            /*is_variable=*/true),
      CreateQuantizedTensor(output_data, output_dims, output_scale,
                            output_zero_point)};

  // Quantize goldens and input sequences up front so validation compares
  // int8 against int8.
  tflite::Quantize(golden_output, golden_output_quantized, golden_output_len,
                   output_scale, output_zero_point);
  tflite::Quantize(input_sequences_data, input_sequences_quantized,
                   input_sequences_len, input_scale, input_zero_point);

  ValidateSVDFGoldens(batch_size, num_units, input_size, rank, tensors,
                      tensor_count, activation, input_sequences_quantized,
                      input_sequences_len, output_data, golden_output_quantized,
                      /*tolerance*/ 1);
}
653
654 } // namespace
655 } // namespace testing
656 } // namespace tflite
657
658 TF_LITE_MICRO_TESTS_BEGIN
659
660 #if !defined(XTENSA) // TODO(b/170332589): xtensa kernels are less general than
661 // reference kernels and we ifdef out test cases that are
662 // currently known to fail.
TF_LITE_MICRO_TEST(SvdfFloat2x2Input2x4OutputShouldMatchGolden) {
  // SVDF geometry: 2 batches of 2 inputs -> 4 units, rank 2, memory depth 10.
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 2;
  constexpr int memory_size = 10;
  constexpr int rank = 2;
  constexpr int num_filters = num_units * rank;

  // Input buffer; refreshed from input_data_2x2x10 on every invoke inside
  // the test driver, so it is left uninitialized here.
  const int input_size_dims_count = batch_size * input_size;
  float input_data[input_size_dims_count];

  const int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  float activation_state_data[activation_state_dims_count];

  // Seed the variable activation-state tensor with its starting values.
  memcpy(activation_state_data,
         tflite::testing::initial_activation_state_data_2x2x10,
         sizeof(tflite::testing::initial_activation_state_data_2x2x10));

  const int scratch_dims_count = batch_size * num_filters;
  float scratch_data[scratch_dims_count];

  const int output_dims_count = batch_size * num_units;
  float output_data[output_dims_count];

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_data, tflite::testing::feature_weights_data_2x2x10,
      tflite::testing::time_weights_data_2x2x10, activation_state_data,
      tflite::testing::bias_data_2x2x10, scratch_data, output_data,
      tflite::testing::input_data_2x2x10,
      sizeof(tflite::testing::input_data_2x2x10) / sizeof(float),
      tflite::testing::golden_output_2x2x10);
}
697 #endif
698
// Quantized counterpart of the 2x2x10 float test, run with a ReLU activation.
TF_LITE_MICRO_TEST(SvdfQuantized2x2Input2x4OutputShouldMatchGolden) {
  // SVDF geometry: 2 batches of 2 inputs -> 4 units, rank 2, memory depth 10.
  constexpr int batch_size = 2;
  constexpr int num_units = 4;
  constexpr int input_size = 2;
  constexpr int memory_size = 10;
  constexpr int rank = 2;
  constexpr int num_filters = num_units * rank;

  const int input_size_dims_count = batch_size * input_size;

  const int activation_state_dims_count =
      batch_size * memory_size * num_filters;

  const int output_dims_count = batch_size * num_units;
  int8_t output_data[output_dims_count];

  // Symmetric quantization scales sized to the stated value ranges.
  float input_scale = 2.5f / INT8_MAX;              // Range is [-2.5, 2.5]
  float feature_weights_scale = 1.f / INT8_MAX;     // Range is [-1, 1]
  float time_weights_scale = 1.f / INT16_MAX;       // Range is [-1, 1]
  float activation_state_scale = 16.f / INT16_MAX;  // Range is [-16, 16]
  float output_scale = 1.f / INT8_MAX;              // Range is [-1, 1]

  int input_zero_point = 0;
  int output_zero_point = 0;

  // Quantization buffers sized to match their float counterparts.
  int8_t input_quantized[input_size_dims_count];
  int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_2x2x10) /
                                   sizeof(float)];
  int8_t feature_weights_quantized
      [sizeof(tflite::testing::feature_weights_data_2x2x10) / sizeof(float)];
  int16_t
      time_weights_quantized[sizeof(tflite::testing::time_weights_data_2x2x10) /
                             sizeof(float)];
  int16_t activation_state_quantized[activation_state_dims_count];
  int32_t
      bias_quantized[sizeof(tflite::testing::bias_data_2x2x10) / sizeof(float)];
  int8_t golden_quantized[sizeof(tflite::testing::golden_output_2x2x10) /
                          sizeof(float)];

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_2x2x10, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_2x2x10,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_2x2x10, bias_quantized,
      tflite::testing::initial_activation_state_data_2x2x10,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_2x2x10,
      input_sequences_quantized,
      sizeof(tflite::testing::input_data_2x2x10) / sizeof(float),
      tflite::testing::golden_output_2x2x10, golden_quantized,
      sizeof(tflite::testing::golden_output_2x2x10) / sizeof(float));
}
753
754 #if !defined(XTENSA) // TODO(b/170332589): xtensa kernels are less general than
755 // reference kernels and we ifdef out test cases that are
756 // currently known to fail.
// Float SVDF: one batch of 16 inputs -> 64 units, rank 1, memory depth 8,
// no fused activation. Output is checked against golden_output_16x1x1.
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputShouldMatchGolden) {
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;

  constexpr int kStateElementCount = batch_size * memory_size * num_filters;
  constexpr int kOutputElementCount = batch_size * num_units;
  constexpr int kInputElementCount = batch_size * input_size;

  // The input buffer is refreshed from input_data_16x1x1 by the test driver
  // on each invoke, so it is deliberately left uninitialized here.
  float input_buffer[kInputElementCount];
  float output_buffer[kOutputElementCount];
  float scratch_buffer[batch_size * num_filters];
  float state_buffer[kStateElementCount];

  // Seed the variable activation-state tensor with its starting values.
  memcpy(state_buffer, tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(tflite::testing::initial_activation_state_data_16x1x1));

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_buffer, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, state_buffer,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_buffer,
      tflite::testing::input_data_16x1x1, input_size,
      tflite::testing::golden_output_16x1x1);
}
787
// Same setup as SvdfFloat1x16Input64x1OutputShouldMatchGolden, but with a
// fused ReLU activation and the matching ReLU golden outputs.
TF_LITE_MICRO_TEST(SvdfFloat1x16Input64x1OutputReluShouldMatchGolden) {
  // SVDF geometry: 1 batch of 16 inputs -> 64 units, rank 1, memory depth 8.
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  // input_data is refreshed from input_data_16x1x1 by the test driver on
  // each invoke, so it is deliberately left uninitialized here.
  float input_data[input_dims_count];
  float output_data[output_dims_count];
  float scratch_buffer[batch_size * num_filters];
  float activation_state_data_mutable[activation_state_dims_count];

  // Initialize activation state to starting values.
  memcpy(activation_state_data_mutable,
         tflite::testing::initial_activation_state_data_16x1x1,
         sizeof(tflite::testing::initial_activation_state_data_16x1x1));

  tflite::testing::TestSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_data, tflite::testing::feature_weights_data_16x1x1,
      tflite::testing::time_weights_data_16x1x1, activation_state_data_mutable,
      tflite::testing::bias_data_16x1x1, scratch_buffer, output_data,
      tflite::testing::input_data_16x1x1, input_size,
      tflite::testing::golden_output_relu_16x1x1);
}
818 #endif
819
// Quantized counterpart of SvdfFloat1x16Input64x1OutputShouldMatchGolden.
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputShouldMatchGolden) {
  // SVDF geometry: 1 batch of 16 inputs -> 64 units, rank 1, memory depth 8.
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  int8_t output_data[output_dims_count];

  // Per-tensor quantization scales and zero points for this geometry
  // (presumably exported from a reference quantized model — values are taken
  // as given by the golden data).
  float input_scale = 0.10075444;
  float feature_weights_scale = 0.00649388;
  float time_weights_scale = 0.001571355;
  float activation_state_scale = 0.00045896982;
  float output_scale = 0.051445257;

  int input_zero_point = 2;
  int output_zero_point = 0;

  // Quantization buffers sized to match their float counterparts.
  int8_t input_quantized[input_dims_count];
  int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_16x1x1) /
                                   sizeof(float)];
  int8_t feature_weights_quantized
      [sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float)];
  int16_t
      time_weights_quantized[sizeof(tflite::testing::time_weights_data_16x1x1) /
                             sizeof(float)];
  int16_t activation_state_quantized[activation_state_dims_count];
  int32_t
      bias_quantized[sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float)];
  int8_t golden_quantized[sizeof(tflite::testing::golden_output_16x1x1) /
                          sizeof(float)];

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActNone,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_16x1x1, bias_quantized,
      tflite::testing::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_16x1x1,
      input_sequences_quantized,
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float),
      tflite::testing::golden_output_16x1x1, golden_quantized,
      sizeof(tflite::testing::golden_output_16x1x1) / sizeof(float));
}
872
// ReLU variant of SvdfQuantized1x16Input64x1OutputShouldMatchGolden. Note the
// output zero point of -128 (presumably so the full int8 range covers the
// non-negative ReLU output — confirm against the model's quantization spec).
TF_LITE_MICRO_TEST(SvdfQuantized1x16Input64x1OutputReluShouldMatchGolden) {
  // SVDF geometry: 1 batch of 16 inputs -> 64 units, rank 1, memory depth 8.
  constexpr int batch_size = 1;
  constexpr int num_units = 64;
  constexpr int input_size = 16;
  constexpr int memory_size = 8;
  constexpr int rank = 1;
  constexpr int num_filters = num_units * rank;
  constexpr int activation_state_dims_count =
      batch_size * memory_size * num_filters;
  constexpr int output_dims_count = batch_size * num_units;
  constexpr int input_dims_count = batch_size * input_size;

  int8_t output_data[output_dims_count];

  // Per-tensor quantization scales; same values as the non-ReLU quantized
  // 16x1x1 test.
  float input_scale = 0.10075444;
  float feature_weights_scale = 0.00649388;
  float time_weights_scale = 0.001571355;
  float activation_state_scale = 0.00045896982;
  float output_scale = 0.051445257;

  int input_zero_point = 2;
  int output_zero_point = -128;

  // Quantization buffers sized to match their float counterparts.
  int8_t input_quantized[input_dims_count];
  int8_t input_sequences_quantized[sizeof(tflite::testing::input_data_16x1x1) /
                                   sizeof(float)];
  int8_t feature_weights_quantized
      [sizeof(tflite::testing::feature_weights_data_16x1x1) / sizeof(float)];
  int16_t
      time_weights_quantized[sizeof(tflite::testing::time_weights_data_16x1x1) /
                             sizeof(float)];
  int16_t activation_state_quantized[activation_state_dims_count];
  int32_t
      bias_quantized[sizeof(tflite::testing::bias_data_16x1x1) / sizeof(float)];
  int8_t golden_quantized[sizeof(tflite::testing::golden_output_relu_16x1x1) /
                          sizeof(float)];

  tflite::testing::TestIntegerSVDF(
      batch_size, num_units, input_size, memory_size, rank, kTfLiteActRelu,
      input_quantized, input_scale, input_zero_point,
      tflite::testing::feature_weights_data_16x1x1, feature_weights_quantized,
      feature_weights_scale, tflite::testing::time_weights_data_16x1x1,
      time_weights_quantized, time_weights_scale,
      tflite::testing::bias_data_16x1x1, bias_quantized,
      tflite::testing::initial_activation_state_data_16x1x1,
      activation_state_quantized, activation_state_scale, output_data,
      output_scale, output_zero_point, tflite::testing::input_data_16x1x1,
      input_sequences_quantized,
      sizeof(tflite::testing::input_data_16x1x1) / sizeof(float),
      tflite::testing::golden_output_relu_16x1x1, golden_quantized,
      sizeof(tflite::testing::golden_output_relu_16x1x1) / sizeof(float));
}
925
926 TF_LITE_MICRO_TESTS_END
927