1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_svm_rbf_predict_f32.c
4  * Description:  SVM Radial Basis Function Classifier
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/svm_functions.h"
30 #include <limits.h>
31 #include <math.h>
32 
33 
34 /**
35  * @addtogroup rbfsvm
36  * @{
37  */
38 
39 
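/**
 * @brief SVM rbf prediction
 *
 * Computes the decision value
 *   intercept + sum_i dualCoefficients[i] * exp(-gamma * ||in - supportVector_i||^2)
 * and stores the corresponding class label.
 *
 * @param[in]    S         Pointer to an instance of the rbf SVM structure.
 * @param[in]    in        Pointer to input vector (S->vectorDimension elements)
 * @param[out]   pResult   Predicted class: S->classes[STEP(decision value)]
 *
 */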
47 
48 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
49 
50 #include "arm_helium_utils.h"
51 #include "arm_vec_math.h"
52 
arm_svm_rbf_predict_f32(const arm_svm_rbf_instance_f32 * S,const float32_t * in,int32_t * pResult)53 ARM_DSP_ATTRIBUTE void arm_svm_rbf_predict_f32(
54     const arm_svm_rbf_instance_f32 *S,
55     const float32_t * in,
56     int32_t * pResult)
57 {
58         /* inlined Matrix x Vector function interleaved with dot prod */
59     uint32_t        numRows = S->nbOfSupportVectors;
60     uint32_t        numCols = S->vectorDimension;
61     const float32_t *pSupport = S->supportVectors;
62     const float32_t *pSrcA = pSupport;
63     const float32_t *pInA0;
64     const float32_t *pInA1;
65     uint32_t         row;
66     uint32_t         blkCnt;     /* loop counters */
67     const float32_t *pDualCoef = S->dualCoefficients;
68     float32_t       sum = S->intercept;
69     f32x4_t         vSum = vdupq_n_f32(0);
70 
71     row = numRows;
72 
73     /*
74      * compute 4 rows in parrallel
75      */
76     while (row >= 4) {
77         const float32_t *pInA2, *pInA3;
78         float32_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec, *pInVec;
79         f32x4_t         vecIn, acc0, acc1, acc2, acc3;
80         float32_t const *pSrcVecPtr = in;
81 
82         /*
83          * Initialize the pointers to 4 consecutive MatrixA rows
84          */
85         pInA0 = pSrcA;
86         pInA1 = pInA0 + numCols;
87         pInA2 = pInA1 + numCols;
88         pInA3 = pInA2 + numCols;
89         /*
90          * Initialize the vector pointer
91          */
92         pInVec = pSrcVecPtr;
93         /*
94          * reset accumulators
95          */
96         acc0 = vdupq_n_f32(0.0f);
97         acc1 = vdupq_n_f32(0.0f);
98         acc2 = vdupq_n_f32(0.0f);
99         acc3 = vdupq_n_f32(0.0f);
100 
101         pSrcA0Vec = pInA0;
102         pSrcA1Vec = pInA1;
103         pSrcA2Vec = pInA2;
104         pSrcA3Vec = pInA3;
105 
106         blkCnt = numCols >> 2;
107         while (blkCnt > 0U) {
108             f32x4_t         vecA;
109             f32x4_t         vecDif;
110 
111             vecIn = vld1q(pInVec);
112             pInVec += 4;
113             vecA = vld1q(pSrcA0Vec);
114             pSrcA0Vec += 4;
115             vecDif = vsubq(vecIn, vecA);
116             acc0 = vfmaq(acc0, vecDif, vecDif);
117             vecA = vld1q(pSrcA1Vec);
118             pSrcA1Vec += 4;
119             vecDif = vsubq(vecIn, vecA);
120             acc1 = vfmaq(acc1, vecDif, vecDif);
121             vecA = vld1q(pSrcA2Vec);
122             pSrcA2Vec += 4;
123             vecDif = vsubq(vecIn, vecA);
124             acc2 = vfmaq(acc2, vecDif, vecDif);
125             vecA = vld1q(pSrcA3Vec);
126             pSrcA3Vec += 4;
127             vecDif = vsubq(vecIn, vecA);
128             acc3 = vfmaq(acc3, vecDif, vecDif);
129 
130             blkCnt--;
131         }
132         /*
133          * tail
134          * (will be merged thru tail predication)
135          */
136         blkCnt = numCols & 3;
137         if (blkCnt > 0U) {
138             mve_pred16_t    p0 = vctp32q(blkCnt);
139             f32x4_t         vecA;
140             f32x4_t         vecDif;
141 
142             vecIn = vldrwq_z_f32(pInVec, p0);
143             vecA = vldrwq_z_f32(pSrcA0Vec, p0);
144             vecDif = vsubq(vecIn, vecA);
145             acc0 = vfmaq(acc0, vecDif, vecDif);
146             vecA = vldrwq_z_f32(pSrcA1Vec, p0);
147             vecDif = vsubq(vecIn, vecA);
148             acc1 = vfmaq(acc1, vecDif, vecDif);
149             vecA = vldrwq_z_f32(pSrcA2Vec, p0);;
150             vecDif = vsubq(vecIn, vecA);
151             acc2 = vfmaq(acc2, vecDif, vecDif);
152             vecA = vldrwq_z_f32(pSrcA3Vec, p0);
153             vecDif = vsubq(vecIn, vecA);
154             acc3 = vfmaq(acc3, vecDif, vecDif);
155         }
156         /*
157          * Sum the partial parts
158          */
159 
160         //sum += *pDualCoef++ * expf(-S->gamma * vecReduceF32Mve(acc0));
161         f32x4_t         vtmp = vuninitializedq_f32();
162         vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);
163         vtmp = vsetq_lane(vecAddAcrossF32Mve(acc1), vtmp, 1);
164         vtmp = vsetq_lane(vecAddAcrossF32Mve(acc2), vtmp, 2);
165         vtmp = vsetq_lane(vecAddAcrossF32Mve(acc3), vtmp, 3);
166 
167         vSum =
168             vfmaq_f32(vSum, vld1q(pDualCoef),
169                       vexpq_f32(vmulq_n_f32(vtmp, -S->gamma)));
170         pDualCoef += 4;
171         pSrcA += numCols * 4;
172         /*
173          * Decrement the row loop counter
174          */
175         row -= 4;
176     }
177 
178     /*
179      * compute 2 rows in parrallel
180      */
181     if (row >= 2) {
182         float32_t const *pSrcA0Vec, *pSrcA1Vec, *pInVec;
183         f32x4_t         vecIn, acc0, acc1;
184         float32_t const *pSrcVecPtr = in;
185 
186         /*
187          * Initialize the pointers to 2 consecutive MatrixA rows
188          */
189         pInA0 = pSrcA;
190         pInA1 = pInA0 + numCols;
191         /*
192          * Initialize the vector pointer
193          */
194         pInVec = pSrcVecPtr;
195         /*
196          * reset accumulators
197          */
198         acc0 = vdupq_n_f32(0.0f);
199         acc1 = vdupq_n_f32(0.0f);
200         pSrcA0Vec = pInA0;
201         pSrcA1Vec = pInA1;
202 
203         blkCnt = numCols >> 2;
204         while (blkCnt > 0U) {
205             f32x4_t         vecA;
206             f32x4_t         vecDif;
207 
208             vecIn = vld1q(pInVec);
209             pInVec += 4;
210             vecA = vld1q(pSrcA0Vec);
211             pSrcA0Vec += 4;
212             vecDif = vsubq(vecIn, vecA);
213             acc0 = vfmaq(acc0, vecDif, vecDif);;
214             vecA = vld1q(pSrcA1Vec);
215             pSrcA1Vec += 4;
216             vecDif = vsubq(vecIn, vecA);
217             acc1 = vfmaq(acc1, vecDif, vecDif);
218 
219             blkCnt--;
220         }
221         /*
222          * tail
223          * (will be merged thru tail predication)
224          */
225         blkCnt = numCols & 3;
226         if (blkCnt > 0U) {
227             mve_pred16_t    p0 = vctp32q(blkCnt);
228             f32x4_t         vecA, vecDif;
229 
230             vecIn = vldrwq_z_f32(pInVec, p0);
231             vecA = vldrwq_z_f32(pSrcA0Vec, p0);
232             vecDif = vsubq(vecIn, vecA);
233             acc0 = vfmaq(acc0, vecDif, vecDif);
234             vecA = vldrwq_z_f32(pSrcA1Vec, p0);
235             vecDif = vsubq(vecIn, vecA);
236             acc1 = vfmaq(acc1, vecDif, vecDif);
237         }
238         /*
239          * Sum the partial parts
240          */
241         f32x4_t         vtmp = vuninitializedq_f32();
242         vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);
243         vtmp = vsetq_lane(vecAddAcrossF32Mve(acc1), vtmp, 1);
244 
245         vSum =
246             vfmaq_m_f32(vSum, vld1q(pDualCoef),
247                         vexpq_f32(vmulq_n_f32(vtmp, -S->gamma)), vctp32q(2));
248         pDualCoef += 2;
249 
250         pSrcA += numCols * 2;
251         row -= 2;
252     }
253 
254     if (row >= 1) {
255         f32x4_t         vecIn, acc0;
256         float32_t const *pSrcA0Vec, *pInVec;
257         float32_t const *pSrcVecPtr = in;
258         /*
259          * Initialize the pointers to last MatrixA row
260          */
261         pInA0 = pSrcA;
262         /*
263          * Initialize the vector pointer
264          */
265         pInVec = pSrcVecPtr;
266         /*
267          * reset accumulators
268          */
269         acc0 = vdupq_n_f32(0.0f);
270 
271         pSrcA0Vec = pInA0;
272 
273         blkCnt = numCols >> 2;
274         while (blkCnt > 0U) {
275             f32x4_t         vecA, vecDif;
276 
277             vecIn = vld1q(pInVec);
278             pInVec += 4;
279             vecA = vld1q(pSrcA0Vec);
280             pSrcA0Vec += 4;
281             vecDif = vsubq(vecIn, vecA);
282             acc0 = vfmaq(acc0, vecDif, vecDif);
283 
284             blkCnt--;
285         }
286         /*
287          * tail
288          * (will be merged thru tail predication)
289          */
290         blkCnt = numCols & 3;
291         if (blkCnt > 0U) {
292             mve_pred16_t    p0 = vctp32q(blkCnt);
293             f32x4_t         vecA, vecDif;
294 
295             vecIn = vldrwq_z_f32(pInVec, p0);
296             vecA = vldrwq_z_f32(pSrcA0Vec, p0);
297             vecDif = vsubq(vecIn, vecA);
298             acc0 = vfmaq(acc0, vecDif, vecDif);
299         }
300         /*
301          * Sum the partial parts
302          */
303         f32x4_t         vtmp = vuninitializedq_f32();
304         vtmp = vsetq_lane(vecAddAcrossF32Mve(acc0), vtmp, 0);
305 
306         vSum =
307             vfmaq_m_f32(vSum, vld1q(pDualCoef),
308                         vexpq_f32(vmulq_n_f32(vtmp, -S->gamma)), vctp32q(1));
309 
310     }
311 
312 
313     sum += vecAddAcrossF32Mve(vSum);
314     *pResult = S->classes[STEP(sum)];
315 }
316 
317 
318 #else
319 #if defined(ARM_MATH_NEON)
320 
321 #include "NEMath.h"
322 
arm_svm_rbf_predict_f32(const arm_svm_rbf_instance_f32 * S,const float32_t * in,int32_t * pResult)323 ARM_DSP_ATTRIBUTE void arm_svm_rbf_predict_f32(
324     const arm_svm_rbf_instance_f32 *S,
325     const float32_t * in,
326     int32_t * pResult)
327 {
328     float32_t sum = S->intercept;
329 
330     float32_t dot;
331     float32x4_t dotV;
332 
333     float32x4_t accuma,accumb,accumc,accumd,accum;
334     float32x2_t accum2;
335     float32x4_t temp;
336     float32x4_t vec1;
337 
338     float32x4_t vec2,vec2a,vec2b,vec2c,vec2d;
339 
340     uint32_t blkCnt;
341     uint32_t vectorBlkCnt;
342 
343     const float32_t *pIn = in;
344 
345     const float32_t *pSupport = S->supportVectors;
346 
347     const float32_t *pSupporta = S->supportVectors;
348     const float32_t *pSupportb;
349     const float32_t *pSupportc;
350     const float32_t *pSupportd;
351 
352     pSupportb = pSupporta + S->vectorDimension;
353     pSupportc = pSupportb + S->vectorDimension;
354     pSupportd = pSupportc + S->vectorDimension;
355 
356     const float32_t *pDualCoefs = S->dualCoefficients;
357 
358 
359     vectorBlkCnt = S->nbOfSupportVectors >> 2;
360     while (vectorBlkCnt > 0U)
361     {
362         accuma = vdupq_n_f32(0);
363         accumb = vdupq_n_f32(0);
364         accumc = vdupq_n_f32(0);
365         accumd = vdupq_n_f32(0);
366 
367         pIn = in;
368 
369         blkCnt = S->vectorDimension >> 2;
370         while (blkCnt > 0U)
371         {
372 
373             vec1 = vld1q_f32(pIn);
374             vec2a = vld1q_f32(pSupporta);
375             vec2b = vld1q_f32(pSupportb);
376             vec2c = vld1q_f32(pSupportc);
377             vec2d = vld1q_f32(pSupportd);
378 
379             pIn += 4;
380             pSupporta += 4;
381             pSupportb += 4;
382             pSupportc += 4;
383             pSupportd += 4;
384 
385             temp = vsubq_f32(vec1, vec2a);
386             accuma = vmlaq_f32(accuma, temp, temp);
387 
388             temp = vsubq_f32(vec1, vec2b);
389             accumb = vmlaq_f32(accumb, temp, temp);
390 
391             temp = vsubq_f32(vec1, vec2c);
392             accumc = vmlaq_f32(accumc, temp, temp);
393 
394             temp = vsubq_f32(vec1, vec2d);
395             accumd = vmlaq_f32(accumd, temp, temp);
396 
397             blkCnt -- ;
398         }
399         accum2 = vpadd_f32(vget_low_f32(accuma),vget_high_f32(accuma));
400         dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,0);
401 
402         accum2 = vpadd_f32(vget_low_f32(accumb),vget_high_f32(accumb));
403         dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,1);
404 
405         accum2 = vpadd_f32(vget_low_f32(accumc),vget_high_f32(accumc));
406         dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,2);
407 
408         accum2 = vpadd_f32(vget_low_f32(accumd),vget_high_f32(accumd));
409         dotV = vsetq_lane_f32(vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1),dotV,3);
410 
411 
412         blkCnt = S->vectorDimension & 3;
413         while (blkCnt > 0U)
414         {
415             dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,0) + ARM_SQ(*pIn - *pSupporta), dotV,0);
416             dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,1) + ARM_SQ(*pIn - *pSupportb), dotV,1);
417             dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,2) + ARM_SQ(*pIn - *pSupportc), dotV,2);
418             dotV = vsetq_lane_f32(vgetq_lane_f32(dotV,3) + ARM_SQ(*pIn - *pSupportd), dotV,3);
419 
420             pSupporta++;
421             pSupportb++;
422             pSupportc++;
423             pSupportd++;
424 
425             pIn++;
426 
427             blkCnt -- ;
428         }
429 
430         vec1 = vld1q_f32(pDualCoefs);
431         pDualCoefs += 4;
432 
433         // To vectorize later
434         dotV = vmulq_n_f32(dotV, -S->gamma);
435         dotV = vexpq_f32(dotV);
436 
437         accum = vmulq_f32(vec1,dotV);
438         accum2 = vpadd_f32(vget_low_f32(accum),vget_high_f32(accum));
439         sum += vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1);
440 
441         pSupporta += 3*S->vectorDimension;
442         pSupportb += 3*S->vectorDimension;
443         pSupportc += 3*S->vectorDimension;
444         pSupportd += 3*S->vectorDimension;
445 
446         vectorBlkCnt -- ;
447     }
448 
449     pSupport = pSupporta;
450     vectorBlkCnt = S->nbOfSupportVectors & 3;
451 
452     while (vectorBlkCnt > 0U)
453     {
454         accum = vdupq_n_f32(0);
455         dot = 0.0f;
456         pIn = in;
457 
458         blkCnt = S->vectorDimension >> 2;
459         while (blkCnt > 0U)
460         {
461 
462             vec1 = vld1q_f32(pIn);
463             vec2 = vld1q_f32(pSupport);
464             pIn += 4;
465             pSupport += 4;
466 
467             temp = vsubq_f32(vec1,vec2);
468             accum = vmlaq_f32(accum, temp,temp);
469 
470             blkCnt -- ;
471         }
472         accum2 = vpadd_f32(vget_low_f32(accum),vget_high_f32(accum));
473         dot = vget_lane_f32(accum2, 0) + vget_lane_f32(accum2, 1);
474 
475 
476         blkCnt = S->vectorDimension & 3;
477         while (blkCnt > 0U)
478         {
479 
480             dot = dot + ARM_SQ(*pIn - *pSupport);
481             pIn++;
482             pSupport++;
483 
484             blkCnt -- ;
485         }
486 
487         sum += *pDualCoefs++ * expf(-S->gamma * dot);
488         vectorBlkCnt -- ;
489     }
490 
491     *pResult=S->classes[STEP(sum)];
492 }
493 #else
arm_svm_rbf_predict_f32(const arm_svm_rbf_instance_f32 * S,const float32_t * in,int32_t * pResult)494 ARM_DSP_ATTRIBUTE void arm_svm_rbf_predict_f32(
495     const arm_svm_rbf_instance_f32 *S,
496     const float32_t * in,
497     int32_t * pResult)
498 {
499     float32_t sum=S->intercept;
500     float32_t dot=0;
501     uint32_t i,j;
502     const float32_t *pSupport = S->supportVectors;
503 
504     for(i=0; i < S->nbOfSupportVectors; i++)
505     {
506         dot=0;
507         for(j=0; j < S->vectorDimension; j++)
508         {
509             dot = dot + ARM_SQ(in[j] - *pSupport);
510             pSupport++;
511         }
512         sum += S->dualCoefficients[i] * expf(-S->gamma * dot);
513     }
514     *pResult=S->classes[STEP(sum)];
515 }
516 #endif
517 
518 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
519 
520 /**
521  * @} end of rbfsvm group
522  */
523