1 /*
2 * Copyright (c) 2020 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2020 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <stdlib.h>
#include <string.h>
#include <arm_math.h>
#include "../../common/test_common.h"

#include "f32.pat"
15
16 #define ABS_ERROR_THRESH (1e-3)
17
18 #define DIMS_IN (dims[0])
19 #define DIMS_VEC (dims[1])
20
21 #define OP_BRAYCURTIS (0)
22 #define OP_CANBERRA (1)
23 #define OP_CHEBYSHEV (2)
24 #define OP_CITYBLOCK (3)
25 #define OP_CORRELATION (4)
26 #define OP_COSINE (5)
27 #define OP_EUCLIDEAN (6)
28 #define OP_JENSENSHANNON (7)
29 #define OP_MINKOWSKI (8)
30
31 ZTEST_SUITE(distance_f32, NULL, NULL, NULL, NULL, NULL);
32
test_arm_distance_f32(int op,bool scratchy,const uint16_t * dims,const uint32_t * dinput1,const uint32_t * dinput2,const uint32_t * ref)33 static void test_arm_distance_f32(int op, bool scratchy, const uint16_t *dims,
34 const uint32_t *dinput1, const uint32_t *dinput2, const uint32_t *ref)
35 {
36 size_t index;
37 const size_t length = DIMS_IN;
38 const float32_t *input1 = (const float32_t *)dinput1;
39 const float32_t *input2 = (const float32_t *)dinput2;
40 float32_t *output, *tmp1 = NULL, *tmp2 = NULL;
41
42 /* Allocate output buffer */
43 output = malloc(length * sizeof(float32_t));
44 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
45
46 /* Allocate scratch buffers */
47 if (scratchy) {
48 tmp1 = malloc(DIMS_VEC * sizeof(float32_t));
49 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
50
51 tmp2 = malloc(DIMS_VEC * sizeof(float32_t));
52 zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
53 }
54
55 /* Enumerate input */
56 for (index = 0; index < length; index++) {
57 float32_t val;
58
59 /* Load input values into the scratch buffers */
60 if (scratchy) {
61 memcpy(tmp1, input1, DIMS_VEC * sizeof(float32_t));
62 memcpy(tmp2, input2, DIMS_VEC * sizeof(float32_t));
63 }
64
65 /* Run test function */
66 switch (op) {
67 case OP_BRAYCURTIS:
68 val = arm_braycurtis_distance_f32(
69 input1, input2, DIMS_VEC);
70 break;
71 case OP_CANBERRA:
72 val = arm_canberra_distance_f32(
73 input1, input2, DIMS_VEC);
74 break;
75 case OP_CHEBYSHEV:
76 val = arm_chebyshev_distance_f32(
77 input1, input2, DIMS_VEC);
78 break;
79 case OP_CITYBLOCK:
80 val = arm_cityblock_distance_f32(
81 input1, input2, DIMS_VEC);
82 break;
83 case OP_CORRELATION:
84 val = arm_correlation_distance_f32(
85 tmp1, tmp2, DIMS_VEC);
86 break;
87 case OP_COSINE:
88 val = arm_cosine_distance_f32(
89 input1, input2, DIMS_VEC);
90 break;
91 case OP_EUCLIDEAN:
92 val = arm_euclidean_distance_f32(
93 input1, input2, DIMS_VEC);
94 break;
95 case OP_JENSENSHANNON:
96 val = arm_jensenshannon_distance_f32(
97 input1, input2, DIMS_VEC);
98 break;
99 default:
100 zassert_unreachable("invalid operation");
101 }
102
103 /* Store output value */
104 output[index] = val;
105
106 /* Increment pointers */
107 input1 += DIMS_VEC;
108 input2 += DIMS_VEC;
109 }
110
111 /* Validate output */
112 zassert_true(
113 test_near_equal_f32(
114 length, output, (float32_t *)ref, ABS_ERROR_THRESH),
115 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
116
117 /* Free buffers */
118 free(output);
119 free(tmp1);
120 free(tmp2);
121 }
122
123 DEFINE_TEST_VARIANT6(distance_f32,
124 arm_distance_f32, braycurtis, OP_BRAYCURTIS, false, in_dims,
125 in_com1, in_com2, ref_braycurtis);
126
127 DEFINE_TEST_VARIANT6(distance_f32,
128 arm_distance_f32, canberra, OP_CANBERRA, false, in_dims,
129 in_com1, in_com2, ref_canberra);
130
131 DEFINE_TEST_VARIANT6(distance_f32,
132 arm_distance_f32, chebyshev, OP_CHEBYSHEV, false, in_dims,
133 in_com1, in_com2, ref_chebyshev);
134
135 DEFINE_TEST_VARIANT6(distance_f32,
136 arm_distance_f32, cityblock, OP_CITYBLOCK, false, in_dims,
137 in_com1, in_com2, ref_cityblock);
138
139 DEFINE_TEST_VARIANT6(distance_f32,
140 arm_distance_f32, correlation, OP_CORRELATION, true, in_dims,
141 in_com1, in_com2, ref_correlation);
142
143 DEFINE_TEST_VARIANT6(distance_f32,
144 arm_distance_f32, cosine, OP_COSINE, false, in_dims,
145 in_com1, in_com2, ref_cosine);
146
147 DEFINE_TEST_VARIANT6(distance_f32,
148 arm_distance_f32, euclidean, OP_EUCLIDEAN, false, in_dims,
149 in_com1, in_com2, ref_euclidean);
150
151 DEFINE_TEST_VARIANT6(distance_f32,
152 arm_distance_f32, jensenshannon, OP_JENSENSHANNON, false, in_dims,
153 in_jen1, in_jen2, ref_jensenshannon);
154
ZTEST(distance_f32,test_arm_distance_f32_minkowski)155 ZTEST(distance_f32, test_arm_distance_f32_minkowski)
156 {
157 size_t index;
158 const size_t length = in_dims_minkowski[0];
159 const size_t dims_vec = in_dims_minkowski[1];
160 const uint16_t *dims = in_dims_minkowski + 2;
161 const float32_t *input1 = (const float32_t *)in_com1;
162 const float32_t *input2 = (const float32_t *)in_com2;
163 float32_t *output;
164
165 /* Allocate output buffer */
166 output = malloc(length * sizeof(float32_t));
167 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
168
169 /* Enumerate input */
170 for (index = 0; index < length; index++) {
171 /* Run test function */
172 output[index] =
173 arm_minkowski_distance_f32(
174 input1, input2, dims[index], dims_vec);
175
176 /* Increment pointers */
177 input1 += dims_vec;
178 input2 += dims_vec;
179 }
180
181 /* Validate output */
182 zassert_true(
183 test_near_equal_f32(length, output, (float32_t *)ref_minkowski,
184 ABS_ERROR_THRESH),
185 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
186
187 /* Free buffers */
188 free(output);
189 }
190