1 /*
2  * Copyright (c) 2020 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2020 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13 
14 #include "f32.pat"
15 
/* Tolerance handed to test_near_equal_f32() when validating results */
#define ABS_ERROR_THRESH	(1e-3)

/*
 * Accessors for a `dims` array that must be in scope at the point of use:
 * entry 0 is used as the number of vector pairs (and results), entry 1 as
 * the number of elements in each vector.
 */
#define DIMS_IN			(dims[0])
#define DIMS_VEC		(dims[1])

/*
 * Selector values identifying the distance function under test.
 *
 * These are a closed set of related integer constants, so they are declared
 * as enumerators rather than object-like macros; the numeric values are
 * unchanged and remain usable as `case` labels and `int` arguments.
 */
enum {
	OP_BRAYCURTIS		= 0,
	OP_CANBERRA		= 1,
	OP_CHEBYSHEV		= 2,
	OP_CITYBLOCK		= 3,
	OP_CORRELATION		= 4,
	OP_COSINE		= 5,
	OP_EUCLIDEAN		= 6,
	OP_JENSENSHANNON	= 7,
	OP_MINKOWSKI		= 8,
};
30 
31 ZTEST_SUITE(distance_f32, NULL, NULL, NULL, NULL, NULL);
32 
test_arm_distance_f32(int op,bool scratchy,const uint16_t * dims,const uint32_t * dinput1,const uint32_t * dinput2,const uint32_t * ref)33 static void test_arm_distance_f32(int op, bool scratchy, const uint16_t *dims,
34 	const uint32_t *dinput1, const uint32_t *dinput2, const uint32_t *ref)
35 {
36 	size_t index;
37 	const size_t length = DIMS_IN;
38 	const float32_t *input1 = (const float32_t *)dinput1;
39 	const float32_t *input2 = (const float32_t *)dinput2;
40 	float32_t *output, *tmp1 = NULL, *tmp2 = NULL;
41 
42 	/* Allocate output buffer */
43 	output = malloc(length * sizeof(float32_t));
44 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
45 
46 	/* Allocate scratch buffers */
47 	if (scratchy) {
48 		tmp1 = malloc(DIMS_VEC * sizeof(float32_t));
49 		zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
50 
51 		tmp2 = malloc(DIMS_VEC * sizeof(float32_t));
52 		zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
53 	}
54 
55 	/* Enumerate input */
56 	for (index = 0; index < length; index++) {
57 		float32_t val;
58 
59 		/* Load input values into the scratch buffers */
60 		if (scratchy) {
61 			memcpy(tmp1, input1, DIMS_VEC * sizeof(float32_t));
62 			memcpy(tmp2, input2, DIMS_VEC * sizeof(float32_t));
63 		}
64 
65 		/* Run test function */
66 		switch (op) {
67 		case OP_BRAYCURTIS:
68 			val = arm_braycurtis_distance_f32(
69 					input1, input2, DIMS_VEC);
70 			break;
71 		case OP_CANBERRA:
72 			val = arm_canberra_distance_f32(
73 					input1, input2, DIMS_VEC);
74 			break;
75 		case OP_CHEBYSHEV:
76 			val = arm_chebyshev_distance_f32(
77 					input1, input2, DIMS_VEC);
78 			break;
79 		case OP_CITYBLOCK:
80 			val = arm_cityblock_distance_f32(
81 					input1, input2, DIMS_VEC);
82 			break;
83 		case OP_CORRELATION:
84 			val = arm_correlation_distance_f32(
85 					tmp1, tmp2, DIMS_VEC);
86 			break;
87 		case OP_COSINE:
88 			val = arm_cosine_distance_f32(
89 					input1, input2, DIMS_VEC);
90 			break;
91 		case OP_EUCLIDEAN:
92 			val = arm_euclidean_distance_f32(
93 					input1, input2, DIMS_VEC);
94 			break;
95 		case OP_JENSENSHANNON:
96 			val = arm_jensenshannon_distance_f32(
97 					input1, input2, DIMS_VEC);
98 			break;
99 		default:
100 			zassert_unreachable("invalid operation");
101 		}
102 
103 		/* Store output value */
104 		output[index] = val;
105 
106 		/* Increment pointers */
107 		input1 += DIMS_VEC;
108 		input2 += DIMS_VEC;
109 	}
110 
111 	/* Validate output */
112 	zassert_true(
113 		test_near_equal_f32(
114 			length, output, (float32_t *)ref, ABS_ERROR_THRESH),
115 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
116 
117 	/* Free buffers */
118 	free(output);
119 	free(tmp1);
120 	free(tmp2);
121 }
122 
123 DEFINE_TEST_VARIANT6(distance_f32,
124 	arm_distance_f32, braycurtis, OP_BRAYCURTIS, false, in_dims,
125 	in_com1, in_com2, ref_braycurtis);
126 
127 DEFINE_TEST_VARIANT6(distance_f32,
128 	arm_distance_f32, canberra, OP_CANBERRA, false, in_dims,
129 	in_com1, in_com2, ref_canberra);
130 
131 DEFINE_TEST_VARIANT6(distance_f32,
132 	arm_distance_f32, chebyshev, OP_CHEBYSHEV, false, in_dims,
133 	in_com1, in_com2, ref_chebyshev);
134 
135 DEFINE_TEST_VARIANT6(distance_f32,
136 	arm_distance_f32, cityblock, OP_CITYBLOCK, false, in_dims,
137 	in_com1, in_com2, ref_cityblock);
138 
139 DEFINE_TEST_VARIANT6(distance_f32,
140 	arm_distance_f32, correlation, OP_CORRELATION, true, in_dims,
141 	in_com1, in_com2, ref_correlation);
142 
143 DEFINE_TEST_VARIANT6(distance_f32,
144 	arm_distance_f32, cosine, OP_COSINE, false, in_dims,
145 	in_com1, in_com2, ref_cosine);
146 
147 DEFINE_TEST_VARIANT6(distance_f32,
148 	arm_distance_f32, euclidean, OP_EUCLIDEAN, false, in_dims,
149 	in_com1, in_com2, ref_euclidean);
150 
151 DEFINE_TEST_VARIANT6(distance_f32,
152 	arm_distance_f32, jensenshannon, OP_JENSENSHANNON, false, in_dims,
153 	in_jen1, in_jen2, ref_jensenshannon);
154 
ZTEST(distance_f32,test_arm_distance_f32_minkowski)155 ZTEST(distance_f32, test_arm_distance_f32_minkowski)
156 {
157 	size_t index;
158 	const size_t length = in_dims_minkowski[0];
159 	const size_t dims_vec = in_dims_minkowski[1];
160 	const uint16_t *dims = in_dims_minkowski + 2;
161 	const float32_t *input1 = (const float32_t *)in_com1;
162 	const float32_t *input2 = (const float32_t *)in_com2;
163 	float32_t *output;
164 
165 	/* Allocate output buffer */
166 	output = malloc(length * sizeof(float32_t));
167 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
168 
169 	/* Enumerate input */
170 	for (index = 0; index < length; index++) {
171 		/* Run test function */
172 		output[index] =
173 			arm_minkowski_distance_f32(
174 				input1, input2, dims[index], dims_vec);
175 
176 		/* Increment pointers */
177 		input1 += dims_vec;
178 		input2 += dims_vec;
179 	}
180 
181 	/* Validate output */
182 	zassert_true(
183 		test_near_equal_f32(length, output, (float32_t *)ref_minkowski,
184 			ABS_ERROR_THRESH),
185 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
186 
187 	/* Free buffers */
188 	free(output);
189 }
190