1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_mult_f16.c
4  * Description:  Floating-point matrix multiplication
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 
33 
34 /**
35  * @ingroup groupMatrix
36  */
37 
38 
39 /**
40  * @addtogroup MatrixMult
41  * @{
42  */
43 
44 /**
45  * @brief Floating-point matrix multiplication.
 * @param[in]       pSrcA points to the first input matrix structure
 * @param[in]       pSrcB points to the second input matrix structure
 * @param[out]      pDst  points to the output matrix structure
 * @return          The function returns either
 * <code>ARM_MATH_SIZE_MISMATCH</code> or <code>ARM_MATH_SUCCESS</code> based on the outcome of size checking.
51  */
52 
53 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
54 
arm_mat_mult_f16_2x2_mve(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)55 __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_2x2_mve(
56     const arm_matrix_instance_f16 *pSrcA,
57     const arm_matrix_instance_f16 *pSrcB,
58     arm_matrix_instance_f16 *pDst)
59 {
60     static const uint16_t offsetA[8] = { 0, 0, 2, 2, 0, 0, 2, 2 };
61     /* offsetB allows to read and duplicate 1 row of B */
62     static const uint16_t offsetB[8] = { 0, 1, 0, 1, 0, 1, 0, 1 };
63     uint16x8_t    vecOffsA, vecOffsB;
64     f16x8_t       vecInA, vecInB, vecDst;
65     float16_t      *pOut = pDst->pData;  /* output data matrix pointer */
66 
67     /*
68      * load initial offsets
69      */
70     vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
71     vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
72     /*
73      * load {a00 a00 a10 a10 x x x x }
74      */
75     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
76     /*
77      * load {b00 b01 b00 b01 x x x x }
78      */
79     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
80     /*
81      *  { a00 b00       a00 b01
82      *    a10 b00       a10 b01
83      *       x             x
84      *       x             x   }
85      */
86     vecDst = vmulq(vecInA, vecInB);
87     /*
88      * move to 2nd column of matrix A
89      */
90     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
91     /*
92      * load {a01 a01 a11 a11 x x x x}
93      */
94     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
95     /*
96      * move to next B row
97      */
98     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 2);
99     /*
100      * load {b10, b11, b10, b11, x x x x }
101      */
102     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
103     /*
104      *  { a00 b00 + a01 b10   a00 b01 + a01 b11
105      *    a10 b00 + a11 b10     a10 b01 + a11 b11
106      *             x                    x
107      *             x                    x       }
108      */
109     vecDst = vfmaq(vecDst, vecInA, vecInB);
110 
111     mve_pred16_t p0 = vctp16q(2*2);
112     /*
113      * Store the result in the destination buffer
114      * (lower half of the vector)
115      */
116     vstrhq_p(pOut, vecDst, p0);
117 
118     return (ARM_MATH_SUCCESS);
119 }
120 
121 
122 
123 
arm_mat_mult_f16_3x3_mve(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)124 __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_3x3_mve(
125     const arm_matrix_instance_f16 *pSrcA,
126     const arm_matrix_instance_f16 *pSrcB,
127     arm_matrix_instance_f16 *pDst)
128 {
129     static const uint16_t offsetA[8] = { 0, 0, 0, 3, 3, 3, 6, 6 };
130     /* offsetB allows to read and duplicate 1 row of B */
131     static const uint16_t offsetB[8] = { 0, 1, 2, 0, 1, 2, 0, 1 };
132     uint16x8_t    vecOffsA, vecOffsB;
133     f16x8_t       vecInA, vecInB, vecDst;
134     float16_t      *pOut = pDst->pData;  /* output data matrix pointer */
135 
136     /*
137      * load initial offsets
138      */
139     vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
140     vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
141 
142     /*
143      * load {a00 a00 a00 a10 a10 a10 a20 a20}
144      */
145     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
146     /*
147      * load {b00 b01 b02 b00 b01 b02 b00 b01}
148      */
149     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
150     /*
151      *  { a00 b00       a00 b01     a00 b02
152      *    a10 b00       a10 b01     a10 b02
153      *    a20 b00       a20 b01}
154      */
155     vecDst = vmulq(vecInA, vecInB);
156 
157     /*
158      * move to 2nd column of matrix A
159      */
160     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
161     /*
162      * load {a01 a01 a01 a11 a11 a11 a21 a21}
163      */
164     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
165     /*
166      * move to next B row
167      */
168     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 3);
169     /*
170      * load {b10, b11, b12, b10, b11, b12, b10, b11}
171      */
172     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
173     /*
174      *  { a00 b00 + a01 b10   a00 b01 + a01 b11     a00 b02 + a01 b12
175      *    a10 b00 + a11 b10     a10 b01 + a11 b11     a10 b02 + a11 b12
176      *    a20 b00 + a21 b10     a20 b01 + a21 b11   }
177      */
178     vecDst = vfmaq(vecDst, vecInA, vecInB);
179     /*
180      * move to 3rd column of matrix A
181      */
182     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 1);
183     /*
184      * load {a02 a02 a02 a12 a12 a12 a22 a22}
185      */
186     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
187     /*
188      * move to next B row
189      */
190     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 3);
191     /*
192      * load {b20, b21, b22, b20, b21, b22, b20, b21}
193      */
194     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
195     /*
196      *  {a00 b00 + a01 b10 + a02 b20  a00 b01 + a01 b11 + a02 b21     a00 b02 + a01 b12 + a02 b22},
197      *   a10 b00 + a11 b10 + a12 b20    a10 b01 + a11 b11 + a12 b21     a10 b02 + a11 b12 + a12 b22},
198      *   a20 b00 + a21 b10 + a22 b20    a20 b01 + a21 b11 + a22 b21   }
199      */
200     vecDst = vfmaq(vecDst, vecInA, vecInB);
201 
202     /*
203      * Store the result in the destination buffer
204      */
205     vst1q(pOut, vecDst); pOut += 8;
206 
207     /* last element computed in scalar mode
208      * a20 b02 + a21 b12 + a22 b22
209      */
210     _Float16 * pA = (_Float16 *)pSrcA->pData;
211     _Float16 * pB = (_Float16 *)pSrcB->pData;
212     *pOut = pA[2*3] * pB[2] + pA[2*3+1] * pB[3+2] + pA[2*3+2] * pB[2*3+2];
213 
214     return (ARM_MATH_SUCCESS);
215 }
216 
217 
218 
219 
220 
arm_mat_mult_f16_4x4_mve(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)221 __STATIC_FORCEINLINE arm_status arm_mat_mult_f16_4x4_mve(
222     const arm_matrix_instance_f16 *pSrcA,
223     const arm_matrix_instance_f16 *pSrcB,
224     arm_matrix_instance_f16 *pDst)
225 {
226     /* offsetA allows to read and duplicate 2 successive column elements of A */
227     static const uint16_t offsetA[8] = { 0, 0, 0, 0, 4, 4, 4, 4 };
228     /* offsetB allows to read and duplicate 1 row of B */
229     static const uint16_t offsetB[8] = { 0, 1, 2, 3, 0, 1, 2, 3 };
230     uint16x8_t    vecOffsA, vecOffsB;
231     f16x8_t       vecInA, vecInB, vecDst0, vecDst1;
232     float16_t      *pOut = pDst->pData;  /* output data matrix pointer */
233 
234     /*
235      * load initial offsets
236      */
237     vecOffsA = vldrhq_u16((uint16_t const *) offsetA);
238     vecOffsB = vldrhq_u16((uint16_t const *) offsetB);
239 
240     /*
241      * load {a00 a00 a00 a00 a10 a10 a10 a10}
242      */
243     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
244     /*
245      * load {b00 b01 b02 b03 b00 b01 b02 b03}
246      */
247     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
248     /*
249      *  { a00 b00       a00 b01     a00 b02     a00 b03
250      *    a10 b00       a10 b01     a10 b02     a10 b03 }
251      */
252     vecDst0 = vmulq(vecInA, vecInB);
253     /*
254      * jump 2 x A rows (2nd half of matrix)
255      */
256     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
257     /*
258      * load {a20 a20 a20 a20 a30 a30 a30 a30}
259      */
260     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
261     /*
262      *  { a20 b00       a20 b01     a20 b02     a20 b03
263      *    a30 b00       a30 b01     a30 b02 +   a31 b12 }
264      */
265     vecDst1 = vmulq(vecInA, vecInB);
266     /*
267      * rewind back to top half of the A matrix (2nd column)
268      */
269     vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
270     /*
271      * load {a01 a01 a01 a01 a11 a11 a11 a11}
272      */
273     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
274     /*
275      * move to next B row
276      */
277     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
278     /*
279      * load {b10, b11, b12, b13, b10, b11, b12, b13}
280      */
281     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
282     /*
283      *  { a00 b00 + a01 b10         a00 b01 + a01 b11       a00 b02 + a01 b12       a00 b03 + a01 b13
284      *    a10 b00 + a11 b10         a10 b01 + a11 b11       a10 b02 + a11 b12       a10 b03 + a11 b13 }
285      */
286     vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
287     /*
288      * jump 2 x A rows (2nd half of matrix)
289      */
290     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
291     /*
292      * load {a21 a21 a21 a21 a31 a31 a31 a31}
293      */
294     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
295     /*
296      *  {a20 b00 + a21 b10      a20 b01 + a21 b11       a20 b02 + a21 b12       a20 b03 + a21 b13
297      *   a30 b00 + a31 b10      a30 b01 + a31 b11       a30 b02 + a31 b12       a30 b03 + a31 b13 }
298      */
299     vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
300 
301     /*
302      * rewind back to top half of the A matrix (3rd column)
303      */
304     vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
305     /*
306      * load {a02 a02 a02 a02 a12 a12 a12 a12}
307      */
308     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
309     /*
310      * move to next B row
311      */
312     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
313     /*
314      * load {b20, b21, b22, b23, b20, b21, b22, b23}
315      */
316     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
317     /*
318      *  { a00 b00 + a01 b10 + a02 b20    a00 b01 + a01 b11 + a02 b21    a00 b02 + a01 b12 + a02 b22   a00 b03 + a01 b13 + a02 b23
319      *    a10 b00 + a11 b10 + a12 b20    a10 b01 + a11 b11 + a12 b21    a10 b02 + a11 b12 + a12 b22   a10 b03 + a11 b13 + a12 b23 }
320      */
321     vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
322     /*
323      * jump 2 x A rows
324      */
325     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
326 
327     /*
328      * load {a22 a22 a22 a22 a32 a32 a32 a32}
329      */
330     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
331     /*
332      *  {a20 b00 + a21 b10 + a22 b20   a20 b01 + a21 b11 + a22 b21  a20 b02 + a21 b12 + a22 b22    a20 b03 + a21 b13 + a22 b23
333      *   a30 b00 + a31 b10 + a32 b20   a30 b01 + a31 b11 + a32 b21  a30 b02 + a31 b12 + a32 b22    a30 b03 + a31 b13 + a32 b23 }
334      */
335     vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
336 
337     /*
338      * rewind back to top half of the A matrix (4th column)
339      */
340     vecOffsA = vsubq(vecOffsA, (uint16_t) 7);
341     /*
342      * load {a03 a03 a03 a03 a13 a13 a13 a13}
343      */
344     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
345     /*
346      * move to next B row
347      */
348     vecOffsB = vaddq_n_u16(vecOffsB, (uint16_t) 4);
349     /*
350      * load {b30, b31, b32, b33, b30, b31, b32, b33}
351      */
352     vecInB = vldrhq_gather_shifted_offset((float16_t const *) pSrcB->pData, vecOffsB);
353     /*
354      * { a00 b00 +...+ a03 b30,    a00 b01 +...+ a03 b31,   a00 b02 +...+ a03 b32,   a00 b03 +...+ a03 b33
355      *   a10 b00 +...+ a13 b30,    a10 b01 +...+ a13 b31,   a10 b02 +...+ a13 b32,   a10 b03 +...+ a13 b33 }
356      */
357     vecDst0 = vfmaq(vecDst0, vecInA, vecInB);
358     /*
359      * jump 2 x A rows
360      */
361     vecOffsA = vaddq_n_u16(vecOffsA, (uint16_t) 8);
362     /*
363      * load {a23 a23 a23 a23 a33 a33 a33 a33}
364      */
365     vecInA = vldrhq_gather_shifted_offset((float16_t const *) pSrcA->pData, vecOffsA);
366     /*
367      *  {a20 b00 +...+ a23 b30,   a20 b01 +...+ a23 b31,   a20 b02 +...+ a23 b32,   a20 b03 +...+ a23 b33
368      *   a30 b00 +...+ a33 b30,   a30 b01 +...+ a33 b31,   a30 b02 +...+ a33 b32,   a30 b03 +...+ a33 b33 }
369      */
370     vecDst1 = vfmaq(vecDst1, vecInA, vecInB);
371 
372     /*
373      * Store the result in the destination buffer
374      */
375     vst1q(pOut, vecDst0); pOut += 8;
376     vst1q(pOut, vecDst1);
377 
378     return (ARM_MATH_SUCCESS);
379 }
380 
381 
arm_mat_mult_f16(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)382 arm_status arm_mat_mult_f16(
383   const arm_matrix_instance_f16 * pSrcA,
384   const arm_matrix_instance_f16 * pSrcB,
385   arm_matrix_instance_f16 * pDst)
386 {
387        float16_t  *pInB = pSrcB->pData;        /* input data matrix pointer B */
388     float16_t  *pInA = pSrcA->pData;        /* input data matrix pointer A  */
389     float16_t  *pOut = pDst->pData;         /* output data matrix pointer */
390     int         numRowsA = pSrcA->numRows;  /* number of rows of input matrix A */
391     int         numColsB = pSrcB->numCols;  /* number of columns of input matrix B */
392     int         numColsA = pSrcA->numCols;  /* number of columns of input matrix A */
393     uint32_t    blkCnt;                     /* loop counters */
394     int         i;
395 
396 
397 #ifdef ARM_MATH_MATRIX_CHECK
398 
399   /* Check for matrix mismatch condition */
400   if ((pSrcA->numCols != pSrcB->numRows) ||
401       (pSrcA->numRows != pDst->numRows)  ||
402       (pSrcB->numCols != pDst->numCols)    )
403   {
404     /* Set status as ARM_MATH_SIZE_MISMATCH */
405     return(ARM_MATH_SIZE_MISMATCH);
406   }
407   else
408 
409 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
410 {
411     /* small squared matrix specialized routines */
412     if(numRowsA == numColsB && numColsB == numColsA) {
413         if(numRowsA == 2)
414             return arm_mat_mult_f16_2x2_mve(pSrcA, pSrcB, pDst);
415         else if(numRowsA == 3)
416             return arm_mat_mult_f16_3x3_mve(pSrcA, pSrcB, pDst);
417         else if(numRowsA == 4)
418             return arm_mat_mult_f16_4x4_mve(pSrcA, pSrcB, pDst);
419     }
420 
421     /* main loop process 4 rows */
422     i = numRowsA / 4;
423     while(i > 0)
424     {
425         float16_t   *pInA0, *pInA1, *pInA2, *pInA3;
426         float16_t   *pInB0;
427         float16_t   *pOut0, *pOut1, *pOut2, *pOut3;
428         f16x8_t    vecMac0, vecMac1, vecMac2, vecMac3;
429         f16x8_t    vecInB;
430 
431         /* pointers to 4 consecutive output rows */
432         pOut0 = pOut;
433         pOut1 = pOut0 + numColsB;
434         pOut2 = pOut1 + numColsB;
435         pOut3 = pOut2 + numColsB;
436         pInB0 = pInB;
437 
438         int       k = numColsB >> 3;
439         while(k > 0)
440         {
441             /* pointers to 4 consecutive Matrix A rows */
442             pInA0 = pInA;
443             pInA1 = pInA0 + numColsA;
444             pInA2 = pInA1 + numColsA;
445             pInA3 = pInA2 + numColsA;
446 
447             vecMac0 = vdupq_n_f16(0.0f16);
448             vecMac1 = vdupq_n_f16(0.0f16);
449             vecMac2 = vdupq_n_f16(0.0f16);
450             vecMac3 = vdupq_n_f16(0.0f16);
451 
452             blkCnt = numColsA;
453 
454             while (blkCnt > 0U)
455             {
456                 /*
457                  * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3..., bi,4n+7}
458                  */
459                 vecInB = *(f16x8_t *)pInB0; /* vldrhq_f16(pInB0, 0); */
460 
461                 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
462                 vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
463                 vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
464                 vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
465 
466                 pInB0 = pInB0 + numColsB;
467                 /*
468                  * Decrement the blockSize loop counter
469                  */
470                 blkCnt--;
471             }
472 
473             /* Store the results (4 x 8 block) in the destination buffer */
474             vst1q(pOut0, vecMac0);  pOut0 += 8;
475             vst1q(pOut1, vecMac1);  pOut1 += 8;
476             vst1q(pOut2, vecMac2);  pOut2 += 8;
477             vst1q(pOut3, vecMac3);  pOut3 += 8;
478             /*
479              * rewind
480              */
481             pInB0 -= (numColsB * numColsA) - 8;
482             k--;
483         }
484 
485         int       colBLeft = numColsB & 7;
486         if (colBLeft)
487         {
488             pInA0 = pInA;
489             pInA1 = pInA0 + numColsA;
490             pInA2 = pInA1 + numColsA;
491             pInA3 = pInA2 + numColsA;
492             mve_pred16_t p0 = vctp16q(colBLeft);
493 
494             vecMac0 = vdupq_n_f16(0.0f16);
495             vecMac1 = vdupq_n_f16(0.0f16);
496             vecMac2 = vdupq_n_f16(0.0f16);
497             vecMac3 = vdupq_n_f16(0.0f16);
498 
499             blkCnt = numColsA;
500 
501             while (blkCnt > 0U)
502             {
503                 /*
504                  * load {bi,4n+0, bi,4n+1, bi,4n+2, ..bi,4n+colBLeft-1, 0, ..}
505                  */
506                 vecInB = vldrhq_z_f16(pInB0, p0);
507 
508                 vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
509                 vecMac1 = vfmaq(vecMac1, vecInB, *pInA1++);
510                 vecMac2 = vfmaq(vecMac2, vecInB, *pInA2++);
511                 vecMac3 = vfmaq(vecMac3, vecInB, *pInA3++);
512 
513                 pInB0 = pInB0 + numColsB;
514                 /*
515                  * Decrement the blockSize loop counter
516                  */
517                 blkCnt--;
518             }
519 
520             /* Store the results (4 x colBLeft block) in the destination buffer */
521             vstrhq_p_f16(pOut0, vecMac0, p0);
522             vstrhq_p_f16(pOut1, vecMac1, p0);
523             vstrhq_p_f16(pOut2, vecMac2, p0);
524             vstrhq_p_f16(pOut3, vecMac3, p0);
525         }
526 
527         pInA += 4 * numColsA;
528         pOut += 4 * numColsB;
529         i--;
530     }
531 
532     /*
533      * non multiple of 4 rows for Matrix A
534      * process single row
535      */
536     if (numRowsA & 3)
537     {
538         i = numRowsA & 3;
539         do
540         {
541             float16_t   *pInA0;
542             float16_t   *pInB0;
543             float16_t   *pOut0;
544             f16x8_t    vecInB;
545             f16x8_t    vecMac0;
546 
547             pOut0 = pOut;
548             pInB0 = pInB;
549 
550             int       k = numColsB >> 3;
551             while(k > 0)
552             {
553                 pInA0 = pInA;
554 
555                 vecMac0 = vdupq_n_f16(0.0f16);
556                 blkCnt = numColsA;
557 
558                 while (blkCnt > 0U)
559                 {
560                     /*
561                      * load {bi,4n+0, bi,4n+1, bi,4n+2, bi,4n+3, ...bi,4n+7}
562                      */
563                     vecInB = *(f16x8_t *)pInB0; /* vldrhq_f16(pInB0, 0); */
564 
565                     vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
566 
567                     pInB0 = pInB0 + numColsB;
568                     /*
569                      * Decrement the blockSize loop counter
570                      */
571                     blkCnt--;
572                 }
573                 /* Store the results (1 x 8 block) in the destination buffer */
574                 vst1q(pOut0, vecMac0);   pOut0 += 8;
575                 /*
576                  * rewind
577                  */
578                 pInB0 -= (numColsB * numColsA) - 8;
579                 k--;
580             }
581 
582             int  colBLeft = numColsB & 7;
583             if (colBLeft)
584             {
585                 pInA0 = pInA;
586                 mve_pred16_t p0 = vctp16q(colBLeft);
587 
588                 vecMac0 = vdupq_n_f16(0.0f16);
589                 blkCnt = numColsA;
590 
591                 while (blkCnt > 0U)
592                 {
593                     /*
594                      * load {bi,4n+0, bi,4n+1, bi,4n+2, ..., bi,4n+colBLeft, 0, ...}
595                      */
596                     vecInB = vldrhq_z_f16(pInB0, p0);
597 
598                     vecMac0 = vfmaq(vecMac0, vecInB, *pInA0++);
599 
600                     pInB0 = pInB0 + numColsB;
601                     /*
602                      * Decrement the blockSize loop counter
603                      */
604                     blkCnt--;
605                 }
606                 /* Store the results (1 x colBLeft block) in the destination buffer */
607                 vstrhq_p_f16(pOut0, vecMac0, p0);
608             }
609 
610             pInA += 1 * numColsA;
611             pOut += 1 * numColsB;
612         }
613         while (--i);
614     }
615     /*
616      * Return to application
617      */
618     return (ARM_MATH_SUCCESS);
619   }
620 }
621 #else
622 
623 
arm_mat_mult_f16(const arm_matrix_instance_f16 * pSrcA,const arm_matrix_instance_f16 * pSrcB,arm_matrix_instance_f16 * pDst)624 arm_status arm_mat_mult_f16(
625   const arm_matrix_instance_f16 * pSrcA,
626   const arm_matrix_instance_f16 * pSrcB,
627         arm_matrix_instance_f16 * pDst)
628 {
629   float16_t *pIn1 = pSrcA->pData;                /* Input data matrix pointer A */
630   float16_t *pIn2 = pSrcB->pData;                /* Input data matrix pointer B */
631   float16_t *pInA = pSrcA->pData;                /* Input data matrix pointer A */
632   float16_t *pInB = pSrcB->pData;                /* Input data matrix pointer B */
633   float16_t *pOut = pDst->pData;                 /* Output data matrix pointer */
634   float16_t *px;                                 /* Temporary output data matrix pointer */
635   _Float16 sum;                                 /* Accumulator */
636   uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
637   uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
638   uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
639   uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
640   arm_status status;                             /* Status of matrix multiplication */
641 
642 #ifdef ARM_MATH_MATRIX_CHECK
643 
644   /* Check for matrix mismatch condition */
645   if ((pSrcA->numCols != pSrcB->numRows) ||
646       (pSrcA->numRows != pDst->numRows)  ||
647       (pSrcB->numCols != pDst->numCols)    )
648   {
649     /* Set status as ARM_MATH_SIZE_MISMATCH */
650     status = ARM_MATH_SIZE_MISMATCH;
651   }
652   else
653 
654 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
655 
656   {
657     /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
658     /* row loop */
659     do
660     {
661       /* Output pointer is set to starting address of row being processed */
662       px = pOut + i;
663 
664       /* For every row wise process, column loop counter is to be initiated */
665       col = numColsB;
666 
667       /* For every row wise process, pIn2 pointer is set to starting address of pSrcB data */
668       pIn2 = pSrcB->pData;
669 
670       /* column loop */
671       do
672       {
673         /* Set the variable sum, that acts as accumulator, to zero */
674         sum = 0.0f16;
675 
676         /* Initialize pointer pIn1 to point to starting address of column being processed */
677         pIn1 = pInA;
678 
679 #if defined (ARM_MATH_LOOPUNROLL)
680 
681         /* Loop unrolling: Compute 4 MACs at a time. */
682         colCnt = numColsA >> 2U;
683 
684         /* matrix multiplication */
685         while (colCnt > 0U)
686         {
687           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
688 
689           /* Perform the multiply-accumulates */
690           sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
691           pIn2 += numColsB;
692 
693           sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
694           pIn2 += numColsB;
695 
696           sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
697           pIn2 += numColsB;
698 
699           sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
700           pIn2 += numColsB;
701 
702           /* Decrement loop counter */
703           colCnt--;
704         }
705 
706         /* Loop unrolling: Compute remaining MACs */
707         colCnt = numColsA % 0x4U;
708 
709 #else
710 
711         /* Initialize cntCnt with number of columns */
712         colCnt = numColsA;
713 
714 #endif /* #if defined (ARM_MATH_LOOPUNROLL) */
715 
716         while (colCnt > 0U)
717         {
718           /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */
719 
720           /* Perform the multiply-accumulates */
721           sum += (_Float16)*pIn1++ * (_Float16)*pIn2;
722           pIn2 += numColsB;
723 
724           /* Decrement loop counter */
725           colCnt--;
726         }
727 
728         /* Store result in destination buffer */
729         *px++ = sum;
730 
731         /* Decrement column loop counter */
732         col--;
733 
734         /* Update pointer pIn2 to point to starting address of next column */
735         pIn2 = pInB + (numColsB - col);
736 
737       } while (col > 0U);
738 
739       /* Update pointer pInA to point to starting address of next row */
740       i = i + numColsB;
741       pInA = pInA + numColsA;
742 
743       /* Decrement row loop counter */
744       row--;
745 
746     } while (row > 0U);
747 
748     /* Set status as ARM_MATH_SUCCESS */
749     status = ARM_MATH_SUCCESS;
750   }
751 
752   /* Return to application */
753   return (status);
754 }
755 
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
757 
758 /**
759  * @} end of MatrixMult group
760  */
761 
762 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
763 
764