/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_cmplx_mult_f32.c
 * Description:  Floating-point matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
28
#include "dsp/matrix_functions.h"

/**
  @ingroup groupMatrix
 */

/**
  @defgroup CmplxMatrixMult Complex Matrix Multiplication

  Complex Matrix multiplication is only defined if the number of columns of the
  first matrix equals the number of rows of the second matrix.
  Multiplying an <code>M x N</code> matrix with an <code>N x P</code> matrix results
  in an <code>M x P</code> matrix.
  @par
  When matrix size checking is enabled, the functions check:
   - that the inner dimensions of <code>pSrcA</code> and <code>pSrcB</code> are equal;
   - that the size of the output matrix equals the outer dimensions of <code>pSrcA</code> and <code>pSrcB</code>.
 */


/**
  @addtogroup CmplxMatrixMult
  @{
 */

/**
  @brief         Floating-point Complex matrix multiplication.
  @param[in]     pSrcA      points to first input complex matrix structure
  @param[in]     pSrcB      points to second input complex matrix structure
  @param[out]    pDst       points to output complex matrix structure
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
 */
#if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)

#include "arm_helium_utils.h"

#define MATRIX_DIM2 2
#define MATRIX_DIM3 3
#define MATRIX_DIM4 4
70
/**
  Specialized 2x2 complex matrix multiplication (Helium/MVE).

  Layout: one 128-bit vector load of an A row holds the full row
  (2 complex = 4 floats).  A gather load assembles one B column
  ({re,im} of B[0][c] and B[1][c]) into a vector.  vcmulq computes the
  "real x vecB" partial complex products and vcmlaq_rot90 accumulates the
  "imag x rotated vecB" parts, so after both calls lanes {0,1} and {2,3}
  each hold a complex partial product; the scalar pairwise adds below
  reduce them to one complex result per output element.

  @param[in]  pSrcA  first input complex matrix (2x2)
  @param[in]  pSrcB  second input complex matrix (2x2)
  @param[out] pDst   output complex matrix (2x2)
  @return     ARM_MATH_SUCCESS (no size checking is done here)
 */
__STATIC_INLINE arm_status arm_mat_cmplx_mult_f32_2x2_mve(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;        /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;         /* output data matrix pointer  */
    uint32x4_t vecColBOffs0;
    float32_t *pInA0 = pInA;                             /* row 0 of A */
    float32_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM2;  /* row 1 of A */
    f32x4_t acc0, acc1;
    f32x4_t vecB, vecA;

    /* gather offsets (in float elements) selecting one column of B:
       {re,im} of row 0 then {re,im} of row 1 */
    static const uint32_t offsetB0[4] = { 0, 1,
        MATRIX_DIM2 * CMPLX_DIM, MATRIX_DIM2 * CMPLX_DIM + 1
    };

    vecColBOffs0 = vldrwq_u32((uint32_t const *) offsetB0);

    pInB = (float32_t const *)pSrcB->pData;

    /* ---- column 0 of B ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    /* row 0 of A times column 0 of B */
    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    /* row 1 of A times column 0 of B */
    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    /* reduce the two complex partial products (lanes {0,1} + lanes {2,3}) */
    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 1] = acc1[1] + acc1[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 1 of B (same scheme) ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM2 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM2 + 1] = acc1[1] + acc1[3];
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
133
134
/**
  Specialized 3x3 complex matrix multiplication (Helium/MVE).

  Each A row is 3 complex = 6 floats, i.e. one full vector (4 floats) plus a
  2-float tail.  Each B column is gathered in two parts: offsetB0 fetches the
  first two complex elements, offsetB1 fetches the third; the upper two lanes
  of the second gather are disabled by the predicate p0 (and read INACTIVELANE
  dummy offsets), so they load as zero and contribute nothing to vcmlaq.

  NOTE(review): the second A load (&pInAx[4]) always reads a full 4-float
  vector although only 2 floats of the row remain; for the last row this
  reads 2 floats past the end of the A matrix.  The extra lanes are
  multiplied by the zeroed B lanes so the arithmetic is unaffected, but the
  over-read assumes those bytes are accessible -- confirm this matches the
  library's buffer-padding expectations.

  @param[in]  pSrcA  first input complex matrix (3x3)
  @param[in]  pSrcB  second input complex matrix (3x3)
  @param[out] pDst   output complex matrix (3x3)
  @return     ARM_MATH_SUCCESS (no size checking is done here)
 */
__STATIC_INLINE arm_status arm_mat_cmplx_mult_f32_3x3_mve(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;        /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;         /* output data matrix pointer  */
    uint32x4_t vecColBOffs0, vecColBOffs1;
    float32_t *pInA0 = pInA;                             /* row 0 of A */
    float32_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM3;  /* row 1 of A */
    float32_t *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM3;  /* row 2 of A */
    f32x4_t acc0, acc1, acc2;
    f32x4_t vecB, vecA;
    /* enable predication to disable upper half complex vector element */
    mve_pred16_t p0 = vctp32q(CMPLX_DIM);

    /* gather offsets for the first two complex elements of a B column */
    static const uint32_t offsetB0[4] = { 0, 1,
        MATRIX_DIM3 * CMPLX_DIM, MATRIX_DIM3 * CMPLX_DIM + 1
    };
    /* gather offsets for the third complex element; upper lanes inactive */
    static const uint32_t offsetB1[4] = { 2 * MATRIX_DIM3 * CMPLX_DIM, 2 * MATRIX_DIM3 * CMPLX_DIM + 1,
        INACTIVELANE, INACTIVELANE
    };

    vecColBOffs0 = vldrwq_u32((uint32_t const *) offsetB0);
    vecColBOffs1 = vldrwq_u32((uint32_t const *) offsetB1);

    pInB = (float32_t const *)pSrcB->pData;

    /* ---- column 0 of B ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    /* first 2 complex elements of each A row times B column head */
    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);


    /* third complex element of the column; inactive lanes load zero */
    vecB = vldrwq_gather_shifted_offset_z(pInB, vecColBOffs1, p0);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);


    /* reduce partial complex products (lanes {0,1} + lanes {2,3}) */
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc2[1] + acc2[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 1 of B (same scheme) ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecB = vldrwq_gather_shifted_offset_z(pInB, vecColBOffs1, p0);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);


    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc2[1] + acc2[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 2 of B (same scheme) ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecB = vldrwq_gather_shifted_offset_z(pInB, vecColBOffs1, p0);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);


    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM3 + 1] = acc2[1] + acc2[3];
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
289
290
291
/**
  Specialized 4x4 complex matrix multiplication (Helium/MVE).

  Each A row is 4 complex = 8 floats, i.e. exactly two 128-bit vectors, so
  no tail predication is needed.  Each B column is gathered in two halves
  (offsetB0 = rows 0-1, offsetB1 = rows 2-3).  For every column, vcmulq
  starts the complex products and vcmlaq/vcmlaq_rot90 accumulate the rest;
  lanes {0,1} and {2,3} of each accumulator then hold complex partial sums
  that the scalar pairwise adds reduce into one complex output element.

  @param[in]  pSrcA  first input complex matrix (4x4)
  @param[in]  pSrcB  second input complex matrix (4x4)
  @param[out] pDst   output complex matrix (4x4)
  @return     ARM_MATH_SUCCESS (no size checking is done here)
 */
__STATIC_INLINE arm_status arm_mat_cmplx_mult_f32_4x4_mve(
    const arm_matrix_instance_f32 * pSrcA,
    const arm_matrix_instance_f32 * pSrcB,
    arm_matrix_instance_f32 * pDst)
{
    float32_t const *pInB = pSrcB->pData;  /* input data matrix pointer B */
    float32_t *pInA = pSrcA->pData;        /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;         /* output data matrix pointer  */
    uint32x4_t vecColBOffs0, vecColBOffs1;
    float32_t *pInA0 = pInA;                             /* row 0 of A */
    float32_t *pInA1 = pInA0 + CMPLX_DIM * MATRIX_DIM4;  /* row 1 of A */
    float32_t *pInA2 = pInA1 + CMPLX_DIM * MATRIX_DIM4;  /* row 2 of A */
    float32_t *pInA3 = pInA2 + CMPLX_DIM * MATRIX_DIM4;  /* row 3 of A */
    f32x4_t acc0, acc1, acc2, acc3;
    f32x4_t vecB, vecA;

    /* gather offsets for the first two complex elements of a B column */
    static const uint32_t offsetB0[4] = { 0, 1,
        MATRIX_DIM4 * CMPLX_DIM, MATRIX_DIM4 * CMPLX_DIM + 1
    };
    /* gather offsets for the last two complex elements of a B column */
    static const uint32_t offsetB1[4] = { 2 * MATRIX_DIM4 * CMPLX_DIM, 2 * MATRIX_DIM4 * CMPLX_DIM + 1,
        3 * MATRIX_DIM4 * CMPLX_DIM, 3 * MATRIX_DIM4 * CMPLX_DIM + 1
    };

    vecColBOffs0 = vldrwq_u32((uint32_t const *) offsetB0);
    vecColBOffs1 = vldrwq_u32((uint32_t const *) offsetB1);

    pInB = (float32_t const *)pSrcB->pData;

    /* ---- column 0 of B ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    /* first half of each A row times the top half of the B column */
    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    /* second half of each A row times the bottom half of the B column */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(&pInA3[4]);
    acc3 = vcmlaq(acc3, vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    /* reduce partial complex products (lanes {0,1} + lanes {2,3}) */
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc2[1] + acc2[3];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc3[0] + acc3[2];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc3[1] + acc3[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 1 of B (same scheme) ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(&pInA3[4]);
    acc3 = vcmlaq(acc3, vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc2[1] + acc2[3];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc3[0] + acc3[2];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc3[1] + acc3[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 2 of B (same scheme) ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(&pInA3[4]);
    acc3 = vcmlaq(acc3, vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc2[1] + acc2[3];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc3[0] + acc3[2];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc3[1] + acc3[3];
    pOut += CMPLX_DIM;

    /*
     * move to next B column
     */
    pInB = pInB + CMPLX_DIM;

    /* ---- column 3 of B (same scheme) ---- */
    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs0);

    vecA = vldrwq_f32(pInA0);
    acc0 = vcmulq(vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(pInA1);
    acc1 = vcmulq(vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(pInA2);
    acc2 = vcmulq(vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(pInA3);
    acc3 = vcmulq(vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    vecB = vldrwq_gather_shifted_offset(pInB, vecColBOffs1);

    vecA = vldrwq_f32(&pInA0[4]);
    acc0 = vcmlaq(acc0, vecA, vecB);
    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

    vecA = vldrwq_f32(&pInA1[4]);
    acc1 = vcmlaq(acc1, vecA, vecB);
    acc1 = vcmlaq_rot90(acc1, vecA, vecB);

    vecA = vldrwq_f32(&pInA2[4]);
    acc2 = vcmlaq(acc2, vecA, vecB);
    acc2 = vcmlaq_rot90(acc2, vecA, vecB);

    vecA = vldrwq_f32(&pInA3[4]);
    acc3 = vcmlaq(acc3, vecA, vecB);
    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc0[0] + acc0[2];
    pOut[0 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc0[1] + acc0[3];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc1[0] + acc1[2];
    pOut[1 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc1[1] + acc1[3];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc2[0] + acc2[2];
    pOut[2 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc2[1] + acc2[3];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 0] = acc3[0] + acc3[2];
    pOut[3 * CMPLX_DIM * MATRIX_DIM4 + 1] = acc3[1] + acc3[3];
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
522
/**
  @brief         Floating-point complex matrix multiplication (Helium/MVE).

  Dispatches 1x1..4x4 square cases to specialized kernels, then handles the
  general case by processing 4 rows of A at a time (plus a scalar-row tail).
  For each output element a gather load walks one column of B two complex
  elements per step while contiguous loads walk the A rows; vcmlaq /
  vcmlaq_rot90 accumulate complex partial products, which are reduced by
  pairwise lane additions.

  @param[in]  pSrcA  points to first input complex matrix structure
  @param[in]  pSrcB  points to second input complex matrix structure
  @param[out] pDst   points to output complex matrix structure
  @return     ARM_MATH_SUCCESS, or ARM_MATH_SIZE_MISMATCH when
              ARM_MATH_MATRIX_CHECK is enabled and dimensions disagree
 */
ARM_DSP_ATTRIBUTE arm_status arm_mat_cmplx_mult_f32(
  const arm_matrix_instance_f32 * pSrcA,
  const arm_matrix_instance_f32 * pSrcB,
  arm_matrix_instance_f32 * pDst)
{
    float32_t const *pInB = (float32_t const *) pSrcB->pData;  /* input data matrix pointer B */
    float32_t const *pInA = (float32_t const *) pSrcA->pData;  /* input data matrix pointer A */
    float32_t *pOut = pDst->pData;         /* output data matrix pointer */
    float32_t *px;                         /* Temporary output data matrix pointer */
    uint16_t numRowsA = pSrcA->numRows;    /* number of rows of input matrix A */
    uint16_t numColsB = pSrcB->numCols;    /* number of columns of input matrix B */
    uint16_t numColsA = pSrcA->numCols;    /* number of columns of input matrix A */
    uint16_t col, i = 0U, row = numRowsA;  /* loop counters */
    arm_status status;                     /* status of matrix multiplication */
    uint32x4_t vecOffs, vecColBOffs;
    uint32_t blkCnt, rowCnt;               /* loop counters */

#ifdef ARM_MATH_MATRIX_CHECK


    /* Check for matrix mismatch condition */
    if ((pSrcA->numCols != pSrcB->numRows) ||
        (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
    {

        /* Set status as ARM_MATH_SIZE_MISMATCH */
        status = ARM_MATH_SIZE_MISMATCH;
    }
    else
#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

    {
        /*
         * small squared matrix specialized routines
         */
        if (numRowsA == numColsB && numColsB == numColsA)
        {
            if (numRowsA == 1)
            {
                /* 1x1: a single scalar complex multiply */
                pOut[0] = pInA[0] * pInB[0] - pInA[1] * pInB[1];
                pOut[1] = pInA[0] * pInB[1] + pInA[1] * pInB[0];
                return (ARM_MATH_SUCCESS);
            }
            else if (numRowsA == 2)
                return arm_mat_cmplx_mult_f32_2x2_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 3)
                return arm_mat_cmplx_mult_f32_3x3_mve(pSrcA, pSrcB, pDst);
            else if (numRowsA == 4)
                return arm_mat_cmplx_mult_f32_4x4_mve(pSrcA, pSrcB, pDst);
        }

        /* lane-wise init of the B-column gather offsets (in float elements):
           {re,im} of one B row, then {re,im} of the row below it */
        vecColBOffs[0] = 0;
        vecColBOffs[1] = 1;
        vecColBOffs[2] = numColsB * CMPLX_DIM;
        vecColBOffs[3] = (numColsB * CMPLX_DIM) + 1;

        /*
         * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
         */

        /*
         * row loop: 4 rows of A per iteration
         */
        rowCnt = row >> 2;
        while (rowCnt > 0u)
        {
            /*
             * Output pointer is set to starting address of the row being processed
             */
            px = pOut + i * CMPLX_DIM;
            i = i + 4 * numColsB;   /* advance by 4 output rows (complex elements) */
            /*
             * For every row wise process, the column loop counter is to be initiated
             */
            col = numColsB;
            /*
             * For every row wise process, the pInB pointer is set
             * to the starting address of the pSrcB data
             */
            pInB = (float32_t const *) pSrcB->pData;
            /*
             * column loop
             */
            while (col > 0u)
            {
                /*
                 * generate 4 columns elements
                 */
                /*
                 * Matrix A columns number of MAC operations are to be performed
                 */

                float32_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
                float32_t const *pInA0 = pInA;                            /* 4 consecutive A rows */
                float32_t const *pInA1 = pInA0 + numColsA * CMPLX_DIM;
                float32_t const *pInA2 = pInA1 + numColsA * CMPLX_DIM;
                float32_t const *pInA3 = pInA2 + numColsA * CMPLX_DIM;
                f32x4_t acc0, acc1, acc2, acc3;

                acc0 = vdupq_n_f32(0.0f);
                acc1 = vdupq_n_f32(0.0f);
                acc2 = vdupq_n_f32(0.0f);
                acc3 = vdupq_n_f32(0.0f);

                pSrcA0Vec = (float32_t const *) pInA0;
                pSrcA1Vec = (float32_t const *) pInA1;
                pSrcA2Vec = (float32_t const *) pInA2;
                pSrcA3Vec = (float32_t const *) pInA3;

                vecOffs = vecColBOffs;

                /*
                 * process 1 x 4 block output
                 */
                blkCnt = (numColsA * CMPLX_DIM) >> 2;
                while (blkCnt > 0U)
                {
                    f32x4_t vecB, vecA;

                    /* gather 2 complex elements of the current B column */
                    vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                    /*
                     * move Matrix B read offsets, 4 rows down
                     */
                    vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                    /* complex MAC: vcmlaq adds re*vecB, vcmlaq_rot90 adds im*rotated vecB */
                    vecA = vld1q(pSrcA0Vec); pSrcA0Vec += 4;
                    acc0 = vcmlaq(acc0, vecA, vecB);
                    acc0 = vcmlaq_rot90(acc0, vecA, vecB);
                    vecA = vld1q(pSrcA1Vec); pSrcA1Vec += 4;
                    acc1 = vcmlaq(acc1, vecA, vecB);
                    acc1 = vcmlaq_rot90(acc1, vecA, vecB);
                    vecA = vld1q(pSrcA2Vec); pSrcA2Vec += 4;
                    acc2 = vcmlaq(acc2, vecA, vecB);
                    acc2 = vcmlaq_rot90(acc2, vecA, vecB);
                    vecA = vld1q(pSrcA3Vec); pSrcA3Vec += 4;
                    acc3 = vcmlaq(acc3, vecA, vecB);
                    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

                    blkCnt--;
                }


                /*
                 * tail
                 * (will be merged thru tail predication)
                 */
                blkCnt = (numColsA * CMPLX_DIM) & 3;
                if (blkCnt > 0U)
                {
                    /* B gather is predicated: inactive lanes load zero.
                       NOTE(review): the A loads below are unpredicated full-vector
                       loads, so they may read past the end of each A row / of the
                       matrix; the extra lanes multiply zeroed B lanes and so do not
                       affect the result -- confirm over-read is acceptable here. */
                    mve_pred16_t p0 = vctp32q(blkCnt);
                    f32x4_t vecB, vecA;

                    vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);
                    /*
                     * move Matrix B read offsets, 4 rows down
                     */
                    vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                    vecA = vld1q(pSrcA0Vec);
                    acc0 = vcmlaq(acc0, vecA, vecB);
                    acc0 = vcmlaq_rot90(acc0, vecA, vecB);
                    vecA = vld1q(pSrcA1Vec);
                    acc1 = vcmlaq(acc1, vecA, vecB);
                    acc1 = vcmlaq_rot90(acc1, vecA, vecB);
                    vecA = vld1q(pSrcA2Vec);
                    acc2 = vcmlaq(acc2, vecA, vecB);
                    acc2 = vcmlaq_rot90(acc2, vecA, vecB);
                    vecA = vld1q(pSrcA3Vec);
                    acc3 = vcmlaq(acc3, vecA, vecB);
                    acc3 = vcmlaq_rot90(acc3, vecA, vecB);

                }

                /* reduce partial complex products (lanes {0,1} + lanes {2,3})
                   into one complex output element per processed A row */
                px[0 * CMPLX_DIM * numColsB + 0] = acc0[0] + acc0[2];
                px[0 * CMPLX_DIM * numColsB + 1] = acc0[1] + acc0[3];
                px[1 * CMPLX_DIM * numColsB + 0] = acc1[0] + acc1[2];
                px[1 * CMPLX_DIM * numColsB + 1] = acc1[1] + acc1[3];
                px[2 * CMPLX_DIM * numColsB + 0] = acc2[0] + acc2[2];
                px[2 * CMPLX_DIM * numColsB + 1] = acc2[1] + acc2[3];
                px[3 * CMPLX_DIM * numColsB + 0] = acc3[0] + acc3[2];
                px[3 * CMPLX_DIM * numColsB + 1] = acc3[1] + acc3[3];
                px += CMPLX_DIM;
                /*
                 * Decrement the column loop counter
                 */
                col--;
                /*
                 * Update the pointer pInB to point to the starting address of the next column
                 */
                pInB = (float32_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
            }

            /*
             * Update the pointer pInA to point to the starting address of the next row
             */
            pInA += (numColsA * 4) * CMPLX_DIM;   /* skip the 4 rows just processed */
            /*
             * Decrement the row loop counter
             */
            rowCnt --;

        }

        /* residual rows (numRowsA not a multiple of 4): one row at a time */
        rowCnt = row & 3;
        while (rowCnt > 0u)
        {
            /*
             * Output pointer is set to starting address of the row being processed
             */
            px = pOut + i * CMPLX_DIM;
            i = i + numColsB;
            /*
             * For every row wise process, the column loop counter is to be initiated
             */
            col = numColsB;
            /*
             * For every row wise process, the pInB pointer is set
             * to the starting address of the pSrcB data
             */
            pInB = (float32_t const *) pSrcB->pData;
            /*
             * column loop
             */
            while (col > 0u)
            {
                /*
                 * generate 4 columns elements
                 */
                /*
                 * Matrix A columns number of MAC operations are to be performed
                 */

                float32_t const *pSrcA0Vec;
                float32_t const *pInA0 = pInA;
                f32x4_t acc0;

                acc0 = vdupq_n_f32(0.0f);

                pSrcA0Vec = (float32_t const *) pInA0;

                vecOffs = vecColBOffs;

                /*
                 * process 1 x 4 block output
                 */
                blkCnt = (numColsA * CMPLX_DIM) >> 2;
                while (blkCnt > 0U)
                {
                    f32x4_t vecB, vecA;

                    vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                    /*
                     * move Matrix B read offsets, 4 rows down
                     */
                    vecOffs = vecOffs + (uint32_t) (numColsB * 2 * CMPLX_DIM);

                    vecA = vld1q(pSrcA0Vec);
                    pSrcA0Vec += 4;
                    acc0 = vcmlaq(acc0, vecA, vecB);
                    acc0 = vcmlaq_rot90(acc0, vecA, vecB);


                    blkCnt--;
                }


                /*
                 * tail
                 */
                blkCnt = (numColsA * CMPLX_DIM) & 3;
                if (blkCnt > 0U)
                {
                    /* predicated B gather zeroes inactive lanes; A load is
                       unpredicated (see NOTE(review) in the 4-row loop above) */
                    mve_pred16_t p0 = vctp32q(blkCnt);
                    f32x4_t vecB, vecA;

                    vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);

                    vecA = vld1q(pSrcA0Vec);
                    acc0 = vcmlaq(acc0, vecA, vecB);
                    acc0 = vcmlaq_rot90(acc0, vecA, vecB);

                }

                /* reduce the two complex partial sums into one output element */
                px[0] = acc0[0] + acc0[2];
                px[1] = acc0[1] + acc0[3];

                px += CMPLX_DIM;
                /*
                 * Decrement the column loop counter
                 */
                col--;
                /*
                 * Update the pointer pInB to point to the starting address of the next column
                 */
                pInB = (float32_t const *) pSrcB->pData + (numColsB - col) * CMPLX_DIM;
            }

            /*
             * Update the pointer pInA to point to the starting address of the next row
             */
            pInA += numColsA * CMPLX_DIM;
            rowCnt--;
        }


        /* Set status as ARM_MATH_SUCCESS */
        status = ARM_MATH_SUCCESS;
    }

    /* Return to application */
    return (status);

}

#else
#if defined(ARM_MATH_NEON)
arm_mat_cmplx_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)839 ARM_DSP_ATTRIBUTE arm_status arm_mat_cmplx_mult_f32(
840 const arm_matrix_instance_f32 * pSrcA,
841 const arm_matrix_instance_f32 * pSrcB,
842 arm_matrix_instance_f32 * pDst)
843 {
844 float32_t *pIn1 = pSrcA->pData; /* input data matrix pointer A */
845 float32_t *pIn2 = pSrcB->pData; /* input data matrix pointer B */
846 float32_t *pInA = pSrcA->pData; /* input data matrix pointer A */
847 float32_t *pOut = pDst->pData; /* output data matrix pointer */
848 float32_t *px; /* Temporary output data matrix pointer */
849 uint16_t numRowsA = pSrcA->numRows; /* number of rows of input matrix A */
850 uint16_t numColsB = pSrcB->numCols; /* number of columns of input matrix B */
851 uint16_t numColsA = pSrcA->numCols; /* number of columns of input matrix A */
852 float32_t sumReal1, sumImag1; /* accumulator */
853 float32_t a1, a1B,b1, b1B, c1, d1;
854 float32_t sumReal2, sumImag2; /* accumulator */
855
856
857 float32x4x2_t a0V, a1V;
858 float32x4_t accR0,accI0, accR1,accI1,tempR, tempI;
859 float32x2_t accum = vdup_n_f32(0);
860 float32_t *pIn1B = pSrcA->pData;
861
862 uint16_t col, i = 0U, j, rowCnt, row = numRowsA, colCnt; /* loop counters */
863 arm_status status; /* status of matrix multiplication */
864 float32_t sumReal1B, sumImag1B;
865 float32_t sumReal2B, sumImag2B;
866 float32_t *pxB;
867
868 #ifdef ARM_MATH_MATRIX_CHECK
869
870
871 /* Check for matrix mismatch condition */
872 if ((pSrcA->numCols != pSrcB->numRows) ||
873 (pSrcA->numRows != pDst->numRows) || (pSrcB->numCols != pDst->numCols))
874 {
875
876 /* Set status as ARM_MATH_SIZE_MISMATCH */
877 status = ARM_MATH_SIZE_MISMATCH;
878 }
879 else
880 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
881
882 {
883 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
884
885 rowCnt = row >> 1;
886
887 /* Row loop */
888 while (rowCnt > 0U)
889 {
890 /* Output pointer is set to starting address of the row being processed */
891 px = pOut + 2 * i;
892 pxB = px + 2 * numColsB;
893
894 /* For every row wise process, the column loop counter is to be initiated */
895 col = numColsB;
896
897 /* For every row wise process, the pIn2 pointer is set
898 ** to the starting address of the pSrcB data */
899 pIn2 = pSrcB->pData;
900
901 j = 0U;
902
903 /* Column loop */
904 while (col > 0U)
905 {
906 /* Set the variable sum, that acts as accumulator, to zero */
907 sumReal1 = 0.0f;
908 sumImag1 = 0.0f;
909 sumReal1B = 0.0f;
910 sumImag1B = 0.0f;
911
912 sumReal2 = 0.0f;
913 sumImag2 = 0.0f;
914 sumReal2B = 0.0f;
915 sumImag2B = 0.0f;
916
917 /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
918 pIn1 = pInA;
919 pIn1B = pIn1 + 2*numColsA;
920
921 accR0 = vdupq_n_f32(0.0);
922 accI0 = vdupq_n_f32(0.0);
923 accR1 = vdupq_n_f32(0.0);
924 accI1 = vdupq_n_f32(0.0);
925
926 /* Compute 4 MACs simultaneously. */
927 colCnt = numColsA >> 2;
928
929 /* Matrix multiplication */
930 while (colCnt > 0U)
931 {
932 /* Reading real part of complex matrix A */
933 a0V = vld2q_f32(pIn1); // load & separate real/imag pSrcA (de-interleave 2)
934 a1V = vld2q_f32(pIn1B); // load & separate real/imag pSrcA (de-interleave 2)
935
936 pIn1 += 8;
937 pIn1B += 8;
938
939 tempR = vsetq_lane_f32(*pIn2,tempR,0);
940 tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,0);
941 pIn2 += 2 * numColsB;
942
943
944 tempR = vsetq_lane_f32(*pIn2,tempR,1);
945 tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,1);
946 pIn2 += 2 * numColsB;
947
948 tempR = vsetq_lane_f32(*pIn2,tempR,2);
949 tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,2);
950 pIn2 += 2 * numColsB;
951
952 tempR = vsetq_lane_f32(*pIn2,tempR,3);
953 tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,3);
954 pIn2 += 2 * numColsB;
955
956 accR0 = vmlaq_f32(accR0,a0V.val[0],tempR);
957 accR0 = vmlsq_f32(accR0,a0V.val[1],tempI);
958
959 accI0 = vmlaq_f32(accI0,a0V.val[1],tempR);
960 accI0 = vmlaq_f32(accI0,a0V.val[0],tempI);
961
962 accR1 = vmlaq_f32(accR1,a1V.val[0],tempR);
963 accR1 = vmlsq_f32(accR1,a1V.val[1],tempI);
964
965 accI1 = vmlaq_f32(accI1,a1V.val[1],tempR);
966 accI1 = vmlaq_f32(accI1,a1V.val[0],tempI);
967
968 /* Decrement the loop count */
969 colCnt--;
970 }
971
972 accum = vpadd_f32(vget_low_f32(accR0), vget_high_f32(accR0));
973 sumReal1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
974
975 accum = vpadd_f32(vget_low_f32(accI0), vget_high_f32(accI0));
976 sumImag1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
977
978 accum = vpadd_f32(vget_low_f32(accR1), vget_high_f32(accR1));
979 sumReal1B += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
980
981 accum = vpadd_f32(vget_low_f32(accI1), vget_high_f32(accI1));
982 sumImag1B += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
983
984 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
985 ** No loop unrolling is used. */
986 colCnt = numColsA & 3;
987
988 while (colCnt > 0U)
989 {
990 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
991 a1 = *pIn1;
992 a1B = *pIn1B;
993
994 c1 = *pIn2;
995
996 b1 = *(pIn1 + 1U);
997 b1B = *(pIn1B + 1U);
998
999 d1 = *(pIn2 + 1U);
1000
1001 sumReal1 += a1 * c1;
1002 sumImag1 += b1 * c1;
1003
1004 sumReal1B += a1B * c1;
1005 sumImag1B += b1B * c1;
1006
1007 pIn1 += 2U;
1008 pIn1B += 2U;
1009 pIn2 += 2 * numColsB;
1010
1011 sumReal2 -= b1 * d1;
1012 sumImag2 += a1 * d1;
1013
1014 sumReal2B -= b1B * d1;
1015 sumImag2B += a1B * d1;
1016
1017 /* Decrement the loop counter */
1018 colCnt--;
1019 }
1020
1021 sumReal1 += sumReal2;
1022 sumImag1 += sumImag2;
1023
1024 sumReal1B += sumReal2B;
1025 sumImag1B += sumImag2B;
1026
1027 /* Store the result in the destination buffer */
1028 *px++ = sumReal1;
1029 *px++ = sumImag1;
1030 *pxB++ = sumReal1B;
1031 *pxB++ = sumImag1B;
1032
1033 /* Update the pointer pIn2 to point to the starting address of the next column */
1034 j++;
1035 pIn2 = pSrcB->pData + 2U * j;
1036
1037 /* Decrement the column loop counter */
1038 col--;
1039 }
1040
1041 /* Update the pointer pInA to point to the starting address of the next 2 row */
1042 i = i + 2*numColsB;
1043 pInA = pInA + 4 * numColsA;
1044
1045 /* Decrement the row loop counter */
1046 rowCnt--;
1047 }
1048
1049 rowCnt = row & 1;
1050 while (rowCnt > 0U)
1051 {
1052 /* Output pointer is set to starting address of the row being processed */
1053 px = pOut + 2 * i;
1054
1055 /* For every row wise process, the column loop counter is to be initiated */
1056 col = numColsB;
1057
1058 /* For every row wise process, the pIn2 pointer is set
1059 ** to the starting address of the pSrcB data */
1060 pIn2 = pSrcB->pData;
1061
1062 j = 0U;
1063
1064 /* Column loop */
1065 while (col > 0U)
1066 {
1067 /* Set the variable sum, that acts as accumulator, to zero */
1068 sumReal1 = 0.0f;
1069 sumImag1 = 0.0f;
1070
1071 sumReal2 = 0.0f;
1072 sumImag2 = 0.0f;
1073
1074 /* Initiate the pointer pIn1 to point to the starting address of the column being processed */
1075 pIn1 = pInA;
1076
1077 accR0 = vdupq_n_f32(0.0);
1078 accI0 = vdupq_n_f32(0.0);
1079
1080 /* Compute 4 MACs simultaneously. */
1081 colCnt = numColsA >> 2;
1082
1083 /* Matrix multiplication */
1084 while (colCnt > 0U)
1085 {
1086 /* Reading real part of complex matrix A */
1087 a0V = vld2q_f32(pIn1); // load & separate real/imag pSrcA (de-interleave 2)
1088 pIn1 += 8;
1089
1090 tempR = vsetq_lane_f32(*pIn2,tempR,0);
1091 tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,0);
1092 pIn2 += 2 * numColsB;
1093
1094 tempR = vsetq_lane_f32(*pIn2,tempR,1);
1095 tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,1);
1096 pIn2 += 2 * numColsB;
1097
1098 tempR = vsetq_lane_f32(*pIn2,tempR,2);
1099 tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,2);
1100 pIn2 += 2 * numColsB;
1101
1102 tempR = vsetq_lane_f32(*pIn2,tempR,3);
1103 tempI = vsetq_lane_f32(*(pIn2 + 1U),tempI,3);
1104 pIn2 += 2 * numColsB;
1105
1106 accR0 = vmlaq_f32(accR0,a0V.val[0],tempR);
1107 accR0 = vmlsq_f32(accR0,a0V.val[1],tempI);
1108
1109 accI0 = vmlaq_f32(accI0,a0V.val[1],tempR);
1110 accI0 = vmlaq_f32(accI0,a0V.val[0],tempI);
1111
1112 /* Decrement the loop count */
1113 colCnt--;
1114 }
1115
1116 accum = vpadd_f32(vget_low_f32(accR0), vget_high_f32(accR0));
1117 sumReal1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
1118
1119 accum = vpadd_f32(vget_low_f32(accI0), vget_high_f32(accI0));
1120 sumImag1 += vget_lane_f32(accum, 0) + vget_lane_f32(accum, 1);
1121
1122 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
1123 ** No loop unrolling is used. */
1124 colCnt = numColsA & 3;
1125
1126 while (colCnt > 0U)
1127 {
1128 /* c(m,n) = a(1,1)*b(1,1) + a(1,2)*b(2,1) + ... + a(m,p)*b(p,n) */
1129 a1 = *pIn1;
1130 c1 = *pIn2;
1131
1132 b1 = *(pIn1 + 1U);
1133 d1 = *(pIn2 + 1U);
1134
1135 sumReal1 += a1 * c1;
1136 sumImag1 += b1 * c1;
1137
1138 pIn1 += 2U;
1139 pIn2 += 2 * numColsB;
1140
1141 sumReal2 -= b1 * d1;
1142 sumImag2 += a1 * d1;
1143
1144 /* Decrement the loop counter */
1145 colCnt--;
1146 }
1147
1148 sumReal1 += sumReal2;
1149 sumImag1 += sumImag2;
1150
1151 /* Store the result in the destination buffer */
1152 *px++ = sumReal1;
1153 *px++ = sumImag1;
1154
1155 /* Update the pointer pIn2 to point to the starting address of the next column */
1156 j++;
1157 pIn2 = pSrcB->pData + 2U * j;
1158
1159 /* Decrement the column loop counter */
1160 col--;
1161
1162 }
1163
1164 /* Update the pointer pInA to point to the starting address of the next row */
1165 i = i + numColsB;
1166 pInA = pInA + 2 * numColsA;
1167
1168 /* Decrement the row loop counter */
1169 rowCnt--;
1170
1171 }
1172
1173 /* Set status as ARM_MATH_SUCCESS */
1174 status = ARM_MATH_SUCCESS;
1175 }
1176
1177 /* Return to application */
1178 return (status);
1179 }
1180 #else
arm_mat_cmplx_mult_f32(const arm_matrix_instance_f32 * pSrcA,const arm_matrix_instance_f32 * pSrcB,arm_matrix_instance_f32 * pDst)1181 ARM_DSP_ATTRIBUTE arm_status arm_mat_cmplx_mult_f32(
1182 const arm_matrix_instance_f32 * pSrcA,
1183 const arm_matrix_instance_f32 * pSrcB,
1184 arm_matrix_instance_f32 * pDst)
1185 {
1186 float32_t *pIn1 = pSrcA->pData; /* Input data matrix pointer A */
1187 float32_t *pIn2 = pSrcB->pData; /* Input data matrix pointer B */
1188 float32_t *pInA = pSrcA->pData; /* Input data matrix pointer A */
1189 float32_t *pOut = pDst->pData; /* Output data matrix pointer */
1190 float32_t *px; /* Temporary output data matrix pointer */
1191 uint16_t numRowsA = pSrcA->numRows; /* Number of rows of input matrix A */
1192 uint16_t numColsB = pSrcB->numCols; /* Number of columns of input matrix B */
1193 uint16_t numColsA = pSrcA->numCols; /* Number of columns of input matrix A */
1194 float32_t sumReal, sumImag; /* Accumulator */
1195 float32_t a1, b1, c1, d1;
1196 uint32_t col, i = 0U, j, row = numRowsA, colCnt; /* loop counters */
1197 arm_status status; /* status of matrix multiplication */
1198
1199 #if defined (ARM_MATH_LOOPUNROLL)
1200 float32_t a0, b0, c0, d0;
1201 #endif
1202
1203 #ifdef ARM_MATH_MATRIX_CHECK
1204
1205 /* Check for matrix mismatch condition */
1206 if ((pSrcA->numCols != pSrcB->numRows) ||
1207 (pSrcA->numRows != pDst->numRows) ||
1208 (pSrcB->numCols != pDst->numCols) )
1209 {
1210 /* Set status as ARM_MATH_SIZE_MISMATCH */
1211 status = ARM_MATH_SIZE_MISMATCH;
1212 }
1213 else
1214
1215 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
1216
1217 {
1218 /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
1219 /* row loop */
1220 do
1221 {
1222 /* Output pointer is set to starting address of the row being processed */
1223 px = pOut + 2 * i;
1224
1225 /* For every row wise process, the column loop counter is to be initiated */
1226 col = numColsB;
1227
1228 /* For every row wise process, the pIn2 pointer is set
1229 ** to the starting address of the pSrcB data */
1230 pIn2 = pSrcB->pData;
1231
1232 j = 0U;
1233
1234 /* column loop */
1235 do
1236 {
1237 /* Set the variable sum, that acts as accumulator, to zero */
1238 sumReal = 0.0f;
1239 sumImag = 0.0f;
1240
1241 /* Initiate pointer pIn1 to point to starting address of column being processed */
1242 pIn1 = pInA;
1243
1244 #if defined (ARM_MATH_LOOPUNROLL)
1245
1246 /* Apply loop unrolling and compute 4 MACs simultaneously. */
1247 colCnt = numColsA >> 2U;
1248
1249 /* matrix multiplication */
1250 while (colCnt > 0U)
1251 {
1252
1253 /* Reading real part of complex matrix A */
1254 a0 = *pIn1;
1255
1256 /* Reading real part of complex matrix B */
1257 c0 = *pIn2;
1258
1259 /* Reading imaginary part of complex matrix A */
1260 b0 = *(pIn1 + 1U);
1261
1262 /* Reading imaginary part of complex matrix B */
1263 d0 = *(pIn2 + 1U);
1264
1265 /* Multiply and Accumlates */
1266 sumReal += a0 * c0;
1267 sumImag += b0 * c0;
1268
1269 /* update pointers */
1270 pIn1 += 2U;
1271 pIn2 += 2 * numColsB;
1272
1273 /* Multiply and Accumlates */
1274 sumReal -= b0 * d0;
1275 sumImag += a0 * d0;
1276
1277 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
1278
1279 /* read real and imag values from pSrcA and pSrcB buffer */
1280 a1 = *(pIn1 );
1281 c1 = *(pIn2 );
1282 b1 = *(pIn1 + 1U);
1283 d1 = *(pIn2 + 1U);
1284
1285 /* Multiply and Accumlates */
1286 sumReal += a1 * c1;
1287 sumImag += b1 * c1;
1288
1289 /* update pointers */
1290 pIn1 += 2U;
1291 pIn2 += 2 * numColsB;
1292
1293 /* Multiply and Accumlates */
1294 sumReal -= b1 * d1;
1295 sumImag += a1 * d1;
1296
1297 a0 = *(pIn1 );
1298 c0 = *(pIn2 );
1299 b0 = *(pIn1 + 1U);
1300 d0 = *(pIn2 + 1U);
1301
1302 /* Multiply and Accumlates */
1303 sumReal += a0 * c0;
1304 sumImag += b0 * c0;
1305
1306 /* update pointers */
1307 pIn1 += 2U;
1308 pIn2 += 2 * numColsB;
1309
1310 /* Multiply and Accumlates */
1311 sumReal -= b0 * d0;
1312 sumImag += a0 * d0;
1313
1314 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
1315
1316 a1 = *(pIn1 );
1317 c1 = *(pIn2 );
1318 b1 = *(pIn1 + 1U);
1319 d1 = *(pIn2 + 1U);
1320
1321 /* Multiply and Accumlates */
1322 sumReal += a1 * c1;
1323 sumImag += b1 * c1;
1324
1325 /* update pointers */
1326 pIn1 += 2U;
1327 pIn2 += 2 * numColsB;
1328
1329 /* Multiply and Accumlates */
1330 sumReal -= b1 * d1;
1331 sumImag += a1 * d1;
1332
1333 /* Decrement loop count */
1334 colCnt--;
1335 }
1336
1337 /* If the columns of pSrcA is not a multiple of 4, compute any remaining MACs here.
1338 ** No loop unrolling is used. */
1339 colCnt = numColsA % 0x4U;
1340
1341 #else
1342
1343 /* Initialize blkCnt with number of samples */
1344 colCnt = numColsA;
1345
1346 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
1347
1348 while (colCnt > 0U)
1349 {
1350 /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
1351 a1 = *(pIn1 );
1352 c1 = *(pIn2 );
1353 b1 = *(pIn1 + 1U);
1354 d1 = *(pIn2 + 1U);
1355
1356 /* Multiply and Accumlates */
1357 sumReal += a1 * c1;
1358 sumImag += b1 * c1;
1359
1360 /* update pointers */
1361 pIn1 += 2U;
1362 pIn2 += 2 * numColsB;
1363
1364 /* Multiply and Accumlates */
1365 sumReal -= b1 * d1;
1366 sumImag += a1 * d1;
1367
1368 /* Decrement loop counter */
1369 colCnt--;
1370 }
1371
1372 /* Store result in destination buffer */
1373 *px++ = sumReal;
1374 *px++ = sumImag;
1375
1376 /* Update pointer pIn2 to point to starting address of next column */
1377 j++;
1378 pIn2 = pSrcB->pData + 2U * j;
1379
1380 /* Decrement column loop counter */
1381 col--;
1382
1383 } while (col > 0U);
1384
1385 /* Update pointer pInA to point to starting address of next row */
1386 i = i + numColsB;
1387 pInA = pInA + 2 * numColsA;
1388
1389 /* Decrement row loop counter */
1390 row--;
1391
1392 } while (row > 0U);
1393
1394 /* Set status as ARM_MATH_SUCCESS */
1395 status = ARM_MATH_SUCCESS;
1396 }
1397
1398 /* Return to application */
1399 return (status);
1400 }
1401
1402 #endif /* #if defined(ARM_MATH_NEON) */
1403 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
1404
1405 /**
1406 @} end of MatrixMult group
1407 */
1408