1 extern "C" {
2     extern void fusion_test();
3 }
4 
5 #include "allocator.h"
6 
7 #include <tuple>
8 #include <array>
9 
10 #include <dsppp/arch.hpp>
11 #include <dsppp/fixed_point.hpp>
12 #include <dsppp/matrix.hpp>
13 #include <dsppp/unroll.hpp>
14 
15 #include <iostream>
16 
17 #include <cmsis_tests.h>
18 
19 template<typename T,int NB>
test()20 static void test()
21 {
22    std::cout << "----\r\n" << "N = " << NB << "\r\n";
23 
24    #if defined(STATIC_TEST)
25    PVector<T,NB> a;
26    PVector<T,NB> b;
27    PVector<T,NB> c;
28    #else
29    PVector<T> a(NB);
30    PVector<T> b(NB);
31    PVector<T> c(NB);
32    #endif
33 
34 
35    init_array(a,NB);
36    init_array(b,NB);
37    init_array(c,NB);
38 
39    #if defined(STATIC_TEST)
40    PVector<T,NB> resa;
41    PVector<T,NB> resb;
42    #else
43    PVector<T> resa(NB);
44    PVector<T> resb(NB);
45    #endif
46 
47 
48    INIT_SYSTICK;
49    START_CYCLE_MEASUREMENT;
50    startSectionNB(1);
51    results(resa,resb) = Merged{a + b,a + c};
52    stopSectionNB(1);
53    STOP_CYCLE_MEASUREMENT;
54 
55    PVector<T,NB> refa;
56    PVector<T,NB> refb;
57 
58    INIT_SYSTICK;
59    START_CYCLE_MEASUREMENT;
60    cmsisdsp_add(a.const_ptr(),b.const_ptr(),refa.ptr(),NB);
61    cmsisdsp_add(a.const_ptr(),c.const_ptr(),refb.ptr(),NB);
62    STOP_CYCLE_MEASUREMENT;
63 
64    if (!validate(resa.const_ptr(),refa.const_ptr(),NB))
65    {
66       printf("add a failed \r\n");
67 
68    }
69 
70    if (!validate(resb.const_ptr(),refb.const_ptr(),NB))
71    {
72       printf("add b failed \r\n");
73 
74    }
75 
76    std::cout << "=====\r\n";
77 }
78 
79 
80 template<typename T,int NB>
test2()81 static void test2()
82 {
83    std::cout << "----\r\n" << "N = " << NB << "\r\n";
84    #if defined(STATIC_TEST)
85    PVector<T,NB> a;
86    PVector<T,NB> b;
87    PVector<T,NB> c;
88    #else
89    PVector<T> a(NB);
90    PVector<T> b(NB);
91    PVector<T> c(NB);
92    #endif
93    using Acc = typename number_traits<T>::accumulator;
94 
95 
96    init_array(a,NB);
97    init_array(b,NB);
98    init_array(c,NB);
99 
100    Acc resa,resb,refa,refb;
101 
102    INIT_SYSTICK;
103    START_CYCLE_MEASUREMENT;
104    startSectionNB(2);
105    std::tie(resa,resb) = dot(Merged{expr(a),expr(a)},
106                              Merged{expr(b),expr(c)});
107    stopSectionNB(2);
108    STOP_CYCLE_MEASUREMENT;
109 
110    INIT_SYSTICK;
111    START_CYCLE_MEASUREMENT;
112    cmsisdsp_dot(a.const_ptr(),b.const_ptr(),refa,NB);
113    cmsisdsp_dot(a.const_ptr(),c.const_ptr(),refb,NB);
114    STOP_CYCLE_MEASUREMENT;
115 
116    if (!validate(resa,refa))
117    {
118       printf("dot a failed \r\n");
119 
120    }
121 
122    if (!validate(resb,refb))
123    {
124       printf("dot b failed \r\n");
125 
126    }
127 
128    std::cout << "=====\r\n";
129 
130 
131 }
132 
133 template<typename T,int NB>
test3()134 static void test3()
135 {
136    std::cout << "----\r\n" << "N = " << NB << "\r\n";
137 
138    constexpr int U = 2;
139    #if defined(STATIC_TEST)
140    PVector<T,NB> a[U];
141    PVector<T,NB> b[U];
142    #else
143    PVector<T> a[U]={PVector<T>(NB),PVector<T>(NB)};
144    PVector<T> b[U]={PVector<T>(NB),PVector<T>(NB)};
145    #endif
146 
147    using Acc = typename number_traits<T>::accumulator;
148 
149    for(int i=0;i<U;i++)
150    {
151       init_array(a[i],NB);
152       init_array(b[i],NB);
153    }
154 
155    std::array<Acc,U> res;
156    Acc ref[U];
157 
158    INIT_SYSTICK;
159    START_CYCLE_MEASUREMENT;
160    startSectionNB(3);
161    results(res) = dot(unroll<U>(
162                        [&a](index_t k){return expr(a[k]);}),
163                       unroll<U>(
164                        [&b](index_t k){return expr(b[k]);})
165               );
166    stopSectionNB(3);
167    STOP_CYCLE_MEASUREMENT;
168 
169    INIT_SYSTICK;
170    START_CYCLE_MEASUREMENT;
171    for(int i=0;i<U;i++)
172    {
173       cmsisdsp_dot(a[i].const_ptr(),b[i].const_ptr(),ref[i],NB);
174    }
175    STOP_CYCLE_MEASUREMENT;
176 
177    for(int i=0;i<U;i++)
178    {
179       if (!validate(res[i],ref[i]))
180       {
181          printf("dot failed %d \r\n",i);
182 
183       }
184    }
185 
186    std::cout << "=====\r\n";
187 
188 }
189 
190 template<typename T>
all_fusion_test()191 void all_fusion_test()
192 {
193 
194     const int nb_tails = TailForTests<T>::tail;
195     const int nb_loops = TailForTests<T>::loop;
196 
197     title<T>("Vector Fusion");
198 
199     test<T,NBVEC_256>();
200     test<T,1>();
201     test<T,nb_tails>();
202     test<T,nb_loops>();
203     test<T,nb_loops+1>();
204     test<T,nb_loops+nb_tails>();
205 
206     title<T>("Dot Product Fusion");
207 
208     test2<T,NBVEC_256>();
209     test2<T,1>();
210     test2<T,nb_tails>();
211     test2<T,nb_loops>();
212     test2<T,nb_loops+1>();
213     test2<T,nb_loops+nb_tails>();
214 
215 
216 
217 
218     title<T>("Unroll Fusion");
219 
220     test3<T,NBVEC_256>();
221     test3<T,1>();
222     test3<T,nb_tails>();
223     test3<T,nb_loops>();
224     test3<T,nb_loops+1>();
225     test3<T,nb_loops+nb_tails>();
226 
227 }
228 
fusion_test()229 void fusion_test()
230 {
231    /*
232 
233 gcc has some issues with this code.
234 FVP is freezing when trying to run it.
235 Since those kind of fusion are not really used in the library
236 (because performance is not good) we can disable those tests
237 to at least be able to test other parts of the library with gcc.
238 
239    */
240    #if !defined(GCC_COMPILER)
241 #if defined(FUSION_TEST)
242    #if defined(F64_DT)
243    all_fusion_test<double>();
244    #endif
245    #if defined(F32_DT)
246    all_fusion_test<float>();
247    #endif
248    #if defined(F16_DT) && !defined(DISABLEFLOAT16)
249    all_fusion_test<float16_t>();
250    #endif
251    #if defined(Q31_DT)
252    all_fusion_test<Q31>();
253    #endif
254    #if defined(Q15_DT)
255    all_fusion_test<Q15>();
256    #endif
257    #if defined(Q7_DT)
258    all_fusion_test<Q7>();
259    #endif
260 #endif
261    #endif
262 
263 }