1 /*
2  * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_softmax_common_s8.c
22  * Description:  Softmax with s8 input and output of s8 or s16.
23  *
24  * $Date:        5 January 2023
25  * $Revision:    V.1.1.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  * -------------------------------------------------------------------- */
29 
30 #include "arm_nnsupportfunctions.h"
31 
32 #define ACCUM_BITS 12
33 
34 /**
35  * @ingroup groupSupport
36  */
37 
38 /**
39  * @defgroup supportSoftmax Softmax
40  *
41  * Support functions for Softmax
42  *
43  */
44 
45 /**
46  * @addtogroup supportSoftmax
47  * @{
48  */
49 
50 /*
51  * Softmax function with s8 input and output of s8 or s16.
52  *
53  * Refer header file for details.
54  *
55  */
arm_nn_softmax_common_s8(const int8_t * input,const int32_t num_rows,const int32_t row_size,const int32_t mult,const int32_t shift,const int32_t diff_min,const bool int16_output,void * output)56 void arm_nn_softmax_common_s8(const int8_t *input,
57                               const int32_t num_rows,
58                               const int32_t row_size,
59                               const int32_t mult,
60                               const int32_t shift,
61                               const int32_t diff_min,
62                               const bool int16_output,
63                               void *output)
64 {
65     const int32_t mask = (1 << shift);
66 
67     int32_t col = 0;
68     int32_t row_idx;
69 
70     for (row_idx = 0; row_idx < num_rows; ++row_idx)
71     {
72         // Find the maximum value in order to ensure numerical stability
73         int8_t max = *input;
74 
75         for (col = 1; col < row_size; ++col)
76         {
77             max = MAX(max, input[col]);
78         }
79 
80         int32_t diff = 0;
81         int32_t sum = 0;
82 
83         for (col = 0; col < row_size; ++col)
84         {
85             diff = input[col] - max;
86             if (diff >= diff_min)
87             {
88                 sum += DIV_POW2(EXP_ON_NEG(MUL_SAT(diff * mask, mult)), ACCUM_BITS);
89             }
90         }
91 
92         const int32_t headroom = CLZ(sum);
93         const int32_t shifted_scale = ONE_OVER1((sum > 0 ? sum << headroom : 0) - (1 << 31));
94         int32_t bits_over_unit;
95 
96         if (int16_output)
97         {
98             int16_t *output_s16 = (int16_t *)output + row_idx * row_size;
99 
100             bits_over_unit = ACCUM_BITS - headroom + 15;
101 
102             for (col = 0; col < row_size; ++col)
103             {
104                 diff = input[col] - max;
105 
106                 if (diff >= diff_min)
107                 {
108                     const int32_t res =
109                         DIV_POW2(MUL_SAT(shifted_scale, EXP_ON_NEG(MUL_SAT(diff * mask, mult))), bits_over_unit) +
110                         NN_Q15_MIN;
111                     output_s16[col] = (int16_t)CLAMP(res, (int32_t)NN_Q15_MAX, (int32_t)NN_Q15_MIN);
112                 }
113                 else
114                 {
115                     output_s16[col] = NN_Q15_MIN;
116                 }
117             }
118         }
119         else
120         {
121             int8_t *output_s8 = (int8_t *)output + row_idx * row_size;
122 
123             bits_over_unit = ACCUM_BITS - headroom + 23;
124 
125             for (col = 0; col < row_size; ++col)
126             {
127                 diff = input[col] - max;
128                 if (diff >= diff_min)
129                 {
130                     const int32_t res =
131                         DIV_POW2(MUL_SAT(shifted_scale, EXP_ON_NEG(MUL_SAT(diff * mask, mult))), bits_over_unit) +
132                         NN_Q7_MIN;
133                     output_s8[col] = (int8_t)CLAMP(res, (int32_t)NN_Q7_MAX, (int32_t)NN_Q7_MIN);
134                 }
135                 else
136                 {
137                     output_s8[col] = NN_Q7_MIN;
138                 }
139             }
140         }
141 
142         input += row_size;
143     }
144 }
145 
146 /**
147  * @} end of Doxygen group
148  */
149