1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_mult_f32.c
4 * Description: Floating-point matrix multiplication
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions.h"
30
31 /**
32 * @ingroup groupMatrix
33 */
34
35 /**
36 * @defgroup MatrixMult Matrix Multiplication
37 *
38 * Multiplies two matrices.
39 *
40 * \image html MatrixMultiplication.gif "Multiplication of two 3 x 3 matrices"
41
42 * Matrix multiplication is only defined if the number of columns of the
43 * first matrix equals the number of rows of the second matrix.
44 * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
45 * in an <code>M x P</code> matrix.
46 * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
47 * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
48 * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
49 */
50
51
52 /**
53 * @addtogroup MatrixMult
54 * @{
55 */
56
57 /**
58 * @brief Floating-point matrix multiplication.
59 * @param[in] *pSrcA points to the first input matrix structure
60 * @param[in] *pSrcB points to the second input matrix structure
61 * @param[out] *pDst points to output matrix structure
62 * @return The function returns either
63 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
64 */
65
66 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
67
68 #define MATRIX_DIM3 3
69 #define MATRIX_DIM4 4
70
arm_mat_mult_f32_2x2_mve(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)71 __STATIC_INLINE arm_status arm_mat_mult_f32_2x2_mve(
72 const arm_matrix_instance_f32 *pSrcA,
73 const arm_matrix_instance_f32 *pSrcB,
74 arm_matrix_instance_f32 *pDst)
75 {
76 /* {a00, a00, a10, a10} */
77 static const uint32_t offsetA0[4] = { 0, 0, 2, 2 };
78 /* {b00, b01, b00, b01} */
79 static const uint32_t offsetB0[4] = { 0, 1, 0, 1 };
80 /* {a01, a01, a11, a11} */
81 static const uint32_t offsetA1[4] = { 1, 1, 3, 3 };
82 /* {b10, b11, b10, b11} */
83 static const uint32_t offsetB1[4] = { 2, 3, 2, 3 };
84
85 uint32x4_t vecOffsA, vecOffsB;
86 f32x4_t vecInA, vecInB, vecDst;
87
88 vecOffsA = vldrwq_u32((uint32_t const *) offsetA0);
89 vecOffsB = vldrwq_u32((uint32_t const *) offsetB0);
90
91 vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
92 vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);
93
94 vecDst = vmulq(vecInA, vecInB);
95
96 vecOffsA = vldrwq_u32((uint32_t const *) offsetA1);
97 vecOffsB = vldrwq_u32((uint32_t const *) offsetB1);
98
99 vecInA = vldrwq_gather_shifted_offset((float32_t const *) pSrcA->pData, vecOffsA);
100 vecInB = vldrwq_gather_shifted_offset((float32_t const *) pSrcB->pData, vecOffsB);
101
102 vecDst = vfmaq(vecDst, vecInA, vecInB);
103
104 vstrwq_f32(pDst->pData, vecDst);
105
106 return (ARM_MATH_SUCCESS);
107
108 }
109
110
111 /*
112 * A = {{a00, a01, a02},
113 * {a10, a11, a12},
114 * {a20, a21, a22}}
115 * B = {{b00, b01, b02},
116 * {b10, b11, b12},
117 * {b20, b21, b22}}
118 *
119 * Dst = {{a00 b00 + a01 b10 + a02 b20, a00 b01 + a01 b11 + a02 b21, a00 b02 + a01 b12 + a02 b22},
120 * {a10 b00 + a11 b10 + a12 b20, a10 b01 + a11 b11 + a12 b21, a10 b02 + a11 b12 + a12 b22},
121 * {a20 b00 + a21 b10 + a22 b20, a20 b01 + a21 b11 + a22 b21, a20 b02 + a21 b12 + a22 b22}}
122 */
/* 3x3 * 3x3 specialization.
 * Each output row i is accumulated as a whole 4-lane vector:
 *   Dst[i][*] = sum_k A[i][k] * B[k][*]
 * Tail predication (vctp32q(3)) masks the unused 4th lane on every
 * 3-element row load and store. */
__STATIC_INLINE arm_status arm_mat_mult_f32_3x3_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;  /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2;   /* pointers to the three rows of A */
    f32x4_t vecMac0, vecMac1, vecMac2;  /* one accumulator per output row */
    f32x4_t vecInB;
    float32_t const *pSrBVec;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM3;
    pInA2 = pInA1 + MATRIX_DIM3;
    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);

    /*
     * load {b0,0, b0,1, b0,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    /* k = 0 term: scale row 0 of B by A[i][0] for each output row i */
    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    /*
     * load {b1,0, b1,1, b1,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    /* k = 1 term */
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    /*
     * load {b2,0, b2,1 , b2,2, 0}
     */
    vecInB = vldrwq_z_f32(pSrBVec, p0);
    pSrBVec += MATRIX_DIM3;

    /* k = 2 term */
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);

    /* partial vector stores: only the first 3 lanes are written per row */
    vstrwq_p_f32(pOut, vecMac0, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac1, p0);
    pOut += MATRIX_DIM3;
    vstrwq_p_f32(pOut, vecMac2, p0);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
183
184
185
186
/* 4x4 * 4x4 specialization.
 * Each output row i is one full 4-lane vector:
 *   Dst[i][*] = sum_k A[i][k] * B[k][*]
 * No predication needed: rows are exactly one vector wide. */
__STATIC_INLINE arm_status arm_mat_mult_f32_4x4_mve(
    const arm_matrix_instance_f32 *pSrcA,
    const arm_matrix_instance_f32 *pSrcB,
    arm_matrix_instance_f32 *pDst)
{
    float32_t const *pSrBVec;
    float32_t *pInB = pSrcB->pData; /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;  /* output data matrix pointer */
    float32_t *pInA0, *pInA1, *pInA2, *pInA3;   /* pointers to the four rows of A */
    f32x4_t vecMac0, vecMac1, vecMac2, vecMac3; /* one accumulator per output row */
    f32x4_t vecInB;

    pSrBVec = (float32_t const *) pInB;

    pInA0 = pInA;
    pInA1 = pInA0 + MATRIX_DIM4;
    pInA2 = pInA1 + MATRIX_DIM4;
    pInA3 = pInA2 + MATRIX_DIM4;
    /*
     * load {b0,0, b0,1, b0,2, b0,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    /* k = 0 term: scale row 0 of B by A[i][0] for each output row i */
    vecMac0 = vmulq(vecInB, *pInA0++);
    vecMac1 = vmulq(vecInB, *pInA1++);
    vecMac2 = vmulq(vecInB, *pInA2++);
    vecMac3 = vmulq(vecInB, *pInA3++);
    /*
     * load {b1,0, b1,1, b1,2, b1,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    /* k = 1 term */
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b2,0, b2,1, b2,2, b2,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    /* k = 2 term */
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
    /*
     * load {b3,0, b3,1, b3,2, b3,3}
     */
    vecInB = vld1q(pSrBVec);
    pSrBVec += MATRIX_DIM4;

    /* k = 3 term */
    vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
    vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
    vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
    vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);

    /* store the four result rows */
    vst1q(pOut, vecMac0);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac1);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac2);
    pOut += MATRIX_DIM4;
    vst1q(pOut, vecMac3);
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
259
260
arm_mat_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)261 arm_status arm_mat_mult_f32(
262 const arm_matrix_instance_f32 * pSrcA,
263 const arm_matrix_instance_f32 * pSrcB,
264 arm_matrix_instance_f32 * pDst)
265 {
266 float32_t *pInB = pSrcB->pData; /* input data matrix pointer B */
267 float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
268 float32_t *pOut = pDst->pData; /* output data matrix pointer */
269 int numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
270 int numColsB = pSrcB->numCols; /* number of columns of input matrix B */
271 int numColsA = pSrcA->numCols; /* number of columns of input matrix A */
272 uint32_t blkCnt; /* loop counters */
273 uint32_t i;
274 arm_status status;
275
276 #ifdef ARM_MATH_MATRIX_CHECK
277
278 /* Check for matrix mismatch condition */
279 if ((pSrcA->numCols != pSrcB->numRows) ||
280 (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
281 {
282 /* Set status as ARM_MATH_SIZE_MISMATCH */
283 status = ARM_MATH_SIZE_MISMATCH;
284 }
285 else
286 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
287 {
288 /* small squared matrix specialized routines */
289 if(numRowsA == numColsB && numColsB == numColsA) {
290 if (numRowsA == 1)
291 {
292 pOut[0] = pInA[0] * pInB[0];
293 return(ARM_MATH_SUCCESS);
294 }
295 else if(numRowsA == 2)
296 return arm_mat_mult_f32_2x2_mve(pSrcA, pSrcB, pDst);
297 else if(numRowsA == 3)
298 return arm_mat_mult_f32_3x3_mve(pSrcA, pSrcB, pDst);
299 else if(numRowsA == 4)
300 return arm_mat_mult_f32_4x4_mve(pSrcA, pSrcB, pDst);
301 }
302
303 /* main loop process 4 rows */
304 i = numRowsA >> 2;
305 while (i > 0U)
306 {
307 float32_t *pInA0, *pInA1, *pInA2, *pInA3;
308 float32_t *pInB0;
309 float32_t *pOut0, *pOut1, *pOut2, *pOut3;
310 f32x4_t vecMac0, vecMac1, vecMac2, vecMac3;
311 f32x4_t vecInB;
312
313 /* pointers to 4 consecutive output rows */
314 pOut0 = pOut;
315 pOut1 = pOut0 + numColsB;
316 pOut2 = pOut1 + numColsB;
317 pOut3 = pOut2 + numColsB;
318 pInB0 = pInB;
319
320 uint32_t k = numColsB >> 2;
321 while (k > 0U)
322 {
323 /* pointers to 4 consecutive Matrix A rows */
324 pInA0 = pInA;
325 pInA1 = pInA0 + numColsA;
326 pInA2 = pInA1 + numColsA;
327 pInA3 = pInA2 + numColsA;
328
329 vecMac0 = vdupq_n_f32(0.0f);
330 vecMac1 = vdupq_n_f32(0.0f);
331 vecMac2 = vdupq_n_f32(0.0f);
332 vecMac3 = vdupq_n_f32(0.0f);
333
334 blkCnt = numColsA;
335
336 while (blkCnt > 0U)
337 {
338 /*
339 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
340 */
341 vecInB = *(f32x4_t *)pInB0; /* vldrwq_f32(pInB0, 0); */
342
343 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
344 vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
345 vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
346 vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
347
348 pInB0 = pInB0 + numColsB;
349 /*
350 * Decrement the blockSize loop counter
351 */
352 blkCnt--;
353 }
354
355 /* Store the results (4 x 4 block) in the destination buffer */
356 vst1q(pOut0, vecMac0);
357 pOut0 += 4;
358 vst1q(pOut1, vecMac1);
359 pOut1 += 4;
360 vst1q(pOut2, vecMac2);
361 pOut2 += 4;
362 vst1q(pOut3, vecMac3);
363 pOut3 += 4;
364
365 /*
366 * rewind
367 */
368 pInB0 -= (numColsB * numColsA) - 4;
369 k--;
370 }
371
372 int colBLeft = numColsB & 3;
373 if (colBLeft)
374 {
375 pInA0 = pInA;
376 pInA1 = pInA0 + numColsA;
377 pInA2 = pInA1 + numColsA;
378 pInA3 = pInA2 + numColsA;
379 mve_pred16_t p0 = vctp32q(colBLeft);
380
381 vecMac0 = vdupq_n_f32(0.0f);
382 vecMac1 = vdupq_n_f32(0.0f);
383 vecMac2 = vdupq_n_f32(0.0f);
384 vecMac3 = vdupq_n_f32(0.0f);
385
386 blkCnt = numColsA;
387
388 while (blkCnt > 0U)
389 {
390 /*
391 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
392 */
393 vecInB = vldrwq_z_f32(pInB0, p0);
394
395 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
396 vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
397 vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
398 vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
399
400 pInB0 = pInB0 + numColsB;
401 /*
402 * Decrement the blockSize loop counter
403 */
404 blkCnt--;
405 }
406
407 /* Store the results (4 x colBLeft block) in the destination buffer */
408 vstrwq_p_f32(pOut0, vecMac0, p0);
409 vstrwq_p_f32(pOut1, vecMac1, p0);
410 vstrwq_p_f32(pOut2, vecMac2, p0);
411 vstrwq_p_f32(pOut3, vecMac3, p0);
412 }
413
414 /* move to next rows */
415 pInA += 4 * numColsA;
416 pOut += 4 * numColsB;
417 i--;
418 }
419
420 /*
421 * non multiple of 4 rows for Matrix A
422 * process single row
423 */
424 if (numRowsA & 3)
425 {
426 i = numRowsA & 3;
427 while (i > 0U)
428 {
429 float32_t *pInA0;
430 float32_t *pInB0;
431 float32_t *pOut0;
432 f32x4_t vecInB;
433 f32x4_t vecMac0;
434
435 pOut0 = pOut;
436 pInB0 = pInB;
437
438 uint32_t k = numColsB >> 2;
439 while (k > 0U)
440 {
441 pInA0 = pInA;
442
443 vecMac0 = vdupq_n_f32(0.0f);
444 blkCnt = numColsA;
445 while (blkCnt > 0U)
446 {
447 /*
448 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
449 */
450 vecInB = *(f32x4_t *)pInB0; /* vldrwq_f32(pInB0, 0); */
451
452 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
453
454 pInB0 = pInB0 + numColsB;
455 /*
456 * Decrement the blockSize loop counter
457 */
458 blkCnt--;
459 }
460
461 /* Store the results (1 x 4 block) in the destination buffer */
462 vst1q(pOut0, vecMac0);
463 pOut0 += 4;
464
465 /*
466 * rewind
467 */
468 pInB0 -= (numColsB * numColsA) - 4;
469 k--;
470 }
471
472 int colBLeft = numColsB & 3;
473 if (colBLeft)
474 {
475 pInA0 = pInA;
476 mve_pred16_t p0 = vctp32q(colBLeft);
477
478 vecMac0 = vdupq_n_f32(0.0f);
479 blkCnt = numColsA;
480 while (blkCnt > 0U)
481 {
482 /*
483 * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3}
484 */
485 vecInB = vldrwq_z_f32(pInB0, p0);
486
487 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
488
489 pInB0 = pInB0 + numColsB;
490 /*
491 * Decrement the blockSize loop counter
492 */
493 blkCnt--;
494 }
495 /* Store the results (1 x colBLeft block) in the destination buffer */
496 vstrwq_p_f32(pOut0, vecMac0, p0);
497 }
498
499 /* move to next row */
500 pInA += 1 * numColsA;
501 pOut += 1 * numColsB;
502 i--;
503 }
504
505 }
506 status = ARM_MATH_SUCCESS;
507 }
508
509 /* Return to application */
510 return (status);
511 }
512 #else
513
514 #if defined(ARM_MATH_NEON)
515
516 #define GROUPOFROWS 8
517
arm_mat_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)518 arm_status arm_mat_mult_f32(
519 const arm_matrix_instance_f32 * pSrcA,
520 const arm_matrix_instance_f32 * pSrcB,
521 arm_matrix_instance_f32 * pDst)
522 {
523 float32_t *pIn1 = pSrcA->pData; /* input data matrix pointer A */
524 float32_t *pIn2 = pSrcB->pData; /* input data matrix pointer B */
525 float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
526 float32_t *pOut = pDst->pData; /* output data matrix pointer */
527 float32_t *px; /* Temporary output data matrix pointer */
528 float32_t sum; /* Accumulator */
529 uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
530 uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
531 uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
532
533
534 uint16_t col, i = 0U, j, row = numRowsA, rowCnt, colCnt; /* loop counters */
535 arm_status status; /* status of matrix multiplication */
536
537 float32x4_t a0V, a1V, a2V, a3V, a4V, a5V, a6V, a7V;
538 float32x4_t acc0,acc1,acc2,acc3,acc4,acc5,acc6,acc7,temp;
539 float32x2_t accum = vdup_n_f32(0);
540 float32_t *pIn1B = pSrcA->pData;
541 float32_t *pIn1C = pSrcA->pData;
542 float32_t *pIn1D = pSrcA->pData;
543 float32_t *pIn1E = pSrcA->pData;
544 float32_t *pIn1F = pSrcA->pData;
545 float32_t *pIn1G = pSrcA->pData;
546 float32_t *pIn1H = pSrcA->pData;
547
548 float32_t *pxB,*pxC, *pxD, *pxE, *pxF, *pxG, *pxH; /* Temporary output data matrix pointer */
549 float32_t sum0,sum1, sum2,sum3, sum4, sum5 , sum6, sum7;
550
551 #ifdef ARM_MATH_MATRIX_CHECK
552
553 /* Check for matrix mismatch condition */
554 if ((pSrcA->numCols != pSrcB->numRows) ||
555 (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
556 {
557 /* Set status as ARM_MATH_SIZE_MISMATCH */
558 status = ARM_MATH_SIZE_MISMATCH;
559 }
560 else
561 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
562 {
563 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
564 /* Row loop */
565 rowCnt = row >> 3;
566
567 while(rowCnt > 0)
568 {
569 /* Output pointer is set to starting address of the row being processed */
570 px = pOut + GROUPOFROWS*i;
571 pxB = px + numColsB;
572 pxC = px + 2*numColsB;
573 pxD = px + 3*numColsB;
574 pxE = px + 4*numColsB;
575 pxF = px + 5*numColsB;
576 pxG = px + 6*numColsB;
577 pxH = px + 7*numColsB;
578
579 /* For every row wise process, the column loop counter is to be initiated */
580 col = numColsB;
581
582 /* For every row wise process, the pIn2 pointer is set
583 ** to the starting address of the pSrcB data */
584 pIn2 = pSrcB->pData;
585
586 j = 0U;
587
588 /* Column loop */
589 do
590 {
591 /* Set the variable sum, that acts as accumulator, to zero */
592 sum0 = 0.0f;
593 sum1 = 0.0f;
594 sum2 = 0.0f;
595 sum3 = 0.0f;
596 sum4 = 0.0f;
597 sum5 = 0.0f;
598 sum6 = 0.0f;
599 sum7 = 0.0f;
600
601 /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
602 pIn1 = pInA;
603 pIn1B = pIn1 + numColsA;
604 pIn1C = pIn1 + 2*numColsA;
605 pIn1D = pIn1 + 3*numColsA;
606 pIn1E = pIn1 + 4*numColsA;
607 pIn1F = pIn1 + 5*numColsA;
608 pIn1G = pIn1 + 6*numColsA;
609 pIn1H = pIn1 + 7*numColsA;
610
611 acc0 = vdupq_n_f32(0.0);
612 acc1 = vdupq_n_f32(0.0);
613 acc2 = vdupq_n_f32(0.0);
614 acc3 = vdupq_n_f32(0.0);
615 acc4 = vdupq_n_f32(0.0);
616 acc5 = vdupq_n_f32(0.0);
617 acc6 = vdupq_n_f32(0.0);
618 acc7 = vdupq_n_f32(0.0);
619
620 /* Compute 4 MACs simultaneously. */
621 colCnt = numColsA >> 2U;
622
623 /* Matrix multiplication */
624 while (colCnt > 0U)
625 {
626 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
627 a0V = vld1q_f32(pIn1);
628 a1V = vld1q_f32(pIn1B);
629 a2V = vld1q_f32(pIn1C);
630 a3V = vld1q_f32(pIn1D);
631 a4V = vld1q_f32(pIn1E);
632 a5V = vld1q_f32(pIn1F);
633 a6V = vld1q_f32(pIn1G);
634 a7V = vld1q_f32(pIn1H);
635
636 pIn1 += 4;
637 pIn1B += 4;
638 pIn1C += 4;
639 pIn1D += 4;
640 pIn1E += 4;
641 pIn1F += 4;
642 pIn1G += 4;
643 pIn1H += 4;
644
645 temp = vsetq_lane_f32(*pIn2,temp,0);
646 pIn2 += numColsB;
647 temp = vsetq_lane_f32(*pIn2,temp,1);
648 pIn2 += numColsB;
649 temp = vsetq_lane_f32(*pIn2,temp,2);
650 pIn2 += numColsB;
651 temp = vsetq_lane_f32(*pIn2,temp,3);
652 pIn2 += numColsB;
653
654 acc0 = vmlaq_f32(acc0,a0V,temp);
655 acc1 = vmlaq_f32(acc1,a1V,temp);
656 acc2 = vmlaq_f32(acc2,a2V,temp);
657 acc3 = vmlaq_f32(acc3,a3V,temp);
658 acc4 = vmlaq_f32(acc4,a4V,temp);
659 acc5 = vmlaq_f32(acc5,a5V,temp);
660 acc6 = vmlaq_f32(acc6,a6V,temp);
661 acc7 = vmlaq_f32(acc7,a7V,temp);
662
663 /* Decrement the loop count */
664 colCnt--;
665 }
666
667 accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
668 sum0 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
669
670 accum = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
671 sum1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
672
673 accum = vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
674 sum2 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
675
676 accum = vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
677 sum3 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
678
679 accum = vpadd_f32(vget_low_f32(acc4), vget_high_f32(acc4));
680 sum4 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
681
682 accum = vpadd_f32(vget_low_f32(acc5), vget_high_f32(acc5));
683 sum5 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
684
685 accum = vpadd_f32(vget_low_f32(acc6), vget_high_f32(acc6));
686 sum6 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
687
688 accum = vpadd_f32(vget_low_f32(acc7), vget_high_f32(acc7));
689 sum7 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
690
691 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
692 ** No loop unrolling is used. */
693 colCnt = numColsA & 3;
694
695 while (colCnt > 0U)
696 {
697 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
698 sum0 += *pIn1++ * (*pIn2);
699 sum1 += *pIn1B++ * (*pIn2);
700 sum2 += *pIn1C++ * (*pIn2);
701 sum3 += *pIn1D++ * (*pIn2);
702 sum4 += *pIn1E++ * (*pIn2);
703 sum5 += *pIn1F++ * (*pIn2);
704 sum6 += *pIn1G++ * (*pIn2);
705 sum7 += *pIn1H++ * (*pIn2);
706 pIn2 += numColsB;
707
708 /* Decrement the loop counter */
709 colCnt--;
710 }
711
712 /* Store the result in the destination buffer */
713 *px++ = sum0;
714 *pxB++ = sum1;
715 *pxC++ = sum2;
716 *pxD++ = sum3;
717 *pxE++ = sum4;
718 *pxF++ = sum5;
719 *pxG++ = sum6;
720 *pxH++ = sum7;
721
722 /* Update the pointer pIn2 to point to the starting address of the next column */
723 j++;
724 pIn2 = pSrcB->pData + j;
725
726 /* Decrement the column loop counter */
727 col--;
728
729 } while (col > 0U);
730
731 /* Update the pointer pInA to point to the starting address of the next row */
732 i = i + numColsB;
733 pInA = pInA + GROUPOFROWS*numColsA;
734
735 /* Decrement the row loop counter */
736 rowCnt--;
737 }
738
739 /*
740
741 i was the index of a group of rows computed by previous loop.
742 Now i is the index of a row since below code is computing row per row
743 and no more group of row per group of rows.
744
745 */
746
747 i = GROUPOFROWS*i;
748 rowCnt = row & 7;
749
750 while(rowCnt > 0)
751 {
752 /* Output pointer is set to starting address of the row being processed */
753 px = pOut + i;
754
755 /* For every row wise process, the column loop counter is to be initiated */
756 col = numColsB;
757
758 /* For every row wise process, the pIn2 pointer is set
759 ** to the starting address of the pSrcB data */
760 pIn2 = pSrcB->pData;
761
762 j = 0U;
763
764 /* Column loop */
765 do
766 {
767 /* Set the variable sum, that acts as accumulator, to zero */
768 sum = 0.0f;
769
770 /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
771 pIn1 = pInA;
772
773 acc0 = vdupq_n_f32(0.0);
774
775 /* Compute 4 MACs simultaneously. */
776 colCnt = numColsA >> 2U;
777
778 /* Matrix multiplication */
779 while (colCnt > 0U)
780 {
781 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
782 a0V = vld1q_f32(pIn1); // load & separate real/imag pSrcA (de-interleave 2)
783 pIn1 += 4;
784
785 temp = vsetq_lane_f32(*pIn2,temp,0);
786 pIn2 += numColsB;
787 temp = vsetq_lane_f32(*pIn2,temp,1);
788 pIn2 += numColsB;
789 temp = vsetq_lane_f32(*pIn2,temp,2);
790 pIn2 += numColsB;
791 temp = vsetq_lane_f32(*pIn2,temp,3);
792 pIn2 += numColsB;
793
794 acc0 = vmlaq_f32(acc0,a0V,temp);
795
796 /* Decrement the loop count */
797 colCnt--;
798 }
799
800 accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
801 sum += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
802
803 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
804 ** No loop unrolling is used. */
805 colCnt = numColsA % 0x4U;
806
807 while (colCnt > 0U)
808 {
809 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
810 sum += *pIn1++ * (*pIn2);
811 pIn2 += numColsB;
812
813 /* Decrement the loop counter */
814 colCnt--;
815 }
816
817 /* Store the result in the destination buffer */
818 *px++ = sum;
819
820 /* Update the pointer pIn2 to point to the starting address of the next column */
821 j++;
822 pIn2 = pSrcB->pData + j;
823
824 /* Decrement the column loop counter */
825 col--;
826
827 } while (col > 0U);
828
829
830 /* Update the pointer pInA to point to the starting address of the next row */
831 i = i + numColsB;
832 pInA = pInA + numColsA;
833
834 /* Decrement the row loop counter */
835 rowCnt--;
836
837 }
838 /* Set status as ARM_MATH_SUCCESS */
839 status = ARM_MATH_SUCCESS;
840 }
841
842 /* Return to application */
843 return (status);
844 }
845 #else
arm_mat_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)846 arm_status arm_mat_mult_f32(
847 const arm_matrix_instance_f32 * pSrcA,
848 const arm_matrix_instance_f32 * pSrcB,
849 arm_matrix_instance_f32 * pDst)
850 {
851 float32_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */
852 float32_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */
853 float32_t *pInA = pSrcA->pData; /* Input data matrix pointer A */
854 float32_t *pInB = pSrcB->pData; /* Input data matrix pointer B */
855 float32_t *pOut = pDst->pData; /* Output data matrix pointer */
856 float32_t *px; /* Temporary output data matrix pointer */
857 float32_t sum; /* Accumulator */
858 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
859 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
860 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
861 uint32_t col, i = 0U, row = numRowsA, colCnt; /* Loop counters */
862 arm_status status; /* Status of matrix multiplication */
863
864 #ifdef ARM_MATH_MATRIX_CHECK
865
866 /* Check for matrix mismatch condition */
867 if ((pSrcA->numCols != pSrcB->numRows) ||
868 (pSrcA->numRows != pDst->numRows) ||
869 (pSrcB->numCols != pDst->numCols) )
870 {
871 /* Set status as ARM_MATH_SIZE_MISMATCH */
872 status = ARM_MATH_SIZE_MISMATCH;
873 }
874 else
875
876 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
877
878 {
879 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
880 /* row loop */
881 do
882 {
883 /* Output pointer is set to starting address of row being processed */
884 px = pOut + i;
885
886 /* For every row wise process, column loop counter is to be initiated */
887 col = numColsB;
888
889 /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
890 pIn2 = pSrcB->pData;
891
892 /* column loop */
893 do
894 {
895 /* Set the variable sum, that acts as accumulator, to zero */
896 sum = 0.0f;
897
898 /* Initialize pointer pIn1 to point to starting address of column being processed */
899 pIn1 = pInA;
900
901 #if defined (ARM_MATH_LOOPUNROLL)
902
903 /* Loop unrolling: Compute 4 MACs at a time. */
904 colCnt = numColsA >> 2U;
905
906 /* matrix multiplication */
907 while (colCnt > 0U)
908 {
909 /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */
910
911 /* Perform the multiply-accumulates */
912 sum += *pIn1++ * *pIn2;
913 pIn2 += numColsB;
914
915 sum += *pIn1++ * *pIn2;
916 pIn2 += numColsB;
917
918 sum += *pIn1++ * *pIn2;
919 pIn2 += numColsB;
920
921 sum += *pIn1++ * *pIn2;
922 pIn2 += numColsB;
923
924 /* Decrement loop counter */
925 colCnt--;
926 }
927
928 /* Loop unrolling: Compute remaining MACs */
929 colCnt = numColsA % 0x4U;
930
931 #else
932
933 /* Initialize cntCnt with number of columns */
934 colCnt = numColsA;
935
936 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
937
938 while (colCnt > 0U)
939 {
940 /* c(m,p) = a(m,1) * b(1,p) + a(m,2) * b(2,p) + .... + a(m,n) * b(n,p) */
941
942 /* Perform the multiply-accumulates */
943 sum += *pIn1++ * *pIn2;
944 pIn2 += numColsB;
945
946 /* Decrement loop counter */
947 colCnt--;
948 }
949
950 /* Store result in destination buffer */
951 *px++ = sum;
952
953 /* Decrement column loop counter */
954 col--;
955
956 /* Update pointer pIn2 to point to starting address of next column */
957 pIn2 = pInB + (numColsB - col);
958
959 } while (col > 0U);
960
961 /* Update pointer pInA to point to starting address of next row */
962 i = i + numColsB;
963 pInA = pInA + numColsA;
964
965 /* Decrement row loop counter */
966 row--;
967
968 } while (row > 0U);
969
970 /* Set status as ARM_MATH_SUCCESS */
971 status = ARM_MATH_SUCCESS;
972 }
973
974 /* Return to application */
975 return (status);
976 }
977
978 #endif /* #if defined(ARM_MATH_NEON) */
979 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
980
981 /**
982 * @} end of MatrixMult group
983 */
984