1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_gaussian_naive_bayes_predict_f32
4 * Description: Naive Gaussian Bayesian Estimator
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/bayes_functions.h"
30 #include <limits.h>
31 #include <math.h>
32
33 #define PI_F 3.1415926535897932384626433832795f
34 #define DPI_F (2.0f*3.1415926535897932384626433832795f)
35
36 /**
37 * @addtogroup groupBayes
38 * @{
39 */
40
41 /**
42 * @brief Naive Gaussian Bayesian Estimator
43 *
44 * @param[in] *S points to a naive bayes instance structure
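45 * @param[in] *in points to the elements of the input vector (vectorDimension values).
46 * @param[out] *pOutputProbabilities points to a buffer of length numberOfClasses receiving, for each class, the log of the class prior plus the Gaussian log-likelihood of the input (unnormalized log probabilities)
47 * @param[out] *pBufferB points to a temporary buffer of length numberOfClasses (used by the Helium/MVE implementation to hold the log priors; ignored by the other implementations)
48 * @return The index of the predicted class, i.e. the class with the largest value in pOutputProbabilities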
49 *
50 *
51 */
52
53 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
54
55 #include "arm_helium_utils.h"
56 #include "arm_vec_math.h"
57
arm_gaussian_naive_bayes_predict_f32(const arm_gaussian_naive_bayes_instance_f32 * S,const float32_t * in,float32_t * pOutputProbabilities,float32_t * pBufferB)58 ARM_DSP_ATTRIBUTE uint32_t arm_gaussian_naive_bayes_predict_f32(const arm_gaussian_naive_bayes_instance_f32 *S,
59 const float32_t * in,
60 float32_t *pOutputProbabilities,
61 float32_t *pBufferB
62 )
63 {
64 uint32_t nbClass;
65 const float32_t *pTheta = S->theta;
66 const float32_t *pSigma = S->sigma;
67 float32_t *buffer = pOutputProbabilities;
68 const float32_t *pIn = in;
69 float32_t result;
70 f32x4_t vsigma;
71 float32_t tmp;
72 f32x4_t vacc1, vacc2;
73 uint32_t index;
74 float32_t *logclassPriors=pBufferB;
75 float32_t *pLogPrior = logclassPriors;
76
77 arm_vlog_f32((float32_t *) S->classPriors, logclassPriors, S->numberOfClasses);
78
79 pTheta = S->theta;
80 pSigma = S->sigma;
81
82 for (nbClass = 0; nbClass < S->numberOfClasses; nbClass++) {
83 pIn = in;
84
85 vacc1 = vdupq_n_f32(0);
86 vacc2 = vdupq_n_f32(0);
87
88 uint32_t blkCnt =S->vectorDimension >> 2;
89 while (blkCnt > 0U) {
90 f32x4_t vinvSigma, vtmp;
91
92 vsigma = vaddq_n_f32(vld1q(pSigma), S->epsilon);
93 vacc1 = vaddq(vacc1, vlogq_f32(vmulq_n_f32(vsigma, 2.0f * PI)));
94
95 vinvSigma = vrecip_medprec_f32(vsigma);
96
97 vtmp = vsubq(vld1q(pIn), vld1q(pTheta));
98 /* squaring */
99 vtmp = vmulq(vtmp, vtmp);
100
101 vacc2 = vfmaq(vacc2, vtmp, vinvSigma);
102
103 pIn += 4;
104 pTheta += 4;
105 pSigma += 4;
106 blkCnt--;
107 }
108
109 blkCnt = S->vectorDimension & 3;
110 if (blkCnt > 0U) {
111 mve_pred16_t p0 = vctp32q(blkCnt);
112 f32x4_t vinvSigma, vtmp;
113
114 vsigma = vaddq_n_f32(vld1q(pSigma), S->epsilon);
115 vacc1 =
116 vaddq_m_f32(vacc1, vacc1, vlogq_f32(vmulq_n_f32(vsigma, 2.0f * PI)), p0);
117
118 vinvSigma = vrecip_medprec_f32(vsigma);
119
120 vtmp = vsubq(vld1q(pIn), vld1q(pTheta));
121 /* squaring */
122 vtmp = vmulq(vtmp, vtmp);
123
124 vacc2 = vfmaq_m_f32(vacc2, vtmp, vinvSigma, p0);
125
126 pTheta += blkCnt;
127 pSigma += blkCnt;
128 }
129
130 tmp = -0.5f * vecAddAcrossF32Mve(vacc1);
131 tmp -= 0.5f * vecAddAcrossF32Mve(vacc2);
132
133 *buffer = tmp + *pLogPrior++;
134 buffer++;
135 }
136
137 arm_max_f32(pOutputProbabilities, S->numberOfClasses, &result, &index);
138
139 return (index);
140 }
141
142 #else
143
144 #if defined(ARM_MATH_NEON)
145
146 #include "NEMath.h"
147
148
149
arm_gaussian_naive_bayes_predict_f32(const arm_gaussian_naive_bayes_instance_f32 * S,const float32_t * in,float32_t * pOutputProbabilities,float32_t * pBufferB)150 ARM_DSP_ATTRIBUTE uint32_t arm_gaussian_naive_bayes_predict_f32(const arm_gaussian_naive_bayes_instance_f32 *S,
151 const float32_t * in,
152 float32_t *pOutputProbabilities,
153 float32_t *pBufferB)
154 {
155
156 const float32_t *pPrior = S->classPriors;
157
158 const float32_t *pTheta = S->theta;
159 const float32_t *pSigma = S->sigma;
160
161 const float32_t *pTheta1 = S->theta + S->vectorDimension;
162 const float32_t *pSigma1 = S->sigma + S->vectorDimension;
163
164 float32_t *buffer = pOutputProbabilities;
165 const float32_t *pIn=in;
166
167 float32_t result;
168 float32_t sigma,sigma1;
169 float32_t tmp,tmp1;
170 uint32_t index;
171 uint32_t vecBlkCnt;
172 uint32_t classBlkCnt;
173 float32x4_t epsilonV;
174 float32x4_t sigmaV,sigmaV1;
175 float32x4_t tmpV,tmpVb,tmpV1;
176 float32x2_t tmpV2;
177 float32x4_t thetaV,thetaV1;
178 float32x4_t inV;
179 (void)pBufferB;
180
181 epsilonV = vdupq_n_f32(S->epsilon);
182
183 classBlkCnt = S->numberOfClasses >> 1;
184 while(classBlkCnt > 0)
185 {
186
187
188 pIn = in;
189
190 tmp = logf(*pPrior++);
191 tmp1 = logf(*pPrior++);
192 tmpV = vdupq_n_f32(0.0f);
193 tmpV1 = vdupq_n_f32(0.0f);
194
195 vecBlkCnt = S->vectorDimension >> 2;
196 while(vecBlkCnt > 0)
197 {
198 sigmaV = vld1q_f32(pSigma);
199 thetaV = vld1q_f32(pTheta);
200
201 sigmaV1 = vld1q_f32(pSigma1);
202 thetaV1 = vld1q_f32(pTheta1);
203
204 inV = vld1q_f32(pIn);
205
206 sigmaV = vaddq_f32(sigmaV, epsilonV);
207 sigmaV1 = vaddq_f32(sigmaV1, epsilonV);
208
209 tmpVb = vmulq_n_f32(sigmaV,DPI_F);
210 tmpVb = vlogq_f32(tmpVb);
211 tmpV = vmlsq_n_f32(tmpV,tmpVb,0.5f);
212
213 tmpVb = vmulq_n_f32(sigmaV1,DPI_F);
214 tmpVb = vlogq_f32(tmpVb);
215 tmpV1 = vmlsq_n_f32(tmpV1,tmpVb,0.5f);
216
217 tmpVb = vsubq_f32(inV,thetaV);
218 tmpVb = vmulq_f32(tmpVb,tmpVb);
219 tmpVb = vmulq_f32(tmpVb, vinvq_f32(sigmaV));
220 tmpV = vmlsq_n_f32(tmpV,tmpVb,0.5f);
221
222 tmpVb = vsubq_f32(inV,thetaV1);
223 tmpVb = vmulq_f32(tmpVb,tmpVb);
224 tmpVb = vmulq_f32(tmpVb, vinvq_f32(sigmaV1));
225 tmpV1 = vmlsq_n_f32(tmpV1,tmpVb,0.5f);
226
227 pIn += 4;
228 pTheta += 4;
229 pSigma += 4;
230 pTheta1 += 4;
231 pSigma1 += 4;
232
233 vecBlkCnt--;
234 }
235 tmpV2 = vpadd_f32(vget_low_f32(tmpV),vget_high_f32(tmpV));
236 tmp += vget_lane_f32(tmpV2, 0) + vget_lane_f32(tmpV2, 1);
237
238 tmpV2 = vpadd_f32(vget_low_f32(tmpV1),vget_high_f32(tmpV1));
239 tmp1 += vget_lane_f32(tmpV2, 0) + vget_lane_f32(tmpV2, 1);
240
241 vecBlkCnt = S->vectorDimension & 3;
242 while(vecBlkCnt > 0)
243 {
244 sigma = *pSigma + S->epsilon;
245 sigma1 = *pSigma1 + S->epsilon;
246
247 tmp -= 0.5f*logf(2.0f * PI_F * sigma);
248 tmp -= 0.5f*(*pIn - *pTheta) * (*pIn - *pTheta) / sigma;
249
250 tmp1 -= 0.5f*logf(2.0f * PI_F * sigma1);
251 tmp1 -= 0.5f*(*pIn - *pTheta1) * (*pIn - *pTheta1) / sigma1;
252
253 pIn++;
254 pTheta++;
255 pSigma++;
256 pTheta1++;
257 pSigma1++;
258 vecBlkCnt--;
259 }
260
261 *buffer++ = tmp;
262 *buffer++ = tmp1;
263
264 pSigma += S->vectorDimension;
265 pTheta += S->vectorDimension;
266 pSigma1 += S->vectorDimension;
267 pTheta1 += S->vectorDimension;
268
269 classBlkCnt--;
270 }
271
272 classBlkCnt = S->numberOfClasses & 1;
273
274 while(classBlkCnt > 0)
275 {
276
277
278 pIn = in;
279
280 tmp = logf(*pPrior++);
281 tmpV = vdupq_n_f32(0.0f);
282
283 vecBlkCnt = S->vectorDimension >> 2;
284 while(vecBlkCnt > 0)
285 {
286 sigmaV = vld1q_f32(pSigma);
287 thetaV = vld1q_f32(pTheta);
288 inV = vld1q_f32(pIn);
289
290 sigmaV = vaddq_f32(sigmaV, epsilonV);
291
292 tmpVb = vmulq_n_f32(sigmaV,DPI_F);
293 tmpVb = vlogq_f32(tmpVb);
294 tmpV = vmlsq_n_f32(tmpV,tmpVb,0.5f);
295
296 tmpVb = vsubq_f32(inV,thetaV);
297 tmpVb = vmulq_f32(tmpVb,tmpVb);
298 tmpVb = vmulq_f32(tmpVb, vinvq_f32(sigmaV));
299 tmpV = vmlsq_n_f32(tmpV,tmpVb,0.5f);
300
301 pIn += 4;
302 pTheta += 4;
303 pSigma += 4;
304
305 vecBlkCnt--;
306 }
307 tmpV2 = vpadd_f32(vget_low_f32(tmpV),vget_high_f32(tmpV));
308 tmp += vget_lane_f32(tmpV2, 0) + vget_lane_f32(tmpV2, 1);
309
310 vecBlkCnt = S->vectorDimension & 3;
311 while(vecBlkCnt > 0)
312 {
313 sigma = *pSigma + S->epsilon;
314 tmp -= 0.5f*logf(2.0f * PI_F * sigma);
315 tmp -= 0.5f*(*pIn - *pTheta) * (*pIn - *pTheta) / sigma;
316
317 pIn++;
318 pTheta++;
319 pSigma++;
320 vecBlkCnt--;
321 }
322
323 *buffer++ = tmp;
324
325 classBlkCnt--;
326 }
327
328 arm_max_f32(pOutputProbabilities,S->numberOfClasses,&result,&index);
329
330 return(index);
331 }
332
333 #else
334
arm_gaussian_naive_bayes_predict_f32(const arm_gaussian_naive_bayes_instance_f32 * S,const float32_t * in,float32_t * pOutputProbabilities,float32_t * pBufferB)335 ARM_DSP_ATTRIBUTE uint32_t arm_gaussian_naive_bayes_predict_f32(const arm_gaussian_naive_bayes_instance_f32 *S,
336 const float32_t * in,
337 float32_t *pOutputProbabilities,
338 float32_t *pBufferB)
339 {
340 uint32_t nbClass;
341 uint32_t nbDim;
342 const float32_t *pPrior = S->classPriors;
343 const float32_t *pTheta = S->theta;
344 const float32_t *pSigma = S->sigma;
345 float32_t *buffer = pOutputProbabilities;
346 const float32_t *pIn=in;
347 float32_t result;
348 float32_t sigma;
349 float32_t tmp;
350 float32_t acc1,acc2;
351 uint32_t index;
352
353 (void)pBufferB;
354
355 pTheta=S->theta;
356 pSigma=S->sigma;
357
358 for(nbClass = 0; nbClass < S->numberOfClasses; nbClass++)
359 {
360
361
362 pIn = in;
363
364 tmp = 0.0;
365 acc1 = 0.0f;
366 acc2 = 0.0f;
367 for(nbDim = 0; nbDim < S->vectorDimension; nbDim++)
368 {
369 sigma = *pSigma + S->epsilon;
370 acc1 += logf(2.0f * PI_F * sigma);
371 acc2 += (*pIn - *pTheta) * (*pIn - *pTheta) / sigma;
372
373 pIn++;
374 pTheta++;
375 pSigma++;
376 }
377
378 tmp = -0.5f * acc1;
379 tmp -= 0.5f * acc2;
380
381
382 *buffer = tmp + logf(*pPrior++);
383 buffer++;
384 }
385
386 arm_max_f32(pOutputProbabilities,S->numberOfClasses,&result,&index);
387
388 return(index);
389 }
390
391 #endif
392 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
393
394 /**
395 * @} end of groupBayes group
396 */
397