1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_vec_mult_f16.c
4 * Description: Floating-point matrix and vector multiplication
5 *
6 * $Date: 23 April 2021
7 *
8 * $Revision: V1.9.0
9 *
10 * Target Processor: Cortex-M and Cortex-A cores
11 * -------------------------------------------------------------------- */
12 /*
13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14 *
15 * SPDX-License-Identifier: Apache-2.0
16 *
17 * Licensed under the Apache License, Version 2.0 (the License); you may
18 * not use this file except in compliance with the License.
19 * You may obtain a copy of the License at
20 *
21 * www.apache.org/licenses/LICENSE-2.0
22 *
23 * Unless required by applicable law or agreed to in writing, software
24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 * See the License for the specific language governing permissions and
27 * limitations under the License.
28 */
29
30 #include "dsp/matrix_functions_f16.h"
31
32 #if defined(ARM_FLOAT16_SUPPORTED)
33
34
35 /**
36 * @ingroup groupMatrix
37 */
38
39
40 /**
41 * @addtogroup MatrixVectMult
42 * @{
43 */
44
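/**
 * @brief         Floating-point (f16) matrix and vector multiplication.
 *                Computes pDst[r] = sum over c of A[r][c] * pVec[c] for every row r.
 * @param[in]     pSrcMat  points to the input matrix structure (numRows x numCols)
 * @param[in]     pVec     points to the input vector (numCols elements)
 * @param[out]    pDst     points to the output vector (numRows elements)
 * @return        none
 */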
51 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
52
53 #include "arm_helium_utils.h"
54
arm_mat_vec_mult_f16(const arm_matrix_instance_f16 * pSrcMat,const float16_t * pSrcVec,float16_t * pDstVec)55 ARM_DSP_ATTRIBUTE void arm_mat_vec_mult_f16(
56 const arm_matrix_instance_f16 *pSrcMat,
57 const float16_t *pSrcVec,
58 float16_t *pDstVec)
59 {
60 uint32_t numRows = pSrcMat->numRows;
61 uint32_t numCols = pSrcMat->numCols;
62 const float16_t *pSrcA = pSrcMat->pData;
63 const float16_t *pInA0;
64 const float16_t *pInA1;
65 float16_t *px;
66 int32_t row;
67 uint32_t blkCnt; /* loop counters */
68
69 row = numRows;
70 px = pDstVec;
71
72 /*
73 * compute 4 rows in parallel
74 */
75 while (row >= 4)
76 {
77 const float16_t *pInA2, *pInA3;
78 float16_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
79 f16x8_t vecIn, acc0, acc1, acc2, acc3;
80 float16_t const *pSrcVecPtr = pSrcVec;
81
82 /*
83 * Initialize the pointers to 4 consecutive MatrixA rows
84 */
85 pInA0 = pSrcA;
86 pInA1 = pInA0 + numCols;
87 pInA2 = pInA1 + numCols;
88 pInA3 = pInA2 + numCols;
89 /*
90 * Initialize the vector pointer
91 */
92 pInVec = pSrcVecPtr;
93 /*
94 * reset accumulators
95 */
96 acc0 = vdupq_n_f16(0.0f);
97 acc1 = vdupq_n_f16(0.0f);
98 acc2 = vdupq_n_f16(0.0f);
99 acc3 = vdupq_n_f16(0.0f);
100
101 pSrcA0Vec = pInA0;
102 pSrcA1Vec = pInA1;
103 pSrcA2Vec = pInA2;
104 pSrcA3Vec = pInA3;
105
106 blkCnt = numCols >> 3;
107 while (blkCnt > 0U)
108 {
109 f16x8_t vecA;
110
111 vecIn = vld1q(pInVec);
112 pInVec += 8;
113 vecA = vld1q(pSrcA0Vec);
114 pSrcA0Vec += 8;
115 acc0 = vfmaq(acc0, vecIn, vecA);
116 vecA = vld1q(pSrcA1Vec);
117 pSrcA1Vec += 8;
118 acc1 = vfmaq(acc1, vecIn, vecA);
119 vecA = vld1q(pSrcA2Vec);
120 pSrcA2Vec += 8;
121 acc2 = vfmaq(acc2, vecIn, vecA);
122 vecA = vld1q(pSrcA3Vec);
123 pSrcA3Vec += 8;
124 acc3 = vfmaq(acc3, vecIn, vecA);
125
126 blkCnt--;
127 }
128 /*
129 * tail
130 * (will be merged thru tail predication)
131 */
132 blkCnt = numCols & 7;
133 if (blkCnt > 0U)
134 {
135 mve_pred16_t p0 = vctp16q(blkCnt);
136 f16x8_t vecA;
137
138 vecIn = vldrhq_z_f16(pInVec, p0);
139 vecA = vld1q(pSrcA0Vec);
140 acc0 = vfmaq(acc0, vecIn, vecA);
141 vecA = vld1q(pSrcA1Vec);
142 acc1 = vfmaq(acc1, vecIn, vecA);
143 vecA = vld1q(pSrcA2Vec);
144 acc2 = vfmaq(acc2, vecIn, vecA);
145 vecA = vld1q(pSrcA3Vec);
146 acc3 = vfmaq(acc3, vecIn, vecA);
147 }
148 /*
149 * Sum the partial parts
150 */
151 *px++ = vecAddAcrossF16Mve(acc0);
152 *px++ = vecAddAcrossF16Mve(acc1);
153 *px++ = vecAddAcrossF16Mve(acc2);
154 *px++ = vecAddAcrossF16Mve(acc3);
155
156 pSrcA += numCols * 4;
157 /*
158 * Decrement the row loop counter
159 */
160 row -= 4;
161 }
162
163 /*
164 * compute 2 rows in parrallel
165 */
166 if (row >= 2)
167 {
168 float16_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
169 f16x8_t vecIn, acc0, acc1;
170 float16_t const *pSrcVecPtr = pSrcVec;
171
172 /*
173 * Initialize the pointers to 2 consecutive MatrixA rows
174 */
175 pInA0 = pSrcA;
176 pInA1 = pInA0 + numCols;
177 /*
178 * Initialize the vector pointer
179 */
180 pInVec = pSrcVecPtr;
181 /*
182 * reset accumulators
183 */
184 acc0 = vdupq_n_f16(0.0f);
185 acc1 = vdupq_n_f16(0.0f);
186 pSrcA0Vec = pInA0;
187 pSrcA1Vec = pInA1;
188
189 blkCnt = numCols >> 3;
190 while (blkCnt > 0U)
191 {
192 f16x8_t vecA;
193
194 vecIn = vld1q(pInVec);
195 pInVec += 8;
196 vecA = vld1q(pSrcA0Vec);
197 pSrcA0Vec += 8;
198 acc0 = vfmaq(acc0, vecIn, vecA);
199 vecA = vld1q(pSrcA1Vec);
200 pSrcA1Vec += 8;
201 acc1 = vfmaq(acc1, vecIn, vecA);
202
203 blkCnt--;
204 }
205 /*
206 * tail
207 * (will be merged thru tail predication)
208 */
209 blkCnt = numCols & 7;
210 if (blkCnt > 0U)
211 {
212 mve_pred16_t p0 = vctp16q(blkCnt);
213 f16x8_t vecA;
214
215 vecIn = vldrhq_z_f16(pInVec, p0);
216 vecA = vld1q(pSrcA0Vec);
217 acc0 = vfmaq(acc0, vecIn, vecA);
218 vecA = vld1q(pSrcA1Vec);
219 acc1 = vfmaq(acc1, vecIn, vecA);
220 }
221 /*
222 * Sum the partial parts
223 */
224 *px++ = vecAddAcrossF16Mve(acc0);
225 *px++ = vecAddAcrossF16Mve(acc1);
226
227 pSrcA += numCols * 2;
228 row -= 2;
229 }
230
231 if (row >= 1)
232 {
233 f16x8_t vecIn, acc0;
234 float16_t const *pSrcA0Vec, *pInVec;
235 float16_t const *pSrcVecPtr = pSrcVec;
236 /*
237 * Initialize the pointers to last MatrixA row
238 */
239 pInA0 = pSrcA;
240 /*
241 * Initialize the vector pointer
242 */
243 pInVec = pSrcVecPtr;
244 /*
245 * reset accumulators
246 */
247 acc0 = vdupq_n_f16(0.0f);
248
249 pSrcA0Vec = pInA0;
250
251 blkCnt = numCols >> 3;
252 while (blkCnt > 0U)
253 {
254 f16x8_t vecA;
255
256 vecIn = vld1q(pInVec);
257 pInVec += 8;
258 vecA = vld1q(pSrcA0Vec);
259 pSrcA0Vec += 8;
260 acc0 = vfmaq(acc0, vecIn, vecA);
261
262 blkCnt--;
263 }
264 /*
265 * tail
266 * (will be merged thru tail predication)
267 */
268 blkCnt = numCols & 7;
269 if (blkCnt > 0U)
270 {
271 mve_pred16_t p0 = vctp16q(blkCnt);
272 f16x8_t vecA;
273
274 vecIn = vldrhq_z_f16(pInVec, p0);
275 vecA = vld1q(pSrcA0Vec);
276 acc0 = vfmaq(acc0, vecIn, vecA);
277 }
278 /*
279 * Sum the partial parts
280 */
281 *px++ = vecAddAcrossF16Mve(acc0);
282 }
283 }
284 #else
arm_mat_vec_mult_f16(const arm_matrix_instance_f16 * pSrcMat,const float16_t * pVec,float16_t * pDst)285 ARM_DSP_ATTRIBUTE void arm_mat_vec_mult_f16(const arm_matrix_instance_f16 *pSrcMat, const float16_t *pVec, float16_t *pDst)
286 {
287 uint32_t numRows = pSrcMat->numRows;
288 uint32_t numCols = pSrcMat->numCols;
289 const float16_t *pSrcA = pSrcMat->pData;
290 const float16_t *pInA1; /* input data matrix pointer A of Q31 type */
291 const float16_t *pInA2; /* input data matrix pointer A of Q31 type */
292 const float16_t *pInA3; /* input data matrix pointer A of Q31 type */
293 const float16_t *pInA4; /* input data matrix pointer A of Q31 type */
294 const float16_t *pInVec; /* input data matrix pointer B of Q31 type */
295 float16_t *px; /* Temporary output data matrix pointer */
296 uint32_t i;
297 uint16_t row, colCnt; /* loop counters */
298 float16_t matData, matData2, vecData, vecData2;
299
300
301 /* Process 4 rows at a time */
302 row = numRows >> 2;
303 i = 0u;
304 px = pDst;
305
306 /* The following loop performs the dot-product of each row in pSrcA with the vector */
307 /* row loop */
308 while (row > 0) {
309 /* For every row wise process, the pInVec pointer is set
310 ** to the starting address of the vector */
311 pInVec = pVec;
312
313 /* Initialize accumulators */
314 float16_t sum1 = 0.0f16;
315 float16_t sum2 = 0.0f16;
316 float16_t sum3 = 0.0f16;
317 float16_t sum4 = 0.0f16;
318
319 /* Loop unrolling: process 2 columns per iteration */
320 colCnt = numCols;
321
322 /* Initialize pointers to the starting address of the column being processed */
323 pInA1 = pSrcA + i;
324 pInA2 = pInA1 + numCols;
325 pInA3 = pInA2 + numCols;
326 pInA4 = pInA3 + numCols;
327
328
329 // Main loop: matrix-vector multiplication
330 while (colCnt > 0u) {
331 // Read 2 values from vector
332 vecData = *(pInVec)++;
333 // Read 8 values from the matrix - 2 values from each of 4 rows, and do multiply accumulate
334 matData = *(pInA1)++;
335 sum1 += (_Float16)matData * (_Float16)vecData;
336 matData = *(pInA2)++;
337 sum2 += (_Float16)matData * (_Float16)vecData;
338 matData = *(pInA3)++;
339 sum3 += (_Float16)matData * (_Float16)vecData;
340 matData = *(pInA4)++;
341 sum4 += (_Float16)matData * (_Float16)vecData;
342
343 // Decrement the loop counter
344 colCnt--;
345 }
346
347 /* Saturate and store the result in the destination buffer */
348 *px++ = sum1;
349 *px++ = sum2;
350 *px++ = sum3;
351 *px++ = sum4;
352
353 i = i + numCols * 4;
354
355 /* Decrement the row loop counter */
356 row--;
357 }
358
359 /* process any remaining rows */
360 row = numRows & 3u;
361 while (row > 0) {
362
363 float16_t sum = 0.0f16;
364 pInVec = pVec;
365 pInA1 = pSrcA + i;
366
367 colCnt = numCols >> 1;
368
369 while (colCnt > 0) {
370 vecData = *(pInVec)++;
371 vecData2 = *(pInVec)++;
372 matData = *(pInA1)++;
373 matData2 = *(pInA1)++;
374 sum += (_Float16)matData * (_Float16)vecData;
375 sum += (_Float16)matData2 * (_Float16)vecData2;
376 colCnt--;
377 }
378 // process remainder of row
379 colCnt = numCols & 1u;
380 while (colCnt > 0) {
381 sum += (_Float16)*pInA1++ * (_Float16)*pInVec++;
382 colCnt--;
383 }
384
385 *px++ = sum;
386 i = i + numCols;
387 row--;
388 }
389 }
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
391
/**
 * @} end of MatrixVectMult group
 */
395
396 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
397
398