1 /*
2 * Copyright (c) 2020 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2020 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
#include <zephyr/ztest.h>
#include <zephyr/kernel.h>
#include <zephyr/sys/util.h>
#include <stdbool.h>
#include <stdlib.h>
#include <arm_math.h>
#include "../../common/test_common.h"

#include "u32.pat"
15
/*
 * Relative-error tolerance handed to test_rel_error_f32().
 * NOTE(review): 1e-8 is below float32 epsilon (~1.19e-7). If the checker
 * measures |out - ref| / |ref|, this effectively demands bit-exact results;
 * confirm that is intended.
 */
#define REL_ERROR_THRESH (1e-8)

/*
 * Accessors into the pattern's dimension table. They expand to reads of a
 * variable that must be named `dims` at the point of use.
 */
#define DIMS_IN      (dims[0]) /* number of vector pairs / outputs */
#define DIMS_VEC     (dims[1]) /* length argument passed to each distance call */
#define DIMS_BIT_VEC (dims[2]) /* uint32_t words between consecutive vectors */

/* Operation selectors understood by test_arm_distance() */
enum {
	OP_DICE = 0,
	OP_HAMMING,
	OP_JACCARD,
	OP_KULSINSKI,
	OP_ROGERSTANIMOTO,
	OP_RUSSELLRAO,
	OP_SOKALMICHENER,
	OP_SOKALSNEATH,
	OP_YULE,
};
31
/* Register the distance_u32 suite; every optional hook/predicate argument is NULL */
ZTEST_SUITE(distance_u32, NULL, NULL, NULL, NULL, NULL);
33
test_arm_distance(int op,const uint16_t * dims,const uint32_t * input1,const uint32_t * input2,const uint32_t * ref)34 static void test_arm_distance(int op, const uint16_t *dims,
35 const uint32_t *input1, const uint32_t *input2, const uint32_t *ref)
36 {
37 size_t index;
38 const size_t length = DIMS_IN;
39 float32_t *output;
40
41 /* Allocate output buffer */
42 output = malloc(length * sizeof(float32_t));
43 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
44
45 /* Enumerate input */
46 for (index = 0; index < length; index++) {
47 float32_t val;
48
49 /* Run test function */
50 switch (op) {
51 case OP_DICE:
52 val = arm_dice_distance(input1, input2, DIMS_VEC);
53 break;
54 case OP_HAMMING:
55 val = arm_hamming_distance(input1, input2, DIMS_VEC);
56 break;
57 case OP_JACCARD:
58 val = arm_jaccard_distance(input1, input2, DIMS_VEC);
59 break;
60 case OP_KULSINSKI:
61 val = arm_kulsinski_distance(input1, input2, DIMS_VEC);
62 break;
63 case OP_ROGERSTANIMOTO:
64 val = arm_rogerstanimoto_distance(
65 input1, input2, DIMS_VEC);
66 break;
67 case OP_RUSSELLRAO:
68 val = arm_russellrao_distance(
69 input1, input2, DIMS_VEC);
70 break;
71 case OP_SOKALMICHENER:
72 val = arm_sokalmichener_distance(
73 input1, input2, DIMS_VEC);
74 break;
75 case OP_SOKALSNEATH:
76 val = arm_sokalsneath_distance(
77 input1, input2, DIMS_VEC);
78 break;
79 case OP_YULE:
80 val = arm_yule_distance(input1, input2, DIMS_VEC);
81 break;
82 default:
83 zassert_unreachable("invalid operation");
84 }
85
86 /* Store output value */
87 output[index] = val;
88
89 /* Increment pointers */
90 input1 += DIMS_BIT_VEC;
91 input2 += DIMS_BIT_VEC;
92 }
93
94 /* Validate output */
95 zassert_true(
96 test_rel_error_f32(length, output, (float32_t *)ref,
97 REL_ERROR_THRESH),
98 ASSERT_MSG_REL_ERROR_LIMIT_EXCEED);
99
100 /* Free output buffer */
101 free(output);
102 }
103
/*
 * One distance_u32 test case per distance operation. Each case passes the
 * shared in_dims / in_com1 / in_com2 pattern data plus the matching
 * ref_<variant> reference table through DEFINE_TEST_VARIANT5 under the
 * arm_distance name (presumably dispatching to test_arm_distance()).
 */
#define DEFINE_DISTANCE_U32_TEST(variant, op)                                  \
	DEFINE_TEST_VARIANT5(distance_u32, arm_distance, variant, op,          \
			     in_dims, in_com1, in_com2, ref_##variant)

DEFINE_DISTANCE_U32_TEST(dice, OP_DICE);
DEFINE_DISTANCE_U32_TEST(hamming, OP_HAMMING);
DEFINE_DISTANCE_U32_TEST(jaccard, OP_JACCARD);
DEFINE_DISTANCE_U32_TEST(kulsinski, OP_KULSINSKI);
DEFINE_DISTANCE_U32_TEST(rogerstanimoto, OP_ROGERSTANIMOTO);
DEFINE_DISTANCE_U32_TEST(russellrao, OP_RUSSELLRAO);
DEFINE_DISTANCE_U32_TEST(sokalmichener, OP_SOKALMICHENER);
DEFINE_DISTANCE_U32_TEST(sokalsneath, OP_SOKALSNEATH);
DEFINE_DISTANCE_U32_TEST(yule, OP_YULE);

#undef DEFINE_DISTANCE_U32_TEST
139