1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
 * Title:        arm_kullback_leibler_f16.c
 * Description:  Kullback-Leibler divergence
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/statistics_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 
33 #include <limits.h>
34 #include <math.h>
35 
36 /**
37   @ingroup groupStats
38  */
39 
40 /**
41   @defgroup Kullback-Leibler Kullback-Leibler divergence
42 
43   Computes the Kullback-Leibler divergence between two distributions
44 
45  */
46 
47 
48 /**
49  * @addtogroup Kullback-Leibler
50  * @{
51  */
52 
53 
54 /**
55  * @brief Kullback-Leibler
56  *
57  * Distribution A may contain 0 with Neon version.
58  * Result will be right but some exception flags will be set.
59  *
60  * Distribution B must not contain 0 probability.
61  *
 * @param[in]  *pSrcA         points to an array of input values for probability distribution A.
 * @param[in]  *pSrcB         points to an array of input values for probability distribution B.
64  * @param[in]  blockSize      number of samples in the input array.
65  * @return Kullback-Leibler divergence D(A || B)
66  *
67  */
68 
69 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
70 
71 #include "arm_helium_utils.h"
72 #include "arm_vec_math_f16.h"
73 
arm_kullback_leibler_f16(const float16_t * pSrcA,const float16_t * pSrcB,uint32_t blockSize)74 float16_t arm_kullback_leibler_f16(const float16_t * pSrcA,const float16_t * pSrcB,uint32_t blockSize)
75 {
76     uint32_t blkCnt;
77     _Float16 accum, pA,pB;
78 
79 
80     blkCnt = blockSize;
81 
82     accum = 0.0f16;
83 
84     f16x8_t         vSum = vdupq_n_f16(0.0f);
85     blkCnt = blockSize >> 3;
86     while(blkCnt > 0)
87     {
88         f16x8_t         vecA = vld1q(pSrcA);
89         f16x8_t         vecB = vld1q(pSrcB);
90         f16x8_t         vRatio;
91 
92         vRatio = vdiv_f16(vecB, vecA);
93         vSum = vaddq_f16(vSum, vmulq(vecA, vlogq_f16(vRatio)));
94 
95         /*
96          * Decrement the blockSize loop counter
97          * Advance vector source and destination pointers
98          */
99         pSrcA += 8;
100         pSrcB += 8;
101         blkCnt --;
102     }
103 
104     accum = vecAddAcrossF16Mve(vSum);
105 
106     blkCnt = blockSize & 7;
107     while(blkCnt > 0)
108     {
109        pA = *pSrcA++;
110        pB = *pSrcB++;
111        accum += pA * logf(pB / pA);
112 
113        blkCnt--;
114 
115     }
116 
117     return(-accum);
118 }
119 
120 #else
arm_kullback_leibler_f16(const float16_t * pSrcA,const float16_t * pSrcB,uint32_t blockSize)121 float16_t arm_kullback_leibler_f16(const float16_t * pSrcA,const float16_t * pSrcB,uint32_t blockSize)
122 {
123     const float16_t *pInA, *pInB;
124     uint32_t blkCnt;
125     _Float16 accum, pA,pB;
126 
127     pInA = pSrcA;
128     pInB = pSrcB;
129     blkCnt = blockSize;
130 
131     accum = 0.0f;
132 
133     while(blkCnt > 0)
134     {
135        pA = *pInA++;
136        pB = *pInB++;
137        accum += pA * logf(pB / pA);
138 
139        blkCnt--;
140 
141     }
142 
143     return(-accum);
144 }
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
146 
147 /**
148  * @} end of Kullback-Leibler group
149  */
150 
151 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
152 
153