1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_vec_mult_q15.c
4  * Description:  Q15 matrix and vector multiplication
5  *
6  * $Date:        23 April 2021
7  *
8  * $Revision:    V1.9.0
9  *
10  * Target Processor: Cortex-M and Cortex-A cores
11  * -------------------------------------------------------------------- */
12 /*
13  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14  *
15  * SPDX-License-Identifier: Apache-2.0
16  *
17  * Licensed under the Apache License, Version 2.0 (the License); you may
18  * not use this file except in compliance with the License.
19  * You may obtain a copy of the License at
20  *
21  * www.apache.org/licenses/LICENSE-2.0
22  *
23  * Unless required by applicable law or agreed to in writing, software
24  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26  * See the License for the specific language governing permissions and
27  * limitations under the License.
28  */
29 
30 #include "dsp/matrix_functions.h"
31 
32 /**
33  * @ingroup groupMatrix
34  */
35 
36 
37 
38 /**
39  * @addtogroup MatrixVectMult
40  * @{
41  */
42 
43 /**
44  * @brief Q15 matrix and vector multiplication.
45  * @param[in]       *pSrcMat points to the input matrix structure
46  * @param[in]       *pVec points to input vector
47  * @param[out]      *pDst points to output vector
48  */
49 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
50 
51 #include "arm_helium_utils.h"
52 
/*
 * Helium (MVE) implementation of the Q15 matrix-by-vector product.
 *
 * Every output element is the dot product of one matrix row with pSrcVec,
 * accumulated in a 64-bit accumulator and narrowed with
 * MVE_ASRL_SAT16(acc, 15) before being stored.
 *
 * Rows are handled three at a time so that a single load of the input
 * vector feeds three accumulators; a remaining pair of rows and/or a
 * remaining single row are processed afterwards.
 *
 * pSrcMat  matrix descriptor; pData is read as numRows rows of numCols values
 * pSrcVec  input vector, numCols values are read
 * pDstVec  output vector, numRows values are written
 */
void arm_mat_vec_mult_q15(
    const arm_matrix_instance_q15 * pSrcMat,
    const q15_t     *pSrcVec,
    q15_t           *pDstVec)
{
    const q15_t *pMatSrc = pSrcMat->pData;      /* first row not yet processed */
    uint32_t     numRows = pSrcMat->numRows;
    uint32_t     numCols = pSrcMat->numCols;
    uint32_t     fullBlocks = numCols >> 3;     /* complete groups of 8 columns */
    uint32_t     tailCnt = numCols & 7U;        /* columns left after the full groups */
    q15_t       *px = pDstVec;                  /* output write pointer */
    int32_t      row = (int32_t) numRows;       /* rows still to compute */
    uint32_t     blkCnt;                        /* column-block loop counter */

    /*
     * Three rows per iteration: 3 x 64-bit accumulators
     */
    while (row >= 3)
    {
        q15_t const *pMat0Vec = pMatSrc;
        q15_t const *pMat1Vec = pMat0Vec + numCols;
        q15_t const *pMat2Vec = pMat1Vec + numCols;
        q15_t const *pVec = pSrcVec;            /* restart at the head of the vector */
        q63_t        acc0 = 0LL;
        q63_t        acc1 = 0LL;
        q63_t        acc2 = 0LL;
        q15x8_t      vecMatA0, vecMatA1, vecMatA2, vecIn;

        for (blkCnt = fullBlocks; blkCnt > 0U; blkCnt--)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 8;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 8;
            vecMatA2 = vld1q(pMat2Vec);
            pMat2Vec += 8;
            vecIn = vld1q(pVec);
            pVec += 8;

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
            acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
        }

        /*
         * Tail: fewer than 8 columns remain. All loads are predicated so
         * that only the first tailCnt lanes are fetched and the others are
         * zero; an unpredicated load here would read past the end of the
         * row (and, for the last row, past the end of the matrix).
         */
        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(tailCnt);

            vecMatA0 = vldrhq_z_s16(pMat0Vec, p0);
            vecMatA1 = vldrhq_z_s16(pMat1Vec, p0);
            vecMatA2 = vldrhq_z_s16(pMat2Vec, p0);
            vecIn = vldrhq_z_s16(pVec, p0);

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
            acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
        }

        *px++ = MVE_ASRL_SAT16(acc0, 15);
        *px++ = MVE_ASRL_SAT16(acc1, 15);
        *px++ = MVE_ASRL_SAT16(acc2, 15);

        /* Step over the three rows just consumed */
        pMatSrc += numCols * 3;
        row -= 3;
    }

    /*
     * A pair of rows may remain
     */
    if (row >= 2)
    {
        q15_t const *pMat0Vec = pMatSrc;
        q15_t const *pMat1Vec = pMat0Vec + numCols;
        q15_t const *pVec = pSrcVec;            /* restart at the head of the vector */
        q63_t        acc0 = 0LL;
        q63_t        acc1 = 0LL;
        q15x8_t      vecMatA0, vecMatA1, vecIn;

        for (blkCnt = fullBlocks; blkCnt > 0U; blkCnt--)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 8;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 8;
            vecIn = vld1q(pVec);
            pVec += 8;

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
        }

        /* Tail: predicated loads keep every access inside the rows */
        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(tailCnt);

            vecMatA0 = vldrhq_z_s16(pMat0Vec, p0);
            vecMatA1 = vldrhq_z_s16(pMat1Vec, p0);
            vecIn = vldrhq_z_s16(pVec, p0);

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
        }

        *px++ = MVE_ASRL_SAT16(acc0, 15);
        *px++ = MVE_ASRL_SAT16(acc1, 15);

        /* Step over the two rows just consumed */
        pMatSrc += numCols * 2;
        row -= 2;
    }

    /*
     * A single row may remain
     */
    if (row >= 1)
    {
        q15_t const *pMat0Vec = pMatSrc;
        q15_t const *pVec = pSrcVec;            /* restart at the head of the vector */
        q63_t        acc0 = 0LL;
        q15x8_t      vecMatA0, vecIn;

        for (blkCnt = fullBlocks; blkCnt > 0U; blkCnt--)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 8;
            vecIn = vld1q(pVec);
            pVec += 8;
            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
        }

        /* Tail: predicated loads keep every access inside the row */
        if (tailCnt > 0U)
        {
            mve_pred16_t p0 = vctp16q(tailCnt);

            vecMatA0 = vldrhq_z_s16(pMat0Vec, p0);
            vecIn = vldrhq_z_s16(pVec, p0);
            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
        }

        *px++ = MVE_ASRL_SAT16(acc0, 15);
    }
}
267 
268 #else
/*
 * Scalar / DSP-extension implementation of the Q15 matrix-by-vector product.
 *
 * Every output element is the dot product of one matrix row with pVec,
 * accumulated in 64 bits, shifted right by 15 and saturated to 16 bits.
 *
 * Rows are first processed in groups of four (two columns per step, so one
 * packed vector read feeds four accumulators); the remaining 0..3 rows are
 * then processed one at a time, four columns per step.
 *
 * pSrcMat  matrix descriptor; pData is read as numRows rows of numCols values
 * pVec     input vector, numCols values are read
 * pDst     output vector, numRows values are written
 */
void arm_mat_vec_mult_q15(const arm_matrix_instance_q15 *pSrcMat, const q15_t *pVec, q15_t *pDst)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const q15_t *pSrcA = pSrcMat->pData; /* start of the first row not yet processed */
    const q15_t *pInA1;                  /* read pointers into up to four matrix rows */
    const q15_t *pInA2;
    const q15_t *pInA3;
    const q15_t *pInA4;
    const q15_t *pInVec;                 /* read pointer into the input vector */
    q15_t *px = pDst;                    /* output write pointer */
    uint32_t row, colCnt;                /* loop counters */
    q31_t matData, matData2, vecData, vecData2;

    /* Groups of four rows */
    for (row = numRows >> 2; row > 0U; row--) {
        q63_t sum1 = 0;
        q63_t sum2 = 0;
        q63_t sum3 = 0;
        q63_t sum4 = 0;

        /* Each group starts again from the head of the vector */
        pInVec = pVec;

        /* One pointer per row of the group */
        pInA1 = pSrcA;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        pInA4 = pInA3 + numCols;

        /* Two columns per iteration */
        for (colCnt = numCols >> 1; colCnt > 0U; colCnt--) {
            /* Two packed vector values, shared by all four rows */
            vecData = read_q15x2_ia (&pInVec);

            /* Two packed values from each row, multiply-accumulated */
            matData = read_q15x2_ia (&pInA1);
            sum1 = __SMLALD(matData, vecData, sum1);
            matData = read_q15x2_ia (&pInA2);
            sum2 = __SMLALD(matData, vecData, sum2);
            matData = read_q15x2_ia (&pInA3);
            sum3 = __SMLALD(matData, vecData, sum3);
            matData = read_q15x2_ia (&pInA4);
            sum4 = __SMLALD(matData, vecData, sum4);
        }

        /* Odd column count: one last column remains */
        if (numCols & 1u) {
            vecData = *pInVec++;
            sum1 += (q63_t)*pInA1++ * vecData;
            sum2 += (q63_t)*pInA2++ * vecData;
            sum3 += (q63_t)*pInA3++ * vecData;
            sum4 += (q63_t)*pInA4++ * vecData;
        }

        /* Scale back to Q15, saturate and store */
        *px++ = (q15_t)(__SSAT((sum1 >> 15), 16));
        *px++ = (q15_t)(__SSAT((sum2 >> 15), 16));
        *px++ = (q15_t)(__SSAT((sum3 >> 15), 16));
        *px++ = (q15_t)(__SSAT((sum4 >> 15), 16));

        /*
         * Move the row base with 32-bit pointer arithmetic. A 16-bit
         * element index would wrap as soon as the matrix holds more than
         * 65535 elements and later rows would read the wrong data.
         */
        pSrcA += numCols * 4U;
    }

    /* Remaining 0..3 rows, one at a time */
    for (row = numRows & 3u; row > 0U; row--) {
        q63_t sum = 0;

        pInVec = pVec;
        pInA1 = pSrcA;

        /* Four columns per iteration */
        for (colCnt = numCols >> 2; colCnt > 0U; colCnt--) {
            vecData = read_q15x2_ia (&pInVec);
            vecData2 = read_q15x2_ia (&pInVec);
            matData = read_q15x2_ia (&pInA1);
            matData2 = read_q15x2_ia (&pInA1);
            sum = __SMLALD(matData, vecData, sum);
            sum = __SMLALD(matData2, vecData2, sum);
        }

        /* Remaining 0..3 columns */
        for (colCnt = numCols & 3u; colCnt > 0U; colCnt--) {
            sum += (q63_t)*pInA1++ * *pInVec++;
        }

        *px++ = (q15_t)(__SSAT((sum >> 15), 16));

        /* Next row; 32-bit arithmetic for the same reason as above */
        pSrcA += numCols;
    }
}
384 #endif /* defined(ARM_MATH_MVEI) */
385 
386 /**
 * @} end of MatrixVectMult group
388  */
389