1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math_f16.h>
#include "../../common/test_common.h"
13 
14 #include "f16.pat"
15 
/*
 * Declare the locals shared by every SVM test case.
 *
 * in_dims:  pointer to int16_t values. Entries 1 and 2 fill classes[],
 *           entry 3 is the sample count, entry 4 the vector dimension and
 *           entry 5 the support-vector count.
 * in_param: parameter data, reinterpreted as float16_t. It starts with
 *           vec_dims * svec_count support-vector values (svec), followed by
 *           the dual coefficients (dual_coeff); the intercept is read
 *           svec_count entries after the start of the dual coefficients.
 *
 * Both arguments are parenthesised so that an expression argument is
 * evaluated as a whole before the assignment or cast is applied.
 * The final declaration has no trailing semicolon; the caller supplies it.
 */
#define DECLARE_COMMON_VARS(in_dims, in_param) \
	const int16_t *dims = (in_dims); \
	const float16_t *params = (const float16_t *)(in_param); \
	const int32_t classes[2] = { dims[1], dims[2] }; \
	const uint16_t sample_count = dims[3]; \
	const uint16_t vec_dims = dims[4]; \
	const uint16_t svec_count = dims[5]; \
	const float16_t intercept = \
		params[svec_count + (vec_dims * svec_count)]; \
	const float16_t *svec = params; \
	const float16_t *dual_coeff = params + (vec_dims * svec_count)
27 
/*
 * Declare the extra locals needed by the polynomial-kernel test.
 *
 * Expects dims, params, vec_dims and svec_count to be declared already in
 * the enclosing scope. degree comes from dims[6]; coeff0 and gamma are the
 * first and second parameter entries after index
 * svec_count * (vec_dims + 1).
 * The final declaration has no trailing semicolon; the caller supplies it.
 */
#define DECLARE_POLY_VARS() \
	const uint16_t degree = dims[6]; \
	const float16_t coeff0 = params[(svec_count * (vec_dims + 1)) + 1]; \
	const float16_t gamma = params[(svec_count * (vec_dims + 1)) + 2]
34 
/*
 * Declare the extra local needed by the RBF-kernel test.
 *
 * Expects params, vec_dims and svec_count to be declared already in the
 * enclosing scope. gamma is the first parameter entry after index
 * svec_count * (vec_dims + 1).
 * The declaration has no trailing semicolon; the caller supplies it.
 */
#define DECLARE_RBF_VARS() \
	const float16_t gamma = params[(svec_count * (vec_dims + 1)) + 1]
38 
/*
 * Declare the extra locals needed by the sigmoid-kernel test.
 *
 * Expects params, vec_dims and svec_count to be declared already in the
 * enclosing scope. coeff0 and gamma are the first and second parameter
 * entries after index svec_count * (vec_dims + 1).
 * The final declaration has no trailing semicolon; the caller supplies it.
 */
#define DECLARE_SIGMOID_VARS() \
	const float16_t coeff0 = params[(svec_count * (vec_dims + 1)) + 1]; \
	const float16_t gamma = params[(svec_count * (vec_dims + 1)) + 2]
44 
ZTEST(svm_f16,test_arm_svm_linear_predict_f16)45 ZTEST(svm_f16, test_arm_svm_linear_predict_f16)
46 {
47 	DECLARE_COMMON_VARS(in_linear_dims, in_linear_param);
48 
49 	arm_svm_linear_instance_f16 inst;
50 	size_t index;
51 	const size_t length = ARRAY_SIZE(ref_linear);
52 	const float16_t *input = (const float16_t *)in_linear_val;
53 	int32_t *output, *output_buf;
54 
55 	/* Initialise instance */
56 	arm_svm_linear_init_f16(&inst, svec_count, vec_dims,
57 		intercept, dual_coeff, svec, classes);
58 
59 	/* Allocate output buffer */
60 	output_buf = malloc(length * sizeof(int32_t));
61 	zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
62 
63 	output = output_buf;
64 
65 	/* Enumerate samples */
66 	for (index = 0; index < sample_count; index++) {
67 		/* Run test function */
68 		arm_svm_linear_predict_f16(&inst, input, output);
69 
70 		/* Increment pointers */
71 		input += vec_dims;
72 		output++;
73 	}
74 
75 	/* Validate output */
76 	zassert_true(
77 		test_equal_q31(length, output_buf, ref_linear),
78 		ASSERT_MSG_INCORRECT_COMP_RESULT);
79 
80 	/* Free output buffer */
81 	free(output_buf);
82 }
83 
ZTEST(svm_f16,test_arm_svm_polynomial_predict_f16)84 ZTEST(svm_f16, test_arm_svm_polynomial_predict_f16)
85 {
86 	DECLARE_COMMON_VARS(in_polynomial_dims, in_polynomial_param);
87 	DECLARE_POLY_VARS();
88 
89 	arm_svm_polynomial_instance_f16 inst;
90 	size_t index;
91 	const size_t length = ARRAY_SIZE(ref_polynomial);
92 	const float16_t *input = (const float16_t *)in_polynomial_val;
93 	int32_t *output, *output_buf;
94 
95 	/* Initialise instance */
96 	arm_svm_polynomial_init_f16(
97 		&inst, svec_count, vec_dims,
98 		intercept, dual_coeff, svec, classes,
99 		degree, coeff0, gamma);
100 
101 	/* Allocate output buffer */
102 	output_buf = malloc(length * sizeof(int32_t));
103 	zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
104 
105 	output = output_buf;
106 
107 	/* Enumerate samples */
108 	for (index = 0; index < sample_count; index++) {
109 		/* Run test function */
110 		arm_svm_polynomial_predict_f16(&inst, input, output);
111 
112 		/* Increment pointers */
113 		input += vec_dims;
114 		output++;
115 	}
116 
117 	/* Validate output */
118 	zassert_true(
119 		test_equal_q31(length, output_buf, ref_polynomial),
120 		ASSERT_MSG_INCORRECT_COMP_RESULT);
121 
122 	/* Free output buffer */
123 	free(output_buf);
124 }
125 
ZTEST(svm_f16,test_arm_svm_rbf_predict_f16)126 ZTEST(svm_f16, test_arm_svm_rbf_predict_f16)
127 {
128 	DECLARE_COMMON_VARS(in_rbf_dims, in_rbf_param);
129 	DECLARE_RBF_VARS();
130 
131 	arm_svm_rbf_instance_f16 inst;
132 	size_t index;
133 	const size_t length = ARRAY_SIZE(ref_rbf);
134 	const float16_t *input = (const float16_t *)in_rbf_val;
135 	int32_t *output, *output_buf;
136 
137 	/* Initialise instance */
138 	arm_svm_rbf_init_f16(
139 		&inst, svec_count, vec_dims,
140 		intercept, dual_coeff, svec, classes, gamma);
141 
142 	/* Allocate output buffer */
143 	output_buf = malloc(length * sizeof(int32_t));
144 	zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
145 
146 	output = output_buf;
147 
148 	/* Enumerate samples */
149 	for (index = 0; index < sample_count; index++) {
150 		/* Run test function */
151 		arm_svm_rbf_predict_f16(&inst, input, output);
152 
153 		/* Increment pointers */
154 		input += vec_dims;
155 		output++;
156 	}
157 
158 	/* Validate output */
159 	zassert_true(
160 		test_equal_q31(length, output_buf, ref_rbf),
161 		ASSERT_MSG_INCORRECT_COMP_RESULT);
162 
163 	/* Free output buffer */
164 	free(output_buf);
165 }
166 
ZTEST(svm_f16,test_arm_svm_sigmoid_predict_f16)167 ZTEST(svm_f16, test_arm_svm_sigmoid_predict_f16)
168 {
169 	DECLARE_COMMON_VARS(in_sigmoid_dims, in_sigmoid_param);
170 	DECLARE_SIGMOID_VARS();
171 
172 	arm_svm_sigmoid_instance_f16 inst;
173 	size_t index;
174 	const size_t length = ARRAY_SIZE(ref_sigmoid);
175 	const float16_t *input = (const float16_t *)in_sigmoid_val;
176 	int32_t *output, *output_buf;
177 
178 	/* Initialise instance */
179 	arm_svm_sigmoid_init_f16(
180 		&inst, svec_count, vec_dims,
181 		intercept, dual_coeff, svec, classes, coeff0, gamma);
182 
183 	/* Allocate output buffer */
184 	output_buf = malloc(length * sizeof(int32_t));
185 	zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
186 
187 	output = output_buf;
188 
189 	/* Enumerate samples */
190 	for (index = 0; index < sample_count; index++) {
191 		/* Run test function */
192 		arm_svm_sigmoid_predict_f16(&inst, input, output);
193 
194 		/* Increment pointers */
195 		input += vec_dims;
196 		output++;
197 	}
198 
199 	/* Validate output */
200 	zassert_true(
201 		test_equal_q31(length, output_buf, ref_sigmoid),
202 		ASSERT_MSG_INCORRECT_COMP_RESULT);
203 
204 	/* Free output buffer */
205 	free(output_buf);
206 }
207 
ZTEST(svm_f16,test_arm_svm_oneclass_predict_f16)208 ZTEST(svm_f16, test_arm_svm_oneclass_predict_f16)
209 {
210 	DECLARE_COMMON_VARS(in_oneclass_dims, in_oneclass_param);
211 
212 	arm_svm_linear_instance_f16 inst;
213 	size_t index;
214 	const size_t length = ARRAY_SIZE(ref_oneclass);
215 	const float16_t *input = (const float16_t *)in_oneclass_val;
216 	int32_t *output, *output_buf;
217 
218 	/* Initialise instance */
219 	arm_svm_linear_init_f16(&inst, svec_count, vec_dims,
220 		intercept, dual_coeff, svec, classes);
221 
222 	/* Allocate output buffer */
223 	output_buf = malloc(length * sizeof(int32_t));
224 	zassert_not_null(output_buf, ASSERT_MSG_BUFFER_ALLOC_FAILED);
225 
226 	output = output_buf;
227 
228 	/* Enumerate samples */
229 	for (index = 0; index < sample_count; index++) {
230 		/* Run test function */
231 		arm_svm_linear_predict_f16(&inst, input, output);
232 
233 		/* Increment pointers */
234 		input += vec_dims;
235 		output++;
236 	}
237 
238 	/* Validate output */
239 	zassert_true(
240 		test_equal_q31(length, output_buf, ref_oneclass),
241 		ASSERT_MSG_INCORRECT_COMP_RESULT);
242 
243 	/* Free output buffer */
244 	free(output_buf);
245 }
246 
247 ZTEST_SUITE(svm_f16, NULL, NULL, NULL, NULL, NULL);
248