1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_mult_fast_q15.c
4 * Description: Q15 matrix multiplication (fast variant)
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions.h"
30
31 /**
32 @ingroup groupMatrix
33 */
34
35 /**
36 @addtogroup MatrixMult
37 @{
38 */
39
40 /**
41 @brief Q15 matrix multiplication (fast variant).
42 @param[in] pSrcA points to the first input matrix structure
43 @param[in] pSrcB points to the second input matrix structure
44 @param[out] pDst points to output matrix structure
45 @param[in] pState points to the array for storing intermediate results
46 @return execution status
47 - \ref ARM_MATH_SUCCESS : Operation successful
48 - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
49
50 @par Scaling and Overflow Behavior
51 The difference between the function \ref arm_mat_mult_q15() and this fast variant is that
52 the fast variant use a 32-bit rather than a 64-bit accumulator.
53 The result of each 1.15 x 1.15 multiplication is truncated to
54 2.30 format. These intermediate results are accumulated in a 32-bit register in 2.30
55 format. Finally, the accumulator is saturated and converted to a 1.15 result.
56 @par
57 The fast version has the same overflow behavior as the standard version but provides
58 less precision since it discards the low 16 bits of each multiplication result.
59 In order to avoid overflows completely the input signals must be scaled down.
60 Scale down one of the input matrices by log2(numColsA) bits to avoid overflows,
61 as a total of numColsA additions are computed internally for each output element.
62 @remark
63 Refer to \ref arm_mat_mult_q15() for a slower implementation of this function
64 which uses 64-bit accumulation to provide higher precision.
65 */
66
arm_mat_mult_fast_q15(const arm_matrix_instance_q15 * pSrcA,const arm_matrix_instance_q15 * pSrcB,arm_matrix_instance_q15 * pDst,q15_t * pState)67 arm_status arm_mat_mult_fast_q15(
68 const arm_matrix_instance_q15 * pSrcA,
69 const arm_matrix_instance_q15 * pSrcB,
70 arm_matrix_instance_q15 * pDst,
71 q15_t * pState)
72 {
73 q31_t sum; /* Accumulator */
74 q15_t *pSrcBT = pState; /* Input data matrix pointer for transpose */
75 q15_t *pInA = pSrcA->pData; /* Input data matrix pointer A of Q15 type */
76 q15_t *pInB = pSrcB->pData; /* Input data matrix pointer B of Q15 type */
77 q15_t *px; /* Temporary output data matrix pointer */
78 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
79 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
80 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
81 uint16_t numRowsB = pSrcB->numRows; /* Number of rows of input matrix B */
82 uint32_t col, i = 0U, row = numRowsB, colCnt; /* Loop counters */
83 arm_status status; /* Status of matrix multiplication */
84
85 #if defined (ARM_MATH_DSP)
86 q31_t in; /* Temporary variable to hold the input value */
87 q31_t inA1, inB1, inA2, inB2;
88 q31_t sum2, sum3, sum4;
89 q15_t *pInA2, *pInB2, *px2;
90 uint32_t j = 0;
91 #else
92 q15_t in; /* Temporary variable to hold the input value */
93 q15_t inA1, inB1, inA2, inB2;
94 #endif /* #if defined (ARM_MATH_DSP) */
95
96 #ifdef ARM_MATH_MATRIX_CHECK
97
98 /* Check for matrix mismatch condition */
99 if ((pSrcA->numCols != pSrcB->numRows) ||
100 (pSrcA->numRows != pDst->numRows) ||
101 (pSrcB->numCols != pDst->numCols) )
102 {
103 /* Set status as ARM_MATH_SIZE_MISMATCH */
104 status = ARM_MATH_SIZE_MISMATCH;
105 }
106 else
107
108 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
109
110 {
111 /* Matrix transpose */
112 do
113 {
114 /* The pointer px is set to starting address of column being processed */
115 px = pSrcBT + i;
116
117 /* Apply loop unrolling and exchange columns with row elements */
118 col = numColsB >> 2U;
119
120 /* First part of the processing with loop unrolling. Compute 4 outputs at a time.
121 ** a second loop below computes the remaining 1 to 3 samples. */
122 while (col > 0U)
123 {
124
125 #if defined (ARM_MATH_DSP)
126
127 /* Read two elements from row */
128 in = read_q15x2_ia ((q15_t **) &pInB);
129
130 /* Unpack and store one element in destination */
131 #ifndef ARM_MATH_BIG_ENDIAN
132 *px = (q15_t) in;
133 #else
134 *px = (q15_t) ((in & (q31_t) 0xffff0000) >> 16);
135 #endif /* #ifndef ARM_MATH_BIG_ENDIAN */
136
137 /* Update pointer px to point to next row of transposed matrix */
138 px += numRowsB;
139
140 /* Unpack and store second element in destination */
141 #ifndef ARM_MATH_BIG_ENDIAN
142 *px = (q15_t) ((in & (q31_t) 0xffff0000) >> 16);
143 #else
144 *px = (q15_t) in;
145 #endif /* #ifndef ARM_MATH_BIG_ENDIAN */
146
147 /* Update pointer px to point to next row of transposed matrix */
148 px += numRowsB;
149
150 in = read_q15x2_ia ((q15_t **) &pInB);
151 #ifndef ARM_MATH_BIG_ENDIAN
152 *px = (q15_t) in;
153 #else
154 *px = (q15_t) ((in & (q31_t) 0xffff0000) >> 16);
155 #endif /* #ifndef ARM_MATH_BIG_ENDIAN */
156 px += numRowsB;
157
158 #ifndef ARM_MATH_BIG_ENDIAN
159 *px = (q15_t) ((in & (q31_t) 0xffff0000) >> 16);
160 #else
161 *px = (q15_t) in;
162 #endif /* #ifndef ARM_MATH_BIG_ENDIAN */
163 px += numRowsB;
164
165 #else /* #if defined (ARM_MATH_DSP) */
166
167 /* Read one element from row */
168 in = *pInB++;
169
170 /* Store one element in destination */
171 *px = in;
172
173 /* Update pointer px to point to next row of transposed matrix */
174 px += numRowsB;
175
176 in = *pInB++;
177 *px = in;
178 px += numRowsB;
179
180 in = *pInB++;
181 *px = in;
182 px += numRowsB;
183
184 in = *pInB++;
185 *px = in;
186 px += numRowsB;
187
188 #endif /* #if defined (ARM_MATH_DSP) */
189
190 /* Decrement column loop counter */
191 col--;
192 }
193
194 /* If the columns of pSrcB is not a multiple of 4, compute any remaining output samples here.
195 ** No loop unrolling is used. */
196 col = numColsB % 0x4U;
197
198 while (col > 0U)
199 {
200 /* Read and store input element in destination */
201 *px = *pInB++;
202
203 /* Update pointer px to point to next row of transposed matrix */
204 px += numRowsB;
205
206 /* Decrement column loop counter */
207 col--;
208 }
209
210 i++;
211
212 /* Decrement row loop counter */
213 row--;
214
215 } while (row > 0U);
216
217 /* Reset variables for usage in following multiplication process */
218 row = numRowsA;
219 i = 0U;
220 px = pDst->pData;
221
222 #if defined (ARM_MATH_DSP)
223 /* Process two rows from matrix A at a time and output two rows at a time */
224 row = row >> 1U;
225 px2 = px + numColsB;
226 #endif
227
228 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
229 /* row loop */
230 while (row > 0U)
231 {
232 /* For every row wise process, column loop counter is to be initiated */
233 col = numColsB;
234
235 /* For every row wise process, pIn2 pointer is set to starting address of transposed pSrcB data */
236 pInB = pSrcBT;
237
238 #if defined (ARM_MATH_DSP)
239 /* Process two (transposed) columns from matrix B at a time */
240 col = col >> 1U;
241 j = 0;
242 #endif
243
244 /* column loop */
245 while (col > 0U)
246 {
247 /* Set variable sum, that acts as accumulator, to zero */
248 sum = 0;
249
250 /* Initiate pointer pInA to point to starting address of column being processed */
251 pInA = pSrcA->pData + i;
252
253 #if defined (ARM_MATH_DSP)
254 sum2 = 0;
255 sum3 = 0;
256 sum4 = 0;
257 pInB = pSrcBT + j;
258 pInA2 = pInA + numColsA;
259 pInB2 = pInB + numRowsB;
260
261 /* Read in two elements at once - allows dual MAC instruction */
262 colCnt = numColsA >> 1U;
263 #else
264 colCnt = numColsA >> 2U;
265 #endif
266
267 /* matrix multiplication */
268 while (colCnt > 0U)
269 {
270 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
271
272 #if defined (ARM_MATH_DSP)
273 /* read real and imag values from pSrcA and pSrcB buffer */
274 inA1 = read_q15x2_ia ((q15_t **) &pInA);
275 inB1 = read_q15x2_ia ((q15_t **) &pInB);
276
277 inA2 = read_q15x2_ia ((q15_t **) &pInA2);
278 inB2 = read_q15x2_ia ((q15_t **) &pInB2);
279
280 /* Multiply and Accumulates */
281 sum = __SMLAD(inA1, inB1, sum);
282 sum2 = __SMLAD(inA1, inB2, sum2);
283 sum3 = __SMLAD(inA2, inB1, sum3);
284 sum4 = __SMLAD(inA2, inB2, sum4);
285 #else
286 /* read real and imag values from pSrcA and pSrcB buffer */
287 inA1 = *pInA++;
288 inB1 = *pInB++;
289 /* Multiply and Accumulates */
290 sum += inA1 * inB1;
291
292 inA2 = *pInA++;
293 inB2 = *pInB++;
294 sum += inA2 * inB2;
295
296 inA1 = *pInA++;
297 inB1 = *pInB++;
298 sum += inA1 * inB1;
299
300 inA2 = *pInA++;
301 inB2 = *pInB++;
302 sum += inA2 * inB2;
303 #endif /* #if defined (ARM_MATH_DSP) */
304
305 /* Decrement loop counter */
306 colCnt--;
307 }
308
309 /* process odd column samples */
310 #if defined (ARM_MATH_DSP)
311 if (numColsA & 1U) {
312 inA1 = *pInA++;
313 inB1 = *pInB++;
314 inA2 = *pInA2++;
315 inB2 = *pInB2++;
316 sum += inA1 * inB1;
317 sum2 += inA1 * inB2;
318 sum3 += inA2 * inB1;
319 sum4 += inA2 * inB2;
320 }
321 #else
322 colCnt = numColsA % 0x4U;
323
324 while (colCnt > 0U)
325 {
326 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
327 sum += (q31_t) *pInA++ * *pInB++;
328
329 /* Decrement loop counter */
330 colCnt--;
331 }
332 #endif /* #if defined (ARM_MATH_DSP) */
333
334 /* Saturate and store result in destination buffer */
335 *px++ = (q15_t) (sum >> 15);
336
337 #if defined (ARM_MATH_DSP)
338 *px++ = (q15_t) (sum2 >> 15);
339 *px2++ = (q15_t) (sum3 >> 15);
340 *px2++ = (q15_t) (sum4 >> 15);
341 j += numRowsB * 2;
342 #endif
343
344 /* Decrement column loop counter */
345 col--;
346
347 }
348
349 i = i + numColsA;
350
351 #if defined (ARM_MATH_DSP)
352 i = i + numColsA;
353 px = px2 + (numColsB & 1U);
354 px2 = px + numColsB;
355 #endif
356
357 /* Decrement row loop counter */
358 row--;
359
360 }
361
362 /* Compute any remaining odd row/column below */
363
364 #if defined (ARM_MATH_DSP)
365
366 /* Compute remaining output column */
367 if (numColsB & 1U) {
368
369 /* Avoid redundant computation of last element */
370 row = numRowsA & (~0x1);
371
372 /* Point to remaining unfilled column in output matrix */
373 px = pDst->pData + numColsB-1;
374 pInA = pSrcA->pData;
375
376 /* row loop */
377 while (row > 0)
378 {
379
380 /* point to last column in matrix B */
381 pInB = pSrcBT + numRowsB * (numColsB-1);
382
383 /* Set variable sum, that acts as accumulator, to zero */
384 sum = 0;
385
386 /* Compute 4 columns at once */
387 colCnt = numColsA >> 2U;
388
389 /* matrix multiplication */
390 while (colCnt > 0U)
391 {
392 inA1 = read_q15x2_ia ((q15_t **) &pInA);
393 inA2 = read_q15x2_ia ((q15_t **) &pInA);
394 inB1 = read_q15x2_ia ((q15_t **) &pInB);
395 inB2 = read_q15x2_ia ((q15_t **) &pInB);
396
397 sum = __SMLAD(inA1, inB1, sum);
398 sum = __SMLAD(inA2, inB2, sum);
399
400 /* Decrement loop counter */
401 colCnt--;
402 }
403
404 colCnt = numColsA & 3U;
405 while (colCnt > 0U) {
406 sum += (q31_t) (*pInA++) * (*pInB++);
407 colCnt--;
408 }
409
410 /* Store result in destination buffer */
411 *px = (q15_t) (sum >> 15);
412 px += numColsB;
413
414 /* Decrement row loop counter */
415 row--;
416 }
417 }
418
419 /* Compute remaining output row */
420 if (numRowsA & 1U) {
421
422 /* point to last row in output matrix */
423 px = pDst->pData + (numColsB) * (numRowsA-1);
424
425 pInB = pSrcBT;
426 col = numColsB;
427 i = 0U;
428
429 /* col loop */
430 while (col > 0)
431 {
432 /* point to last row in matrix A */
433 pInA = pSrcA->pData + (numRowsA-1) * numColsA;
434
435 /* Set variable sum, that acts as accumulator, to zero */
436 sum = 0;
437
438 /* Compute 4 columns at once */
439 colCnt = numColsA >> 2U;
440
441 /* matrix multiplication */
442 while (colCnt > 0U)
443 {
444 inA1 = read_q15x2_ia ((q15_t **) &pInA);
445 inA2 = read_q15x2_ia ((q15_t **) &pInA);
446 inB1 = read_q15x2_ia ((q15_t **) &pInB);
447 inB2 = read_q15x2_ia ((q15_t **) &pInB);
448
449 sum = __SMLAD(inA1, inB1, sum);
450 sum = __SMLAD(inA2, inB2, sum);
451
452 /* Decrement loop counter */
453 colCnt--;
454 }
455
456 colCnt = numColsA % 4U;
457 while (colCnt > 0U) {
458 sum += (q31_t) (*pInA++) * (*pInB++);
459
460 colCnt--;
461 }
462
463 /* Store result in destination buffer */
464 *px++ = (q15_t) (sum >> 15);
465
466 /* Decrement column loop counter */
467 col--;
468 }
469 }
470
471 #endif /* #if defined (ARM_MATH_DSP) */
472
473 /* Set status as ARM_MATH_SUCCESS */
474 status = ARM_MATH_SUCCESS;
475 }
476
477 /* Return to application */
478 return (status);
479 }
480
481 /**
482 @} end of MatrixMult group
483 */
484