1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math_f16.h>
12 #include "../../common/test_common.h"
13
14 #include "f16.pat"
15
/*
 * Relative error tolerances passed to test_rel_error_f16(): the default one,
 * a looser one used only for Jensen-Shannon, and one for the Minkowski test.
 */
#define REL_ERROR_THRESH (5e-3)
#define REL_JS_ERROR_THRESH (3e-2)
#define REL_MK_ERROR_THRESH (1e-2)

/*
 * Accessors for a pattern's dimension array. They expand to reads of a
 * variable named `dims`, which must be in scope where they are used.
 * DIMS_IN is the number of vector pairs; DIMS_VEC is the length of each vector.
 */
#define DIMS_IN (dims[0])
#define DIMS_VEC (dims[1])

/*
 * Operation codes selecting the distance function in test_arm_distance_f16().
 * OP_MINKOWSKI is not dispatched there; Minkowski has a dedicated test case.
 */
#define OP_BRAYCURTIS (0)
#define OP_CANBERRA (1)
#define OP_CHEBYSHEV (2)
#define OP_CITYBLOCK (3)
#define OP_CORRELATION (4)
#define OP_COSINE (5)
#define OP_EUCLIDEAN (6)
#define OP_JENSENSHANNON (7)
#define OP_MINKOWSKI (8)

/* Register the suite; all optional suite callbacks are left NULL. */
ZTEST_SUITE(distance_f16, NULL, NULL, NULL, NULL, NULL);
34
test_arm_distance_f16(int op,bool scratchy,const uint16_t * dims,const uint16_t * dinput1,const uint16_t * dinput2,const uint16_t * ref)35 static void test_arm_distance_f16(int op, bool scratchy, const uint16_t *dims,
36 const uint16_t *dinput1, const uint16_t *dinput2, const uint16_t *ref)
37 {
38 size_t index;
39 const size_t length = DIMS_IN;
40 const float16_t *input1 = (const float16_t *)dinput1;
41 const float16_t *input2 = (const float16_t *)dinput2;
42 float16_t *output, *tmp1 = NULL, *tmp2 = NULL;
43
44 /* Allocate output buffer */
45 output = malloc(length * sizeof(float16_t));
46 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
47
48 /* Allocate scratch buffers */
49 if (scratchy) {
50 tmp1 = malloc(DIMS_VEC * sizeof(float16_t));
51 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
52
53 tmp2 = malloc(DIMS_VEC * sizeof(float16_t));
54 zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
55 }
56
57 /* Enumerate input */
58 for (index = 0; index < length; index++) {
59 float16_t val;
60
61 /* Load input values into the scratch buffers */
62 if (scratchy) {
63 memcpy(tmp1, input1, DIMS_VEC * sizeof(float16_t));
64 memcpy(tmp2, input2, DIMS_VEC * sizeof(float16_t));
65 }
66
67 /* Run test function */
68 switch (op) {
69 case OP_BRAYCURTIS:
70 val = arm_braycurtis_distance_f16(
71 input1, input2, DIMS_VEC);
72 break;
73 case OP_CANBERRA:
74 val = arm_canberra_distance_f16(
75 input1, input2, DIMS_VEC);
76 break;
77 case OP_CHEBYSHEV:
78 val = arm_chebyshev_distance_f16(
79 input1, input2, DIMS_VEC);
80 break;
81 case OP_CITYBLOCK:
82 val = arm_cityblock_distance_f16(
83 input1, input2, DIMS_VEC);
84 break;
85 case OP_CORRELATION:
86 val = arm_correlation_distance_f16(
87 tmp1, tmp2, DIMS_VEC);
88 break;
89 case OP_COSINE:
90 val = arm_cosine_distance_f16(
91 input1, input2, DIMS_VEC);
92 break;
93 case OP_EUCLIDEAN:
94 val = arm_euclidean_distance_f16(
95 input1, input2, DIMS_VEC);
96 break;
97 case OP_JENSENSHANNON:
98 val = arm_jensenshannon_distance_f16(
99 input1, input2, DIMS_VEC);
100 break;
101 default:
102 zassert_unreachable("invalid operation");
103 }
104
105 /* Store output value */
106 output[index] = val;
107
108 /* Increment pointers */
109 input1 += DIMS_VEC;
110 input2 += DIMS_VEC;
111 }
112
113 /* Validate output */
114 switch (op) {
115 case OP_JENSENSHANNON:
116 zassert_true(
117 test_rel_error_f16(
118 length, output, (float16_t *)ref,
119 REL_JS_ERROR_THRESH),
120 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
121 break;
122 default:
123 zassert_true(
124 test_rel_error_f16(
125 length, output, (float16_t *)ref,
126 REL_ERROR_THRESH),
127 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
128 break;
129 }
130
131 /* Free buffers */
132 free(output);
133 free(tmp1);
134 free(tmp2);
135 }
136
/*
 * One test case per distance function, each driven by test_arm_distance_f16().
 * DEFINE_TEST_VARIANT6 is presumably provided by ../../common/test_common.h
 * (not visible here). It is expected to define a case in the distance_f16
 * suite that forwards the six trailing arguments
 * (op, scratchy, dims, input1, input2, ref) to test_arm_distance_f16().
 * Only correlation requests scratch buffers. Jensen-Shannon reads its own
 * input vectors (in_jen1/in_jen2) instead of the shared in_com1/in_com2.
 */
DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, braycurtis,
	OP_BRAYCURTIS, false,
	in_dims, in_com1, in_com2, ref_braycurtis);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, canberra,
	OP_CANBERRA, false,
	in_dims, in_com1, in_com2, ref_canberra);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, chebyshev,
	OP_CHEBYSHEV, false,
	in_dims, in_com1, in_com2, ref_chebyshev);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, cityblock,
	OP_CITYBLOCK, false,
	in_dims, in_com1, in_com2, ref_cityblock);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, correlation,
	OP_CORRELATION, true,
	in_dims, in_com1, in_com2, ref_correlation);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, cosine,
	OP_COSINE, false,
	in_dims, in_com1, in_com2, ref_cosine);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, euclidean,
	OP_EUCLIDEAN, false,
	in_dims, in_com1, in_com2, ref_euclidean);

DEFINE_TEST_VARIANT6(distance_f16, arm_distance_f16, jensenshannon,
	OP_JENSENSHANNON, false,
	in_dims, in_jen1, in_jen2, ref_jensenshannon);
168
ZTEST(distance_f16,test_arm_distance_f16_minkowski)169 ZTEST(distance_f16, test_arm_distance_f16_minkowski)
170 {
171 size_t index;
172 const size_t length = in_dims_minkowski[0];
173 const size_t dims_vec = in_dims_minkowski[1];
174 const uint16_t *dims = in_dims_minkowski + 2;
175 const float16_t *input1 = (const float16_t *)in_com1;
176 const float16_t *input2 = (const float16_t *)in_com2;
177 float16_t *output;
178
179 /* Allocate output buffer */
180 output = malloc(length * sizeof(float16_t));
181 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
182
183 /* Enumerate input */
184 for (index = 0; index < length; index++) {
185 /* Run test function */
186 output[index] =
187 arm_minkowski_distance_f16(
188 input1, input2, dims[index], dims_vec);
189
190 /* Increment pointers */
191 input1 += dims_vec;
192 input2 += dims_vec;
193 }
194
195 /* Validate output */
196 zassert_true(
197 test_rel_error_f16(length, output, (float16_t *)ref_minkowski,
198 REL_MK_ERROR_THRESH),
199 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
200
201 /* Free buffers */
202 free(output);
203 }
204