1 /*
2  * SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_lstm_step_s8_s16.c
22  * Description:  Update LSTM function for a single iteration step.
23  *
24  * $Date:        9 February 2023
25  * $Revision:    V.1.1.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 #include "arm_nnsupportfunctions.h"
31 /**
32  * @ingroup groupSupport
33  */
34 
35 /**
36  * @addtogroup supportLSTM
37  * @{
38  */
39 
40 /*
41  * Calculate the output state tensor of an LSTM step, s8 input/output and s16 weight version.
42  * Refer to header file for details.
43  */
arm_nn_lstm_step_s8_s16(const int8_t * input,const int8_t * input_to_input_weight,const int8_t * input_to_forget_weight,const int8_t * input_to_cell_weight,const int8_t * input_to_output_weight,const int8_t * recurrent_to_input_weight,const int8_t * recurrent_to_forget_weight,const int8_t * recurrent_to_cell_weight,const int8_t * recurrent_to_output_weight,const cmsis_nn_lstm_params * lstm,const int n_batch,const int n_cell,const int n_input,const int n_output,int8_t * output_state,int16_t * cell_state,int8_t * output,cmsis_nn_lstm_context * scratch_buffers)44 arm_cmsis_nn_status arm_nn_lstm_step_s8_s16(const int8_t *input,
45                                             const int8_t *input_to_input_weight,
46                                             const int8_t *input_to_forget_weight,
47                                             const int8_t *input_to_cell_weight,
48                                             const int8_t *input_to_output_weight,
49                                             const int8_t *recurrent_to_input_weight,
50                                             const int8_t *recurrent_to_forget_weight,
51                                             const int8_t *recurrent_to_cell_weight,
52                                             const int8_t *recurrent_to_output_weight,
53                                             const cmsis_nn_lstm_params *lstm,
54                                             const int n_batch,
55                                             const int n_cell,
56                                             const int n_input,
57                                             const int n_output,
58                                             int8_t *output_state,
59                                             int16_t *cell_state,
60                                             int8_t *output,
61                                             cmsis_nn_lstm_context *scratch_buffers)
62 {
63     const int32_t n_block = n_batch * n_cell;
64 
65     // Calculate the input gate
66     arm_nn_lstm_calculate_gate_s8_s16(input,
67                                       input_to_input_weight,
68                                       lstm->i2i_effective_bias,
69                                       lstm->input_to_input_scaling,
70                                       output_state,
71                                       recurrent_to_input_weight,
72                                       lstm->r2i_effective_bias,
73                                       lstm->recurrent_to_input_scaling,
74                                       n_batch,
75                                       n_input,
76                                       n_output,
77                                       n_cell,
78                                       ARM_SIGMOID,
79                                       scratch_buffers->input_gate);
80 
81     // Calculate the forget gate
82     arm_nn_lstm_calculate_gate_s8_s16(input,
83                                       input_to_forget_weight,
84                                       lstm->i2f_effective_bias,
85                                       lstm->input_to_forget_scaling,
86                                       output_state,
87                                       recurrent_to_forget_weight,
88                                       lstm->r2f_effective_bias,
89                                       lstm->recurrent_to_forget_scaling,
90                                       n_batch,
91                                       n_input,
92                                       n_output,
93                                       n_cell,
94                                       ARM_SIGMOID,
95                                       scratch_buffers->forget_gate);
96 
97     // Calculate the cell update gate
98     arm_nn_lstm_calculate_gate_s8_s16(input,
99                                       input_to_cell_weight,
100                                       lstm->i2c_effective_bias,
101                                       lstm->input_to_cell_scaling,
102                                       output_state,
103                                       recurrent_to_cell_weight,
104                                       lstm->r2c_effective_bias,
105                                       lstm->recurrent_to_cell_scaling,
106                                       n_batch,
107                                       n_input,
108                                       n_output,
109                                       n_cell,
110                                       ARM_TANH,
111                                       scratch_buffers->cell_gate);
112 
113     // Update the cell state
114     arm_nn_lstm_update_cell_state_s16(n_block,
115                                       lstm->cell_state_shift,
116                                       cell_state,
117                                       scratch_buffers->input_gate,
118                                       scratch_buffers->forget_gate,
119                                       scratch_buffers->cell_gate);
120 
121     // Calculate the output gate
122     arm_nn_lstm_calculate_gate_s8_s16(input,
123                                       input_to_output_weight,
124                                       lstm->i2o_effective_bias,
125                                       lstm->input_to_output_scaling,
126                                       output_state,
127                                       recurrent_to_output_weight,
128                                       lstm->r2o_effective_bias,
129                                       lstm->recurrent_to_output_scaling,
130                                       n_batch,
131                                       n_input,
132                                       n_output,
133                                       n_cell,
134                                       ARM_SIGMOID,
135                                       scratch_buffers->output_gate);
136 
137     // Update the output state
138     arm_nn_lstm_update_output_s8_s16(n_batch,
139                                      n_cell,
140                                      cell_state,
141                                      lstm->cell_state_shift,
142                                      scratch_buffers->output_gate,
143                                      lstm->hidden_scaling,
144                                      lstm->hidden_offset,
145                                      output_state,
146                                      scratch_buffers->input_gate);
147 
148     arm_memcpy_s8(output, output_state, n_batch * n_output * sizeof(int8_t));
149 
150     return ARM_CMSIS_NN_SUCCESS;
151 }
152 /**
153  * @} end of supportLSTM group
154  */
155