1 // -*- C++ -*-
2 /** @file */
3 #pragma once
4
5 #ifdef DOXYGEN
6 #define ARM_MATH_MVEI
7 #define ARM_MATH_MVEF
8 #define ARM_MATH_MVE_FLOAT16
9 #endif
10
11 /** \addtogroup HELIUMALG
12 * @{
13 */
14
15 template<typename M,
16 typename V,
17 typename RES>
_dot_m_v(RES & res,const M & m,const V & v,const Helium * =nullptr)18 inline void _dot_m_v(RES &res,
19 const M&m,const V&v,
20 const Helium* = nullptr)
21 {
22
23 const vector_length_t nb_rows=m.rows();
24 constexpr int U = 4;
25
26 index_t row=0;
27
28 DISABLE_LOOP_UNROLL
29 for(; row<=nb_rows-U; row += U)
30 {
31 results<U>([&res,&row](index_t k){return &res[row+k];}) =
32 inner::from_accumulator(dot(unroll<U>(
33 [&row,&m](index_t k){return m.row(row+k);}),
34 replicate<U>(v)
35 ));
36 }
37
38 switch (nb_rows-row)
39 {
40 case 3:
41 results<3>([&res,row](index_t k){return &res[row+k];}) =
42 inner::from_accumulator(dot(unroll<3>(
43 [row,&m](index_t k){return m.row(row+k);}),
44 replicate<3>(v)
45 ));
46 break;
47 case 2:
48 results<2>([&res,row](index_t k){return &res[row+k];}) =
49 inner::from_accumulator(dot(unroll<2>(
50 [row,&m](index_t k){return m.row(row+k);}),
51 replicate<2>(v)
52 ));
53 break;
54 case 1:
55 res[row] = inner::from_accumulator(dot(m.row(row),v));
56 break;
57 }
58
59 }
60
61 #define MATRIX_DIM2 2
62 #define MATRIX_DIM3 3
63 #define MATRIX_DIM4 4
64
65 #if defined(ARM_MATH_MVEI)
66
67 /* Fixed point specific cases*/
68 #include "matrix_multiply_fixed.hpp"
69
70 #endif
71
72 #if defined(ARM_MATH_MVEF)
73
74 /* Datatype specific cases*/
75 #include "matrix_multiply_f16.hpp"
76 #include "matrix_multiply_f32.hpp"
77
78 /* Generic float */
79 template<typename MA,
80 typename MB,
81 typename RES,
82 typename std::enable_if<
83 has_vector_inst<MA>() &&
84 number_traits<typename traits<MA>::Scalar>::is_float,bool>::type = true>
85 __STATIC_INLINE void _dot_m_m(const MA&pSrcA,const MB&pSrcB,
86 RES &&pDst,
87 const Helium* = nullptr)
88 {
89 using T = typename traits<MA>::Scalar;
90 using ACC = typename vector_traits<T>::temp_accumulator;
91 using VEC = typename vector_traits<T>::vector;
92 constexpr int nb_lanes = vector_traits<T>::nb_lanes;
93
94 T *pInB = pSrcB.ptr(); /* input data matrix pointer B */
95 T *pInA = pSrcA.ptr(); /* input data matrix pointer A */
96 T *pOut = pDst.ptr(); /* output data matrix pointer */
97 int numRowsA = pSrcA.rows(); /* number of rows of input matrix A */
98 int numColsB = pSrcB.columns(); /* number of columns of input matrix B */
99 int numColsA = pSrcA.columns(); /* number of columns of input matrix A */
100 uint32_t blkCnt; /* loop counters */
101 uint32_t i;
102
103 {
104 /* small squared matrix specialized routines */
105 if(numRowsA == numColsB && numColsB == numColsA) {
106 if (numRowsA == 1)
107 {
108 pDst(0,0)= pSrcA(0,0) * pSrcB(0,0);
109 return;
110 }
111 else if(numRowsA == 2)
112 return _arm_mat_mult_2x2_mve(pSrcA, pSrcB, std::forward<RES>(pDst));
113 else if(numRowsA == 3)
114 return _arm_mat_mult_3x3_mve(pSrcA, pSrcB, std::forward<RES>(pDst));
115 else if(numRowsA == 4)
116 return _arm_mat_mult_4x4_mve(pSrcA, pSrcB, std::forward<RES>(pDst));
117 }
118
119 /* main loop process 4 rows */
120 i = numRowsA >> 2;
121 while (i > 0U)
122 {
123 T *pInA0, *pInA1, *pInA2, *pInA3;
124 T *pInB0;
125 T *pOut0, *pOut1, *pOut2, *pOut3;
126 ACC vecMac0, vecMac1, vecMac2, vecMac3;
127 VEC vecInB;
128
129 /* pointers to 4 consecutive output rows */
130 pOut0 = pOut;
131 pOut1 = pOut0 + pDst.stride();
132 pOut2 = pOut1 + pDst.stride();
133 pOut3 = pOut2 + pDst.stride();
134 pInB0 = pInB;
135
136 uint32_t k = numColsB / nb_lanes;
137 while (k > 0U)
138 {
139 /* pointers to 4 consecutive Matrix A rows */
140 pInA0 = pInA;
141 pInA1 = pInA0 + pSrcA.stride();
142 pInA2 = pInA1 + pSrcA.stride();
143 pInA3 = pInA2 + pSrcA.stride();
144
145 vecMac0 = vector_traits<T>::temp_acc_zero();
146 vecMac1 = vector_traits<T>::temp_acc_zero();
147 vecMac2 = vector_traits<T>::temp_acc_zero();
148 vecMac3 = vector_traits<T>::temp_acc_zero();
149
150 blkCnt = numColsA;
151
152 while (blkCnt > 0U)
153 {
154 /*
155 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
156 */
157 vecInB = inner::vload1<1>(pInB0); /* vldrwq_f32(pInB0, 0); */
158
159 vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
160 vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
161 vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
162 vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
163
164 pInB0 = pInB0 + pSrcB.stride();
165 /*
166 * Decrement the blockSize loop counter
167 */
168 blkCnt--;
169 }
170
171 /* Store the results (4 x 4 block) in the destination buffer */
172 inner::vstore1<1>(pOut0, vecMac0);
173 pOut0 += nb_lanes;
174 inner::vstore1<1>(pOut1, vecMac1);
175 pOut1 += nb_lanes;
176 inner::vstore1<1>(pOut2, vecMac2);
177 pOut2 += nb_lanes;
178 inner::vstore1<1>(pOut3, vecMac3);
179 pOut3 += nb_lanes;
180
181 /*
182 * rewind
183 */
184 pInB0 -= (pSrcB.stride() * numColsA) - nb_lanes;
185 k--;
186 }
187
188 int colBLeft = numColsB & (nb_lanes - 1);
189 if (colBLeft)
190 {
191 pInA0 = pInA;
192 pInA1 = pInA0 + pSrcA.stride();
193 pInA2 = pInA1 + pSrcA.stride();
194 pInA3 = pInA2 + pSrcA.stride();
195
196 mve_pred16_t p0 = inner::vctpq<T>::mk(colBLeft);
197
198 vecMac0 = vector_traits<T>::temp_acc_zero();
199 vecMac1 = vector_traits<T>::temp_acc_zero();
200 vecMac2 = vector_traits<T>::temp_acc_zero();
201 vecMac3 = vector_traits<T>::temp_acc_zero();
202
203 blkCnt = numColsA;
204
205 while (blkCnt > 0U)
206 {
207 /*
208 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
209 */
210 vecInB = inner::vload1_z<1>(pInB0, colBLeft,p0);
211
212 vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
213 vecMac1 = inner::vmacc(vecMac1, vecInB, *pInA1++);
214 vecMac2 = inner::vmacc(vecMac2, vecInB, *pInA2++);
215 vecMac3 = inner::vmacc(vecMac3, vecInB, *pInA3++);
216
217 pInB0 = pInB0 + pSrcB.stride();
218 /*
219 * Decrement the blockSize loop counter
220 */
221 blkCnt--;
222 }
223
224 /* Store the results (4 x colBLeft block) in the destination buffer */
225 inner::vstore1_z<1>(pOut0, vecMac0, colBLeft,p0);
226 inner::vstore1_z<1>(pOut1, vecMac1, colBLeft,p0);
227 inner::vstore1_z<1>(pOut2, vecMac2, colBLeft,p0);
228 inner::vstore1_z<1>(pOut3, vecMac3, colBLeft,p0);
229 }
230
231 /* move to next rows */
232 pInA += 4 * pSrcA.stride();
233 pOut += 4 * pDst.stride();
234 i--;
235 }
236
237 /*
238 * non multiple of 4 rows for Matrix A
239 * process single row
240 */
241 if (numRowsA & 3)
242 {
243 i = numRowsA & 3;
244 while (i > 0U)
245 {
246 T *pInA0;
247 T *pInB0;
248 T *pOut0;
249 VEC vecInB;
250 ACC vecMac0;
251
252 pOut0 = pOut;
253 pInB0 = pInB;
254
255 uint32_t k = numColsB / nb_lanes;
256 while (k > 0U)
257 {
258 pInA0 = pInA;
259
260 vecMac0 = vector_traits<T>::temp_acc_zero();
261 blkCnt = numColsA;
262 while (blkCnt > 0U)
263 {
264 /*
265 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
266 */
267 vecInB = inner::vload1<1>(pInB0); /* vldrwq_f32(pInB0, 0); */
268
269 vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
270
271 pInB0 = pInB0 + pSrcB.stride();
272 /*
273 * Decrement the blockSize loop counter
274 */
275 blkCnt--;
276 }
277
278 /* Store the results (1 x 4 block) in the destination buffer */
279 inner::vstore1<1>(pOut0, vecMac0);
280 pOut0 += nb_lanes;
281
282 /*
283 * rewind
284 */
285 pInB0 -= (pSrcB.stride() * numColsA) - nb_lanes;
286 k--;
287 }
288
289 int colBLeft = numColsB & (nb_lanes-1);
290 if (colBLeft)
291 {
292 pInA0 = pInA;
293 mve_pred16_t p0 = inner::vctpq<T>::mk(colBLeft);
294
295 vecMac0 = vector_traits<T>::temp_acc_zero();
296 blkCnt = numColsA;
297 while (blkCnt > 0U)
298 {
299 /*
300 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
301 */
302 vecInB = inner::vload1_z<1>(pInB0, colBLeft,p0);
303
304 vecMac0 = inner::vmacc(vecMac0, vecInB, *pInA0++);
305
306 pInB0 = pInB0 + pSrcB.stride();
307 /*
308 * Decrement the blockSize loop counter
309 */
310 blkCnt--;
311 }
312 /* Store the results (1 x colBLeft block) in the destination buffer */
313 inner::vstore1_z<1>(pOut0, vecMac0, colBLeft,p0);
314 }
315
316 /* move to next row */
317 pInA += 1 * pSrcA.stride();
318 pOut += 1 * pDst.stride();
319 i--;
320 }
321
322 }
323
324 }
325
326 }
327
328
329 #undef MATRIX_DIM2
330 #undef MATRIX_DIM3
331 #undef MATRIX_DIM4
332
333 #endif
334
335 /*! @} */