1 /*
2  * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_lstm_step_s8.c
22  * Description:  Update LSTM function for a single iteration step.
23  *
24  * $Date:        19 January 2024
25  * $Revision:    V.1.0.0
26  *
27  * Target :  Arm(R) M-Profile Architecture
28  *
29  * -------------------------------------------------------------------- */
30 #include "arm_nnfunctions.h"
31 #include "arm_nnsupportfunctions.h"
32 /**
33  * @ingroup groupSupport
34  */
35 
36 /**
37  * @addtogroup supportLSTM
38  * @{
39  */
40 
41 /*
42  * Calculate the output state tensor of an LSTM step, s8 input/output/weights and s16 internal buffers version.
43  * Refer to header file for details.
44  */
arm_nn_lstm_step_s8(const int8_t * data_in,const int8_t * hidden_in,int8_t * hidden_out,const cmsis_nn_lstm_params * params,cmsis_nn_lstm_context * buffers,const int32_t batch_offset)45 arm_cmsis_nn_status arm_nn_lstm_step_s8(const int8_t *data_in,
46                                         const int8_t *hidden_in,
47                                         int8_t *hidden_out,
48                                         const cmsis_nn_lstm_params *params,
49                                         cmsis_nn_lstm_context *buffers,
50                                         const int32_t batch_offset)
51 {
52     int16_t *forget_gate = buffers->temp1;
53     int16_t *input_gate = buffers->temp1;
54     int16_t *cell_gate = buffers->temp2;
55     int16_t *output_gate = buffers->temp1;
56     int16_t *hidden_temp = buffers->temp2;
57 
58     int16_t *cell_state = buffers->cell_state;
59 
60     arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, &params->forget_gate, params, forget_gate, batch_offset);
61 
62     // Calculate first term of cell state in place early to maximise reuse of scratch-buffers
63     arm_elementwise_mul_s16(forget_gate,
64                             cell_state,
65                             0,
66                             0,
67                             cell_state,
68                             0,
69                             params->forget_to_cell_multiplier,
70                             params->forget_to_cell_shift,
71                             NN_Q15_MIN,
72                             NN_Q15_MAX,
73                             params->hidden_size * params->batch_size);
74 
75     arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, &params->input_gate, params, input_gate, batch_offset);
76     arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, &params->cell_gate, params, cell_gate, batch_offset);
77 
78     // Reminder of cell state calculation, multiply and add to previous result.
79     arm_elementwise_mul_acc_s16(forget_gate,
80                                 cell_gate,
81                                 0,
82                                 0,
83                                 cell_state,
84                                 0,
85                                 params->input_to_cell_multiplier,
86                                 params->input_to_cell_shift,
87                                 -params->cell_clip,
88                                 params->cell_clip,
89                                 params->hidden_size * params->batch_size);
90 
91     arm_nn_lstm_calculate_gate_s8_s16(data_in, hidden_in, &params->output_gate, params, output_gate, batch_offset);
92 
93     // Calculate hidden state directly to output.
94     arm_nn_activation_s16(
95         cell_state, hidden_temp, params->hidden_size * params->batch_size, params->cell_scale_power + 12, ARM_TANH);
96     arm_elementwise_mul_s16_s8(output_gate,
97                                hidden_temp,
98                                hidden_out,
99                                params->output_offset,
100                                params->output_multiplier,
101                                params->output_shift,
102                                params->hidden_size,
103                                params->batch_size,
104                                batch_offset);
105 
106     return ARM_CMSIS_NN_SUCCESS;
107 }
108 /**
109  * @} end of supportLSTM group
110  */
111