1 /*
2  * Copyright (c) 2020 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2020 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13 
14 #include "u32.pat"
15 
/*
 * Relative-error tolerance used when comparing computed distances with the
 * reference values. It is below single-precision machine epsilon (~1.19e-7),
 * so for non-zero references the results effectively have to match exactly.
 */
#define REL_ERROR_THRESH	(1e-8)

/*
 * Accessors for the pattern's dimension array. They expand to `dims[...]`,
 * so they are only usable where a `dims` pointer is in scope.
 */
#define DIMS_IN			(dims[0])
#define DIMS_VEC		(dims[1])
#define DIMS_BIT_VEC		(dims[2])

/* Distance operation selector passed as the `op` argument of the test. */
enum {
	OP_DICE = 0,
	OP_HAMMING = 1,
	OP_JACCARD = 2,
	OP_KULSINSKI = 3,
	OP_ROGERSTANIMOTO = 4,
	OP_RUSSELLRAO = 5,
	OP_SOKALMICHENER = 6,
	OP_SOKALSNEATH = 7,
	OP_YULE = 8,
};

ZTEST_SUITE(distance_u32, NULL, NULL, NULL, NULL, NULL);
33 
test_arm_distance(int op,const uint16_t * dims,const uint32_t * input1,const uint32_t * input2,const uint32_t * ref)34 static void test_arm_distance(int op, const uint16_t *dims,
35 	const uint32_t *input1, const uint32_t *input2, const uint32_t *ref)
36 {
37 	size_t index;
38 	const size_t length = DIMS_IN;
39 	float32_t *output;
40 
41 	/* Allocate output buffer */
42 	output = malloc(length * sizeof(float32_t));
43 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
44 
45 	/* Enumerate input */
46 	for (index = 0; index < length; index++) {
47 		float32_t val;
48 
49 		/* Run test function */
50 		switch (op) {
51 		case OP_DICE:
52 			val = arm_dice_distance(input1, input2, DIMS_VEC);
53 			break;
54 		case OP_HAMMING:
55 			val = arm_hamming_distance(input1, input2, DIMS_VEC);
56 			break;
57 		case OP_JACCARD:
58 			val = arm_jaccard_distance(input1, input2, DIMS_VEC);
59 			break;
60 		case OP_KULSINSKI:
61 			val = arm_kulsinski_distance(input1, input2, DIMS_VEC);
62 			break;
63 		case OP_ROGERSTANIMOTO:
64 			val = arm_rogerstanimoto_distance(
65 					input1, input2, DIMS_VEC);
66 			break;
67 		case OP_RUSSELLRAO:
68 			val = arm_russellrao_distance(
69 					input1, input2, DIMS_VEC);
70 			break;
71 		case OP_SOKALMICHENER:
72 			val = arm_sokalmichener_distance(
73 					input1, input2, DIMS_VEC);
74 			break;
75 		case OP_SOKALSNEATH:
76 			val = arm_sokalsneath_distance(
77 					input1, input2, DIMS_VEC);
78 			break;
79 		case OP_YULE:
80 			val = arm_yule_distance(input1, input2, DIMS_VEC);
81 			break;
82 		default:
83 			zassert_unreachable("invalid operation");
84 		}
85 
86 		/* Store output value */
87 		output[index] = val;
88 
89 		/* Increment pointers */
90 		input1 += DIMS_BIT_VEC;
91 		input2 += DIMS_BIT_VEC;
92 	}
93 
94 	/* Validate output */
95 	zassert_true(
96 		test_rel_error_f32(length, output, (float32_t *)ref,
97 			REL_ERROR_THRESH),
98 		ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
99 
100 	/* Free output buffer */
101 	free(output);
102 }
103 
/*
 * Register one test case per distance operation in the distance_u32 suite.
 * Every case shares the in_dims/in_com1/in_com2 pattern data and checks the
 * operation's own ref_<name> table, presumably by calling
 * test_arm_distance() through DEFINE_TEST_VARIANT5.
 */
#define DEFINE_DISTANCE_TEST(name, op)					\
	DEFINE_TEST_VARIANT5(distance_u32, arm_distance, name, op,	\
			     in_dims, in_com1, in_com2, ref_##name)

DEFINE_DISTANCE_TEST(dice, OP_DICE);
DEFINE_DISTANCE_TEST(hamming, OP_HAMMING);
DEFINE_DISTANCE_TEST(jaccard, OP_JACCARD);
DEFINE_DISTANCE_TEST(kulsinski, OP_KULSINSKI);
DEFINE_DISTANCE_TEST(rogerstanimoto, OP_ROGERSTANIMOTO);
DEFINE_DISTANCE_TEST(russellrao, OP_RUSSELLRAO);
DEFINE_DISTANCE_TEST(sokalmichener, OP_SOKALMICHENER);
DEFINE_DISTANCE_TEST(sokalsneath, OP_SOKALSNEATH);
DEFINE_DISTANCE_TEST(yule, OP_YULE);

#undef DEFINE_DISTANCE_TEST
139