1 /*
2  * Copyright (c) 2021 Stephanos Ioannidis <root@stephanos.io>
3  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <ztest.h>
9 #include <zephyr.h>
10 #include <stdlib.h>
11 #include <arm_math.h>
12 #include "../../common/test_common.h"
13 
14 #include "binary_q7.pat"
15 
/*
 * Accuracy limits handed to test_snr_error_q7() and test_near_equal_q7()
 * when the computed output is compared against the reference pattern.
 */
#define SNR_ERROR_THRESH		((float32_t)20.0f)
#define ABS_ERROR_THRESH_Q7		((q7_t)5)

/* Largest row/column count of any operand; sizes the operand buffers */
#define MAX_MATRIX_DIM			(47)

/* in_dims stores one (rows, internal, columns) triplet per matrix case */
#define NUM_MATRICES			(ARRAY_SIZE(in_dims) / 3)

/* Operation selector understood by test_op2() */
#define OP2_MULT			(0)
23 
test_op2(int op,const q7_t * input1,const q7_t * input2,const q7_t * ref,size_t length)24 static void test_op2(int op, const q7_t *input1, const q7_t *input2,
25 	const q7_t *ref, size_t length)
26 {
27 	size_t index;
28 	uint16_t *dims = (uint16_t *)in_dims;
29 	q7_t *tmp1, *tmp2, *scratch, *output;
30 	uint16_t rows, internal, columns;
31 	arm_status status;
32 
33 	arm_matrix_instance_q7 mat_in1;
34 	arm_matrix_instance_q7 mat_in2;
35 	arm_matrix_instance_q7 mat_out;
36 
37 	/* Allocate buffers */
38 	output = malloc(length * sizeof(q7_t));
39 	zassert_not_null(output, ASSERT_MSG_BUFFER_ALLOC_FAILED);
40 
41 	tmp1 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
42 	zassert_not_null(tmp1, ASSERT_MSG_BUFFER_ALLOC_FAILED);
43 
44 	tmp2 = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
45 	zassert_not_null(tmp2, ASSERT_MSG_BUFFER_ALLOC_FAILED);
46 
47 	scratch = malloc(MAX_MATRIX_DIM * MAX_MATRIX_DIM * sizeof(q7_t));
48 	zassert_not_null(scratch, ASSERT_MSG_BUFFER_ALLOC_FAILED);
49 
50 	/* Initialise contexts */
51 	mat_in1.pData = tmp1;
52 	mat_in2.pData = tmp2;
53 	mat_out.pData = output;
54 
55 	/* Iterate matrices */
56 	for (index = 0; index < NUM_MATRICES; index++) {
57 		rows = *dims++;
58 		internal = *dims++;
59 		columns = *dims++;
60 
61 		/* Initialise matrix dimensions */
62 		mat_in1.numRows = rows;
63 		mat_in1.numCols = internal;
64 
65 		mat_in2.numRows = internal;
66 		mat_in2.numCols = columns;
67 
68 		mat_out.numRows = rows;
69 		mat_out.numCols = columns;
70 
71 		/* Load matrix data */
72 		memcpy(mat_in1.pData, input1,
73 		       rows * internal * sizeof(q7_t));
74 
75 		memcpy(mat_in2.pData, input2,
76 		       internal * columns * sizeof(q7_t));
77 
78 		/* Run test function */
79 		switch (op) {
80 		case OP2_MULT:
81 			status = arm_mat_mult_q7(
82 					&mat_in1, &mat_in2, &mat_out,
83 					scratch);
84 			break;
85 		default:
86 			zassert_unreachable("invalid operation");
87 		}
88 
89 		/* Validate status */
90 		zassert_equal(status, ARM_MATH_SUCCESS,
91 			      ASSERT_MSG_INCORRECT_COMP_RESULT);
92 
93 		/* Increment output pointer */
94 		mat_out.pData += (rows * columns);
95 	}
96 
97 	/* Validate output */
98 	zassert_true(
99 		test_snr_error_q7(length, output, ref, SNR_ERROR_THRESH),
100 		ASSERT_MSG_SNR_LIMIT_EXCEED);
101 
102 	zassert_true(
103 		test_near_equal_q7(length, output, ref,
104 			ABS_ERROR_THRESH_Q7),
105 		ASSERT_MSG_ABS_ERROR_LIMIT_EXCEED);
106 
107 	/* Free buffers */
108 	free(tmp1);
109 	free(tmp2);
110 	free(scratch);
111 	free(output);
112 }
113 
114 DEFINE_TEST_VARIANT5(
115 	op2, arm_mat_mult_q7, OP2_MULT,
116 	in_mult1, in_mult2, ref_mult,
117 	ARRAY_SIZE(ref_mult));
118 
test_matrix_binary_q7(void)119 void test_matrix_binary_q7(void)
120 {
121 	ztest_test_suite(matrix_binary_q7,
122 		ztest_unit_test(test_op2_arm_mat_mult_q7)
123 		);
124 
125 	ztest_run_test_suite(matrix_binary_q7);
126 }
127