1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_logsumexp_f16.c
4  * Description:  LogSumExp
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/statistics_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 
33 #include <limits.h>
34 #include <math.h>
35 
36 
37 /**
38  * @addtogroup LogSumExp
39  * @{
40  */
41 
42 
43 /**
44  * @brief Computation of the LogSumExp
45  *
46  * In probabilistic computations, the dynamic range of probability values can be very
47  * wide, for example because they come from Gaussian functions.
48  * To avoid underflow and overflow issues, the values are represented by their logs.
49  * In this representation, multiplying the original values is easy: their logs are added.
50  * Adding the original values, however, requires special handling, which is the
51  * purpose of the LogSumExp function.
52  *
53  * If the values are x1...xn, the function is computing:
54  *
55  * ln(exp(x1) + ... + exp(xn)) and the computation is done in such a way that
56  * rounding issues are minimised.
57  *
58  * The max xm of the values is extracted and the function is computing:
59  * xm + ln(exp(x1 - xm) + ... + exp(xn - xm))
60  *
61  * @param[in]  *in         Pointer to an array of input values.
62  * @param[in]  blockSize   Number of samples in the input array.
63  * @return LogSumExp
64  *
65  */
66 
67 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
68 
69 #include "arm_helium_utils.h"
70 #include "arm_vec_math_f16.h"
71 
arm_logsumexp_f16(const float16_t * in,uint32_t blockSize)72 float16_t arm_logsumexp_f16(const float16_t *in, uint32_t blockSize)
73 {
74     float16_t       maxVal;
75     const float16_t *pIn;
76     int32_t         blkCnt;
77     _Float16       accum=0.0f16;
78     _Float16       tmp;
79 
80 
81     arm_max_no_idx_f16((float16_t *) in, blockSize, &maxVal);
82 
83 
84     blkCnt = blockSize;
85     pIn = in;
86 
87 
88     f16x8_t         vSum = vdupq_n_f16(0.0f16);
89     blkCnt = blockSize >> 3;
90     while(blkCnt > 0)
91     {
92         f16x8_t         vecIn = vld1q(pIn);
93         f16x8_t         vecExp;
94 
95         vecExp = vexpq_f16(vsubq_n_f16(vecIn, maxVal));
96 
97         vSum = vaddq_f16(vSum, vecExp);
98 
99         /*
100          * Decrement the blockSize loop counter
101          * Advance vector source and destination pointers
102          */
103         pIn += 8;
104         blkCnt --;
105     }
106 
107     /* sum + log */
108     accum = vecAddAcrossF16Mve(vSum);
109 
110     blkCnt = blockSize & 0x7;
111     while(blkCnt > 0)
112     {
113        tmp = *pIn++;
114        accum += expf(tmp - maxVal);
115        blkCnt--;
116 
117     }
118 
119     accum = maxVal + logf(accum);
120 
121     return (accum);
122 }
123 
124 #else
arm_logsumexp_f16(const float16_t * in,uint32_t blockSize)125 float16_t arm_logsumexp_f16(const float16_t *in, uint32_t blockSize)
126 {
127     _Float16 maxVal;
128     _Float16 tmp;
129     const float16_t *pIn;
130     uint32_t blkCnt;
131     _Float16 accum;
132 
133     pIn = in;
134     blkCnt = blockSize;
135 
136     maxVal = *pIn++;
137     blkCnt--;
138 
139     while(blkCnt > 0)
140     {
141        tmp = *pIn++;
142 
143        if (tmp > maxVal)
144        {
145           maxVal = tmp;
146        }
147        blkCnt--;
148 
149     }
150 
151     blkCnt = blockSize;
152     pIn = in;
153     accum = 0;
154     while(blkCnt > 0)
155     {
156        tmp = *pIn++;
157        accum += expf(tmp - maxVal);
158        blkCnt--;
159 
160     }
161     accum = maxVal + logf(accum);
162 
163     return(accum);
164 }
165 #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
166 
167 /**
168  * @} end of LogSumExp group
169  */
170 
171 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
172 
173