1 /*
2 * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 *
6 * Licensed under the Apache License, Version 2.0 (the License); you may
7 * not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 /* ----------------------------------------------------------------------
20 * Project: CMSIS NN Library
21 * Title: arm_nn_lstm_calculate_gate_s8_s16.c
22 * Description: Updates a single gate for one incremental step of the LSTM function.
23 *
24 * $Date: 8 September 2022
25 * $Revision: V.1.0.0
26 *
27 * Target Processor: Cortex-M cores
28 *
29 * -------------------------------------------------------------------- */
30
31 #include "arm_nn_tables.h"
32 #include "arm_nnfunctions.h"
33 #include "arm_nnsupportfunctions.h"
34
35 /**
36 * @ingroup groupSupport
37 */
38
39 /**
40 * @defgroup supportLSTM LSTM
41 *
42 * Support functions for LSTM
43 *
44 */
45
46 /**
47 * @addtogroup supportLSTM
48 * @{
49 */
50
51 /*
52 * Calculates a single LSTM gate, int8x8_16 version.
53 * Refer to header file for details
54 */
arm_nn_lstm_calculate_gate_s8_s16(const int8_t * input,const int8_t * input_to_gate_weights,const int32_t * input_to_gate_bias,const cmsis_nn_scaling input_to_gate_scaling,const int8_t * output_state,const int8_t * recurrent_to_gate_weights,const int32_t * recurrent_to_gate_bias,const cmsis_nn_scaling recurrent_to_gate,const int32_t n_batch,const int32_t n_input,const int32_t n_output,const int32_t n_cell,const arm_nn_activation_type activation_type,int16_t * gate)55 void arm_nn_lstm_calculate_gate_s8_s16(const int8_t *input,
56 const int8_t *input_to_gate_weights,
57 const int32_t *input_to_gate_bias,
58 const cmsis_nn_scaling input_to_gate_scaling,
59 const int8_t *output_state,
60 const int8_t *recurrent_to_gate_weights,
61 const int32_t *recurrent_to_gate_bias,
62 const cmsis_nn_scaling recurrent_to_gate,
63 const int32_t n_batch,
64 const int32_t n_input,
65 const int32_t n_output,
66 const int32_t n_cell,
67 const arm_nn_activation_type activation_type,
68 int16_t *gate)
69 {
70 const int32_t n_block = n_batch * n_cell;
71
72 memset(gate, 0, n_block * sizeof(int16_t));
73 arm_nn_vec_mat_mul_result_acc_s8(input,
74 input_to_gate_weights,
75 input_to_gate_bias,
76 gate,
77 0,
78 input_to_gate_scaling.multiplier,
79 input_to_gate_scaling.shift,
80 n_input,
81 n_cell,
82 n_batch);
83
84 arm_nn_vec_mat_mul_result_acc_s8(output_state,
85 recurrent_to_gate_weights,
86 recurrent_to_gate_bias,
87 gate,
88 0,
89 recurrent_to_gate.multiplier,
90 recurrent_to_gate.shift,
91 n_output,
92 n_cell,
93 n_batch);
94
95 arm_nn_activation_s16(gate, gate, n_block, 0, activation_type);
96 }
97 /**
98 * @} end of supportLSTM group
99 */
100