1 /*
2 * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8 #include <ztest.h>
9 #include <zephyr.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13
14 #include "binary_q7.pat"
15
16 #define SNR_ERROR_THRESH ((float32_t)20)
17 #define ABS_ERROR_THRESH_Q7 ((q7_t)5)
18
19 #define NUM_MATRICES (ARRAY_SIZE(in_dims) / 3)
20 #define MAX_MATRIX_DIM (47)
21
22 #define OP2_MULT (0)
23
test_op2(int op,const q7_t * input1,const q7_t * input2,const q7_t * ref,size_t length)24 static void test_op2(int op, const q7_t *input1, const q7_t *input2,
25 const q7_t *ref, size_t length)
26 {
27 size_t index;
28 uint16_t *dims = (uint16_t *)in_dims;
29 q7_t *tmp1, *tmp2, *scratch, *output;
30 uint16_t rows, internal, columns;
31 arm_status status;
32
33 arm_matrix_instance_q7 mat_in1;
34 arm_matrix_instance_q7 mat_in2;
35 arm_matrix_instance_q7 mat_out;
36
37 /* Allocate buffers */
38 output = malloc(length * sizeof(q7_t));
39 zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
40
41 tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
42 zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
43
44 tmp2 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
45 zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
46
47 scratch = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
48 zassert_not_null(scratch, ASSERT_MSG_BUFFER_ALLOC_FAILED);
49
50 /* Initialise contexts */
51 mat_in1.pData = tmp1;
52 mat_in2.pData = tmp2;
53 mat_out.pData = output;
54
55 /* Iterate matrices */
56 for (index = 0; index < NUM_MATRICES; index++) {
57 rows = *dims++;
58 internal = *dims++;
59 columns = *dims++;
60
61 /* Initialise matrix dimensions */
62 mat_in1.numRows = rows;
63 mat_in1.numCols = internal;
64
65 mat_in2.numRows = internal;
66 mat_in2.numCols = columns;
67
68 mat_out.numRows = rows;
69 mat_out.numCols = columns;
70
71 /* Load matrix data */
72 memcpy(mat_in1.pData, input1,
73 rows * internal * sizeof(q7_t));
74
75 memcpy(mat_in2.pData, input2,
76 internal * columns * sizeof(q7_t));
77
78 /* Run test function */
79 switch (op) {
80 case OP2_MULT:
81 status = arm_mat_mult_q7(
82 &mat_in1, &mat_in2, &mat_out,
83 scratch);
84 break;
85 default:
86 zassert_unreachable("invalid operation");
87 }
88
89 /* Validate status */
90 zassert_equal(status, ARM_MATH_SUCCESS,
91 ASSERT_MSG_INCORRECT_COMP_RESULT);
92
93 /* Increment output pointer */
94 mat_out.pData += (rows * columns);
95 }
96
97 /* Validate output */
98 zassert_true(
99 test_snr_error_q7(length, output, ref, SNR_ERROR_THRESH),
100 ASSERT_MSG_SNR_LIMIT_EXCEED);
101
102 zassert_true(
103 test_near_equal_q7(length, output, ref,
104 ABS_ERROR_THRESH_Q7),
105 ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
106
107 /* Free buffers */
108 free(tmp1);
109 free(tmp2);
110 free(scratch);
111 free(output);
112 }
113
114 DEFINE_TEST_VARIANT5(
115 op2, arm_mat_mult_q7, OP2_MULT,
116 in_mult1, in_mult2, ref_mult,
117 ARRAY_SIZE(ref_mult));
118
test_matrix_binary_q7(void)119 void test_matrix_binary_q7(void)
120 {
121 ztest_test_suite(matrix_binary_q7,
122 ztest_unit_test(test_op2_arm_mat_mult_q7)
123 );
124
125 ztest_run_test_suite(matrix_binary_q7);
126 }
127