1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_vec_mult_f16.c
4  * Description:  Floating-point matrix and vector multiplication
5  *
6  * $Date:        23 April 2021
7  *
8  * $Revision:    V1.9.0
9  *
10  * Target Processor: Cortex-M and Cortex-A cores
11  * -------------------------------------------------------------------- */
12 /*
13  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14  *
15  * SPDX-License-Identifier: Apache-2.0
16  *
17  * Licensed under the Apache License, Version 2.0 (the License); you may
18  * not use this file except in compliance with the License.
19  * You may obtain a copy of the License at
20  *
21  * www.apache.org/licenses/LICENSE-2.0
22  *
23  * Unless required by applicable law or agreed to in writing, software
24  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26  * See the License for the specific language governing permissions and
27  * limitations under the License.
28  */
29 
30 #include "dsp/matrix_functions_f16.h"
31 
32 #if defined(ARM_FLOAT16_SUPPORTED)
33 
34 
35 /**
36  * @ingroup groupMatrix
37  */
38 
39 
40 /**
41  * @addtogroup MatrixVectMult
42  * @{
43  */
44 
45 /**
46  * @brief Floating-point matrix and vector multiplication.
47  * @param[in]       *pSrcMat points to the input matrix structure
48  * @param[in]       *pVec points to input vector
49  * @param[out]      *pDst points to output vector
50  */
51 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
52 
53 #include "arm_helium_utils.h"
54 
arm_mat_vec_mult_f16(const arm_matrix_instance_f16 * pSrcMat,const float16_t * pSrcVec,float16_t * pDstVec)55 ARM_DSP_ATTRIBUTE void arm_mat_vec_mult_f16(
56     const arm_matrix_instance_f16   *pSrcMat,
57     const float16_t                 *pSrcVec,
58     float16_t                       *pDstVec)
59 {
60     uint32_t         numRows = pSrcMat->numRows;
61     uint32_t         numCols = pSrcMat->numCols;
62     const float16_t *pSrcA = pSrcMat->pData;
63     const float16_t *pInA0;
64     const float16_t *pInA1;
65     float16_t       *px;
66     int32_t          row;
67     uint32_t         blkCnt;           /* loop counters */
68 
69     row = numRows;
70     px = pDstVec;
71 
72     /*
73      * compute 4 rows in parallel
74      */
75     while (row >= 4)
76     {
77         const float16_t     *pInA2, *pInA3;
78         float16_t const    *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
79         f16x8_t            vecIn, acc0, acc1, acc2, acc3;
80         float16_t const     *pSrcVecPtr = pSrcVec;
81 
82         /*
83          * Initialize the pointers to 4 consecutive MatrixA rows
84          */
85         pInA0 = pSrcA;
86         pInA1 = pInA0 + numCols;
87         pInA2 = pInA1 + numCols;
88         pInA3 = pInA2 + numCols;
89         /*
90          * Initialize the vector pointer
91          */
92         pInVec =  pSrcVecPtr;
93         /*
94          * reset accumulators
95          */
96         acc0 = vdupq_n_f16(0.0f);
97         acc1 = vdupq_n_f16(0.0f);
98         acc2 = vdupq_n_f16(0.0f);
99         acc3 = vdupq_n_f16(0.0f);
100 
101         pSrcA0Vec = pInA0;
102         pSrcA1Vec = pInA1;
103         pSrcA2Vec = pInA2;
104         pSrcA3Vec = pInA3;
105 
106         blkCnt = numCols >> 3;
107         while (blkCnt > 0U)
108         {
109             f16x8_t vecA;
110 
111             vecIn = vld1q(pInVec);
112             pInVec += 8;
113             vecA = vld1q(pSrcA0Vec);
114             pSrcA0Vec += 8;
115             acc0 = vfmaq(acc0, vecIn, vecA);
116             vecA = vld1q(pSrcA1Vec);
117             pSrcA1Vec += 8;
118             acc1 = vfmaq(acc1, vecIn, vecA);
119             vecA = vld1q(pSrcA2Vec);
120             pSrcA2Vec += 8;
121             acc2 = vfmaq(acc2, vecIn, vecA);
122             vecA = vld1q(pSrcA3Vec);
123             pSrcA3Vec += 8;
124             acc3 = vfmaq(acc3, vecIn, vecA);
125 
126             blkCnt--;
127         }
128         /*
129          * tail
130          * (will be merged thru tail predication)
131          */
132         blkCnt = numCols & 7;
133         if (blkCnt > 0U)
134         {
135             mve_pred16_t p0 = vctp16q(blkCnt);
136             f16x8_t vecA;
137 
138             vecIn = vldrhq_z_f16(pInVec, p0);
139             vecA = vld1q(pSrcA0Vec);
140             acc0 = vfmaq(acc0, vecIn, vecA);
141             vecA = vld1q(pSrcA1Vec);
142             acc1 = vfmaq(acc1, vecIn, vecA);
143             vecA = vld1q(pSrcA2Vec);
144             acc2 = vfmaq(acc2, vecIn, vecA);
145             vecA = vld1q(pSrcA3Vec);
146             acc3 = vfmaq(acc3, vecIn, vecA);
147         }
148         /*
149          * Sum the partial parts
150          */
151         *px++ = vecAddAcrossF16Mve(acc0);
152         *px++ = vecAddAcrossF16Mve(acc1);
153         *px++ = vecAddAcrossF16Mve(acc2);
154         *px++ = vecAddAcrossF16Mve(acc3);
155 
156         pSrcA += numCols * 4;
157         /*
158          * Decrement the row loop counter
159          */
160         row -= 4;
161     }
162 
163     /*
164      * compute 2 rows in parrallel
165      */
166     if (row >= 2)
167     {
168         float16_t const    *pSrcA0Vec, *pSrcA1Vec, *pInVec;
169         f16x8_t            vecIn, acc0, acc1;
170         float16_t const     *pSrcVecPtr = pSrcVec;
171 
172         /*
173          * Initialize the pointers to 2 consecutive MatrixA rows
174          */
175         pInA0 = pSrcA;
176         pInA1 = pInA0 + numCols;
177         /*
178          * Initialize the vector pointer
179          */
180         pInVec = pSrcVecPtr;
181         /*
182          * reset accumulators
183          */
184         acc0 = vdupq_n_f16(0.0f);
185         acc1 = vdupq_n_f16(0.0f);
186         pSrcA0Vec = pInA0;
187         pSrcA1Vec = pInA1;
188 
189         blkCnt = numCols >> 3;
190         while (blkCnt > 0U)
191         {
192             f16x8_t vecA;
193 
194             vecIn = vld1q(pInVec);
195             pInVec += 8;
196             vecA = vld1q(pSrcA0Vec);
197             pSrcA0Vec += 8;
198             acc0 = vfmaq(acc0, vecIn, vecA);
199             vecA = vld1q(pSrcA1Vec);
200             pSrcA1Vec += 8;
201             acc1 = vfmaq(acc1, vecIn, vecA);
202 
203             blkCnt--;
204         }
205         /*
206          * tail
207          * (will be merged thru tail predication)
208          */
209         blkCnt = numCols & 7;
210         if (blkCnt > 0U)
211         {
212             mve_pred16_t p0 = vctp16q(blkCnt);
213             f16x8_t vecA;
214 
215             vecIn = vldrhq_z_f16(pInVec, p0);
216             vecA = vld1q(pSrcA0Vec);
217             acc0 = vfmaq(acc0, vecIn, vecA);
218             vecA = vld1q(pSrcA1Vec);
219             acc1 = vfmaq(acc1, vecIn, vecA);
220         }
221         /*
222          * Sum the partial parts
223          */
224         *px++ = vecAddAcrossF16Mve(acc0);
225         *px++ = vecAddAcrossF16Mve(acc1);
226 
227         pSrcA += numCols * 2;
228         row -= 2;
229     }
230 
231     if (row >= 1)
232     {
233         f16x8_t             vecIn, acc0;
234         float16_t const     *pSrcA0Vec, *pInVec;
235         float16_t const      *pSrcVecPtr = pSrcVec;
236         /*
237          * Initialize the pointers to last MatrixA row
238          */
239         pInA0 = pSrcA;
240         /*
241          * Initialize the vector pointer
242          */
243         pInVec = pSrcVecPtr;
244         /*
245          * reset accumulators
246          */
247         acc0 = vdupq_n_f16(0.0f);
248 
249         pSrcA0Vec = pInA0;
250 
251         blkCnt = numCols >> 3;
252         while (blkCnt > 0U)
253         {
254             f16x8_t vecA;
255 
256             vecIn = vld1q(pInVec);
257             pInVec += 8;
258             vecA = vld1q(pSrcA0Vec);
259             pSrcA0Vec += 8;
260             acc0 = vfmaq(acc0, vecIn, vecA);
261 
262             blkCnt--;
263         }
264         /*
265          * tail
266          * (will be merged thru tail predication)
267          */
268         blkCnt = numCols & 7;
269         if (blkCnt > 0U)
270         {
271             mve_pred16_t p0 = vctp16q(blkCnt);
272             f16x8_t vecA;
273 
274             vecIn = vldrhq_z_f16(pInVec, p0);
275             vecA = vld1q(pSrcA0Vec);
276             acc0 = vfmaq(acc0, vecIn, vecA);
277         }
278         /*
279          * Sum the partial parts
280          */
281         *px++ = vecAddAcrossF16Mve(acc0);
282     }
283 }
284 #else
arm_mat_vec_mult_f16(const arm_matrix_instance_f16 * pSrcMat,const float16_t * pVec,float16_t * pDst)285 ARM_DSP_ATTRIBUTE void arm_mat_vec_mult_f16(const arm_matrix_instance_f16 *pSrcMat, const float16_t *pVec, float16_t *pDst)
286 {
287     uint32_t numRows = pSrcMat->numRows;
288     uint32_t numCols = pSrcMat->numCols;
289     const float16_t *pSrcA = pSrcMat->pData;
290     const float16_t *pInA1;      /* input data matrix pointer A of Q31 type */
291     const float16_t *pInA2;      /* input data matrix pointer A of Q31 type */
292     const float16_t *pInA3;      /* input data matrix pointer A of Q31 type */
293     const float16_t *pInA4;      /* input data matrix pointer A of Q31 type */
294     const float16_t *pInVec;     /* input data matrix pointer B of Q31 type */
295     float16_t *px;               /* Temporary output data matrix pointer */
296     uint32_t i;
297     uint16_t row, colCnt; /* loop counters */
298     float16_t matData, matData2, vecData, vecData2;
299 
300 
301     /* Process 4 rows at a time */
302     row = numRows >> 2;
303     i = 0u;
304     px = pDst;
305 
306     /* The following loop performs the dot-product of each row in pSrcA with the vector */
307     /* row loop */
308     while (row > 0) {
309         /* For every row wise process, the pInVec pointer is set
310          ** to the starting address of the vector */
311         pInVec = pVec;
312 
313         /* Initialize accumulators */
314         float16_t sum1 = 0.0f16;
315         float16_t sum2 = 0.0f16;
316         float16_t sum3 = 0.0f16;
317         float16_t sum4 = 0.0f16;
318 
319         /* Loop unrolling: process 2 columns per iteration */
320         colCnt = numCols;
321 
322         /* Initialize pointers to the starting address of the column being processed */
323         pInA1 = pSrcA + i;
324         pInA2 = pInA1 + numCols;
325         pInA3 = pInA2 + numCols;
326         pInA4 = pInA3 + numCols;
327 
328 
329         // Main loop: matrix-vector multiplication
330         while (colCnt > 0u) {
331             // Read 2 values from vector
332             vecData = *(pInVec)++;
333             // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
334             matData = *(pInA1)++;
335             sum1 += (_Float16)matData * (_Float16)vecData;
336             matData = *(pInA2)++;
337             sum2 += (_Float16)matData * (_Float16)vecData;
338             matData = *(pInA3)++;
339             sum3 += (_Float16)matData * (_Float16)vecData;
340             matData = *(pInA4)++;
341             sum4 += (_Float16)matData * (_Float16)vecData;
342 
343             // Decrement the loop counter
344             colCnt--;
345         }
346 
347         /* Saturate and store the result in the destination buffer */
348         *px++ = sum1;
349         *px++ = sum2;
350         *px++ = sum3;
351         *px++ = sum4;
352 
353         i = i + numCols * 4;
354 
355         /* Decrement the row loop counter */
356         row--;
357     }
358 
359     /* process any remaining rows */
360     row = numRows & 3u;
361     while (row > 0) {
362 
363         float16_t sum = 0.0f16;
364         pInVec = pVec;
365         pInA1 = pSrcA + i;
366 
367         colCnt = numCols >> 1;
368 
369         while (colCnt > 0) {
370             vecData = *(pInVec)++;
371             vecData2 = *(pInVec)++;
372             matData = *(pInA1)++;
373             matData2 = *(pInA1)++;
374             sum += (_Float16)matData * (_Float16)vecData;
375             sum += (_Float16)matData2 * (_Float16)vecData2;
376             colCnt--;
377         }
378         // process remainder of row
379         colCnt = numCols & 1u;
380         while (colCnt > 0) {
381             sum += (_Float16)*pInA1++ * (_Float16)*pInVec++;
382             colCnt--;
383         }
384 
385         *px++ = sum;
386         i = i + numCols;
387         row--;
388     }
389 }
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
391 
392 /**
 * @} end of MatrixVectMult group
394  */
395 
396 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
397 
398