1 /*
2  * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nn_lstm_calculate_gate_s8_s16.c
22  * Description:  Update single gate for an incremental step of LSTM function.
23  *
24  * $Date:        8 September 2022
25  * $Revision:    V.1.0.0
26  *
27  * Target Processor:  Cortex-M cores
28  *
29  * -------------------------------------------------------------------- */
30 
31 #include "arm_nn_tables.h"
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34 
35 /**
36  * @ingroup groupSupport
37  */
38 
39 /**
40  * @defgroup supportLSTM LSTM
41  *
42  * Support functions for LSTM
43  *
44  */
45 
46 /**
47  * @addtogroup supportLSTM
48  * @{
49  */
50 
51 /*
52  * Calculates a single LSTM gate, int8x8_16 version.
53  * Refer to header file for details
54  */
arm_nn_lstm_calculate_gate_s8_s16(const int8_t * input,const int8_t * input_to_gate_weights,const int32_t * input_to_gate_bias,const cmsis_nn_scaling input_to_gate_scaling,const int8_t * output_state,const int8_t * recurrent_to_gate_weights,const int32_t * recurrent_to_gate_bias,const cmsis_nn_scaling recurrent_to_gate,const int32_t n_batch,const int32_t n_input,const int32_t n_output,const int32_t n_cell,const arm_nn_activation_type activation_type,int16_t * gate)55 void arm_nn_lstm_calculate_gate_s8_s16(const int8_t *input,
56                                        const int8_t *input_to_gate_weights,
57                                        const int32_t *input_to_gate_bias,
58                                        const cmsis_nn_scaling input_to_gate_scaling,
59                                        const int8_t *output_state,
60                                        const int8_t *recurrent_to_gate_weights,
61                                        const int32_t *recurrent_to_gate_bias,
62                                        const cmsis_nn_scaling recurrent_to_gate,
63                                        const int32_t n_batch,
64                                        const int32_t n_input,
65                                        const int32_t n_output,
66                                        const int32_t n_cell,
67                                        const arm_nn_activation_type activation_type,
68                                        int16_t *gate)
69 {
70     const int32_t n_block = n_batch * n_cell;
71 
72     memset(gate, 0, n_block * sizeof(int16_t));
73     arm_nn_vec_mat_mul_result_acc_s8(input,
74                                      input_to_gate_weights,
75                                      input_to_gate_bias,
76                                      gate,
77                                      0,
78                                      input_to_gate_scaling.multiplier,
79                                      input_to_gate_scaling.shift,
80                                      n_input,
81                                      n_cell,
82                                      n_batch);
83 
84     arm_nn_vec_mat_mul_result_acc_s8(output_state,
85                                      recurrent_to_gate_weights,
86                                      recurrent_to_gate_bias,
87                                      gate,
88                                      0,
89                                      recurrent_to_gate.multiplier,
90                                      recurrent_to_gate.shift,
91                                      n_output,
92                                      n_cell,
93                                      n_batch);
94 
95     arm_nn_activation_s16(gate, gate, n_block, 0, activation_type);
96 }
97 /**
98  * @} end of supportLSTM group
99  */
100