1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_mult_f64.c
4 * Description: Floating-point matrix multiplication
5 *
6 * $Date: 10 August 2022
7 * $Revision: V1.9.1
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions.h"
30 #if defined(ARM_MATH_NEON) && defined(__aarch64__)
31 #define GROUPOFROWS 8
32 #endif
33
34 /**
35 * @ingroup groupMatrix
36 */
37
38 /**
39 * @defgroup MatrixMult Matrix Multiplication
40 *
41 * Multiplies two matrices.
42 *
43 * \image html MatrixMultiplication.gif "Multiplication of two 3 x 3 matrices"
44
45 * Matrix multiplication is only defined if the number of columns of the
46 * first matrix equals the number of rows of the second matrix.
47 * Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
48 * in an <code>M x P</code> matrix.
49 * When matrix size checking is enabled, the functions check: (1) that the inner dimensions of
50 * <code>pSrcA</code> and <code>pSrcB</code> are equal; and (2) that the size of the output
51 * matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
52 */
53
54
55 /**
56 * @addtogroup MatrixMult
57 * @{
58 */
59
/**
 * @brief         Double-precision floating-point matrix multiplication.
 * @param[in]     pSrcA  points to the first input matrix structure
 * @param[in]     pSrcB  points to the second input matrix structure
 * @param[out]    pDst   points to output matrix structure
 * @return        execution status
 *                  - \ref ARM_MATH_SUCCESS       : Operation successful
 *                  - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
 *                    (only possible when ARM_MATH_MATRIX_CHECK is defined)
 */
68
69 #if defined(ARM_MATH_NEON) && defined(__aarch64__)
arm_mat_mult_f64(const arm_matrix_instance_f64 * pSrcA,const arm_matrix_instance_f64 * pSrcB,arm_matrix_instance_f64 * pDst)70 ARM_DSP_ATTRIBUTE arm_status arm_mat_mult_f64(
71 const arm_matrix_instance_f64 * pSrcA,
72 const arm_matrix_instance_f64 * pSrcB,
73 arm_matrix_instance_f64 * pDst)
74 {
75 float64_t *pIn1 = pSrcA->pData; /* input data matrix pointer A */
76 float64_t *pIn2 = pSrcB->pData; /* input data matrix pointer B */
77 float64_t *pInA = pSrcA->pData; /* input data matrix pointer A */
78 float64_t *pOut = pDst->pData; /* output data matrix pointer */
79 float64_t *px; /* Temporary output data matrix pointer */
80 float64_t sum; /* Accumulator */
81 uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
82 uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
83 uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
84
85
86 uint32_t col, i = 0U, j, row = numRowsA, rowCnt, colCnt; /* loop counters */
87 arm_status status; /* status of matrix multiplication */
88
89 float64x2_t a0V, a1V, a2V, a3V, a4V, a5V, a6V, a7V;
90 float64x2_t acc0,acc1,acc2,acc3,acc4,acc5,acc6,acc7,temp;
91 float64_t *pIn1B = pSrcA->pData;
92 float64_t *pIn1C = pSrcA->pData;
93 float64_t *pIn1D = pSrcA->pData;
94 float64_t *pIn1E = pSrcA->pData;
95 float64_t *pIn1F = pSrcA->pData;
96 float64_t *pIn1G = pSrcA->pData;
97 float64_t *pIn1H = pSrcA->pData;
98
99 float64_t *pxB,*pxC, *pxD, *pxE, *pxF, *pxG, *pxH; /* Temporary output data matrix pointer */
100 float64_t sum0,sum1, sum2,sum3, sum4, sum5 , sum6, sum7;
101
102 #ifdef ARM_MATH_MATRIX_CHECK
103
104 /* Check for matrix mismatch condition */
105 if ((pSrcA->numCols != pSrcB->numRows) ||
106 (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
107 {
108 /* Set status as ARM_MATH_SIZE_MISMATCH */
109 status = ARM_MATH_SIZE_MISMATCH;
110 }
111 else
112 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
113 {
114 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
115 /* Row loop */
116 rowCnt = row >> 3;
117
118 while(rowCnt > 0)
119 {
120 /* Output pointer is set to starting address of the row being processed */
121 px = pOut + GROUPOFROWS*i;
122 pxB = px + numColsB;
123 pxC = px + 2*numColsB;
124 pxD = px + 3*numColsB;
125 pxE = px + 4*numColsB;
126 pxF = px + 5*numColsB;
127 pxG = px + 6*numColsB;
128 pxH = px + 7*numColsB;
129
130 /* For every row wise process, the column loop counter is to be initiated */
131 col = numColsB;
132
133 /* For every row wise process, the pIn2 pointer is set
134 ** to the starting address of the pSrcB data */
135 pIn2 = pSrcB->pData;
136
137 j = 0U;
138
139 /* Column loop */
140 do
141 {
142 /* Set the variable sum, that acts as accumulator, to zero */
143 sum0 = 0.0;
144 sum1 = 0.0;
145 sum2 = 0.0;
146 sum3 = 0.0;
147 sum4 = 0.0;
148 sum5 = 0.0;
149 sum6 = 0.0;
150 sum7 = 0.0;
151
152 /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
153 pIn1 = pInA;
154 pIn1B = pIn1 + numColsA;
155 pIn1C = pIn1 + 2*numColsA;
156 pIn1D = pIn1 + 3*numColsA;
157 pIn1E = pIn1 + 4*numColsA;
158 pIn1F = pIn1 + 5*numColsA;
159 pIn1G = pIn1 + 6*numColsA;
160 pIn1H = pIn1 + 7*numColsA;
161
162 acc0 = vdupq_n_f64(0.0);
163 acc1 = vdupq_n_f64(0.0);
164 acc2 = vdupq_n_f64(0.0);
165 acc3 = vdupq_n_f64(0.0);
166 acc4 = vdupq_n_f64(0.0);
167 acc5 = vdupq_n_f64(0.0);
168 acc6 = vdupq_n_f64(0.0);
169 acc7 = vdupq_n_f64(0.0);
170
171 /* Compute 2 MACs simultaneously. */
172 colCnt = numColsA >> 1U;
173
174 /* Matrix multiplication */
175 while (colCnt > 0U)
176 {
177 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
178 a0V = vld1q_f64(pIn1);
179 a1V = vld1q_f64(pIn1B);
180 a2V = vld1q_f64(pIn1C);
181 a3V = vld1q_f64(pIn1D);
182 a4V = vld1q_f64(pIn1E);
183 a5V = vld1q_f64(pIn1F);
184 a6V = vld1q_f64(pIn1G);
185 a7V = vld1q_f64(pIn1H);
186
187 pIn1 += 2;
188 pIn1B += 2;
189 pIn1C += 2;
190 pIn1D += 2;
191 pIn1E += 2;
192 pIn1F += 2;
193 pIn1G += 2;
194 pIn1H += 2;
195
196 temp = vsetq_lane_f64(*pIn2,temp,0);
197 pIn2 += numColsB;
198 temp = vsetq_lane_f64(*pIn2,temp,1);
199 pIn2 += numColsB;
200
201
202 acc0 = vmlaq_f64(acc0,a0V,temp);
203 acc1 = vmlaq_f64(acc1,a1V,temp);
204 acc2 = vmlaq_f64(acc2,a2V,temp);
205 acc3 = vmlaq_f64(acc3,a3V,temp);
206 acc4 = vmlaq_f64(acc4,a4V,temp);
207 acc5 = vmlaq_f64(acc5,a5V,temp);
208 acc6 = vmlaq_f64(acc6,a6V,temp);
209 acc7 = vmlaq_f64(acc7,a7V,temp);
210
211 /* Decrement the loop count */
212 colCnt--;
213 }
214
215 sum0 += vaddvq_f64(acc0);
216 sum1 += vaddvq_f64(acc1);
217 sum2 += vaddvq_f64(acc2);
218 sum3 += vaddvq_f64(acc3);
219 sum4 += vaddvq_f64(acc4);
220 sum5 += vaddvq_f64(acc5);
221 sum6 += vaddvq_f64(acc6);
222 sum7 += vaddvq_f64(acc7);
223
224 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
225 ** No loop unrolling is used. */
226 colCnt = numColsA & 1;
227
228 while (colCnt > 0U)
229 {
230 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
231 sum0 += *pIn1++ * (*pIn2);
232 sum1 += *pIn1B++ * (*pIn2);
233 sum2 += *pIn1C++ * (*pIn2);
234 sum3 += *pIn1D++ * (*pIn2);
235 sum4 += *pIn1E++ * (*pIn2);
236 sum5 += *pIn1F++ * (*pIn2);
237 sum6 += *pIn1G++ * (*pIn2);
238 sum7 += *pIn1H++ * (*pIn2);
239 pIn2 += numColsB;
240
241 /* Decrement the loop counter */
242 colCnt--;
243 }
244
245 /* Store the result in the destination buffer */
246 *px++ = sum0;
247 *pxB++ = sum1;
248 *pxC++ = sum2;
249 *pxD++ = sum3;
250 *pxE++ = sum4;
251 *pxF++ = sum5;
252 *pxG++ = sum6;
253 *pxH++ = sum7;
254
255 /* Update the pointer pIn2 to point to the starting address of the next column */
256 j++;
257 pIn2 = pSrcB->pData + j;
258
259 /* Decrement the column loop counter */
260 col--;
261
262 } while (col > 0U);
263
264 /* Update the pointer pInA to point to the starting address of the next row */
265 i = i + numColsB;
266 pInA = pInA + GROUPOFROWS*numColsA;
267
268 /* Decrement the row loop counter */
269 rowCnt--;
270 }
271
272 /*
273
274 i was the index of a group of rows computed by previous loop.
275 Now i is the index of a row since below code is computing row per row
276 and no more group of row per group of rows.
277
278 */
279
280 i = GROUPOFROWS*i;
281 rowCnt = row & 7;
282
283 while(rowCnt > 0)
284 {
285 /* Output pointer is set to starting address of the row being processed */
286 px = pOut + i;
287
288 /* For every row wise process, the column loop counter is to be initiated */
289 col = numColsB;
290
291 /* For every row wise process, the pIn2 pointer is set
292 ** to the starting address of the pSrcB data */
293 pIn2 = pSrcB->pData;
294
295 j = 0U;
296
297 /* Column loop */
298 do
299 {
300 /* Set the variable sum, that acts as accumulator, to zero */
301 sum = 0.0;
302
303 /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
304 pIn1 = pInA;
305
306 acc0 = vdupq_n_f64(0.0);
307
308 /* Compute 4 MACs simultaneously. */
309 colCnt = numColsA >> 1U;
310
311 /* Matrix multiplication */
312 while (colCnt > 0U)
313 {
314 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
315 a0V = vld1q_f64(pIn1); // load & separate real/imag pSrcA (de-interleave 2)
316 pIn1 += 2;
317
318 temp = vsetq_lane_f64(*pIn2,temp,0);
319 pIn2 += numColsB;
320 temp = vsetq_lane_f64(*pIn2,temp,1);
321 pIn2 += numColsB;
322
323
324 acc0 = vmlaq_f64(acc0,a0V,temp);
325
326 /* Decrement the loop count */
327 colCnt--;
328 }
329
330 //accum = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
331 sum += vaddvq_f64(acc0);
332
333 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
334 ** No loop unrolling is used. */
335 colCnt = numColsA % 0x2U;
336
337 while (colCnt > 0U)
338 {
339 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
340 sum += *pIn1++ * (*pIn2);
341 pIn2 += numColsB;
342
343 /* Decrement the loop counter */
344 colCnt--;
345 }
346
347 /* Store the result in the destination buffer */
348 *px++ = sum;
349
350 /* Update the pointer pIn2 to point to the starting address of the next column */
351 j++;
352 pIn2 = pSrcB->pData + j;
353
354 /* Decrement the column loop counter */
355 col--;
356
357 } while (col > 0U);
358
359
360 /* Update the pointer pInA to point to the starting address of the next row */
361 i = i + numColsB;
362 pInA = pInA + numColsA;
363
364 /* Decrement the row loop counter */
365 rowCnt--;
366
367 }
368 /* Set status as ARM_MATH_SUCCESS */
369 status = ARM_MATH_SUCCESS;
370 }
371
372 /* Return to application */
373 return (status);
374 }
375 #else
arm_mat_mult_f64(const arm_matrix_instance_f64 * pSrcA,const arm_matrix_instance_f64 * pSrcB,arm_matrix_instance_f64 * pDst)376 ARM_DSP_ATTRIBUTE arm_status arm_mat_mult_f64(
377 const arm_matrix_instance_f64 * pSrcA,
378 const arm_matrix_instance_f64 * pSrcB,
379 arm_matrix_instance_f64 * pDst)
380 {
381 float64_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */
382 float64_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */
383 float64_t *pInA = pSrcA->pData; /* Input data matrix pointer A */
384 float64_t *pInB = pSrcB->pData; /* Input data matrix pointer B */
385 float64_t *pOut = pDst->pData; /* Output data matrix pointer */
386 float64_t *px; /* Temporary output data matrix pointer */
387 float64_t sum; /* Accumulator */
388 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
389 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
390 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
391 uint64_t col, i = 0U, row = numRowsA, colCnt; /* Loop counters */
392 arm_status status; /* Status of matrix multiplication */
393
394 #ifdef ARM_MATH_MATRIX_CHECK
395
396 /* Check for matrix mismatch condition */
397 if ((pSrcA->numCols != pSrcB->numRows) ||
398 (pSrcA->numRows != pDst->numRows) ||
399 (pSrcB->numCols != pDst->numCols) )
400 {
401 /* Set status as ARM_MATH_SIZE_MISMATCH */
402 status = ARM_MATH_SIZE_MISMATCH;
403 }
404 else
405
406 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
407
408 {
409 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
410 /* row loop */
411 do
412 {
413 /* Output pointer is set to starting address of row being processed */
414 px = pOut + i;
415
416 /* For every row wise process, column loop counter is to be initiated */
417 col = numColsB;
418
419 /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
420 pIn2 = pSrcB->pData;
421
422 /* column loop */
423 do
424 {
425 /* Set the variable sum, that acts as accumulator, to zero */
426 sum = 0.0;
427
428 /* Initialize pointer pIn1 to point to starting address of column being processed */
429 pIn1 = pInA;
430
431 #if defined (ARM_MATH_LOOPUNROLL)
432
433 /* Loop unrolling: Compute 4 MACs at a time. */
434 colCnt = numColsA >> 2U;
435
436 /* matrix multiplication */
437 while (colCnt > 0U)
438 {
439 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
440
441 /* Perform the multiply-accumulates */
442 sum += *pIn1++ * *pIn2;
443 pIn2 += numColsB;
444
445 sum += *pIn1++ * *pIn2;
446 pIn2 += numColsB;
447
448 sum += *pIn1++ * *pIn2;
449 pIn2 += numColsB;
450
451 sum += *pIn1++ * *pIn2;
452 pIn2 += numColsB;
453
454 /* Decrement loop counter */
455 colCnt--;
456 }
457
458 /* Loop unrolling: Compute remaining MACs */
459 colCnt = numColsA % 0x4U;
460
461 #else
462
463 /* Initialize cntCnt with number of columns */
464 colCnt = numColsA;
465
466 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
467
468 while (colCnt > 0U)
469 {
470 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
471
472 /* Perform the multiply-accumulates */
473 sum += *pIn1++ * *pIn2;
474 pIn2 += numColsB;
475
476 /* Decrement loop counter */
477 colCnt--;
478 }
479
480 /* Store result in destination buffer */
481 *px++ = sum;
482
483 /* Decrement column loop counter */
484 col--;
485
486 /* Update pointer pIn2 to point to starting address of next column */
487 pIn2 = pInB + (numColsB - col);
488
489 } while (col > 0U);
490
491 /* Update pointer pInA to point to starting address of next row */
492 i = i + numColsB;
493 pInA = pInA + numColsA;
494
495 /* Decrement row loop counter */
496 row--;
497
498 } while (row > 0U);
499
500 /* Set status as ARM_MATH_SUCCESS */
501 status = ARM_MATH_SUCCESS;
502 }
503
504 /* Return to application */
505 return (status);
506 }
507 #endif
508
509 /**
510 * @} end of MatrixMult group
511 */
512