1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_cmplx_mult_f16.c
4 * Description: Half-precision floating-point complex matrix multiplication
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions_f16.h"
30
31 #if defined(ARM_FLOAT16_SUPPORTED)
32
33
34 /**
35 @ingroup groupMatrix
36 */
37
38
39 /**
40 @addtogroup CmplxMatrixMult
41 @{
42 */
43
44 /**
45 @brief Floating-point Complex matrix multiplication.
46 @param[in] pSrcA points to first input complex matrix structure
47 @param[in] pSrcB points to second input complex matrix structure
48 @param[out] pDst points to output complex matrix structure
49 @return execution status
50 - \ref ARM_MATH_SUCCESS : Operation successful
51 - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
52 */
53
54 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && defined(__CMSIS_GCC_H)
55 #pragma message "Scalar version of arm_mat_cmplx_mult_f16 built. Helium version has build issues with gcc."
56 #endif
57
58 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H)
59
60 #include "arm_helium_utils.h"
61
62 #define DONTCARE 0 /* inactive lane content */
63
64
arm_mat_cmplx_mult_f16_2x2_mve(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)65 __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_2x2_mve(
66 const arm_matrix_instance_f16 * pSrcA,
67 const arm_matrix_instance_f16 * pSrcB,
68 arm_matrix_instance_f16 * pDst)
69 {
70 #define MATRIX_DIM 2
71 float16_t const *pInB = pSrcB->pData; /* input data matrix pointer B */
72 float16_t *pInA = pSrcA->pData; /* input data matrix pointer A */
73 float16_t *pOut = pDst->pData; /* output data matrix pointer */
74 uint16x8_t vecColBOffs0,vecColAOffs0,vecColAOffs1;
75 float16_t *pInA0 = pInA;
76 f16x8_t acc0, acc1;
77 f16x8_t vecB, vecA0, vecA1;
78 f16x8_t vecTmp;
79 uint16_t tmp;
80 static const uint16_t offsetB0[8] = { 0, 1,
81 MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
82 2, 3,
83 MATRIX_DIM * CMPLX_DIM + 2 , MATRIX_DIM * CMPLX_DIM + 3,
84 };
85
86
87 vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);
88
89 tmp = 0;
90 vecColAOffs0 = viwdupq_u16(tmp, 4, 1);
91
92 tmp = (CMPLX_DIM * MATRIX_DIM);
93 vecColAOffs1 = vecColAOffs0 + (uint16_t)(CMPLX_DIM * MATRIX_DIM);
94
95
96 pInB = (float16_t const *)pSrcB->pData;
97
98 vecA0 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs0);
99 vecA1 = vldrhq_gather_shifted_offset_f16(pInA0, vecColAOffs1);
100
101
102 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
103
104 acc0 = vcmulq(vecA0, vecB);
105 acc0 = vcmlaq_rot90(acc0, vecA0, vecB);
106
107 acc1 = vcmulq(vecA1, vecB);
108 acc1 = vcmlaq_rot90(acc1, vecA1, vecB);
109
110
111 /*
112 * Compute
113 * re0+re1 | im0+im1 | re0+re1 | im0+im1
114 * re2+re3 | im2+im3 | re2+re3 | im2+im3
115 */
116
117 vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc0);
118 vecTmp = vaddq(vecTmp, acc0);
119
120
121 *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
122 *(float32_t *)(&pOut[0 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];
123
124 vecTmp = (f16x8_t) vrev64q_s32((int32x4_t) acc1);
125 vecTmp = vaddq(vecTmp, acc1);
126
127 *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM]) = ((f32x4_t)vecTmp)[0];
128 *(float32_t *)(&pOut[1 * CMPLX_DIM * MATRIX_DIM + CMPLX_DIM]) = ((f32x4_t)vecTmp)[2];
129
130 /*
131 * Return to application
132 */
133 return (ARM_MATH_SUCCESS);
134 #undef MATRIX_DIM
135 }
136
137
138
arm_mat_cmplx_mult_f16_3x3_mve(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)139 __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_3x3_mve(
140 const arm_matrix_instance_f16 * pSrcA,
141 const arm_matrix_instance_f16 * pSrcB,
142 arm_matrix_instance_f16 * pDst)
143 {
144 #define MATRIX_DIM 3
145 float16_t const *pInB = pSrcB->pData; /* input data matrix pointer B */
146 float16_t *pInA = pSrcA->pData; /* input data matrix pointer A */
147 float16_t *pOut = pDst->pData; /* output data matrix pointer */
148 uint16x8_t vecColBOffs0;
149 float16_t *pInA0 = pInA;
150 float16_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;
151 float16_t *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;
152 f16x8_t acc0, acc1, acc2;
153 f16x8_t vecB, vecA0, vecA1, vecA2;
154 static const uint16_t offsetB0[8] = { 0, 1,
155 MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
156 2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
157 DONTCARE, DONTCARE
158 };
159
160
161 /* enable predication to disable upper half complex vector element */
162 mve_pred16_t p0 = vctp16q(MATRIX_DIM * CMPLX_DIM);
163
164 vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);
165
166 pInB = (float16_t const *)pSrcB->pData;
167
168 vecA0 = vldrhq_f16(pInA0);
169 vecA1 = vldrhq_f16(pInA1);
170 vecA2 = vldrhq_f16(pInA2);
171
172 vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);
173
174 acc0 = vcmulq(vecA0, vecB);
175 acc0 = vcmlaq_rot90(acc0, vecA0, vecB);
176
177 acc1 = vcmulq(vecA1, vecB);
178 acc1 = vcmlaq_rot90(acc1, vecA1, vecB);
179
180 acc2 = vcmulq(vecA2, vecB);
181 acc2 = vcmlaq_rot90(acc2, vecA2, vecB);
182
183 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
184 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
185 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
186 pOut += CMPLX_DIM;
187 /*
188 * move to next B column
189 */
190 pInB = pInB + CMPLX_DIM;
191
192 vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);
193
194 acc0 = vcmulq(vecA0, vecB);
195 acc0 = vcmlaq_rot90(acc0, vecA0, vecB);
196
197 acc1 = vcmulq(vecA1, vecB);
198 acc1 = vcmlaq_rot90(acc1, vecA1, vecB);
199
200 acc2 = vcmulq(vecA2, vecB);
201 acc2 = vcmlaq_rot90(acc2, vecA2, vecB);
202
203 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
204 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
205 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
206 pOut += CMPLX_DIM;
207 /*
208 * move to next B column
209 */
210 pInB = pInB + CMPLX_DIM;
211
212 vecB = vldrhq_gather_shifted_offset_z(pInB, vecColBOffs0, p0);
213
214 acc0 = vcmulq(vecA0, vecB);
215 acc0 = vcmlaq_rot90(acc0, vecA0, vecB);
216
217 acc1 = vcmulq(vecA1, vecB);
218 acc1 = vcmlaq_rot90(acc1, vecA1, vecB);
219
220 acc2 = vcmulq(vecA2, vecB);
221 acc2 = vcmlaq_rot90(acc2, vecA2, vecB);
222
223 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
224 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
225 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
226 /*
227 * Return to application
228 */
229 return (ARM_MATH_SUCCESS);
230 #undef MATRIX_DIM
231 }
232
233
234
235
arm_mat_cmplx_mult_f16_4x4_mve(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)236 __STATIC_FORCEINLINE arm_status arm_mat_cmplx_mult_f16_4x4_mve(
237 const arm_matrix_instance_f16 * pSrcA,
238 const arm_matrix_instance_f16 * pSrcB,
239 arm_matrix_instance_f16 * pDst)
240 {
241 #define MATRIX_DIM 4
242 float16_t const *pInB = pSrcB->pData; /* input data matrix pointer B */
243 float16_t *pInA = pSrcA->pData; /* input data matrix pointer A */
244 float16_t *pOut = pDst->pData; /* output data matrix pointer */
245 uint16x8_t vecColBOffs0;
246 float16_t *pInA0 = pInA;
247 float16_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM;
248 float16_t *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM;
249 float16_t *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM;
250 f16x8_t acc0, acc1, acc2, acc3;
251 f16x8_t vecB, vecA;
252 static const uint16_t offsetB0[8] = { 0, 1,
253 MATRIX_DIM * CMPLX_DIM, MATRIX_DIM * CMPLX_DIM + 1,
254 2 * MATRIX_DIM * CMPLX_DIM, 2 * MATRIX_DIM * CMPLX_DIM + 1,
255 3 * MATRIX_DIM * CMPLX_DIM, 3 * MATRIX_DIM * CMPLX_DIM + 1
256 };
257
258 vecColBOffs0 = vldrhq_u16((uint16_t const *) offsetB0);
259
260 pInB = (float16_t const *)pSrcB->pData;
261
262 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
263
264 vecA = vldrhq_f16(pInA0);
265 acc0 = vcmulq(vecA, vecB);
266 acc0 = vcmlaq_rot90(acc0, vecA, vecB);
267
268 vecA = vldrhq_f16(pInA1);
269 acc1 = vcmulq(vecA, vecB);
270 acc1 = vcmlaq_rot90(acc1, vecA, vecB);
271
272 vecA = vldrhq_f16(pInA2);
273 acc2 = vcmulq(vecA, vecB);
274 acc2 = vcmlaq_rot90(acc2, vecA, vecB);
275
276 vecA = vldrhq_f16(pInA3);
277 acc3 = vcmulq(vecA, vecB);
278 acc3 = vcmlaq_rot90(acc3, vecA, vecB);
279
280
281 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
282 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
283 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
284 mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
285 pOut += CMPLX_DIM;
286 /*
287 * move to next B column
288 */
289 pInB = pInB + CMPLX_DIM;
290
291 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
292
293 vecA = vldrhq_f16(pInA0);
294 acc0 = vcmulq(vecA, vecB);
295 acc0 = vcmlaq_rot90(acc0, vecA, vecB);
296
297 vecA = vldrhq_f16(pInA1);
298 acc1 = vcmulq(vecA, vecB);
299 acc1 = vcmlaq_rot90(acc1, vecA, vecB);
300
301 vecA = vldrhq_f16(pInA2);
302 acc2 = vcmulq(vecA, vecB);
303 acc2 = vcmlaq_rot90(acc2, vecA, vecB);
304
305 vecA = vldrhq_f16(pInA3);
306 acc3 = vcmulq(vecA, vecB);
307 acc3 = vcmlaq_rot90(acc3, vecA, vecB);
308
309
310 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
311 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
312 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
313 mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
314 pOut += CMPLX_DIM;
315 /*
316 * move to next B column
317 */
318 pInB = pInB + CMPLX_DIM;
319
320 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
321
322 vecA = vldrhq_f16(pInA0);
323 acc0 = vcmulq(vecA, vecB);
324 acc0 = vcmlaq_rot90(acc0, vecA, vecB);
325
326 vecA = vldrhq_f16(pInA1);
327 acc1 = vcmulq(vecA, vecB);
328 acc1 = vcmlaq_rot90(acc1, vecA, vecB);
329
330 vecA = vldrhq_f16(pInA2);
331 acc2 = vcmulq(vecA, vecB);
332 acc2 = vcmlaq_rot90(acc2, vecA, vecB);
333
334 vecA = vldrhq_f16(pInA3);
335 acc3 = vcmulq(vecA, vecB);
336 acc3 = vcmlaq_rot90(acc3, vecA, vecB);
337
338
339 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
340 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
341 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
342 mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
343 pOut += CMPLX_DIM;
344 /*
345 * move to next B column
346 */
347 pInB = pInB + CMPLX_DIM;
348
349 vecB = vldrhq_gather_shifted_offset(pInB, vecColBOffs0);
350
351 vecA = vldrhq_f16(pInA0);
352 acc0 = vcmulq(vecA, vecB);
353 acc0 = vcmlaq_rot90(acc0, vecA, vecB);
354
355 vecA = vldrhq_f16(pInA1);
356 acc1 = vcmulq(vecA, vecB);
357 acc1 = vcmlaq_rot90(acc1, vecA, vecB);
358
359 vecA = vldrhq_f16(pInA2);
360 acc2 = vcmulq(vecA, vecB);
361 acc2 = vcmlaq_rot90(acc2, vecA, vecB);
362
363 vecA = vldrhq_f16(pInA3);
364 acc3 = vcmulq(vecA, vecB);
365 acc3 = vcmlaq_rot90(acc3, vecA, vecB);
366
367
368 mve_cmplx_sum_intra_vec_f16(acc0, &pOut[0 * CMPLX_DIM * MATRIX_DIM]);
369 mve_cmplx_sum_intra_vec_f16(acc1, &pOut[1 * CMPLX_DIM * MATRIX_DIM]);
370 mve_cmplx_sum_intra_vec_f16(acc2, &pOut[2 * CMPLX_DIM * MATRIX_DIM]);
371 mve_cmplx_sum_intra_vec_f16(acc3, &pOut[3 * CMPLX_DIM * MATRIX_DIM]);
372 /*
373 * Return to application
374 */
375 return (ARM_MATH_SUCCESS);
376 #undef MATRIX_DIM
377 }
378
379
380
arm_mat_cmplx_mult_f16(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)381 arm_status arm_mat_cmplx_mult_f16(
382 const arm_matrix_instance_f16 * pSrcA,
383 const arm_matrix_instance_f16 * pSrcB,
384 arm_matrix_instance_f16 * pDst)
385 {
386 float16_t const *pInB = (float16_t const *) pSrcB->pData; /* input data matrix pointer B */
387 float16_t const *pInA = (float16_t const *) pSrcA->pData; /* input data matrix pointer A */
388 float16_t *pOut = pDst->pData; /* output data matrix pointer */
389 float16_t *px; /* Temporary output data matrix pointer */
390 uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
391 uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
392 uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
393 uint16_t col, i = 0U, row = numRowsA; /* loop counters */
394 arm_status status; /* status of matrix multiplication */
395 uint16x8_t vecOffs, vecColBOffs;
396 uint32_t blkCnt,rowCnt; /* loop counters */
397
398 #ifdef ARM_MATH_MATRIX_CHECK
399
400 /* Check for matrix mismatch condition */
401 if ((pSrcA->numCols != pSrcB->numRows) ||
402 (pSrcA->numRows != pDst->numRows) ||
403 (pSrcB->numCols != pDst->numCols) )
404 {
405 /* Set status as ARM_MATH_SIZE_MISMATCH */
406 status = ARM_MATH_SIZE_MISMATCH;
407 }
408 else
409
410 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
411
412 {
413
414 /*
415 * small squared matrix specialized routines
416 */
417 if (numRowsA == numColsB && numColsB == numColsA)
418 {
419 if (numRowsA == 1)
420 {
421 pOut[0] = (_Float16)pInA[0] * (_Float16)pInB[0] - (_Float16)pInA[1] * (_Float16)pInB[1];
422 pOut[1] = (_Float16)pInA[0] * (_Float16)pInB[1] + (_Float16)pInA[1] * (_Float16)pInB[0];
423 return (ARM_MATH_SUCCESS);
424 }
425 else if (numRowsA == 2)
426 return arm_mat_cmplx_mult_f16_2x2_mve(pSrcA, pSrcB, pDst);
427 else if (numRowsA == 3)
428 return arm_mat_cmplx_mult_f16_3x3_mve(pSrcA, pSrcB, pDst);
429 else if (numRowsA == 4)
430 return arm_mat_cmplx_mult_f16_4x4_mve(pSrcA, pSrcB, pDst);
431 }
432
433 vecColBOffs[0] = 0;
434 vecColBOffs[1] = 1;
435 vecColBOffs[2] = numColsB * CMPLX_DIM;
436 vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;
437 vecColBOffs[4] = 2*numColsB * CMPLX_DIM;
438 vecColBOffs[5] = 2*(numColsB * CMPLX_DIM) + 1;
439 vecColBOffs[6] = 3*numColsB * CMPLX_DIM;
440 vecColBOffs[7] = 3*(numColsB * CMPLX_DIM) + 1;
441
442 /*
443 * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
444 */
445
446 /*
447 * row loop
448 */
449 rowCnt = row >> 2;
450 while (rowCnt > 0u)
451 {
452 /*
453 * Output pointer is set to starting address of the row being processed
454 */
455 px = pOut + i * CMPLX_DIM;
456 i = i + 4 * numColsB;
457 /*
458 * For every row wise process, the column loop counter is to be initiated
459 */
460 col = numColsB;
461 /*
462 * For every row wise process, the pInB pointer is set
463 * to the starting address of the pSrcB data
464 */
465 pInB = (float16_t const *) pSrcB->pData;
466 /*
467 * column loop
468 */
469 while (col > 0u)
470 {
471 /*
472 * generate 4 columns elements
473 */
474 /*
475 * Matrix A columns number of MAC operations are to be performed
476 */
477
478 float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
479 float16_t const *pInA0 = pInA;
480 float16_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM;
481 float16_t const *pInA2 = pInA1 + numColsA * CMPLX_DIM;
482 float16_t const *pInA3 = pInA2 + numColsA * CMPLX_DIM;
483 f16x8_t acc0, acc1, acc2, acc3;
484
485 acc0 = vdupq_n_f16(0.0f16);
486 acc1 = vdupq_n_f16(0.0f16);
487 acc2 = vdupq_n_f16(0.0f16);
488 acc3 = vdupq_n_f16(0.0f16);
489
490 pSrcA0Vec = (float16_t const *) pInA0;
491 pSrcA1Vec = (float16_t const *) pInA1;
492 pSrcA2Vec = (float16_t const *) pInA2;
493 pSrcA3Vec = (float16_t const *) pInA3;
494
495 vecOffs = vecColBOffs;
496
497 /*
498 * process 1 x 4 block output
499 */
500 blkCnt = (numColsA * CMPLX_DIM) >> 3;
501 while (blkCnt > 0U)
502 {
503 f16x8_t vecB, vecA;
504
505 vecB = vldrhq_gather_shifted_offset_f16(pInB, vecOffs);
506 /*
507 * move Matrix B read offsets, 4 rows down
508 */
509 vecOffs = vaddq_n_u16(vecOffs , (uint16_t) (numColsB * 4 * CMPLX_DIM));
510
511 vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 8;
512 acc0 = vcmlaq(acc0, vecA, vecB);
513 acc0 = vcmlaq_rot90(acc0, vecA, vecB);
514
515 vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 8;
516 acc1 = vcmlaq(acc1, vecA, vecB);
517 acc1 = vcmlaq_rot90(acc1, vecA, vecB);
518
519 vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 8;
520 acc2 = vcmlaq(acc2, vecA, vecB);
521 acc2 = vcmlaq_rot90(acc2, vecA, vecB);
522
523 vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 8;
524 acc3 = vcmlaq(acc3, vecA, vecB);
525 acc3 = vcmlaq_rot90(acc3, vecA, vecB);
526
527 blkCnt--;
528 }
529 /*
530 * Unsupported addressing mode compiler crash
531 */
532 /*
533 * tail
534 * (will be merged thru tail predication)
535 */
536 blkCnt = (numColsA * CMPLX_DIM) & 7;
537 if (blkCnt > 0U)
538 {
539 mve_pred16_t p0 = vctp16q(blkCnt);
540 f16x8_t vecB, vecA;
541
542 vecB = vldrhq_gather_shifted_offset_z_f16(pInB, vecOffs, p0);
543 /*
544 * move Matrix B read offsets, 4 rows down
545 */
546 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (numColsB * 4 * CMPLX_DIM));
547
548 vecA = vld1q(pSrcA0Vec);
549 acc0 = vcmlaq(acc0, vecA, vecB);
550 acc0 = vcmlaq_rot90(acc0, vecA, vecB);
551
552 vecA = vld1q(pSrcA1Vec);
553 acc1 = vcmlaq(acc1, vecA, vecB);
554 acc1 = vcmlaq_rot90(acc1, vecA, vecB);
555
556 vecA = vld1q(pSrcA2Vec);
557 acc2 = vcmlaq(acc2, vecA, vecB);
558 acc2 = vcmlaq_rot90(acc2, vecA, vecB);
559
560 vecA = vld1q(pSrcA3Vec);
561 acc3 = vcmlaq(acc3, vecA, vecB);
562 acc3 = vcmlaq_rot90(acc3, vecA, vecB);
563
564 }
565
566
567 mve_cmplx_sum_intra_vec_f16(acc0, &px[0 * CMPLX_DIM * numColsB + 0]);
568 mve_cmplx_sum_intra_vec_f16(acc1, &px[1 * CMPLX_DIM * numColsB + 0]);
569 mve_cmplx_sum_intra_vec_f16(acc2, &px[2 * CMPLX_DIM * numColsB + 0]);
570 mve_cmplx_sum_intra_vec_f16(acc3, &px[3 * CMPLX_DIM * numColsB + 0]);
571
572 px += CMPLX_DIM;
573 /*
574 * Decrement the column loop counter
575 */
576 col--;
577 /*
578 * Update the pointer pInB to point to the starting address of the next column
579 */
580 pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
581 }
582
583 /*
584 * Update the pointer pInA to point to the starting address of the next row
585 */
586 pInA += (numColsA * 4) * CMPLX_DIM;
587 /*
588 * Decrement the row loop counter
589 */
590 rowCnt --;
591
592 }
593
594 rowCnt = row & 3;
595 while (rowCnt > 0u)
596 {
597 /*
598 * Output pointer is set to starting address of the row being processed
599 */
600 px = pOut + i * CMPLX_DIM;
601 i = i + numColsB;
602 /*
603 * For every row wise process, the column loop counter is to be initiated
604 */
605 col = numColsB;
606 /*
607 * For every row wise process, the pInB pointer is set
608 * to the starting address of the pSrcB data
609 */
610 pInB = (float16_t const *) pSrcB->pData;
611 /*
612 * column loop
613 */
614 while (col > 0u)
615 {
616 /*
617 * generate 4 columns elements
618 */
619 /*
620 * Matrix A columns number of MAC operations are to be performed
621 */
622
623 float16_t const *pSrcA0Vec;
624 float16_t const *pInA0 = pInA;
625 f16x8_t acc0;
626
627 acc0 = vdupq_n_f16(0.0f16);
628
629 pSrcA0Vec = (float16_t const *) pInA0;
630
631 vecOffs = vecColBOffs;
632
633 /*
634 * process 1 x 4 block output
635 */
636 blkCnt = (numColsA * CMPLX_DIM) >> 3;
637 while (blkCnt > 0U)
638 {
639 f16x8_t vecB, vecA;
640
641 vecB = vldrhq_gather_shifted_offset(pInB, vecOffs);
642 /*
643 * move Matrix B read offsets, 4 rows down
644 */
645 vecOffs = vaddq_n_u16(vecOffs, (uint16_t) (4*numColsB * CMPLX_DIM));
646
647 vecA = vld1q(pSrcA0Vec);
648 pSrcA0Vec += 8;
649 acc0 = vcmlaq(acc0, vecA, vecB);
650 acc0 = vcmlaq_rot90(acc0, vecA, vecB);
651
652
653 blkCnt--;
654 }
655
656
657 /*
658 * tail
659 */
660 blkCnt = (numColsA * CMPLX_DIM) & 7;
661 if (blkCnt > 0U)
662 {
663 mve_pred16_t p0 = vctp16q(blkCnt);
664 f16x8_t vecB, vecA;
665
666 vecB = vldrhq_gather_shifted_offset_z(pInB, vecOffs, p0);
667
668 vecA = vld1q(pSrcA0Vec);
669 acc0 = vcmlaq(acc0, vecA, vecB);
670 acc0 = vcmlaq_rot90(acc0, vecA, vecB);
671
672 }
673
674 mve_cmplx_sum_intra_vec_f16(acc0, &px[0]);
675
676
677 px += CMPLX_DIM;
678 /*
679 * Decrement the column loop counter
680 */
681 col--;
682 /*
683 * Update the pointer pInB to point to the starting address of the next column
684 */
685 pInB = (float16_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
686 }
687
688 /*
689 * Update the pointer pInA to point to the starting address of the next row
690 */
691 pInA += numColsA * CMPLX_DIM;
692 rowCnt--;
693 }
694
695 /*
696 * set status as ARM_MATH_SUCCESS
697 */
698 status = ARM_MATH_SUCCESS;
699 }
700 /*
701 * Return to application
702 */
703 return (status);
704 }
705 #else
706
arm_mat_cmplx_mult_f16(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)707 arm_status arm_mat_cmplx_mult_f16(
708 const arm_matrix_instance_f16 * pSrcA,
709 const arm_matrix_instance_f16 * pSrcB,
710 arm_matrix_instance_f16 * pDst)
711 {
712 float16_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */
713 float16_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */
714 float16_t *pInA = pSrcA->pData; /* Input data matrix pointer A */
715 float16_t *pOut = pDst->pData; /* Output data matrix pointer */
716 float16_t *px; /* Temporary output data matrix pointer */
717 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
718 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
719 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
720 _Float16 sumReal, sumImag; /* Accumulator */
721 _Float16 a1, b1, c1, d1;
722 uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* loop counters */
723 arm_status status; /* status of matrix multiplication */
724
725 #if defined (ARM_MATH_LOOPUNROLL)
726 _Float16 a0, b0, c0, d0;
727 #endif
728
729 #ifdef ARM_MATH_MATRIX_CHECK
730
731 /* Check for matrix mismatch condition */
732 if ((pSrcA->numCols != pSrcB->numRows) ||
733 (pSrcA->numRows != pDst->numRows) ||
734 (pSrcB->numCols != pDst->numCols) )
735 {
736 /* Set status as ARM_MATH_SIZE_MISMATCH */
737 status = ARM_MATH_SIZE_MISMATCH;
738 }
739 else
740
741 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
742
743 {
744 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
745 /* row loop */
746 do
747 {
748 /* Output pointer is set to starting address of the row being processed */
749 px = pOut + 2 * i;
750
751 /* For every row wise process, the column loop counter is to be initiated */
752 col = numColsB;
753
754 /* For every row wise process, the pIn2 pointer is set
755 ** to the starting address of the pSrcB data */
756 pIn2 = pSrcB->pData;
757
758 j = 0U;
759
760 /* column loop */
761 do
762 {
763 /* Set the variable sum, that acts as accumulator, to zero */
764 sumReal = 0.0f16;
765 sumImag = 0.0f16;
766
767 /* Initiate pointer pIn1 to point to starting address of column being processed */
768 pIn1 = pInA;
769
770 #if defined (ARM_MATH_LOOPUNROLL)
771
772 /* Apply loop unrolling and compute 4 MACs simultaneously. */
773 colCnt = numColsA >> 2U;
774
775 /* matrix multiplication */
776 while (colCnt > 0U)
777 {
778
779 /* Reading real part of complex matrix A */
780 a0 = *pIn1;
781
782 /* Reading real part of complex matrix B */
783 c0 = *pIn2;
784
785 /* Reading imaginary part of complex matrix A */
786 b0 = *(pIn1 + 1U);
787
788 /* Reading imaginary part of complex matrix B */
789 d0 = *(pIn2 + 1U);
790
791 /* Multiply and Accumlates */
792 sumReal += a0 * c0;
793 sumImag += b0 * c0;
794
795 /* update pointers */
796 pIn1 += 2U;
797 pIn2 += 2 * numColsB;
798
799 /* Multiply and Accumlates */
800 sumReal -= b0 * d0;
801 sumImag += a0 * d0;
802
803 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
804
805 /* read real and imag values from pSrcA and pSrcB buffer */
806 a1 = *(pIn1 );
807 c1 = *(pIn2 );
808 b1 = *(pIn1 + 1U);
809 d1 = *(pIn2 + 1U);
810
811 /* Multiply and Accumlates */
812 sumReal += a1 * c1;
813 sumImag += b1 * c1;
814
815 /* update pointers */
816 pIn1 += 2U;
817 pIn2 += 2 * numColsB;
818
819 /* Multiply and Accumlates */
820 sumReal -= b1 * d1;
821 sumImag += a1 * d1;
822
823 a0 = *(pIn1 );
824 c0 = *(pIn2 );
825 b0 = *(pIn1 + 1U);
826 d0 = *(pIn2 + 1U);
827
828 /* Multiply and Accumlates */
829 sumReal += a0 * c0;
830 sumImag += b0 * c0;
831
832 /* update pointers */
833 pIn1 += 2U;
834 pIn2 += 2 * numColsB;
835
836 /* Multiply and Accumlates */
837 sumReal -= b0 * d0;
838 sumImag += a0 * d0;
839
840 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
841
842 a1 = *(pIn1 );
843 c1 = *(pIn2 );
844 b1 = *(pIn1 + 1U);
845 d1 = *(pIn2 + 1U);
846
847 /* Multiply and Accumlates */
848 sumReal += a1 * c1;
849 sumImag += b1 * c1;
850
851 /* update pointers */
852 pIn1 += 2U;
853 pIn2 += 2 * numColsB;
854
855 /* Multiply and Accumlates */
856 sumReal -= b1 * d1;
857 sumImag += a1 * d1;
858
859 /* Decrement loop count */
860 colCnt--;
861 }
862
863 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
864 ** No loop unrolling is used. */
865 colCnt = numColsA % 0x4U;
866
867 #else
868
869 /* Initialize blkCnt with number of samples */
870 colCnt = numColsA;
871
872 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
873
874 while (colCnt > 0U)
875 {
876 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
877 a1 = *(pIn1 );
878 c1 = *(pIn2 );
879 b1 = *(pIn1 + 1U);
880 d1 = *(pIn2 + 1U);
881
882 /* Multiply and Accumlates */
883 sumReal += a1 * c1;
884 sumImag += b1 * c1;
885
886 /* update pointers */
887 pIn1 += 2U;
888 pIn2 += 2 * numColsB;
889
890 /* Multiply and Accumlates */
891 sumReal -= b1 * d1;
892 sumImag += a1 * d1;
893
894 /* Decrement loop counter */
895 colCnt--;
896 }
897
898 /* Store result in destination buffer */
899 *px++ = sumReal;
900 *px++ = sumImag;
901
902 /* Update pointer pIn2 to point to starting address of next column */
903 j++;
904 pIn2 = pSrcB->pData + 2U * j;
905
906 /* Decrement column loop counter */
907 col--;
908
909 } while (col > 0U);
910
911 /* Update pointer pInA to point to starting address of next row */
912 i = i + numColsB;
913 pInA = pInA + 2 * numColsA;
914
915 /* Decrement row loop counter */
916 row--;
917
918 } while (row > 0U);
919
920 /* Set status as ARM_MATH_SUCCESS */
921 status = ARM_MATH_SUCCESS;
922 }
923
924 /* Return to application */
925 return (status);
926 }
927
928 #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) && !defined(__CMSIS_GCC_H) */
929
930 /**
931 @} end of CmplxMatrixMult group
932 */
933
934 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
935
936