1 /*
2  * SPDX-FileCopyrightText: Copyright 2010-2020, 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_softmax_u8.c
22  * Description:  U8 softmax function
23  *
24  * $Date:        5 January 2023
25  * $Revision:    V.1.1.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nnfunctions.h"
32 #include "arm_nnsupportfunctions.h"
33 
34 #define ACCUM_BITS 12
35 
36 /**
37  *  @ingroup Public
38  */
39 
40 /**
41  * @addtogroup Softmax
42  * @{
43  */
arm_softmax_u8(const uint8_t * input,const int32_t num_rows,const int32_t row_size,const int32_t mult,const int32_t shift,const int32_t diff_min,uint8_t * output)44 void arm_softmax_u8(const uint8_t *input,
45                     const int32_t num_rows,
46                     const int32_t row_size,
47                     const int32_t mult,
48                     const int32_t shift,
49                     const int32_t diff_min,
50                     uint8_t *output)
51 {
52     const int32_t mask = (1 << shift);
53 
54     int32_t col = 0;
55     int32_t row_idx;
56 
57     for (row_idx = 0; row_idx < num_rows; ++row_idx)
58     {
59         // Find the maximum value in order to ensure numerical stability
60         uint8_t max = *input;
61 
62         for (col = 1; col < row_size; ++col)
63         {
64             max = MAX(max, input[col]);
65         }
66 
67         int32_t diff = 0;
68         int32_t sum = 0;
69 
70         for (col = 0; col < row_size; ++col)
71         {
72             diff = input[col] - max;
73             if (diff >= diff_min)
74             {
75                 sum += DIV_POW2(EXP_ON_NEG(MUL_SAT(diff * mask, mult)), ACCUM_BITS);
76             }
77         }
78 
79         const int32_t headroom = CLZ((uint32_t)sum);
80         const int32_t bits_over_unit = ACCUM_BITS - headroom + 23;
81         const int32_t shifted_scale = ONE_OVER1((sum << headroom) - (1 << 31));
82 
83         for (col = 0; col < row_size; ++col)
84         {
85             diff = input[col] - max;
86             if (diff >= diff_min)
87             {
88                 const int32_t res =
89                     DIV_POW2(MUL_SAT(shifted_scale, EXP_ON_NEG(MUL_SAT(diff * mask, mult))), bits_over_unit);
90                 output[col] = (uint8_t)CLAMP(res, (int32_t)255, (int32_t)0);
91             }
92             else
93             {
94                 output[col] = 0;
95             }
96         }
97         input += row_size;
98         output += row_size;
99     }
100 }
101 /**
102  * @} end of Softmax group
103  */