1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_cmplx_mat_mult_q15.c
4 * Description: Q15 complex matrix multiplication
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions.h"
30
31 /**
32 @ingroup groupMatrix
33 */
34
35 /**
36 @addtogroup CmplxMatrixMult
37 @{
38 */
39
40 /**
41 @brief Q15 Complex matrix multiplication.
42 @param[in] pSrcA points to first input complex matrix structure
43 @param[in] pSrcB points to second input complex matrix structure
44 @param[out] pDst points to output complex matrix structure
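@param[in] pScratch points to an array for storing intermediate results (the transposed copy of pSrcB; it must hold at least 2 * numRowsB * numColsB q15_t values). It is not used by the MVE implementation.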
46 @return execution status
47 - \ref ARM_MATH_SUCCESS : Operation successful
48 - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
49
50 @par Conditions for optimum performance
Input, output and scratch buffers should be 32-bit aligned.
52
53 @par Scaling and Overflow Behavior
54 The function is implemented using an internal 64-bit accumulator. The inputs to the
55 multiplications are in 1.15 format and multiplications yield a 2.30 result.
56 The 2.30 intermediate results are accumulated in a 64-bit accumulator in 34.30 format.
57 This approach provides 33 guard bits and there is no risk of overflow. The 34.30 result is then
58 truncated to 34.15 format by discarding the low 15 bits and then saturated to 1.15 format.
59 */
60 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
61
/*
 * MVE_ASRL_SAT16(acc, shift): narrow a 64-bit accumulator to 16 bits,
 * i.e. (acc >> shift) saturated to the q15 range.
 * sqrshrl_sat48 is given a negative shift amount, so it shifts acc LEFT by
 * (32 - shift) with saturation to 48 bits; ">> 32" then extracts the top
 * 16 of those 48 bits. Both arguments are parenthesized so that compound
 * expressions (e.g. a shift of "a + b") expand correctly.
 */
#define MVE_ASRL_SAT16(acc, shift) \
    ((sqrshrl_sat48((acc), -(32 - (shift))) >> 32) & 0xffffffff)
63
arm_mat_cmplx_mult_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pScratch)64 arm_status arm_mat_cmplx_mult_q15(
65 const arm_matrix_instance_q15 * pSrcA,
66 const arm_matrix_instance_q15 * pSrcB,
67 arm_matrix_instance_q15 * pDst,
68 q15_t * pScratch)
69 {
70 q15_t const *pInA = (q15_t const *) pSrcA->pData; /* input data matrix pointer A of Q15 type */
71 q15_t const *pInB = (q15_t const *) pSrcB->pData; /* input data matrix pointer B of Q15 type */
72 q15_t const *pInB2;
73 q15_t *px; /* Temporary output data matrix pointer */
74 uint32_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
75 uint32_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
76 uint32_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
77 uint32_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */
78 uint32_t col, i = 0u, j, row = numRowsB; /* loop counters */
79 uint32_t blkCnt; /* loop counters */
80 uint16x8_t vecOffs, vecColBOffs;
81 arm_status status; /* Status of matrix multiplication */
82 (void)pScratch;
83
84 #ifdef ARM_MATH_MATRIX_CHECK
85
86 /* Check for matrix mismatch condition */
87 if ((pSrcA->numCols != pSrcB->numRows) ||
88 (pSrcA->numRows != pDst->numRows) ||
89 (pSrcB->numCols != pDst->numCols) )
90 {
91 /* Set status as ARM_MATH_SIZE_MISMATCH */
92 status = ARM_MATH_SIZE_MISMATCH;
93 }
94 else
95
96 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
97
98 {
99 vecColBOffs[0] = 0;
100 vecColBOffs[1] = 1;
101 vecColBOffs[2] = numColsB * CMPLX_DIM;
102 vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;
103 vecColBOffs[4] = 2 * numColsB * CMPLX_DIM;
104 vecColBOffs[5] = 2 * (numColsB * CMPLX_DIM) + 1;
105 vecColBOffs[6] = 3 * numColsB * CMPLX_DIM;
106 vecColBOffs[7] = 3 * (numColsB * CMPLX_DIM) + 1;
107
108 /*
109 * Reset the variables for the usage in the following multiplication process
110 */
111 i = 0;
112 row = numRowsA;
113 px = pDst->pData;
114
115 /*
116 * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
117 */
118
119 /*
120 * row loop
121 */
122 while (row > 0u)
123 {
124 /*
125 * For every row wise process, the column loop counter is to be initiated
126 */
127 col = numColsB >> 1;
128 j = 0;
129 /*
130 * column loop
131 */
132 while (col > 0u)
133 {
134 q15_t const *pSrcAVec;
135 //, *pSrcBVec, *pSrcB2Vec;
136 q15x8_t vecA, vecB, vecB2;
137 q63_t acc0, acc1, acc2, acc3;
138
139 /*
140 * Initiate the pointer pIn1 to point to the starting address of the column being processed
141 */
142 pInA = pSrcA->pData + i;
143 pInB = pSrcB->pData + j;
144 pInB2 = pInB + CMPLX_DIM;
145
146 j += 2 * CMPLX_DIM;
147 /*
148 * Decrement the column loop counter
149 */
150 col--;
151
152 /*
153 * Initiate the pointers
154 * - current Matrix A rows
155 * - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
156 */
157 pSrcAVec = (q15_t const *) pInA;
158
159 acc0 = 0LL;
160 acc1 = 0LL;
161 acc2 = 0LL;
162 acc3 = 0LL;
163
164 vecOffs = vecColBOffs;
165
166
167 blkCnt = (numColsA * CMPLX_DIM) >> 3;
168 while (blkCnt > 0U)
169 {
170 vecA = vld1q(pSrcAVec);
171 pSrcAVec += 8;
172 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
173
174 acc0 = vmlsldavaq_s16(acc0, vecA, vecB);
175 acc1 = vmlaldavaxq_s16(acc1, vecA, vecB);
176 vecB2 = vldrhq_gather_shifted_offset(pInB2, vecOffs);
177 /*
178 * move Matrix B read offsets, 4 rows down
179 */
180 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));
181
182 acc2 = vmlsldavaq_s16(acc2, vecA, vecB2);
183 acc3 = vmlaldavaxq_s16(acc3, vecA, vecB2);
184
185 blkCnt--;
186 }
187
188 /*
189 * tail
190 */
191 blkCnt = (numColsA * CMPLX_DIM) & 7;
192 if (blkCnt > 0U)
193 {
194 mve_pred16_t p0 = vctp16q(blkCnt);
195 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
196
197 vecA = vldrhq_z_s16(pSrcAVec, p0);
198
199 acc0 = vmlsldavaq_s16(acc0, vecA, vecB);
200 acc1 = vmlaldavaxq_s16(acc1, vecA, vecB);
201 vecB2 = vldrhq_gather_shifted_offset(pInB2, vecOffs);
202
203 /*
204 * move Matrix B read offsets, 4 rows down
205 */
206 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));
207
208 acc2 = vmlsldavaq_s16(acc2, vecA, vecB2);
209 acc3 = vmlaldavaxq_s16(acc3, vecA, vecB2);
210
211 }
212 /*
213 * Convert to 1.15, Store the results (1 x 2 block) in the destination buffer
214 */
215 *px++ = (q15_t)MVE_ASRL_SAT16(acc0, 15);
216 *px++ = (q15_t)MVE_ASRL_SAT16(acc1, 15);
217 *px++ = (q15_t)MVE_ASRL_SAT16(acc2, 15);
218 *px++ = (q15_t)MVE_ASRL_SAT16(acc3, 15);
219 }
220
221 col = numColsB & 1;
222 /*
223 * column loop
224 */
225 while (col > 0u)
226 {
227
228 q15_t const *pSrcAVec;
229 //, *pSrcBVec, *pSrcB2Vec;
230 q15x8_t vecA, vecB;
231 q63_t acc0, acc1;
232
233 /*
234 * Initiate the pointer pIn1 to point to the starting address of the column being processed
235 */
236 pInA = pSrcA->pData + i;
237 pInB = pSrcB->pData + j;
238
239 j += CMPLX_DIM;
240 /*
241 * Decrement the column loop counter
242 */
243 col--;
244
245 /*
246 * Initiate the pointers
247 * - current Matrix A rows
248 * - 2 x consecutive Matrix B' rows (j increment is 2 x numRowsB)
249 */
250 pSrcAVec = (q15_t const *) pInA;
251
252 acc0 = 0LL;
253 acc1 = 0LL;
254
255
256 vecOffs = vecColBOffs;
257
258
259
260 blkCnt = (numColsA * CMPLX_DIM) >> 3;
261 while (blkCnt > 0U)
262 {
263 vecA = vld1q(pSrcAVec);
264 pSrcAVec += 8;
265 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
266
267 acc0 = vmlsldavaq_s16(acc0, vecA, vecB);
268 acc1 = vmlaldavaxq_s16(acc1, vecA, vecB);
269 /*
270 * move Matrix B read offsets, 4 rows down
271 */
272 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));
273
274 blkCnt--;
275 }
276
277 /*
278 * tail
279 */
280 blkCnt = (numColsA * CMPLX_DIM) & 7;
281 if (blkCnt > 0U)
282 {
283 mve_pred16_t p0 = vctp16q(blkCnt);
284 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
285 vecA = vldrhq_z_s16(pSrcAVec, p0);
286
287 acc0 = vmlsldavaq_s16(acc0, vecA, vecB);
288 acc1 = vmlaldavaxq_s16(acc1, vecA, vecB);
289
290 }
291 /*
292 * Convert to 1.15, Store the results (1 x 2 block) in the destination buffer
293 */
294 *px++ = (q15_t)MVE_ASRL_SAT16(acc0, 15);
295 *px++ = (q15_t)MVE_ASRL_SAT16(acc1, 15);
296
297 }
298
299 i = i + numColsA * CMPLX_DIM;
300
301 /*
302 * Decrement the row loop counter
303 */
304 row--;
305 }
306
307
308 status = ARM_MATH_SUCCESS;
309 }
310
311 /* Return to application */
312 return (status);
313 }
314 #else
arm_mat_cmplx_mult_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pScratch)315 arm_status arm_mat_cmplx_mult_q15(
316 const arm_matrix_instance_q15 * pSrcA,
317 const arm_matrix_instance_q15 * pSrcB,
318 arm_matrix_instance_q15 * pDst,
319 q15_t * pScratch)
320 {
321 q15_t *pSrcBT = pScratch; /* input data matrix pointer for transpose */
322 q15_t *pInA = pSrcA->pData; /* input data matrix pointer A of Q15 type */
323 q15_t *pInB = pSrcB->pData; /* input data matrix pointer B of Q15 type */
324 q15_t *px; /* Temporary output data matrix pointer */
325 uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
326 uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
327 uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
328 uint16_t numRowsB = pSrcB->numRows; /* number of rows of input matrix A */
329 q63_t sumReal, sumImag; /* accumulator */
330 uint32_t col, i = 0U, row = numRowsB, colCnt; /* Loop counters */
331 arm_status status; /* Status of matrix multiplication */
332
333 #if defined (ARM_MATH_DSP)
334 q31_t prod1, prod2;
335 q31_t pSourceA, pSourceB;
336 #else
337 q15_t a, b, c, d;
338 #endif /* #if defined (ARM_MATH_DSP) */
339
340 #ifdef ARM_MATH_MATRIX_CHECK
341
342 /* Check for matrix mismatch condition */
343 if ((pSrcA->numCols != pSrcB->numRows) ||
344 (pSrcA->numRows != pDst->numRows) ||
345 (pSrcB->numCols != pDst->numCols) )
346 {
347 /* Set status as ARM_MATH_SIZE_MISMATCH */
348 status = ARM_MATH_SIZE_MISMATCH;
349 }
350 else
351
352 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
353
354 {
355 /* Matrix transpose */
356 do
357 {
358 /* The pointer px is set to starting address of column being processed */
359 px = pSrcBT + i;
360
361 #if defined (ARM_MATH_LOOPUNROLL)
362
363 /* Apply loop unrolling and exchange the columns with row elements */
364 col = numColsB >> 2;
365
366 /* First part of the processing with loop unrolling. Compute 4 outputs at a time.
367 a second loop below computes the remaining 1 to 3 samples. */
368 while (col > 0U)
369 {
370 /* Read two elements from row */
371 write_q15x2 (px, read_q15x2_ia (&pInB));
372
373 /* Update pointer px to point to next row of transposed matrix */
374 px += numRowsB * 2;
375
376 /* Read two elements from row */
377 write_q15x2 (px, read_q15x2_ia (&pInB));
378
379 /* Update pointer px to point to next row of transposed matrix */
380 px += numRowsB * 2;
381
382 /* Read two elements from row */
383 write_q15x2 (px, read_q15x2_ia (&pInB));
384
385 /* Update pointer px to point to next row of transposed matrix */
386 px += numRowsB * 2;
387
388 /* Read two elements from row */
389 write_q15x2 (px, read_q15x2_ia (&pInB));
390
391 /* Update pointer px to point to next row of transposed matrix */
392 px += numRowsB * 2;
393
394 /* Decrement column loop counter */
395 col--;
396 }
397
398 /* If the columns of pSrcB is not a multiple of 4, compute any remaining output samples here.
399 ** No loop unrolling is used. */
400 col = numColsB % 0x4U;
401
402 #else
403
404 /* Initialize blkCnt with number of samples */
405 col = numColsB;
406
407 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
408
409 while (col > 0U)
410 {
411 /* Read two elements from row */
412 write_q15x2 (px, read_q15x2_ia (&pInB));
413
414 /* Update pointer px to point to next row of transposed matrix */
415 px += numRowsB * 2;
416
417 /* Decrement column loop counter */
418 col--;
419 }
420
421 i = i + 2U;
422
423 /* Decrement row loop counter */
424 row--;
425
426 } while (row > 0U);
427
428 /* Reset variables for usage in following multiplication process */
429 row = numRowsA;
430 i = 0U;
431 px = pDst->pData;
432
433 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
434 /* row loop */
435 do
436 {
437 /* For every row wise process, column loop counter is to be initiated */
438 col = numColsB;
439
440 /* For every row wise process, pIn2 pointer is set to starting address of transposed pSrcB data */
441 pInB = pSrcBT;
442
443 /* column loop */
444 do
445 {
446 /* Set variable sum, that acts as accumulator, to zero */
447 sumReal = 0;
448 sumImag = 0;
449
450 /* Initiate pointer pInA to point to starting address of column being processed */
451 pInA = pSrcA->pData + i * 2;
452
453 /* Apply loop unrolling and compute 2 MACs simultaneously. */
454 colCnt = numColsA >> 1U;
455
456 /* matrix multiplication */
457 while (colCnt > 0U)
458 {
459 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
460
461 #if defined (ARM_MATH_DSP)
462
463 /* read real and imag values from pSrcA and pSrcB buffer */
464 pSourceA = read_q15x2_ia ((q15_t **) &pInA);
465 pSourceB = read_q15x2_ia ((q15_t **) &pInB);
466
467 /* Multiply and Accumlates */
468 #ifdef ARM_MATH_BIG_ENDIAN
469 prod1 = -__SMUSD(pSourceA, pSourceB);
470 #else
471 prod1 = __SMUSD(pSourceA, pSourceB);
472 #endif
473 prod2 = __SMUADX(pSourceA, pSourceB);
474 sumReal += (q63_t) prod1;
475 sumImag += (q63_t) prod2;
476
477 /* read real and imag values from pSrcA and pSrcB buffer */
478 pSourceA = read_q15x2_ia ((q15_t **) &pInA);
479 pSourceB = read_q15x2_ia ((q15_t **) &pInB);
480
481 /* Multiply and Accumlates */
482 #ifdef ARM_MATH_BIG_ENDIAN
483 prod1 = -__SMUSD(pSourceA, pSourceB);
484 #else
485 prod1 = __SMUSD(pSourceA, pSourceB);
486 #endif
487 prod2 = __SMUADX(pSourceA, pSourceB);
488 sumReal += (q63_t) prod1;
489 sumImag += (q63_t) prod2;
490
491 #else /* #if defined (ARM_MATH_DSP) */
492
493 /* read real and imag values from pSrcA buffer */
494 a = *pInA;
495 b = *(pInA + 1U);
496 /* read real and imag values from pSrcB buffer */
497 c = *pInB;
498 d = *(pInB + 1U);
499
500 /* Multiply and Accumlates */
501 sumReal += (q31_t) a *c;
502 sumImag += (q31_t) a *d;
503 sumReal -= (q31_t) b *d;
504 sumImag += (q31_t) b *c;
505
506 /* read next real and imag values from pSrcA buffer */
507 a = *(pInA + 2U);
508 b = *(pInA + 3U);
509 /* read next real and imag values from pSrcB buffer */
510 c = *(pInB + 2U);
511 d = *(pInB + 3U);
512
513 /* update pointer */
514 pInA += 4U;
515
516 /* Multiply and Accumlates */
517 sumReal += (q31_t) a * c;
518 sumImag += (q31_t) a * d;
519 sumReal -= (q31_t) b * d;
520 sumImag += (q31_t) b * c;
521 /* update pointer */
522 pInB += 4U;
523
524 #endif /* #if defined (ARM_MATH_DSP) */
525
526 /* Decrement loop counter */
527 colCnt--;
528 }
529
530 /* process odd column samples */
531 if ((numColsA & 0x1U) > 0U)
532 {
533 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
534
535 #if defined (ARM_MATH_DSP)
536 /* read real and imag values from pSrcA and pSrcB buffer */
537 pSourceA = read_q15x2_ia ((q15_t **) &pInA);
538 pSourceB = read_q15x2_ia ((q15_t **) &pInB);
539
540 /* Multiply and Accumlates */
541 #ifdef ARM_MATH_BIG_ENDIAN
542 prod1 = -__SMUSD(pSourceA, pSourceB);
543 #else
544 prod1 = __SMUSD(pSourceA, pSourceB);
545 #endif
546 prod2 = __SMUADX(pSourceA, pSourceB);
547 sumReal += (q63_t) prod1;
548 sumImag += (q63_t) prod2;
549
550 #else /* #if defined (ARM_MATH_DSP) */
551
552 /* read real and imag values from pSrcA and pSrcB buffer */
553 a = *pInA++;
554 b = *pInA++;
555 c = *pInB++;
556 d = *pInB++;
557
558 /* Multiply and Accumlates */
559 sumReal += (q31_t) a * c;
560 sumImag += (q31_t) a * d;
561 sumReal -= (q31_t) b * d;
562 sumImag += (q31_t) b * c;
563
564 #endif /* #if defined (ARM_MATH_DSP) */
565
566 }
567
568 /* Saturate and store result in destination buffer */
569 *px++ = (q15_t) (__SSAT(sumReal >> 15, 16));
570 *px++ = (q15_t) (__SSAT(sumImag >> 15, 16));
571
572 /* Decrement column loop counter */
573 col--;
574
575 } while (col > 0U);
576
577 i = i + numColsA;
578
579 /* Decrement row loop counter */
580 row--;
581
582 } while (row > 0U);
583
584 /* Set status as ARM_MATH_SUCCESS */
585 status = ARM_MATH_SUCCESS;
586 }
587
588 /* Return to application */
589 return (status);
590 }
591 #endif /* defined(ARM_MATH_MVEI) */
592
593 /**
@} end of CmplxMatrixMult group
595 */
596