1 // -*- C++ -*-
2 /** @file */
3 #pragma once
4 
5 #ifdef DOXYGEN
6 #define ARM_MATH_MVEI
7 #define ARM_MATH_MVEF
8 #define ARM_MATH_MVE_FLOAT16
9 #endif
10 
11 /** \addtogroup HELIUMALG
12  *  @{
13  */
14 
15 template<typename M,
16          typename V,
17          typename RES>
_dot_m_v(RES & res,const M & m,const V & v,const Helium * =nullptr)18 inline void _dot_m_v(RES &res,
19                     const M&m,const V&v,
20                     const Helium* = nullptr)
21 {
22 
23    const vector_length_t nb_rows=m.rows();
24    constexpr int U = 4;
25 
26    index_t row=0;
27 
28    DISABLE_LOOP_UNROLL
29    for(; row<=nb_rows-U; row += U)
30    {
31       results<U>([&res,&row](index_t k){return &res[row+k];}) =
32            inner::from_accumulator(dot(unroll<U>(
33                                  [&row,&m](index_t k){return m.row(row+k);}),
34                                  replicate<U>(v)
35               ));
36    }
37 
38    switch (nb_rows-row)
39    {
40       case 3:
41          results<3>([&res,row](index_t k){return &res[row+k];}) =
42            inner::from_accumulator(dot(unroll<3>(
43                                  [row,&m](index_t k){return m.row(row+k);}),
44                                  replicate<3>(v)
45                            ));
46       break;
47       case 2:
48          results<2>([&res,row](index_t k){return &res[row+k];}) =
49            inner::from_accumulator(dot(unroll<2>(
50                                  [row,&m](index_t k){return m.row(row+k);}),
51                                  replicate<2>(v)
52                                ));
53       break;
54       case 1:
55          res[row] = inner::from_accumulator(dot(m.row(row),v));
56       break;
57    }
58 
59 }
60 
/* Small matrix dimension constants.
 * Nothing in this file's visible code refers to them. Presumably they are
 * consumed by the specialization headers included below
 * (matrix_multiply_fixed.hpp / _f16.hpp / _f32.hpp) -- TODO confirm.
 * They are defined before both conditional include groups so that either
 * group can see them. Matching #undef directives appear at the end of this
 * header. */
61 #define MATRIX_DIM2 2
62 #define MATRIX_DIM3 3
63 #define MATRIX_DIM4 4
64 
65 #if defined(ARM_MATH_MVEI)
66 
67 /* Fixed point specific cases*/
68 #include "matrix_multiply_fixed.hpp"
69 
70 #endif
71 
72 #if defined(ARM_MATH_MVEF)
73 
74 /* Datatype specific cases*/
75 #include "matrix_multiply_f16.hpp"
76 #include "matrix_multiply_f32.hpp"
77 
78 /* Generic float */
79 template<typename MA,
80          typename MB,
81          typename RES,
82          typename std::enable_if<
83          has_vector_inst<MA>() &&
84          number_traits<typename traits<MA>::Scalar>::is_float,bool>::type = true>
85 __STATIC_INLINE void _dot_m_m(const MA&pSrcA,const MB&pSrcB,
86                      RES &&pDst,
87                      const Helium* = nullptr)
88    {
89     using T = typename traits<MA>::Scalar;
90     using ACC = typename vector_traits<T>::temp_accumulator;
91     using VEC = typename vector_traits<T>::vector;
92     constexpr int nb_lanes = vector_traits<T>::nb_lanes;
93 
94     T  *pInB = pSrcB.ptr();        /* input data matrix pointer B */
95     T  *pInA = pSrcA.ptr();        /* input data matrix pointer A  */
96     T  *pOut = pDst.ptr();         /* output data matrix pointer */
97     int         numRowsA = pSrcA.rows();  /* number of rows of input matrix A */
98     int         numColsB = pSrcB.columns();  /* number of columns of input matrix B */
99     int         numColsA = pSrcA.columns();  /* number of columns of input matrix A */
100     uint32_t    blkCnt;                     /* loop counters */
101     uint32_t    i;
102 
103   {
104       /* small squared matrix specialized routines */
105     if(numRowsA == numColsB && numColsB == numColsA) {
106         if (numRowsA == 1)
107         {
108            pDst(0,0)= pSrcA(0,0) * pSrcB(0,0);
109            return;
110         }
111         else if(numRowsA == 2)
112             return _arm_mat_mult_2x2_mve(pSrcA, pSrcB, std::forward<RES>(pDst));
113         else if(numRowsA == 3)
114             return _arm_mat_mult_3x3_mve(pSrcA, pSrcB, std::forward<RES>(pDst));
115         else if(numRowsA == 4)
116             return _arm_mat_mult_4x4_mve(pSrcA, pSrcB, std::forward<RES>(pDst));
117     }
118 
119     /* main loop process 4 rows */
120     i = numRowsA >> 2;
121     while (i > 0U)
122     {
123         T *pInA0, *pInA1, *pInA2, *pInA3;
124         T *pInB0;
125         T *pOut0, *pOut1, *pOut2, *pOut3;
126         ACC vecMac0, vecMac1, vecMac2, vecMac3;
127         VEC vecInB;
128 
129         /* pointers to 4 consecutive output rows */
130         pOut0 = pOut;
131         pOut1 = pOut0 + pDst.stride();
132         pOut2 = pOut1 + pDst.stride();
133         pOut3 = pOut2 + pDst.stride();
134         pInB0 = pInB;
135 
136         uint32_t  k = numColsB / nb_lanes;
137         while (k > 0U)
138         {
139             /* pointers to 4 consecutive Matrix A rows */
140             pInA0 = pInA;
141             pInA1 = pInA0 + pSrcA.stride();
142             pInA2 = pInA1 + pSrcA.stride();
143             pInA3 = pInA2 + pSrcA.stride();
144 
145             vecMac0 = vector_traits<T>::temp_acc_zero();
146             vecMac1 = vector_traits<T>::temp_acc_zero();
147             vecMac2 = vector_traits<T>::temp_acc_zero();
148             vecMac3 = vector_traits<T>::temp_acc_zero();
149 
150             blkCnt = numColsA;
151 
152             while (blkCnt > 0U)
153             {
154                 /*
155                  * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
156                  */
157                 vecInB = inner::vload1<1>(pInB0); /* vldrwq_f32(pInB0, 0); */
158 
159                 vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
160                 vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
161                 vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
162                 vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
163 
164                 pInB0 = pInB0 + pSrcB.stride();
165                 /*
166                  * Decrement the blockSize loop counter
167                  */
168                 blkCnt--;
169             }
170 
171             /* Store the results (4 x 4 block) in the destination buffer */
172             inner::vstore1<1>(pOut0, vecMac0);
173             pOut0 += nb_lanes;
174             inner::vstore1<1>(pOut1, vecMac1);
175             pOut1 += nb_lanes;
176             inner::vstore1<1>(pOut2, vecMac2);
177             pOut2 += nb_lanes;
178             inner::vstore1<1>(pOut3, vecMac3);
179             pOut3 += nb_lanes;
180 
181             /*
182              * rewind
183              */
184             pInB0 -= (pSrcB.stride() * numColsA) - nb_lanes;
185             k--;
186         }
187 
188         int       colBLeft = numColsB & (nb_lanes - 1);
189         if (colBLeft)
190         {
191             pInA0 = pInA;
192             pInA1 = pInA0 + pSrcA.stride();
193             pInA2 = pInA1 + pSrcA.stride();
194             pInA3 = pInA2 + pSrcA.stride();
195 
196             mve_pred16_t    p0 = inner::vctpq<T>::mk(colBLeft);
197 
198             vecMac0 = vector_traits<T>::temp_acc_zero();
199             vecMac1 = vector_traits<T>::temp_acc_zero();
200             vecMac2 = vector_traits<T>::temp_acc_zero();
201             vecMac3 = vector_traits<T>::temp_acc_zero();
202 
203             blkCnt = numColsA;
204 
205             while (blkCnt > 0U)
206             {
207                 /*
208                  * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
209                  */
210                 vecInB = inner::vload1_z<1>(pInB0, colBLeft,p0);
211 
212                 vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
213                 vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
214                 vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
215                 vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
216 
217                 pInB0 = pInB0 + pSrcB.stride();
218                 /*
219                  * Decrement the blockSize loop counter
220                  */
221                 blkCnt--;
222             }
223 
224             /* Store the results (4 x colBLeft block) in the destination buffer */
225             inner::vstore1_z<1>(pOut0, vecMac0, colBLeft,p0);
226             inner::vstore1_z<1>(pOut1, vecMac1, colBLeft,p0);
227             inner::vstore1_z<1>(pOut2, vecMac2, colBLeft,p0);
228             inner::vstore1_z<1>(pOut3, vecMac3, colBLeft,p0);
229         }
230 
231         /* move to next rows */
232         pInA += 4 * pSrcA.stride();
233         pOut += 4 * pDst.stride();
234         i--;
235     }
236 
237     /*
238      * non multiple of 4 rows for Matrix A
239      * process single row
240      */
241     if (numRowsA & 3)
242     {
243         i = numRowsA & 3;
244         while (i > 0U)
245         {
246             T   *pInA0;
247             T   *pInB0;
248             T   *pOut0;
249             VEC    vecInB;
250             ACC    vecMac0;
251 
252             pOut0 = pOut;
253             pInB0 = pInB;
254 
255             uint32_t       k = numColsB / nb_lanes;
256             while (k > 0U)
257             {
258                 pInA0 = pInA;
259 
260                 vecMac0 = vector_traits<T>::temp_acc_zero();
261                 blkCnt = numColsA;
262                 while (blkCnt > 0U)
263                 {
264                     /*
265                      * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
266                      */
267                     vecInB = inner::vload1<1>(pInB0); /* vldrwq_f32(pInB0, 0); */
268 
269                     vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
270 
271                     pInB0 = pInB0 + pSrcB.stride();
272                     /*
273                      * Decrement the blockSize loop counter
274                      */
275                     blkCnt--;
276                 }
277 
278                 /* Store the results (1 x 4 block) in the destination buffer */
279                 inner::vstore1<1>(pOut0, vecMac0);
280                 pOut0 += nb_lanes;
281 
282                 /*
283                  * rewind
284                  */
285                 pInB0 -= (pSrcB.stride() * numColsA) - nb_lanes;
286                 k--;
287             }
288 
289             int       colBLeft = numColsB & (nb_lanes-1);
290             if (colBLeft)
291             {
292                 pInA0 = pInA;
293                 mve_pred16_t    p0 = inner::vctpq<T>::mk(colBLeft);
294 
295                 vecMac0 = vector_traits<T>::temp_acc_zero();
296                 blkCnt = numColsA;
297                 while (blkCnt > 0U)
298                 {
299                     /*
300                      * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
301                      */
302                     vecInB = inner::vload1_z<1>(pInB0, colBLeft,p0);
303 
304                     vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
305 
306                     pInB0 = pInB0 + pSrcB.stride();
307                     /*
308                      * Decrement the blockSize loop counter
309                      */
310                     blkCnt--;
311                 }
312                 /* Store the results (1 x colBLeft block) in the destination buffer */
313                 inner::vstore1_z<1>(pOut0, vecMac0, colBLeft,p0);
314             }
315 
316             /* move to next row */
317             pInA += 1 * pSrcA.stride();
318             pOut += 1 * pDst.stride();
319             i--;
320         }
321 
322       }
323 
324 }
325 
326 }
327 
328 
329 #undef MATRIX_DIM2
330 #undef MATRIX_DIM3
331 #undef MATRIX_DIM4
332 
333 #endif
334 
335 /*! @} */