1 /*
2  * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_lstm_step_s16.c
22  * Description:  Update LSTM function for a single iteration step.
23  *
24  * $Date:        26 March 2024
25  * $Revision:    V.1.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 
33 /**
34  * @ingroup groupSupport
35  */
36 
37 /**
38  * @addtogroup supportLSTM
39  * @{
40  */
41 
42 /*
43  * Calculate the output state tensor of an LSTM step, s16 input/output/weights and s16 internal buffers version.
44  * Refer to header file for details.
45  */
arm_nn_lstm_step_s16(const int16_t * data_in,const int16_t * hidden_in,int16_t * hidden_out,const cmsis_nn_lstm_params * params,cmsis_nn_lstm_context * buffers,const int32_t batch_offset)46 arm_cmsis_nn_status arm_nn_lstm_step_s16(const int16_t *data_in,
47                                          const int16_t *hidden_in,
48                                          int16_t *hidden_out,
49                                          const cmsis_nn_lstm_params *params,
50                                          cmsis_nn_lstm_context *buffers,
51                                          const int32_t batch_offset)
52 {
53     int16_t *forget_gate = buffers->temp1;
54     int16_t *input_gate = buffers->temp1;
55     int16_t *cell_gate = buffers->temp2;
56     int16_t *output_gate = buffers->temp1;
57     int16_t *hidden_temp = buffers->temp2;
58 
59     int16_t *cell_state = buffers->cell_state;
60 
61     arm_nn_lstm_calculate_gate_s16(data_in, hidden_in, &params->forget_gate, params, forget_gate, batch_offset);
62 
63     // Calculate first term of cell state in place early to maximise reuse of scratch-buffers
64     arm_elementwise_mul_s16(forget_gate,
65                             cell_state,
66                             0,
67                             0,
68                             cell_state,
69                             0,
70                             params->forget_to_cell_multiplier,
71                             params->forget_to_cell_shift,
72                             NN_Q15_MIN,
73                             NN_Q15_MAX,
74                             params->hidden_size * params->batch_size);
75 
76     arm_nn_lstm_calculate_gate_s16(data_in, hidden_in, &params->input_gate, params, input_gate, batch_offset);
77 
78     arm_nn_lstm_calculate_gate_s16(data_in, hidden_in, &params->cell_gate, params, cell_gate, batch_offset);
79 
80     // Reminder of cell state calculation, multiply and add to previous result.
81     arm_elementwise_mul_acc_s16(forget_gate,
82                                 cell_gate,
83                                 0,
84                                 0,
85                                 cell_state,
86                                 0,
87                                 params->input_to_cell_multiplier,
88                                 params->input_to_cell_shift,
89                                 -params->cell_clip,
90                                 params->cell_clip,
91                                 params->hidden_size * params->batch_size);
92 
93     arm_nn_lstm_calculate_gate_s16(data_in, hidden_in, &params->output_gate, params, output_gate, batch_offset);
94 
95     // Calculate hidden state directly to output.
96     arm_nn_activation_s16(
97         cell_state, hidden_temp, params->hidden_size * params->batch_size, params->cell_scale_power + 12, ARM_TANH);
98     arm_elementwise_mul_s16_batch_offset(output_gate,
99                                          hidden_temp,
100                                          hidden_out,
101                                          params->output_offset,
102                                          params->output_multiplier,
103                                          params->output_shift,
104                                          params->hidden_size,
105                                          params->batch_size,
106                                          batch_offset);
107 
108     return ARM_CMSIS_NN_SUCCESS;
109 }
110 /**
111  * @} end of supportLSTM group
112  */
113