1 extern "C" {
2     extern void dot_test();
3 }
4 
5 #include "allocator.h"
6 
7 #include <dsppp/arch.hpp>
8 #include <dsppp/fixed_point.hpp>
9 #include <dsppp/matrix.hpp>
10 
11 #include <iostream>
12 
13 #include <cmsis_tests.h>
14 
15 #include "dsp/basic_math_functions.h"
16 #include "dsp/basic_math_functions_f16.h"
17 
18 
19 
20 
21 
22 template<typename T,int NB,typename O>
complex_test(const T scale)23 static void complex_test(const T scale)
24 {
25    std::cout << "----\r\n" << "N = " << NB << "\r\n";
26    #if defined(STATIC_TEST)
27    PVector<T,NB> a;
28    PVector<T,NB> b;
29    PVector<T,NB> c;
30    PVector<T,NB> d;
31 
32    PVector<T,NB> res;
33    #else
34    PVector<T> a(NB);
35    PVector<T> b(NB);
36    PVector<T> c(NB);
37    PVector<T> d(NB);
38 
39    PVector<T> res(NB);
40    #endif
41 
42 
43    init_array(a,NB);
44    init_array(b,NB);
45    init_array(c,NB);
46    init_array(d,NB);
47 
48    INIT_SYSTICK;
49    START_CYCLE_MEASUREMENT;
50    startSectionNB(1);
51    O result = dot(scale*(a+b),c*d);
52    stopSectionNB(1);
53    STOP_CYCLE_MEASUREMENT;
54 
55    O ref;
56    PVector<T,NB> tmp1;
57    PVector<T,NB> tmp2;
58    INIT_SYSTICK;
59    START_CYCLE_MEASUREMENT;
60    cmsisdsp_dot_expr(a.const_ptr(),
61                      b.const_ptr(),
62                      c.const_ptr(),
63                      d.const_ptr(),
64                      tmp1.ptr(),
65                      tmp2.ptr(),
66                      scale,
67                      ref,NB);
68    STOP_CYCLE_MEASUREMENT;
69 
70    if (!validate(result,ref))
71    {
72       printf("dot expr failed \r\n");
73 
74    }
75 
76    std::cout << "=====\r\n";
77 
78 }
79 
80 
81 template<typename T,int NB,typename O>
test()82 static void test()
83 {
84    std::cout << "----\r\n" << "N = " << NB << "\r\n";
85    #if defined(STATIC_TEST)
86    PVector<T,NB> a;
87    PVector<T,NB> b;
88 
89    PVector<T,NB> res;
90    #else
91    PVector<T> a(NB);
92    PVector<T> b(NB);
93 
94    PVector<T> res(NB);
95    #endif
96 
97    init_array(a,NB);
98    init_array(b,NB);
99 
100    INIT_SYSTICK;
101    START_CYCLE_MEASUREMENT;
102    startSectionNB(1);
103    O result = dot(a,b);
104    stopSectionNB(1);
105    STOP_CYCLE_MEASUREMENT;
106 
107 
108    O ref;
109    INIT_SYSTICK;
110    START_CYCLE_MEASUREMENT;
111    cmsisdsp_dot(a.const_ptr(),b.const_ptr(),ref,NB);
112    STOP_CYCLE_MEASUREMENT;
113 
114    if (!validate(result,ref))
115    {
116       printf("dot failed \r\n");
117 
118    }
119 
120    std::cout << "=====\r\n";
121 
122 }
123 
124 
125 template<typename T>
all_dot_test()126 void all_dot_test()
127 {
128 
129    const int nb_tails = TailForTests<T>::tail;
130    const int nb_loops = TailForTests<T>::loop;
131 
132     using ACC = typename number_traits<T>::accumulator;
133     constexpr auto v = TestConstant<T>::v;
134 
135     title<T>("Dot product");
136 
137 
138     test<T,NBVEC_4,ACC>();
139     test<T,NBVEC_8,ACC>();
140     test<T,NBVEC_9,ACC>();
141     test<T,NBVEC_16,ACC>();
142     test<T,NBVEC_32,ACC>();
143     test<T,NBVEC_64,ACC>();
144     test<T,NBVEC_128,ACC>();
145     test<T,NBVEC_256,ACC>();
146     test<T,NBVEC_258,ACC>();
147     test<T,NBVEC_512,ACC>();
148     test<T,NBVEC_1024,ACC>();
149     if constexpr (!std::is_same<T,double>::value)
150     {
151        test<T,NBVEC_2048,ACC>();
152     }
153 
154     test<T,1,ACC>();
155     test<T,nb_tails,ACC>();
156     test<T,nb_loops,ACC>();
157     test<T,nb_loops+1,ACC>();
158     test<T,nb_loops+nb_tails,ACC>();
159 
160 
161     title<T>("Dot product with expressions");
162 
163 
164     complex_test<T,NBVEC_4,ACC>(v);
165     complex_test<T,NBVEC_8,ACC>(v);
166     complex_test<T,NBVEC_9,ACC>(v);
167     complex_test<T,NBVEC_32,ACC>(v);
168     complex_test<T,NBVEC_64,ACC>(v);
169     complex_test<T,NBVEC_128,ACC>(v);
170 
171     complex_test<T,NBVEC_256,ACC>(v);
172 
173     complex_test<T,NBVEC_258,ACC>(v);
174     complex_test<T,NBVEC_512,ACC>(v);
175     complex_test<T,NBVEC_1024,ACC>(v);
176     if constexpr (!std::is_same<T,double>::value)
177     {
178        complex_test<T,NBVEC_2048,ACC>(v);
179     }
180 
181     complex_test<T,1,ACC>(v);
182     complex_test<T,nb_tails,ACC>(v);
183     complex_test<T,nb_loops,ACC>(v);
184     complex_test<T,nb_loops+1,ACC>(v);
185     complex_test<T,nb_loops+nb_tails,ACC>(v);
186 
187     //print_map("Stats",max_stats);
188 
189 }
190 
dot_test()191 void dot_test()
192 {
193 #if defined(DOT_TEST)
194    #if defined(F64_DT)
195    all_dot_test<double>();
196    #endif
197    #if defined(F32_DT)
198    all_dot_test<float>();
199    #endif
200    #if defined(F16_DT) && !defined(DISABLEFLOAT16)
201    all_dot_test<float16_t>();
202    #endif
203    #if defined(Q31_DT)
204    all_dot_test<Q31>();
205    #endif
206    #if defined(Q15_DT)
207    all_dot_test<Q15>();
208    #endif
209    #if defined(Q7_DT)
210    all_dot_test<Q7>();
211    #endif
212 #endif
213 }
214