1 extern "C" {
2 extern void row_test();
3 }
4
5 #include "allocator.h"
6
7 #include <dsppp/arch.hpp>
8 #include <dsppp/fixed_point.hpp>
9 #include <dsppp/matrix.hpp>
10
11 #include <iostream>
12
13 #include <cmsis_tests.h>
14
15
16
17 #include "dsp/matrix_functions.h"
18 #include "matrix_utils.h"
19
20 template<typename T,int R,int C>
test()21 static void test()
22 {
23 constexpr int NBOUT = C-1;
24 std::cout << "----\r\n";
25 std::cout << R << " x " << C << "\r\n";
26 std::cout << "NBOUT = " << NBOUT << "\r\n";
27
28 #if defined(STATIC_TEST)
29 PMat<T,R,C> a;
30 #else
31 PMat<T> a(R,C);
32 #endif
33
34 init_array(a,R*C);
35
36
37 INIT_SYSTICK
38 START_CYCLE_MEASUREMENT;
39 startSectionNB(1);
40 #if defined(STATIC_TEST)
41 PVector<T,NBOUT> res = a.row(0,1) + a.row(1,1);
42 #else
43 PVector<T> res = a.row(0,1) + a.row(1,1);
44 #endif
45 stopSectionNB(1);
46 STOP_CYCLE_MEASUREMENT;
47
48 INIT_SYSTICK;
49 START_CYCLE_MEASUREMENT;
50 PVector<T,NBOUT> da (a.row(0,1));
51 PVector<T,NBOUT> db (a.row(1,1));
52 PVector<T,NBOUT> ref;
53
54 cmsisdsp_add(da.const_ptr(),db.const_ptr(),ref.ptr(),NBOUT);
55 STOP_CYCLE_MEASUREMENT;
56
57 if (!validate(res.const_ptr(),ref.const_ptr(),NBOUT))
58 {
59 printf("row add failed \r\n");
60 }
61
62 std::cout << "=====\r\n";
63 }
64
65 template<typename T,int R,int C>
swaptest()66 static void swaptest()
67 {
68 constexpr int NBOUT = C-2;
69 std::cout << "----\r\n";
70 std::cout << R << " x " << C << "\r\n";
71 std::cout << "NBOUT = " << NBOUT << "\r\n";
72
73 #if defined(STATIC_TEST)
74 PMat<T,R,C> a;
75 PMat<T,R,C> b;
76 #else
77 PMat<T> a(R,C);
78 PMat<T> b(R,C);
79 #endif
80
81 init_array(a,R*C);
82 init_array(b,R*C);
83
84
85 INIT_SYSTICK;
86 START_CYCLE_MEASUREMENT;
87 startSectionNB(1);
88 swap(a.row(0,2) , a.row(1,2));
89 stopSectionNB(1);
90 STOP_CYCLE_MEASUREMENT;
91
92 typename CMSISMatrixType<T>::type mat;
93 mat.numCols = C;
94 mat.numRows = R;
95 mat.pData = b.ptr();
96
97 INIT_SYSTICK;
98 START_CYCLE_MEASUREMENT;
99 SWAP_ROWS_F32(&mat,2,0,1);
100 STOP_CYCLE_MEASUREMENT;
101
102 if (!validate(a.const_ptr(),(const float32_t*)mat.pData,R*C))
103 {
104 printf("row add failed \r\n");
105 }
106
107 std::cout << "=====\r\n";
108
109
110
111 }
112
113
114
115
116 template<typename T>
all_row_test()117 void all_row_test()
118 {
119 const int nb_tails = TailForTests<T>::tail;
120 const int nb_loops = TailForTests<T>::loop;
121
122
123 title<T>("Row test");
124
125 test<T,2,NBVEC_4>();
126 test<T,4,NBVEC_4>();
127 test<T,5,NBVEC_4>();
128 test<T,9,NBVEC_4>();
129
130 test<T,2,NBVEC_8>();
131 test<T,4,NBVEC_8>();
132 test<T,5,NBVEC_8>();
133 test<T,9,NBVEC_8>();
134
135 test<T,2,NBVEC_16>();
136 test<T,4,NBVEC_16>();
137 test<T,5,NBVEC_16>();
138 test<T,9,NBVEC_16>();
139
140 test<T,2,nb_loops>();
141 test<T,4,nb_loops>();
142 test<T,5,nb_loops>();
143 test<T,9,nb_loops>();
144
145 test<T,2,nb_loops+1>();
146 test<T,4,nb_loops+1>();
147 test<T,5,nb_loops+1>();
148 test<T,9,nb_loops+1>();
149
150 test<T,2,nb_loops+nb_tails>();
151 test<T,4,nb_loops+nb_tails>();
152 test<T,5,nb_loops+nb_tails>();
153 test<T,9,nb_loops+nb_tails>();
154
155 if constexpr (std::is_same<T,float>::value)
156 {
157 title<T>("Swap test");
158
159 swaptest<T,2,NBVEC_32>();
160 swaptest<T,4,NBVEC_32>();
161 swaptest<T,5,NBVEC_32>();
162 swaptest<T,9,NBVEC_32>();
163
164 swaptest<T,2,nb_loops>();
165 swaptest<T,4,nb_loops>();
166 swaptest<T,5,nb_loops>();
167 swaptest<T,9,nb_loops>();
168
169 swaptest<T,2,nb_loops+1>();
170 swaptest<T,4,nb_loops+1>();
171 swaptest<T,5,nb_loops+1>();
172 swaptest<T,9,nb_loops+1>();
173
174 swaptest<T,2,nb_loops+nb_tails>();
175 swaptest<T,4,nb_loops+nb_tails>();
176 swaptest<T,5,nb_loops+nb_tails>();
177 swaptest<T,9,nb_loops+nb_tails>();
178 }
179 //print_map("Stats",max_stats);
180 }
181
/**
 * Entry point for the row tests (declared extern "C" at the top of the file).
 *
 * Does nothing unless ROW_TEST is defined. Otherwise it runs all_row_test for
 * each element type whose *_DT macro is defined, in the order f64, f32, f16,
 * Q31, Q15, Q7. float16 additionally requires DISABLEFLOAT16 to be undefined.
 */
void row_test()
{
#ifdef ROW_TEST
#  ifdef F64_DT
    all_row_test<double>();
#  endif
#  ifdef F32_DT
    all_row_test<float>();
#  endif
#  ifdef F16_DT
#    ifndef DISABLEFLOAT16
    all_row_test<float16_t>();
#    endif
#  endif
#  ifdef Q31_DT
    all_row_test<Q31>();
#  endif
#  ifdef Q15_DT
    all_row_test<Q15>();
#  endif
#  ifdef Q7_DT
    all_row_test<Q7>();
#  endif
#endif
}
205