1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math.h>
#include "../../common/test_common.h"
13
14 #include "f32.pat"
15
/*
 * Declare the locals shared by every SVM test from a pattern's dims
 * header and its flat parameter array.
 *
 * dims layout (as read here): [1], [2] = the two class labels,
 * [3] = sample count, [4] = vector dimension, [5] = support-vector count.
 *
 * params layout: svec_count * vec_dims support-vector values, followed by
 * svec_count dual coefficients, followed by the intercept. Kernel-specific
 * parameters are stored after the intercept.
 *
 * Both arguments are parenthesised so that an expression argument such as
 * `base + offset` is evaluated before the pointer cast is applied.
 */
#define DECLARE_COMMON_VARS(in_dims, in_param)                          \
	const int16_t *dims = (in_dims);                                \
	const float32_t *params = (const float32_t *)(in_param);        \
	const int32_t classes[2] = { dims[1], dims[2] };                \
	const uint16_t sample_count = dims[3];                          \
	const uint16_t vec_dims = dims[4];                              \
	const uint16_t svec_count = dims[5];                            \
	const float32_t intercept =                                     \
		params[svec_count + (vec_dims * svec_count)];           \
	const float32_t *svec = params;                                 \
	const float32_t *dual_coeff = params + (vec_dims * svec_count)
27
/*
 * Polynomial-kernel extras; expand after DECLARE_COMMON_VARS(), whose
 * dims, params, vec_dims and svec_count locals are used here.
 * The degree is read from dims[6]; coeff0 and gamma are the two params
 * entries immediately following the intercept.
 */
#define DECLARE_POLY_VARS()                                             \
	const float32_t gamma =                                         \
		params[(vec_dims * svec_count) + svec_count + 2];       \
	const float32_t coeff0 =                                        \
		params[(vec_dims * svec_count) + svec_count + 1];       \
	const uint16_t degree = dims[6]
34
/*
 * RBF-kernel extras; expand after DECLARE_COMMON_VARS(). gamma is the
 * params entry immediately following the intercept.
 */
#define DECLARE_RBF_VARS()                                              \
	const float32_t gamma =                                         \
		params[(vec_dims * svec_count) + svec_count + 1]
38
/*
 * Sigmoid-kernel extras; expand after DECLARE_COMMON_VARS().
 * coeff0 and gamma are the two params entries immediately following
 * the intercept, in that order.
 */
#define DECLARE_SIGMOID_VARS()                                          \
	const float32_t gamma =                                         \
		params[(vec_dims * svec_count) + svec_count + 2];       \
	const float32_t coeff0 =                                        \
		params[(vec_dims * svec_count) + svec_count + 1]
44
ZTEST(svm_f32,test_arm_svm_linear_predict_f32)45 ZTEST(svm_f32, test_arm_svm_linear_predict_f32)
46 {
47 DECLARE_COMMON_VARS(in_linear_dims, in_linear_param);
48
49 arm_svm_linear_instance_f32 inst;
50 size_t index;
51 const size_t length = ARRAY_SIZE(ref_linear);
52 const float32_t *input = (const float32_t *)in_linear_val;
53 int32_t *output, *output_buf;
54
55 /* Initialise instance */
56 arm_svm_linear_init_f32(&inst, svec_count, vec_dims,
57 intercept, dual_coeff, svec, classes);
58
59 /* Allocate output buffer */
60 output_buf = malloc(length * sizeof(int32_t));
61 zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
62
63 output = output_buf;
64
65 /* Enumerate samples */
66 for (index = 0; index < sample_count; index++) {
67 /* Run test function */
68 arm_svm_linear_predict_f32(&inst, input, output);
69
70 /* Increment pointers */
71 input += vec_dims;
72 output++;
73 }
74
75 /* Validate output */
76 zassert_true(
77 test_equal_q31(length, output_buf, ref_linear),
78 ASSERT_MSG_INCORRECT_COMP_RESULT);
79
80 /* Free output buffer */
81 free(output_buf);
82 }
83
ZTEST(svm_f32,test_arm_svm_polynomial_predict_f32)84 ZTEST(svm_f32, test_arm_svm_polynomial_predict_f32)
85 {
86 DECLARE_COMMON_VARS(in_polynomial_dims, in_polynomial_param);
87 DECLARE_POLY_VARS();
88
89 arm_svm_polynomial_instance_f32 inst;
90 size_t index;
91 const size_t length = ARRAY_SIZE(ref_polynomial);
92 const float32_t *input = (const float32_t *)in_polynomial_val;
93 int32_t *output, *output_buf;
94
95 /* Initialise instance */
96 arm_svm_polynomial_init_f32(
97 &inst, svec_count, vec_dims,
98 intercept, dual_coeff, svec, classes,
99 degree, coeff0, gamma);
100
101 /* Allocate output buffer */
102 output_buf = malloc(length * sizeof(int32_t));
103 zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
104
105 output = output_buf;
106
107 /* Enumerate samples */
108 for (index = 0; index < sample_count; index++) {
109 /* Run test function */
110 arm_svm_polynomial_predict_f32(&inst, input, output);
111
112 /* Increment pointers */
113 input += vec_dims;
114 output++;
115 }
116
117 /* Validate output */
118 zassert_true(
119 test_equal_q31(length, output_buf, ref_polynomial),
120 ASSERT_MSG_INCORRECT_COMP_RESULT);
121
122 /* Free output buffer */
123 free(output_buf);
124 }
125
ZTEST(svm_f32,test_arm_svm_rbf_predict_f32)126 ZTEST(svm_f32, test_arm_svm_rbf_predict_f32)
127 {
128 DECLARE_COMMON_VARS(in_rbf_dims, in_rbf_param);
129 DECLARE_RBF_VARS();
130
131 arm_svm_rbf_instance_f32 inst;
132 size_t index;
133 const size_t length = ARRAY_SIZE(ref_rbf);
134 const float32_t *input = (const float32_t *)in_rbf_val;
135 int32_t *output, *output_buf;
136
137 /* Initialise instance */
138 arm_svm_rbf_init_f32(
139 &inst, svec_count, vec_dims,
140 intercept, dual_coeff, svec, classes, gamma);
141
142 /* Allocate output buffer */
143 output_buf = malloc(length * sizeof(int32_t));
144 zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
145
146 output = output_buf;
147
148 /* Enumerate samples */
149 for (index = 0; index < sample_count; index++) {
150 /* Run test function */
151 arm_svm_rbf_predict_f32(&inst, input, output);
152
153 /* Increment pointers */
154 input += vec_dims;
155 output++;
156 }
157
158 /* Validate output */
159 zassert_true(
160 test_equal_q31(length, output_buf, ref_rbf),
161 ASSERT_MSG_INCORRECT_COMP_RESULT);
162
163 /* Free output buffer */
164 free(output_buf);
165 }
166
ZTEST(svm_f32,test_arm_svm_sigmoid_predict_f32)167 ZTEST(svm_f32, test_arm_svm_sigmoid_predict_f32)
168 {
169 DECLARE_COMMON_VARS(in_sigmoid_dims, in_sigmoid_param);
170 DECLARE_SIGMOID_VARS();
171
172 arm_svm_sigmoid_instance_f32 inst;
173 size_t index;
174 const size_t length = ARRAY_SIZE(ref_sigmoid);
175 const float32_t *input = (const float32_t *)in_sigmoid_val;
176 int32_t *output, *output_buf;
177
178 /* Initialise instance */
179 arm_svm_sigmoid_init_f32(
180 &inst, svec_count, vec_dims,
181 intercept, dual_coeff, svec, classes, coeff0, gamma);
182
183 /* Allocate output buffer */
184 output_buf = malloc(length * sizeof(int32_t));
185 zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
186
187 output = output_buf;
188
189 /* Enumerate samples */
190 for (index = 0; index < sample_count; index++) {
191 /* Run test function */
192 arm_svm_sigmoid_predict_f32(&inst, input, output);
193
194 /* Increment pointers */
195 input += vec_dims;
196 output++;
197 }
198
199 /* Validate output */
200 zassert_true(
201 test_equal_q31(length, output_buf, ref_sigmoid),
202 ASSERT_MSG_INCORRECT_COMP_RESULT);
203
204 /* Free output buffer */
205 free(output_buf);
206 }
207
ZTEST(svm_f32,test_arm_svm_oneclass_predict_f32)208 ZTEST(svm_f32, test_arm_svm_oneclass_predict_f32)
209 {
210 DECLARE_COMMON_VARS(in_oneclass_dims, in_oneclass_param);
211
212 arm_svm_linear_instance_f32 inst;
213 size_t index;
214 const size_t length = ARRAY_SIZE(ref_oneclass);
215 const float32_t *input = (const float32_t *)in_oneclass_val;
216 int32_t *output, *output_buf;
217
218 /* Initialise instance */
219 arm_svm_linear_init_f32(&inst, svec_count, vec_dims,
220 intercept, dual_coeff, svec, classes);
221
222 /* Allocate output buffer */
223 output_buf = malloc(length * sizeof(int32_t));
224 zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
225
226 output = output_buf;
227
228 /* Enumerate samples */
229 for (index = 0; index < sample_count; index++) {
230 /* Run test function */
231 arm_svm_linear_predict_f32(&inst, input, output);
232
233 /* Increment pointers */
234 input += vec_dims;
235 output++;
236 }
237
238 /* Validate output */
239 zassert_true(
240 test_equal_q31(length, output_buf, ref_oneclass),
241 ASSERT_MSG_INCORRECT_COMP_RESULT);
242
243 /* Free output buffer */
244 free(output_buf);
245 }
246
/*
 * Register the svm_f32 suite that the ZTEST() cases in this file belong to.
 * Every remaining argument is NULL, so no suite predicate or fixture hooks
 * are supplied.
 */
ZTEST_SUITE(svm_f32, NULL, NULL, NULL, NULL, NULL);
248