1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math_f16.h>
12 #include "../../common/test_common.h"
13 
14 #include "f16.pat"
15 
/*
 * Relative error limits handed to test_rel_error_f16():
 *   - REL_ERROR_THRESH    : every operation without a dedicated limit
 *   - REL_JS_ERROR_THRESH : Jensen-Shannon distance
 *   - REL_MK_ERROR_THRESH : Minkowski distance
 */
#define REL_ERROR_THRESH	(5e-3)
#define REL_JS_ERROR_THRESH	(3e-2)
#define REL_MK_ERROR_THRESH	(1e-2)

/*
 * Accessors for a `dims` array that must be in scope at the point of use:
 * element 0 is the number of vector pairs (and thus of outputs), element 1
 * is the number of elements in each vector.
 */
#define DIMS_IN			(dims[0])
#define DIMS_VEC		(dims[1])

/* Operation selectors understood by the distance test helper */
#define OP_BRAYCURTIS		(0)
#define OP_CANBERRA		(1)
#define OP_CHEBYSHEV		(2)
#define OP_CITYBLOCK		(3)
#define OP_CORRELATION		(4)
#define OP_COSINE		(5)
#define OP_EUCLIDEAN		(6)
#define OP_JENSENSHANNON	(7)
#define OP_MINKOWSKI		(8)
32 
/*
 * Register the distance_f16 test suite. All five remaining arguments are
 * NULL, i.e. no optional callbacks are supplied for this suite.
 */
ZTEST_SUITE(distance_f16, NULL, NULL, NULL, NULL, NULL);
34 
test_arm_distance_f16(int op,bool scratchy,const uint16_t * dims,const uint16_t * dinput1,const uint16_t * dinput2,const uint16_t * ref)35 static void test_arm_distance_f16(int op, bool scratchy, const uint16_t *dims,
36 	const uint16_t *dinput1, const uint16_t *dinput2, const uint16_t *ref)
37 {
38 	size_t index;
39 	const size_t length = DIMS_IN;
40 	const float16_t *input1 = (const float16_t *)dinput1;
41 	const float16_t *input2 = (const float16_t *)dinput2;
42 	float16_t *output, *tmp1 = NULL, *tmp2 = NULL;
43 
44 	/* Allocate output buffer */
45 	output = malloc(length * sizeof(float16_t));
46 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
47 
48 	/* Allocate scratch buffers */
49 	if (scratchy) {
50 		tmp1 = malloc(DIMS_VEC * sizeof(float16_t));
51 		zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
52 
53 		tmp2 = malloc(DIMS_VEC * sizeof(float16_t));
54 		zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
55 	}
56 
57 	/* Enumerate input */
58 	for (index = 0; index < length; index++) {
59 		float16_t val;
60 
61 		/* Load input values into the scratch buffers */
62 		if (scratchy) {
63 			memcpy(tmp1, input1, DIMS_VEC * sizeof(float16_t));
64 			memcpy(tmp2, input2, DIMS_VEC * sizeof(float16_t));
65 		}
66 
67 		/* Run test function */
68 		switch (op) {
69 		case OP_BRAYCURTIS:
70 			val = arm_braycurtis_distance_f16(
71 					input1, input2, DIMS_VEC);
72 			break;
73 		case OP_CANBERRA:
74 			val = arm_canberra_distance_f16(
75 					input1, input2, DIMS_VEC);
76 			break;
77 		case OP_CHEBYSHEV:
78 			val = arm_chebyshev_distance_f16(
79 					input1, input2, DIMS_VEC);
80 			break;
81 		case OP_CITYBLOCK:
82 			val = arm_cityblock_distance_f16(
83 					input1, input2, DIMS_VEC);
84 			break;
85 		case OP_CORRELATION:
86 			val = arm_correlation_distance_f16(
87 					tmp1, tmp2, DIMS_VEC);
88 			break;
89 		case OP_COSINE:
90 			val = arm_cosine_distance_f16(
91 					input1, input2, DIMS_VEC);
92 			break;
93 		case OP_EUCLIDEAN:
94 			val = arm_euclidean_distance_f16(
95 					input1, input2, DIMS_VEC);
96 			break;
97 		case OP_JENSENSHANNON:
98 			val = arm_jensenshannon_distance_f16(
99 					input1, input2, DIMS_VEC);
100 			break;
101 		default:
102 			zassert_unreachable("invalid operation");
103 		}
104 
105 		/* Store output value */
106 		output[index] = val;
107 
108 		/* Increment pointers */
109 		input1 += DIMS_VEC;
110 		input2 += DIMS_VEC;
111 	}
112 
113 	/* Validate output */
114 	switch (op) {
115 	case OP_JENSENSHANNON:
116 		zassert_true(
117 			test_rel_error_f16(
118 				length, output, (float16_t *)ref,
119 				REL_JS_ERROR_THRESH),
120 			ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
121 		break;
122 	default:
123 		zassert_true(
124 			test_rel_error_f16(
125 				length, output, (float16_t *)ref,
126 				REL_ERROR_THRESH),
127 			ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
128 		break;
129 	}
130 
131 	/* Free buffers */
132 	free(output);
133 	free(tmp1);
134 	free(tmp2);
135 }
136 
/*
 * One test case per distance operation. The six trailing arguments line up
 * with the parameters of test_arm_distance_f16(): operation selector,
 * scratch-buffer flag, dimensions, first input, second input and reference
 * output. Only the correlation variant requests scratch buffers, and only
 * the Jensen-Shannon variant uses the in_jen1/in_jen2 input patterns.
 */
DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, braycurtis,
	OP_BRAYCURTIS, false, in_dims, in_com1, in_com2, ref_braycurtis);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, canberra,
	OP_CANBERRA, false, in_dims, in_com1, in_com2, ref_canberra);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, chebyshev,
	OP_CHEBYSHEV, false, in_dims, in_com1, in_com2, ref_chebyshev);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, cityblock,
	OP_CITYBLOCK, false, in_dims, in_com1, in_com2, ref_cityblock);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, correlation,
	OP_CORRELATION, true, in_dims, in_com1, in_com2, ref_correlation);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, cosine,
	OP_COSINE, false, in_dims, in_com1, in_com2, ref_cosine);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, euclidean,
	OP_EUCLIDEAN, false, in_dims, in_com1, in_com2, ref_euclidean);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, jensenshannon,
	OP_JENSENSHANNON, false, in_dims, in_jen1, in_jen2, ref_jensenshannon);
168 
ZTEST(distance_f16,test_arm_distance_f16_minkowski)169 ZTEST(distance_f16, test_arm_distance_f16_minkowski)
170 {
171 	size_t index;
172 	const size_t length = in_dims_minkowski[0];
173 	const size_t dims_vec = in_dims_minkowski[1];
174 	const uint16_t *dims = in_dims_minkowski + 2;
175 	const float16_t *input1 = (const float16_t *)in_com1;
176 	const float16_t *input2 = (const float16_t *)in_com2;
177 	float16_t *output;
178 
179 	/* Allocate output buffer */
180 	output = malloc(length * sizeof(float16_t));
181 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
182 
183 	/* Enumerate input */
184 	for (index = 0; index < length; index++) {
185 		/* Run test function */
186 		output[index] =
187 			arm_minkowski_distance_f16(
188 				input1, input2, dims[index], dims_vec);
189 
190 		/* Increment pointers */
191 		input1 += dims_vec;
192 		input2 += dims_vec;
193 	}
194 
195 	/* Validate output */
196 	zassert_true(
197 		test_rel_error_f16(length, output, (float16_t *)ref_minkowski,
198 			REL_MK_ERROR_THRESH),
199 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
200 
201 	/* Free buffers */
202 	free(output);
203 }
204