// Give dot_test() C linkage so it can be referenced by its unmangled name
// (presumably from a C test driver / main — confirm against the caller).
extern "C" {
extern void dot_test();
}
4
5 #include "allocator.h"
6
7 #include <dsppp/arch.hpp>
8 #include <dsppp/fixed_point.hpp>
9 #include <dsppp/matrix.hpp>
10
11 #include <iostream>
12
13 #include <cmsis_tests.h>
14
15 #include "dsp/basic_math_functions.h"
16 #include "dsp/basic_math_functions_f16.h"
17
18
19
20
21
22 template<typename T,int NB,typename O>
complex_test(const T scale)23 static void complex_test(const T scale)
24 {
25 std::cout << "----\r\n" << "N = " << NB << "\r\n";
26 #if defined(STATIC_TEST)
27 PVector<T,NB> a;
28 PVector<T,NB> b;
29 PVector<T,NB> c;
30 PVector<T,NB> d;
31
32 PVector<T,NB> res;
33 #else
34 PVector<T> a(NB);
35 PVector<T> b(NB);
36 PVector<T> c(NB);
37 PVector<T> d(NB);
38
39 PVector<T> res(NB);
40 #endif
41
42
43 init_array(a,NB);
44 init_array(b,NB);
45 init_array(c,NB);
46 init_array(d,NB);
47
48 INIT_SYSTICK;
49 START_CYCLE_MEASUREMENT;
50 startSectionNB(1);
51 O result = dot(scale*(a+b),c*d);
52 stopSectionNB(1);
53 STOP_CYCLE_MEASUREMENT;
54
55 O ref;
56 PVector<T,NB> tmp1;
57 PVector<T,NB> tmp2;
58 INIT_SYSTICK;
59 START_CYCLE_MEASUREMENT;
60 cmsisdsp_dot_expr(a.const_ptr(),
61 b.const_ptr(),
62 c.const_ptr(),
63 d.const_ptr(),
64 tmp1.ptr(),
65 tmp2.ptr(),
66 scale,
67 ref,NB);
68 STOP_CYCLE_MEASUREMENT;
69
70 if (!validate(result,ref))
71 {
72 printf("dot expr failed \r\n");
73
74 }
75
76 std::cout << "=====\r\n";
77
78 }
79
80
81 template<typename T,int NB,typename O>
test()82 static void test()
83 {
84 std::cout << "----\r\n" << "N = " << NB << "\r\n";
85 #if defined(STATIC_TEST)
86 PVector<T,NB> a;
87 PVector<T,NB> b;
88
89 PVector<T,NB> res;
90 #else
91 PVector<T> a(NB);
92 PVector<T> b(NB);
93
94 PVector<T> res(NB);
95 #endif
96
97 init_array(a,NB);
98 init_array(b,NB);
99
100 INIT_SYSTICK;
101 START_CYCLE_MEASUREMENT;
102 startSectionNB(1);
103 O result = dot(a,b);
104 stopSectionNB(1);
105 STOP_CYCLE_MEASUREMENT;
106
107
108 O ref;
109 INIT_SYSTICK;
110 START_CYCLE_MEASUREMENT;
111 cmsisdsp_dot(a.const_ptr(),b.const_ptr(),ref,NB);
112 STOP_CYCLE_MEASUREMENT;
113
114 if (!validate(result,ref))
115 {
116 printf("dot failed \r\n");
117
118 }
119
120 std::cout << "=====\r\n";
121
122 }
123
124
125 template<typename T>
all_dot_test()126 void all_dot_test()
127 {
128
129 const int nb_tails = TailForTests<T>::tail;
130 const int nb_loops = TailForTests<T>::loop;
131
132 using ACC = typename number_traits<T>::accumulator;
133 constexpr auto v = TestConstant<T>::v;
134
135 title<T>("Dot product");
136
137
138 test<T,NBVEC_4,ACC>();
139 test<T,NBVEC_8,ACC>();
140 test<T,NBVEC_9,ACC>();
141 test<T,NBVEC_16,ACC>();
142 test<T,NBVEC_32,ACC>();
143 test<T,NBVEC_64,ACC>();
144 test<T,NBVEC_128,ACC>();
145 test<T,NBVEC_256,ACC>();
146 test<T,NBVEC_258,ACC>();
147 test<T,NBVEC_512,ACC>();
148 test<T,NBVEC_1024,ACC>();
149 if constexpr (!std::is_same<T,double>::value)
150 {
151 test<T,NBVEC_2048,ACC>();
152 }
153
154 test<T,1,ACC>();
155 test<T,nb_tails,ACC>();
156 test<T,nb_loops,ACC>();
157 test<T,nb_loops+1,ACC>();
158 test<T,nb_loops+nb_tails,ACC>();
159
160
161 title<T>("Dot product with expressions");
162
163
164 complex_test<T,NBVEC_4,ACC>(v);
165 complex_test<T,NBVEC_8,ACC>(v);
166 complex_test<T,NBVEC_9,ACC>(v);
167 complex_test<T,NBVEC_32,ACC>(v);
168 complex_test<T,NBVEC_64,ACC>(v);
169 complex_test<T,NBVEC_128,ACC>(v);
170
171 complex_test<T,NBVEC_256,ACC>(v);
172
173 complex_test<T,NBVEC_258,ACC>(v);
174 complex_test<T,NBVEC_512,ACC>(v);
175 complex_test<T,NBVEC_1024,ACC>(v);
176 if constexpr (!std::is_same<T,double>::value)
177 {
178 complex_test<T,NBVEC_2048,ACC>(v);
179 }
180
181 complex_test<T,1,ACC>(v);
182 complex_test<T,nb_tails,ACC>(v);
183 complex_test<T,nb_loops,ACC>(v);
184 complex_test<T,nb_loops+1,ACC>(v);
185 complex_test<T,nb_loops+nb_tails,ACC>(v);
186
187 //print_map("Stats",max_stats);
188
189 }
190
/*
 * Top-level entry for the dot-product tests.
 *
 * Nothing runs unless DOT_TEST is defined. When it is, each element type
 * is tested only if its *_DT selection macro is defined. float16_t
 * additionally requires that DISABLEFLOAT16 is not defined.
 */
void dot_test()
{
#if defined(DOT_TEST) && defined(F64_DT)
    all_dot_test<double>();
#endif

#if defined(DOT_TEST) && defined(F32_DT)
    all_dot_test<float>();
#endif

#if defined(DOT_TEST) && defined(F16_DT) && !defined(DISABLEFLOAT16)
    all_dot_test<float16_t>();
#endif

#if defined(DOT_TEST) && defined(Q31_DT)
    all_dot_test<Q31>();
#endif

#if defined(DOT_TEST) && defined(Q15_DT)
    all_dot_test<Q15>();
#endif

#if defined(DOT_TEST) && defined(Q7_DT)
    all_dot_test<Q7>();
#endif
}
214