1 // -*- C++ -*-
2 /** @file */
3 #pragma once
4 
5 #ifdef DOXYGEN
6 #define ARM_MATH_MVEI
7 #define ARM_MATH_MVEF
8 #define ARM_MATH_MVE_FLOAT16
9 #endif
10 
11 /** \addtogroup HELIUMALG
12  *  @{
13  */
14 
15 #if defined(ARM_MATH_MVE_FLOAT16)
16 
17 /*
18 
19 This can't be used with stride bigger than 21845
20 which for embedded is acceptable.
21 
22 No check is done at runtime or build time that the stride is not
23 too big.
24 
25 */
26 
27 template<typename MA,
28          typename MB,
29          typename RES,
30 typename std::enable_if<
31          has_vector_inst<MA>() &&
32          SameElementType<MA,float16_t>::value,bool>::type = true>
33 __STATIC_INLINE  void _arm_mat_mult_2x2_mve(
34     const MA &pSrcA,
35     const MB &pSrcB,
36     RES &&pDst)
37 {
38     using T = typename traits<MA>::Scalar;
39     //using ACC = typename vector_traits<T>::temp_accumulator;
40     using VEC = typename vector_traits<T>::vector;
41 
42     const uint16_t offsetA[8] = { 0, 0, (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(),
43                                   0, 0, (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride() };
44     /* offsetB allows to read and duplicate 1 row of B */
45     const uint16_t offsetB[8] = { 0, 1, 0, 1, 0, 1, 0, 1 };
46 
47     /* {d00, d01, d10, d11} */
48     const uint16_t  offsetD[8] = { 0, 1, (uint16_t)pDst.stride(), (uint16_t)(pDst.stride()+1),
49                                    0,0,0,0 };
50 
51     uint16x8_t    vecOffsA, vecOffsB,vecOffsD;
52     VEC       vecInA, vecInB, vecDst;
53     T      *pOut = pDst.ptr();  /* output data matrix pointer */
54 
55     /*
56      * load initial offsets
57      */
58     vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
59     vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
60     /*
61      * load {a00 a00 a10 a10 x x x x }
62      */
63     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
64     /*
65      * load {b00 b01 b00 b01 x x x x }
66      */
67     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
68     /*
69      *  { a00 b00       a00 b01
70      *    a10 b00       a10 b01
71      *       x             x
72      *       x             x   }
73      */
74     vecDst = vmulq(vecInA, vecInB);
75     /*
76      * move to 2nd column of matrix A
77      */
78     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
79     /*
80      * load {a01 a01 a11 a11 x x x x}
81      */
82     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
83     /*
84      * move to next B row
85      */
86     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
87     /*
88      * load {b10, b11, b10, b11, x x x x }
89      */
90     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
91     /*
92      *  { a00 b00 + a01 b10   a00 b01 + a01 b11
93      *    a10 b00 + a11 b10     a10 b01 + a11 b11
94      *             x                    x
95      *             x                    x       }
96      */
97     vecDst = vfmaq(vecDst, vecInA, vecInB);
98 
99     mve_pred16_t p0 = vctp16q(2*2);
100     /*
101      * Store the result in the destination buffer
102      * (lower half of the vector)
103      */
104 
105     vecOffsD = vldrhq_u16((uint16_t const *) offsetD);
106 
107     vstrhq_scatter_shifted_offset_p(pOut,vecOffsD,vecDst,p0);
108 
109 }
110 
111 
112 template<typename MA,
113          typename MB,
114          typename RES,
115 typename std::enable_if<
116          has_vector_inst<MA>() &&
117          SameElementType<MA,float16_t>::value,bool>::type = true>
118 __STATIC_INLINE  void _arm_mat_mult_3x3_mve(
119     const MA &pSrcA,
120     const MB &pSrcB,
121     RES &&pDst)
122 {
123     const uint16_t offsetA[8] = { 0, 0, 0,
124                                          (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(),
125                                          (uint16_t)(2U*pSrcA.stride()), (uint16_t)(2U*pSrcA.stride()) };
126     /* offsetB allows to read and duplicate 1 row of B */
127     const uint16_t offsetB[8] = { 0, 1, 2, 0, 1, 2, 0, 1 };
128     const uint16_t offsetD[8] = { 0, 1, 2,
129                                 (uint16_t)(0+pDst.stride()), (uint16_t)(1+pDst.stride()),
130                                 (uint16_t)(2+pDst.stride()),
131                                 (uint16_t)(0+2*pDst.stride()),
132                                 (uint16_t)(1+2*pDst.stride()) };
133 
134     uint16x8_t    vecOffsA, vecOffsB,vecOffsD;
135     float16x8_t       vecInA, vecInB, vecDst;
136     float16_t      *pOut = pDst.ptr();  /* output data matrix pointer */
137 
138     /*
139      * load initial offsets
140      */
141     vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
142     vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
143 
144     /*
145      * load {a00 a00 a00 a10 a10 a10 a20 a20}
146      */
147     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
148     /*
149      * load {b00 b01 b02 b00 b01 b02 b00 b01}
150      */
151     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
152     /*
153      *  { a00 b00       a00 b01     a00 b02
154      *    a10 b00       a10 b01     a10 b02
155      *    a20 b00       a20 b01}
156      */
157     vecDst = vmulq(vecInA, vecInB);
158 
159     /*
160      * move to 2nd column of matrix A
161      */
162     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
163     /*
164      * load {a01 a01 a01 a11 a11 a11 a21 a21}
165      */
166     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
167     /*
168      * move to next B row
169      */
170     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
171     /*
172      * load {b10, b11, b12, b10, b11, b12, b10, b11}
173      */
174     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
175     /*
176      *  { a00 b00 + a01 b10   a00 b01 + a01 b11     a00 b02 + a01 b12
177      *    a10 b00 + a11 b10     a10 b01 + a11 b11     a10 b02 + a11 b12
178      *    a20 b00 + a21 b10     a20 b01 + a21 b11   }
179      */
180     vecDst = vfmaq(vecDst, vecInA, vecInB);
181     /*
182      * move to 3rd column of matrix A
183      */
184     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
185     /*
186      * load {a02 a02 a02 a12 a12 a12 a22 a22}
187      */
188     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
189     /*
190      * move to next B row
191      */
192     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t)  pSrcB.stride());
193     /*
194      * load {b20, b21, b22, b20, b21, b22, b20, b21}
195      */
196     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
197     /*
198      *  {a00 b00 + a01 b10 + a02 b20  a00 b01 + a01 b11 + a02 b21     a00 b02 + a01 b12 + a02 b22},
199      *   a10 b00 + a11 b10 + a12 b20    a10 b01 + a11 b11 + a12 b21     a10 b02 + a11 b12 + a12 b22},
200      *   a20 b00 + a21 b10 + a22 b20    a20 b01 + a21 b11 + a22 b21   }
201      */
202     vecDst = vfmaq(vecDst, vecInA, vecInB);
203 
204     /*
205      * Store the result in the destination buffer
206      */
207     vecOffsD = vldrhq_u16((uint16_t const *) offsetD);
208 
209     vstrhq_scatter_shifted_offset(pOut,vecOffsD,vecDst);
210 
211     pOut += 2*pDst.stride()+2;
212 
213     /* last element computed in scalar mode
214      * a20 b02 + a21 b12 + a22 b22
215      */
216 
217     const _Float16 * pA = (const _Float16 *)pSrcA.const_ptr();
218     const _Float16 * pB = (const _Float16 *)pSrcB.const_ptr();
219     const index_t sa =pSrcA.stride();
220     const index_t sb =pSrcB.stride();
221     *pOut = pA[2*sa] * pB[2] + pA[1+2*sa] * pB[2+sb] + pA[2+2*sa] * pB[2+2*sb];
222 
223 }
224 
225 
226 
227 template<typename MA,
228          typename MB,
229          typename RES,
230 typename std::enable_if<
231          has_vector_inst<MA>() &&
232          SameElementType<MA,float16_t>::value,bool>::type = true>
233 __STATIC_INLINE  void _arm_mat_mult_4x4_mve(
234     const MA &pSrcA,
235     const MB &pSrcB,
236     RES &&pDst)
237 {
238      /* offsetA allows to read and duplicate 2 successive column elements of A */
239     const uint16_t offsetA[8] = { 0, 0, 0, 0,
240                                          (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride(), (uint16_t)pSrcA.stride() };
241     /* offsetB allows to read and duplicate 1 row of B */
242     const uint16_t offsetB[8] = { 0, 1, 2, 3, 0, 1, 2, 3 };
243 
244     const uint16_t offsetD[8] = { 0, 1, 2, 3,
245                                   (uint16_t)(0+pDst.stride()), (uint16_t)(1+pDst.stride()),
246                                   (uint16_t)(2+pDst.stride()), (uint16_t)(3+pDst.stride()) };
247 
248     uint16x8_t    vecOffsA, vecOffsB,vecOffsD;
249     float16x8_t       vecInA, vecInB, vecDst0, vecDst1;
250     float16_t      *pOut = pDst.ptr();  /* output data matrix pointer */
251 
252     /*
253      * load initial offsets
254      */
255     vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
256     vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
257 
258     /*
259      * load {a00 a00 a00 a00 a10 a10 a10 a10}
260      */
261     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
262     /*
263      * load {b00 b01 b02 b03 b00 b01 b02 b03}
264      */
265     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
266 
267     /*
268      *  { a00 b00       a00 b01     a00 b02     a00 b03
269      *    a10 b00       a10 b01     a10 b02     a10 b03 }
270      */
271     vecDst0 = vmulq(vecInA, vecInB);
272     /*
273      * jump 2 x A rows (2nd half of matrix)
274      */
275     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) pSrcA.stride()*2);
276     /*
277      * load {a20 a20 a20 a20 a30 a30 a30 a30}
278      */
279     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
280     /*
281      *  { a20 b00       a20 b01     a20 b02     a20 b03
282      *    a30 b00       a30 b01     a30 b02 +   a31 b12 }
283      */
284     vecDst1 = vmulq(vecInA, vecInB);
285     /*
286      * rewind back to top half of the A matrix (2nd column)
287      */
288     vecOffsA = vsubq(vecOffsA, (uint16_t) (2*pSrcA.stride()-1));
289     /*
290      * load {a01 a01 a01 a01 a11 a11 a11 a11}
291      */
292     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
293 
294     /*
295      * move to next B row
296      */
297     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
298     /*
299      * load {b10, b11, b12, b13, b10, b11, b12, b13}
300      */
301     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
302     /*
303      *  { a00 b00 + a01 b10         a00 b01 + a01 b11       a00 b02 + a01 b12       a00 b03 + a01 b13
304      *    a10 b00 + a11 b10         a10 b01 + a11 b11       a10 b02 + a11 b12       a10 b03 + a11 b13 }
305      */
306     vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
307     /*
308      * jump 2 x A rows (2nd half of matrix)
309      */
310     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) pSrcA.stride()*2);
311     /*
312      * load {a21 a21 a21 a21 a31 a31 a31 a31}
313      */
314     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
315     /*
316      *  {a20 b00 + a21 b10      a20 b01 + a21 b11       a20 b02 + a21 b12       a20 b03 + a21 b13
317      *   a30 b00 + a31 b10      a30 b01 + a31 b11       a30 b02 + a31 b12       a30 b03 + a31 b13 }
318      */
319     vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
320 
321     /*
322      * rewind back to top half of the A matrix (3rd column)
323      */
324     vecOffsA = vsubq(vecOffsA, (uint16_t) (2*pSrcA.stride()-1));
325     /*
326      * load {a02 a02 a02 a02 a12 a12 a12 a12}
327      */
328     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
329     /*
330      * move to next B row
331      */
332     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
333     /*
334      * load {b20, b21, b22, b23, b20, b21, b22, b23}
335      */
336     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
337     /*
338      *  { a00 b00 + a01 b10 + a02 b20    a00 b01 + a01 b11 + a02 b21    a00 b02 + a01 b12 + a02 b22   a00 b03 + a01 b13 + a02 b23
339      *    a10 b00 + a11 b10 + a12 b20    a10 b01 + a11 b11 + a12 b21    a10 b02 + a11 b12 + a12 b22   a10 b03 + a11 b13 + a12 b23 }
340      */
341     vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
342     /*
343      * jump 2 x A rows
344      */
345     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 2*pSrcA.stride());
346 
347     /*
348      * load {a22 a22 a22 a22 a32 a32 a32 a32}
349      */
350     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
351     /*
352      *  {a20 b00 + a21 b10 + a22 b20   a20 b01 + a21 b11 + a22 b21  a20 b02 + a21 b12 + a22 b22    a20 b03 + a21 b13 + a22 b23
353      *   a30 b00 + a31 b10 + a32 b20   a30 b01 + a31 b11 + a32 b21  a30 b02 + a31 b12 + a32 b22    a30 b03 + a31 b13 + a32 b23 }
354      */
355     vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
356 
357     /*
358      * rewind back to top half of the A matrix (4th column)
359      */
360     vecOffsA = vsubq(vecOffsA, (uint16_t) (2*pSrcA.stride()-1));
361     /*
362      * load {a03 a03 a03 a03 a13 a13 a13 a13}
363      */
364     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
365     /*
366      * move to next B row
367      */
368     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) pSrcB.stride());
369     /*
370      * load {b30, b31, b32, b33, b30, b31, b32, b33}
371      */
372     vecInB = vldrhq_gather_shifted_offset(pSrcB.const_ptr(), vecOffsB);
373     /*
374      * { a00 b00 +...+ a03 b30,    a00 b01 +...+ a03 b31,   a00 b02 +...+ a03 b32,   a00 b03 +...+ a03 b33
375      *   a10 b00 +...+ a13 b30,    a10 b01 +...+ a13 b31,   a10 b02 +...+ a13 b32,   a10 b03 +...+ a13 b33 }
376      */
377     vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
378     /*
379      * jump 2 x A rows
380      */
381     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) pSrcA.stride()*2);
382     /*
383      * load {a23 a23 a23 a23 a33 a33 a33 a33}
384      */
385     vecInA = vldrhq_gather_shifted_offset(pSrcA.const_ptr(), vecOffsA);
386     /*
387      *  {a20 b00 +...+ a23 b30,   a20 b01 +...+ a23 b31,   a20 b02 +...+ a23 b32,   a20 b03 +...+ a23 b33
388      *   a30 b00 +...+ a33 b30,   a30 b01 +...+ a33 b31,   a30 b02 +...+ a33 b32,   a30 b03 +...+ a33 b33 }
389      */
390     vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
391 
392     /*
393      * Store the result in the destination buffer
394      */
395     vecOffsD = vldrhq_u16((uint16_t const *) offsetD);
396     vstrhq_scatter_shifted_offset(pOut,vecOffsD,vecDst0);
397     pOut += 2*pDst.stride();
398     vstrhq_scatter_shifted_offset(pOut,vecOffsD,vecDst1);
399 
400 }
401 
402 #endif
403 
404 /*! @} */