1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_mult_f32.c
4 * Description: Floating-point matrix multiplication
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions.h"
30
31 #if defined(ARM_MATH_NEON)
32 #define GROUPOFROWS 8
33 #endif
34
35 /**
36 * @ingroup groupMatrix
37 */
38
39 /**
40 * @defgroup MatrixMult Matrix Multiplication
41 *
42 * Multiplies two matrices.
43 *
44 * @par Multiplication of two 3x3 matrices:
45 *
46 * \f[
47 * \begin{pmatrix}
48 * a_{1,1} & a_{1,2} & a_{1,3} \\
49 * a_{2,1} & a_{2,2} & a_{2,3} \\
50 * a_{3,1} & a_{3,2} & a_{3,3} \\
51 * \end{pmatrix}
52 *
53 * \begin{pmatrix}
54 * b_{1,1} & b_{1,2} & b_{1,3} \\
55 * b_{2,1} & b_{2,2} & b_{2,3} \\
56 * b_{3,1} & b_{3,2} & b_{3,3} \\
57 * \end{pmatrix}
58 * =
59 * \begin{pmatrix}
60 * a_{1,1} b_{1,1}+a_{1,2} b_{2,1}+a_{1,3} b_{3,1} & a_{1,1} b_{1,2}+a_{1,2} b_{2,2}+a_{1,3} b_{3,2} & a_{1,1} b_{1,3}+a_{1,2} b_{2,3}+a_{1,3} b_{3,3} \\
61 * a_{2,1} b_{1,1}+a_{2,2} b_{2,1}+a_{2,3} b_{3,1} & a_{2,1} b_{1,2}+a_{2,2} b_{2,2}+a_{2,3} b_{3,2} & a_{2,1} b_{1,3}+a_{2,2} b_{2,3}+a_{2,3} b_{3,3} \\
62 * a_{3,1} b_{1,1}+a_{3,2} b_{2,1}+a_{3,3} b_{3,1} & a_{3,1} b_{1,2}+a_{3,2} b_{2,2}+a_{3,3} b_{3,2} & a_{3,1} b_{1,3}+a_{3,2} b_{2,3}+a_{3,3} b_{3,3} \\
63 * \end{pmatrix}
64 * \f]
65
66 * Matrix multiplication is only defined if the number of columns of the
67 * first matrix equals the number of rows of the second matrix.
68 * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
69 * in an <code>M x P</code> matrix.
70 * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
71 * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
72 * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
73 */
74
75
76 /**
77 * @addtogroup MatrixMult
78 * @{
79 */
80
81
82
83 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
84
85 #define MATRIX_DIM3 3
86 #define MATRIX_DIM4 4
87
/**
 * @brief  2x2 floating-point matrix multiplication, Helium (MVE) specialization.
 *
 * Computes Dst = A * B for 2x2 row-major matrices in exactly two vector
 * operations: a gather-load multiply followed by a gather-load FMA.
 * The constant offset tables below select the operands so that one full
 * f32x4 vector holds all four elements of the result {c00, c01, c10, c11}.
 *
 * @param[in]  pSrcA  points to the first input matrix (2x2)
 * @param[in]  pSrcB  points to the second input matrix (2x2)
 * @param[out] pDst   points to the output matrix (2x2)
 * @return     ARM_MATH_SUCCESS (no size checking is performed here;
 *             the caller has already validated the dimensions)
 */
__STATIC_INLINE arm_status arm_mat_mult_f32_2x2_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    /* {a00, a00, a10, a10} */
    static const uint32_t offsetA0[4] = { 0, 0, 2, 2 };
    /* {b00, b01, b00, b01} */
    static const uint32_t offsetB0[4] = { 0, 1, 0, 1 };
    /* {a01, a01, a11, a11} */
    static const uint32_t offsetA1[4] = { 1, 1, 3, 3 };
    /* {b10, b11, b10, b11} */
    static const uint32_t offsetB1[4] = { 2, 3, 2, 3 };

    uint32x4_t vecOffsA, vecOffsB;
    f32x4_t vecInA, vecInB, vecDst;

    vecOffsA = vldrwq_u32((uint32_t const *) offsetA0);
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB0);

    /* shifted-offset gather: element offsets are scaled by sizeof(float32_t) */
    vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
    vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);

    /* vecDst = {a00*b00, a00*b01, a10*b00, a10*b01} */
    vecDst = vmulq(vecInA, vecInB);

    vecOffsA = vldrwq_u32((uint32_t const *) offsetA1);
    vecOffsB = vldrwq_u32((uint32_t const *) offsetB1);

    vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
    vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);

    /* accumulate {a01*b10, a01*b11, a11*b10, a11*b11} -> {c00, c01, c10, c11} */
    vecDst = vfmaq(vecDst, vecInA, vecInB);

    /* single full-width store of the 2x2 result */
    vstrwq_f32(pDst->pData, vecDst);

    return (ARM_MATH_SUCCESS);

}
126
127
128 /*
129 * A = {{a00, a01, a02},
130 * {a10, a11, a12},
131 * {a20, a21, a22}}
132 * B = {{b00, b01, b02},
133 * {b10, b11, b12},
134 * {b20, b21, b22}}
135 *
136 * Dst = {{a00 b00 + a01 b10 + a02 b20, a00 b01 + a01 b11 + a02 b21, a00 b02 + a01 b12 + a02 b22},
137 * {a10 b00 + a11 b10 + a12 b20, a10 b01 + a11 b11 + a12 b21, a10 b02 + a11 b12 + a12 b22},
138 * {a20 b00 + a21 b10 + a22 b20, a20 b01 + a21 b11 + a22 b21, a20 b02 + a21 b12 + a22 b22}}
139 */
/**
 * @brief  3x3 floating-point matrix multiplication, Helium (MVE) specialization.
 *
 * Strategy: each row of B is loaded as a (predicated, 3-active-lane) vector,
 * and broadcast-multiplied by the matching scalar from each of the three rows
 * of A, accumulating one output row per vector accumulator:
 *   row_i(Dst) = a(i,0)*row_0(B) + a(i,1)*row_1(B) + a(i,2)*row_2(B)
 *
 * @param[in]  pSrcA  points to the first input matrix (3x3)
 * @param[in]  pSrcB  points to the second input matrix (3x3)
 * @param[out] pDst   points to the output matrix (3x3)
 * @return     ARM_MATH_SUCCESS (dimensions validated by the caller)
 */
__STATIC_INLINE arm_status arm_mat_mult_f32_3x3_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;   /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2;            /* cursors over the 3 rows of A */
    f32x4_t vecMac0, vecMac1, vecMac2;           /* one accumulator per output row */
    f32x4_t vecInB;
    float32_t const *pSrBVec;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM3;
    pInA2 = pInA1 + MATRIX_DIM3;
    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);

    /*
     * load {b0,0, b0,1, b0,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    /* first column of A: initialize the accumulators (vector * scalar) */
    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    /*
     * load {b1,0, b1,1, b1,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    /*
     * load {b2,0, b2,1 , b2,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);

    /* partial vector stores: only the 3 active lanes are written,
       so consecutive output rows (stride 3) do not clobber each other */
    vstrwq_p_f32(pOut, vecMac0, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac1, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac2, p0);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
200
201
202
203
/**
 * @brief  4x4 floating-point matrix multiplication, Helium (MVE) specialization.
 *
 * Same broadcast scheme as the 3x3 variant but with full-width (4-lane)
 * loads and stores, so no predication is needed:
 *   row_i(Dst) = sum_k a(i,k) * row_k(B)
 * Four vector accumulators hold the four output rows.
 *
 * @param[in]  pSrcA  points to the first input matrix (4x4)
 * @param[in]  pSrcB  points to the second input matrix (4x4)
 * @param[out] pDst   points to the output matrix (4x4)
 * @return     ARM_MATH_SUCCESS (dimensions validated by the caller)
 */
__STATIC_INLINE arm_status arm_mat_mult_f32_4x4_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t const *pSrBVec;
    float32_t *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;  /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;   /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2, *pInA3;            /* cursors over the 4 rows of A */
    f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;          /* one accumulator per output row */
    f32x4_t vecInB;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM4;
    pInA2 = pInA1 + MATRIX_DIM4;
    pInA3 = pInA2 + MATRIX_DIM4;
    /*
     * load {b0,0, b0,1, b0,2, b0,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    /* first column of A: initialize the accumulators (vector * scalar) */
    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    vecMac3 = vmulq(vecInB, *pInA3++);
    /*
     * load {b1,0, b1,1, b1,2, b1,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b2,0, b2,1, b2,2, b2,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b3,0, b3,1, b3,2, b3,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

    /* store the four completed output rows */
    vst1q(pOut, vecMac0);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac1);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac2);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac3);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
276
277
278 /**
279 * @brief Floating-point matrix multiplication.
280 * @param[in] *pSrcA points to the first input matrix structure
281 * @param[in] *pSrcB points to the second input matrix structure
282 * @param[out] *pDst points to output matrix structure
283 * @return The function returns either
284 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
285 */
/**
 * @brief  Floating-point matrix multiplication (Helium/MVE general path).
 *
 * Dispatches tiny square matrices (1x1 .. 4x4) to dedicated kernels, then
 * computes the general case in 4x4 output tiles: for each tile, four rows
 * of A are walked in lock-step against a 4-column slab of B, with tail
 * handling (predicated loads/stores) for numColsB % 4 and a scalar-row
 * pass for numRowsA % 4.
 *
 * @param[in]  *pSrcA points to the first input matrix structure
 * @param[in]  *pSrcB points to the second input matrix structure
 * @param[out] *pDst points to output matrix structure
 * @return     The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
 */
arm_status arm_mat_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
    float32_t  *pInB = pSrcB->pData;        /* input data matrix pointer B */
    float32_t  *pInA = pSrcA->pData;        /* input data matrix pointer A  */
    float32_t  *pOut = pDst->pData;         /* output data matrix pointer */
    int         numRowsA = pSrcA->numRows;  /* number of rows of input matrix A */
    int         numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
    int         numColsA = pSrcA->numCols;  /* number of columns of input matrix A */
    uint32_t    blkCnt;                     /* loop counters */
    uint32_t    i;
    arm_status status;

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
     (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */
  {
    /* small squared matrix specialized routines */
    if(numRowsA == numColsB && numColsB == numColsA) {
        if (numRowsA == 1)
        {
          /* 1x1: plain scalar product */
          pOut[0] = pInA[0] * pInB[0];
          return(ARM_MATH_SUCCESS);
        }
        else if(numRowsA == 2)
          return arm_mat_mult_f32_2x2_mve(pSrcA, pSrcB, pDst);
        else if(numRowsA == 3)
          return arm_mat_mult_f32_3x3_mve(pSrcA, pSrcB, pDst);
        else if(numRowsA == 4)
          return arm_mat_mult_f32_4x4_mve(pSrcA, pSrcB, pDst);
    }

    /* main loop process 4 rows */
    i = numRowsA >> 2;
    while (i > 0U)
    {
        float32_t *pInA0, *pInA1, *pInA2, *pInA3;
        float32_t *pInB0;
        float32_t *pOut0, *pOut1, *pOut2, *pOut3;
        f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
        f32x4_t vecInB;

        /* pointers to 4 consecutive output rows */
        pOut0 = pOut;
        pOut1 = pOut0 + numColsB;
        pOut2 = pOut1 + numColsB;
        pOut3 = pOut2 + numColsB;
        pInB0 = pInB;

        uint32_t  k = numColsB >> 2;
        while (k > 0U)
        {
            /* pointers to 4 consecutive Matrix A rows */
            pInA0 = pInA;
            pInA1 = pInA0 + numColsA;
            pInA2 = pInA1 + numColsA;
            pInA3 = pInA2 + numColsA;

            vecMac0 = vdupq_n_f32(0.0f);
            vecMac1 = vdupq_n_f32(0.0f);
            vecMac2 = vdupq_n_f32(0.0f);
            vecMac3 = vdupq_n_f32(0.0f);

            /* walk the full inner (dot-product) dimension */
            blkCnt = numColsA;

            while (blkCnt > 0U)
            {
                /*
                 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                 */
                vecInB = *(f32x4_t *)pInB0;     /* vldrwq_f32(pInB0, 0); */

                /* scalar-broadcast FMA: one A element per row x one B row slice */
                vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                /* advance to the same 4-column slice of the next B row */
                pInB0 = pInB0 + numColsB;
                /*
                 * Decrement the blockSize loop counter
                 */
                blkCnt--;
            }

            /* Store the results (4 x 4 block) in the destination buffer */
            vst1q(pOut0, vecMac0);
            pOut0 += 4;
            vst1q(pOut1, vecMac1);
            pOut1 += 4;
            vst1q(pOut2, vecMac2);
            pOut2 += 4;
            vst1q(pOut3, vecMac3);
            pOut3 += 4;

            /*
             * rewind B to the top row, shifted right by the 4 columns just consumed
             */
            pInB0 -= (numColsB * numColsA) - 4;
            k--;
        }

        /* tail: numColsB not a multiple of 4 -> predicated loads/stores */
        int       colBLeft = numColsB & 3;
        if (colBLeft)
        {
            pInA0 = pInA;
            pInA1 = pInA0 + numColsA;
            pInA2 = pInA1 + numColsA;
            pInA3 = pInA2 + numColsA;
            mve_pred16_t p0 = vctp32q(colBLeft);

            vecMac0 = vdupq_n_f32(0.0f);
            vecMac1 = vdupq_n_f32(0.0f);
            vecMac2 = vdupq_n_f32(0.0f);
            vecMac3 = vdupq_n_f32(0.0f);

            blkCnt = numColsA;

            while (blkCnt > 0U)
            {
                /*
                 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                 */
                vecInB = vldrwq_z_f32(pInB0, p0);

                vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
                vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
                vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
                vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

                pInB0 = pInB0 + numColsB;
                /*
                 * Decrement the blockSize loop counter
                 */
                blkCnt--;
            }

            /* Store the results (4 x colBLeft block) in the destination buffer */
            vstrwq_p_f32(pOut0, vecMac0, p0);
            vstrwq_p_f32(pOut1, vecMac1, p0);
            vstrwq_p_f32(pOut2, vecMac2, p0);
            vstrwq_p_f32(pOut3, vecMac3, p0);
        }

        /* move to next rows */
        pInA += 4 * numColsA;
        pOut += 4 * numColsB;
        i--;
    }

    /*
     * non multiple of 4 rows for Matrix A
     * process single row
     */
    if (numRowsA & 3)
    {
        i = numRowsA & 3;
        while (i > 0U)
        {
            float32_t *pInA0;
            float32_t *pInB0;
            float32_t *pOut0;
            f32x4_t vecInB;
            f32x4_t vecMac0;

            pOut0 = pOut;
            pInB0 = pInB;

            uint32_t  k = numColsB >> 2;
            while (k > 0U)
            {
                pInA0 = pInA;

                vecMac0 = vdupq_n_f32(0.0f);
                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /*
                     * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                     */
                    vecInB = *(f32x4_t *)pInB0; /* vldrwq_f32(pInB0, 0); */

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                    pInB0 = pInB0 + numColsB;
                    /*
                     * Decrement the blockSize loop counter
                     */
                    blkCnt--;
                }

                /* Store the results (1 x 4 block) in the destination buffer */
                vst1q(pOut0, vecMac0);
                pOut0 += 4;

                /*
                 * rewind B to the top row, shifted right by 4 columns
                 */
                pInB0 -= (numColsB * numColsA) - 4;
                k--;
            }

            /* tail: numColsB not a multiple of 4 -> predicated loads/stores */
            int       colBLeft = numColsB & 3;
            if (colBLeft)
            {
                pInA0 = pInA;
                mve_pred16_t p0 = vctp32q(colBLeft);

                vecMac0 = vdupq_n_f32(0.0f);
                blkCnt = numColsA;
                while (blkCnt > 0U)
                {
                    /*
                     * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
                     */
                    vecInB = vldrwq_z_f32(pInB0, p0);

                    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);

                    pInB0 = pInB0 + numColsB;
                    /*
                     * Decrement the blockSize loop counter
                     */
                    blkCnt--;
                }
                /* Store the results (1 x colBLeft block) in the destination buffer */
                vstrwq_p_f32(pOut0, vecMac0, p0);
            }

            /* move to next row */
            pInA += 1 * numColsA;
            pOut += 1 * numColsB;
            i--;
        }

    }
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}
537 #else
538
539 #if defined(ARM_MATH_NEON)
540 /**
541 * @brief Floating-point matrix multiplication.
542 * @param[in] *pSrcA points to the first input matrix structure
543 * @param[in] *pSrcB points to the second input matrix structure
544 * @param[out] *pDst points to output matrix structure
545 * @return The function returns either
546 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
547 */
arm_mat_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)548 arm_status arm_mat_mult_f32(
549 const arm_matrix_instance_f32 * pSrcA,
550 const arm_matrix_instance_f32 * pSrcB,
551 arm_matrix_instance_f32 * pDst)
552 {
553 float32_t *pIn1 = pSrcA->pData; /* input data matrix pointer A */
554 float32_t *pIn2 = pSrcB->pData; /* input data matrix pointer B */
555 float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
556 float32_t *pOut = pDst->pData; /* output data matrix pointer */
557 float32_t *px; /* Temporary output data matrix pointer */
558 float32_t sum; /* Accumulator */
559 uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
560 uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
561 uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
562
563
564 uint32_t col, i = 0U, j, row = numRowsA, rowCnt, colCnt; /* loop counters */
565 arm_status status; /* status of matrix multiplication */
566
567 float32x4_t a0V, a1V, a2V, a3V, a4V, a5V, a6V, a7V;
568 float32x4_t acc0,acc1,acc2,acc3,acc4,acc5,acc6,acc7,temp;
569 float32x2_t accum = vdup_n_f32(0);
570 float32_t *pIn1B = pSrcA->pData;
571 float32_t *pIn1C = pSrcA->pData;
572 float32_t *pIn1D = pSrcA->pData;
573 float32_t *pIn1E = pSrcA->pData;
574 float32_t *pIn1F = pSrcA->pData;
575 float32_t *pIn1G = pSrcA->pData;
576 float32_t *pIn1H = pSrcA->pData;
577
578 float32_t *pxB,*pxC, *pxD, *pxE, *pxF, *pxG, *pxH; /* Temporary output data matrix pointer */
579 float32_t sum0,sum1, sum2,sum3, sum4, sum5 , sum6, sum7;
580
581 #ifdef ARM_MATH_MATRIX_CHECK
582
583 /* Check for matrix mismatch condition */
584 if ((pSrcA->numCols != pSrcB->numRows) ||
585 (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
586 {
587 /* Set status as ARM_MATH_SIZE_MISMATCH */
588 status = ARM_MATH_SIZE_MISMATCH;
589 }
590 else
591 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
592 {
593 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
594 /* Row loop */
595 rowCnt = row >> 3;
596
597 while(rowCnt > 0)
598 {
599 /* Output pointer is set to starting address of the row being processed */
600 px = pOut + GROUPOFROWS*i;
601 pxB = px + numColsB;
602 pxC = px + 2*numColsB;
603 pxD = px + 3*numColsB;
604 pxE = px + 4*numColsB;
605 pxF = px + 5*numColsB;
606 pxG = px + 6*numColsB;
607 pxH = px + 7*numColsB;
608
609 /* For every row wise process, the column loop counter is to be initiated */
610 col = numColsB;
611
612 /* For every row wise process, the pIn2 pointer is set
613 ** to the starting address of the pSrcB data */
614 pIn2 = pSrcB->pData;
615
616 j = 0U;
617
618 /* Column loop */
619 do
620 {
621 /* Set the variable sum, that acts as accumulator, to zero */
622 sum0 = 0.0f;
623 sum1 = 0.0f;
624 sum2 = 0.0f;
625 sum3 = 0.0f;
626 sum4 = 0.0f;
627 sum5 = 0.0f;
628 sum6 = 0.0f;
629 sum7 = 0.0f;
630
631 /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
632 pIn1 = pInA;
633 pIn1B = pIn1 + numColsA;
634 pIn1C = pIn1 + 2*numColsA;
635 pIn1D = pIn1 + 3*numColsA;
636 pIn1E = pIn1 + 4*numColsA;
637 pIn1F = pIn1 + 5*numColsA;
638 pIn1G = pIn1 + 6*numColsA;
639 pIn1H = pIn1 + 7*numColsA;
640
641 acc0 = vdupq_n_f32(0.0);
642 acc1 = vdupq_n_f32(0.0);
643 acc2 = vdupq_n_f32(0.0);
644 acc3 = vdupq_n_f32(0.0);
645 acc4 = vdupq_n_f32(0.0);
646 acc5 = vdupq_n_f32(0.0);
647 acc6 = vdupq_n_f32(0.0);
648 acc7 = vdupq_n_f32(0.0);
649
650 /* Compute 4 MACs simultaneously. */
651 colCnt = numColsA >> 2U;
652
653 /* Matrix multiplication */
654 while (colCnt > 0U)
655 {
656 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
657 a0V = vld1q_f32(pIn1);
658 a1V = vld1q_f32(pIn1B);
659 a2V = vld1q_f32(pIn1C);
660 a3V = vld1q_f32(pIn1D);
661 a4V = vld1q_f32(pIn1E);
662 a5V = vld1q_f32(pIn1F);
663 a6V = vld1q_f32(pIn1G);
664 a7V = vld1q_f32(pIn1H);
665
666 pIn1 += 4;
667 pIn1B += 4;
668 pIn1C += 4;
669 pIn1D += 4;
670 pIn1E += 4;
671 pIn1F += 4;
672 pIn1G += 4;
673 pIn1H += 4;
674
675 temp = vsetq_lane_f32(*pIn2,temp,0);
676 pIn2 += numColsB;
677 temp = vsetq_lane_f32(*pIn2,temp,1);
678 pIn2 += numColsB;
679 temp = vsetq_lane_f32(*pIn2,temp,2);
680 pIn2 += numColsB;
681 temp = vsetq_lane_f32(*pIn2,temp,3);
682 pIn2 += numColsB;
683
684 acc0 = vmlaq_f32(acc0,a0V,temp);
685 acc1 = vmlaq_f32(acc1,a1V,temp);
686 acc2 = vmlaq_f32(acc2,a2V,temp);
687 acc3 = vmlaq_f32(acc3,a3V,temp);
688 acc4 = vmlaq_f32(acc4,a4V,temp);
689 acc5 = vmlaq_f32(acc5,a5V,temp);
690 acc6 = vmlaq_f32(acc6,a6V,temp);
691 acc7 = vmlaq_f32(acc7,a7V,temp);
692
693 /* Decrement the loop count */
694 colCnt--;
695 }
696
697 accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
698 sum0 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
699
700 accum = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
701 sum1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
702
703 accum = vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
704 sum2 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
705
706 accum = vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
707 sum3 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
708
709 accum = vpadd_f32(vget_low_f32(acc4), vget_high_f32(acc4));
710 sum4 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
711
712 accum = vpadd_f32(vget_low_f32(acc5), vget_high_f32(acc5));
713 sum5 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
714
715 accum = vpadd_f32(vget_low_f32(acc6), vget_high_f32(acc6));
716 sum6 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
717
718 accum = vpadd_f32(vget_low_f32(acc7), vget_high_f32(acc7));
719 sum7 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
720
721 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
722 ** No loop unrolling is used. */
723 colCnt = numColsA & 3;
724
725 while (colCnt > 0U)
726 {
727 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
728 sum0 += *pIn1++ * (*pIn2);
729 sum1 += *pIn1B++ * (*pIn2);
730 sum2 += *pIn1C++ * (*pIn2);
731 sum3 += *pIn1D++ * (*pIn2);
732 sum4 += *pIn1E++ * (*pIn2);
733 sum5 += *pIn1F++ * (*pIn2);
734 sum6 += *pIn1G++ * (*pIn2);
735 sum7 += *pIn1H++ * (*pIn2);
736 pIn2 += numColsB;
737
738 /* Decrement the loop counter */
739 colCnt--;
740 }
741
742 /* Store the result in the destination buffer */
743 *px++ = sum0;
744 *pxB++ = sum1;
745 *pxC++ = sum2;
746 *pxD++ = sum3;
747 *pxE++ = sum4;
748 *pxF++ = sum5;
749 *pxG++ = sum6;
750 *pxH++ = sum7;
751
752 /* Update the pointer pIn2 to point to the starting address of the next column */
753 j++;
754 pIn2 = pSrcB->pData + j;
755
756 /* Decrement the column loop counter */
757 col--;
758
759 } while (col > 0U);
760
761 /* Update the pointer pInA to point to the starting address of the next row */
762 i = i + numColsB;
763 pInA = pInA + GROUPOFROWS*numColsA;
764
765 /* Decrement the row loop counter */
766 rowCnt--;
767 }
768
769 /*
770
771 i was the index of a group of rows computed by previous loop.
772 Now i is the index of a row since below code is computing row per row
773 and no more group of row per group of rows.
774
775 */
776
777 i = GROUPOFROWS*i;
778 rowCnt = row & 7;
779
780 while(rowCnt > 0)
781 {
782 /* Output pointer is set to starting address of the row being processed */
783 px = pOut + i;
784
785 /* For every row wise process, the column loop counter is to be initiated */
786 col = numColsB;
787
788 /* For every row wise process, the pIn2 pointer is set
789 ** to the starting address of the pSrcB data */
790 pIn2 = pSrcB->pData;
791
792 j = 0U;
793
794 /* Column loop */
795 do
796 {
797 /* Set the variable sum, that acts as accumulator, to zero */
798 sum = 0.0f;
799
800 /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
801 pIn1 = pInA;
802
803 acc0 = vdupq_n_f32(0.0);
804
805 /* Compute 4 MACs simultaneously. */
806 colCnt = numColsA >> 2U;
807
808 /* Matrix multiplication */
809 while (colCnt > 0U)
810 {
811 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
812 a0V = vld1q_f32(pIn1); // load & separate real/imag pSrcA (de-interleave 2)
813 pIn1 += 4;
814
815 temp = vsetq_lane_f32(*pIn2,temp,0);
816 pIn2 += numColsB;
817 temp = vsetq_lane_f32(*pIn2,temp,1);
818 pIn2 += numColsB;
819 temp = vsetq_lane_f32(*pIn2,temp,2);
820 pIn2 += numColsB;
821 temp = vsetq_lane_f32(*pIn2,temp,3);
822 pIn2 += numColsB;
823
824 acc0 = vmlaq_f32(acc0,a0V,temp);
825
826 /* Decrement the loop count */
827 colCnt--;
828 }
829
830 accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
831 sum += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
832
833 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
834 ** No loop unrolling is used. */
835 colCnt = numColsA % 0x4U;
836
837 while (colCnt > 0U)
838 {
839 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
840 sum += *pIn1++ * (*pIn2);
841 pIn2 += numColsB;
842
843 /* Decrement the loop counter */
844 colCnt--;
845 }
846
847 /* Store the result in the destination buffer */
848 *px++ = sum;
849
850 /* Update the pointer pIn2 to point to the starting address of the next column */
851 j++;
852 pIn2 = pSrcB->pData + j;
853
854 /* Decrement the column loop counter */
855 col--;
856
857 } while (col > 0U);
858
859
860 /* Update the pointer pInA to point to the starting address of the next row */
861 i = i + numColsB;
862 pInA = pInA + numColsA;
863
864 /* Decrement the row loop counter */
865 rowCnt--;
866
867 }
868 /* Set status as ARM_MATH_SUCCESS */
869 status = ARM_MATH_SUCCESS;
870 }
871
872 /* Return to application */
873 return (status);
874 }
875 #else
876 /**
877 * @brief Floating-point matrix multiplication.
878 * @param[in] *pSrcA points to the first input matrix structure
879 * @param[in] *pSrcB points to the second input matrix structure
880 * @param[out] *pDst points to output matrix structure
881 * @return The function returns either
882 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
883 */
arm_mat_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)884 arm_status arm_mat_mult_f32(
885 const arm_matrix_instance_f32 * pSrcA,
886 const arm_matrix_instance_f32 * pSrcB,
887 arm_matrix_instance_f32 * pDst)
888 {
889 float32_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */
890 float32_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */
891 float32_t *pInA = pSrcA->pData; /* Input data matrix pointer A */
892 float32_t *pInB = pSrcB->pData; /* Input data matrix pointer B */
893 float32_t *pOut = pDst->pData; /* Output data matrix pointer */
894 float32_t *px; /* Temporary output data matrix pointer */
895 float32_t sum; /* Accumulator */
896 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
897 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
898 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
899 uint32_t col, i = 0U, row = numRowsA, colCnt; /* Loop counters */
900 arm_status status; /* Status of matrix multiplication */
901
902 #ifdef ARM_MATH_MATRIX_CHECK
903
904 /* Check for matrix mismatch condition */
905 if ((pSrcA->numCols != pSrcB->numRows) ||
906 (pSrcA->numRows != pDst->numRows) ||
907 (pSrcB->numCols != pDst->numCols) )
908 {
909 /* Set status as ARM_MATH_SIZE_MISMATCH */
910 status = ARM_MATH_SIZE_MISMATCH;
911 }
912 else
913
914 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
915
916 {
917 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
918 /* row loop */
919 do
920 {
921 /* Output pointer is set to starting address of row being processed */
922 px = pOut + i;
923
924 /* For every row wise process, column loop counter is to be initiated */
925 col = numColsB;
926
927 /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
928 pIn2 = pSrcB->pData;
929
930 /* column loop */
931 do
932 {
933 /* Set the variable sum, that acts as accumulator, to zero */
934 sum = 0.0f;
935
936 /* Initialize pointer pIn1 to point to starting address of column being processed */
937 pIn1 = pInA;
938
939 #if defined (ARM_MATH_LOOPUNROLL)
940
941 /* Loop unrolling: Compute 4 MACs at a time. */
942 colCnt = numColsA >> 2U;
943
944 /* matrix multiplication */
945 while (colCnt > 0U)
946 {
947 /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */
948
949 /* Perform the multiply-accumulates */
950 sum += *pIn1++ * *pIn2;
951 pIn2 += numColsB;
952
953 sum += *pIn1++ * *pIn2;
954 pIn2 += numColsB;
955
956 sum += *pIn1++ * *pIn2;
957 pIn2 += numColsB;
958
959 sum += *pIn1++ * *pIn2;
960 pIn2 += numColsB;
961
962 /* Decrement loop counter */
963 colCnt--;
964 }
965
966 /* Loop unrolling: Compute remaining MACs */
967 colCnt = numColsA % 0x4U;
968
969 #else
970
971 /* Initialize cntCnt with number of columns */
972 colCnt = numColsA;
973
974 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
975
976 while (colCnt > 0U)
977 {
978 /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */
979
980 /* Perform the multiply-accumulates */
981 sum += *pIn1++ * *pIn2;
982 pIn2 += numColsB;
983
984 /* Decrement loop counter */
985 colCnt--;
986 }
987
988 /* Store result in destination buffer */
989 *px++ = sum;
990
991 /* Decrement column loop counter */
992 col--;
993
994 /* Update pointer pIn2 to point to starting address of next column */
995 pIn2 = pInB + (numColsB - col);
996
997 } while (col > 0U);
998
999 /* Update pointer pInA to point to starting address of next row */
1000 i = i + numColsB;
1001 pInA = pInA + numColsA;
1002
1003 /* Decrement row loop counter */
1004 row--;
1005
1006 } while (row > 0U);
1007
1008 /* Set status as ARM_MATH_SUCCESS */
1009 status = ARM_MATH_SUCCESS;
1010 }
1011
1012 /* Return to application */
1013 return (status);
1014 }
1015
1016 #endif /* #if defined(ARM_MATH_NEON) */
1017 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
1018
1019 /**
1020 * @} end of MatrixMult group
1021 */
1022