/*
 * SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nnsupportfunctions.h
 * Description:  Public header file of support functions for CMSIS NN Library
 *
 * $Date:        30 April 2024
 * $Revision:    V.22.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 * -------------------------------------------------------------------- */

#ifndef ARM_NNSUPPORTFUNCTIONS_H
#define ARM_NNSUPPORTFUNCTIONS_H

#include "Internal/arm_nn_compiler.h"
#include "arm_nn_math_types.h"
#include "arm_nn_types.h"

#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

#define USE_FAST_DW_CONV_S16_FUNCTION(dw_conv_params, filter_dims, input_dims)                                         \
    (dw_conv_params->ch_mult == 1 && dw_conv_params->dilation.w == 1 && dw_conv_params->dilation.h == 1 &&             \
     filter_dims->w * filter_dims->h < 512)

#define LEFT_SHIFT(_shift) (_shift > 0 ? _shift : 0)
#define RIGHT_SHIFT(_shift) (_shift > 0 ? 0 : -_shift)
#define MASK_IF_ZERO(x) ((x) == 0 ? ~0 : 0)
#define MASK_IF_NON_ZERO(x) ((x) != 0 ? ~0 : 0)
#define SELECT_USING_MASK(mask, a, b) (((mask) & (a)) ^ (~(mask) & (b)))

#define MAX(A, B) ((A) > (B) ? (A) : (B))
#define MIN(A, B) ((A) < (B) ? (A) : (B))
#define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
#define REDUCE_MULTIPLIER(_mult) ((_mult < 0x7FFF0000) ? ((_mult + (1 << 15)) >> 16) : 0x7FFF)
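
/*
 * Usage sketch (illustrative): a signed requantization shift is split into
 * left and right components, and a Q31 multiplier is rounded down to Q15.
 *
 *   int32_t shift = -3;
 *   int32_t ls = LEFT_SHIFT(shift);                  // 0
 *   int32_t rs = RIGHT_SHIFT(shift);                 // 3
 *   int16_t mult16 = REDUCE_MULTIPLIER(1518500250);  // (m + (1 << 15)) >> 16
 */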

// Number of channels processed in a block for DW Conv with int8 weights (MVE)
// Requirement: Greater than 0 and less than 128.
// This can be fine-tuned to match the number of input channels for best performance.
// A layer with fewer channels than CH_IN_BLOCK_MVE results in higher
// scratch buffer usage, and a layer with more channels than CH_IN_BLOCK_MVE
// results in lower scratch buffer usage.
#define CH_IN_BLOCK_MVE (124)

// Number of channels processed in a block for DW Conv with int4 weights (MVE)
// Requirement: See CH_IN_BLOCK_MVE.
// An additional requirement for this signed 4-bit variant is that it must be an even number.
#define S4_CH_IN_BLOCK_MVE (124)

// For int16 input, when the number of columns exceeds this limit, int64 accumulation is needed
// to not lose precision.
#define MAX_COL_COUNT (512)

/**
 * @brief definition to pack four 8 bit values.
 */
#define PACK_S8x4_32x1(v0, v1, v2, v3)                                                                                 \
    ((((int32_t)(v0) << 0) & (int32_t)0x000000FF) | (((int32_t)(v1) << 8) & (int32_t)0x0000FF00) |                     \
     (((int32_t)(v2) << 16) & (int32_t)0x00FF0000) | (((int32_t)(v3) << 24) & (int32_t)0xFF000000))

/**
 * @brief definition to pack two 16 bit values.
 */
#define PACK_Q15x2_32x1(v0, v1) (((int32_t)(v0) & (int32_t)0xFFFF) | ((int32_t)(v1) << 16))
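
/*
 * Example (illustrative): packing places v0 in the least significant byte
 * (or halfword) of the resulting word.
 *
 *   int32_t w = PACK_S8x4_32x1(0x01, 0x02, 0x03, 0x04); // 0x04030201
 *   int32_t h = PACK_Q15x2_32x1(0x1111, 0x2222);        // 0x22221111
 */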

/**
 * @defgroup groupSupport Private
 *
 * Internal Support functions. Not intended to be called directly by a CMSIS-NN user.
 *
 */

/**
 * @defgroup genPrivTypes Structure Types
 * @ingroup groupSupport
 * @brief Data structure types used by private functions.
 * @{
 */

/**
 * @brief Union for SIMD access of q31/s16/s8 types
 */
union arm_nnword
{
    int32_t word;          /**< q31 type */
    int16_t half_words[2]; /**< s16 type */
    int8_t bytes[4];       /**< s8 type */
};

/**
 * @brief Struct holding the low/high words of a 64-bit value
 */
struct arm_nn_double
{
    uint32_t low;
    int32_t high;
};

/**
 * @brief Union for data type long long
 */
union arm_nn_long_long
{
    int64_t long_long;
    struct arm_nn_double word;
};
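
/*
 * Usage sketch (illustrative): splitting a 64-bit accumulator into 32-bit
 * halves without explicit shifts. The low/high mapping shown assumes a
 * little-endian target.
 *
 *   union arm_nn_long_long acc;
 *   acc.long_long = (int64_t)1 << 40;
 *   // acc.word.low == 0, acc.word.high == 256
 */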

/**
 * @} // end group genPrivTypes
 */

/**
 * @defgroup supportConversion Data Conversion
 *
 * Perform data type conversion in-between neural network operations
 *
 */

/**
 * @brief Converts the elements from a s8 vector to a s16 vector with an added offset
 * @param[in]    src        pointer to the s8 input vector
 * @param[out]   dst        pointer to the s16 output vector
 * @param[in]    block_size length of the input vector
 * @param[in]    offset     s16 offset to be added to each input vector element.
 *
 * \par Description:
 *
 * Output elements are ordered.
 * The equation used for the conversion process is:
 *
 * <pre>
 *  dst[n] = (int16_t) src[n] + offset;   0 <= n < block_size.
 * </pre>
 *
 */
void arm_q7_to_q15_with_offset(const int8_t *src, int16_t *dst, int32_t block_size, int16_t offset);
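
/*
 * Usage sketch (illustrative): converting s8 activations to s16 with the
 * negated zero-point as offset before an s16 MAC loop.
 *
 *   const int8_t src[4] = {-128, -1, 0, 127};
 *   int16_t dst[4];
 *   arm_q7_to_q15_with_offset(src, dst, 4, 128);
 *   // dst is now {0, 127, 128, 255}
 */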

#if defined(ARM_MATH_DSP)
/**
 * @brief Converts the elements from a s8 vector to a s16 vector with an added offset
 * @param[in]    src        pointer to the s8 input vector
 * @param[out]   dst        pointer to the s16 output vector
 * @param[in]    block_size length of the input vector
 * @param[in]    offset     s16 offset to be added to each input vector element.
 *
 * \par Description:
 *
 * No additional ordering is done, so the output elements are not in order:
 * input in ABCD order is written out as ACBD.
 * Note this is for processors with the DSP extension only.
 * The equation used for the conversion process is:
 *
 * <pre>
 *  dst[n - 0] = (int16_t) src[n - 0] + offset;   0 <= n < block_size.
 *  dst[n - 1] = (int16_t) src[n - 2] + offset;   0 <= n < block_size.
 *  dst[n - 2] = (int16_t) src[n - 1] + offset;   0 <= n < block_size.
 *  dst[n - 3] = (int16_t) src[n - 3] + offset;   0 <= n < block_size.
 * </pre>
 *
 */
void arm_s8_to_s16_unordered_with_offset(const int8_t *src, int16_t *dst, int32_t block_size, int16_t offset);

#endif

/**
 * @brief Get the required buffer size for optimized s8 depthwise convolution
 *        function with constraint that in_channel equals out_channel.
 *        This is for processors with MVE extension.
 *        Refer to arm_depthwise_conv_s8_opt_get_buffer_size() for function argument details.
 *
 * @note  Intended for compilation on Host. If compiling for an Arm target, use
 *        arm_depthwise_conv_s8_opt_get_buffer_size(). Note also that this is a support function,
 *        so calling it directly is not recommended even on Host.
 *
 */
int32_t arm_depthwise_conv_s8_opt_get_buffer_size_mve(const cmsis_nn_dims *input_dims,
                                                      const cmsis_nn_dims *filter_dims);

/**
 * @brief Get the required buffer size for optimized s8 depthwise convolution
 *        function with constraint that in_channel equals out_channel.
 *        This is for processors with DSP extension.
 *        Refer to arm_depthwise_conv_s8_opt_get_buffer_size() for function argument details.
 *
 * @note  Intended for compilation on Host. If compiling for an Arm target, use
 *        arm_depthwise_conv_s8_opt_get_buffer_size(). Note also that this is a support function,
 *        so calling it directly is not recommended even on Host.
 *
 */
int32_t arm_depthwise_conv_s8_opt_get_buffer_size_dsp(const cmsis_nn_dims *input_dims,
                                                      const cmsis_nn_dims *filter_dims);

/**
 * @brief Depthwise conv on an im2col buffer where the input channel equals output channel.
 * @param[in]    row     pointer to row
 * @param[in]    col     pointer to im2col buffer, always consists of 2 columns.
 * @param[in]    num_ch  number of channels
 * @param[in]    out_shift  pointer to per output channel requantization shift parameter.
 * @param[in]    out_mult   pointer to per output channel requantization multiplier parameter.
 * @param[in]    out_offset      output tensor offset.
 * @param[in]    activation_min   minimum value to clamp the output to. Range : int8
 * @param[in]    activation_max   maximum value to clamp the output to. Range : int8
 * @param[in]    kernel_size   number of elements in one column.
 * @param[in]    output_bias per output channel bias. Range : int32
 * @param[out]   out         pointer to output
 * @return     The function returns one of the two:
 *              1. The incremented output pointer for a successful operation, or
 *              2. NULL if implementation is not available.
 *
 * @details     Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_depthwise_conv_s8_core(const int8_t *row,
                                      const int16_t *col,
                                      const uint16_t num_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const uint16_t kernel_size,
                                      const int32_t *const output_bias,
                                      int8_t *out);

/**
 * @brief General Matrix-multiplication function with per-channel requantization.
 * @param[in]       input_row    pointer to row operand
 * @param[in]       input_col    pointer to col operand
 * @param[in]       output_ch    number of rows of input_row
 * @param[in]       col_batches  number of column batches. Range: 1 to 4
 * @param[in]       output_shift  pointer to per output channel requantization shift parameter.
 * @param[in]       output_mult   pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset    output tensor offset.
 * @param[in]       col_offset    input tensor (col) offset.
 * @param[in]       row_offset    kernel (row) offset. Not used.
 * @param[in]       out_activation_min   minimum value to clamp the output to. Range : int8
 * @param[in]       out_activation_max   maximum value to clamp the output to. Range : int8
 * @param[in]       row_len       number of elements in each row
 * @param[in]       bias          per output channel bias. Range : int32
 * @param[in,out]   out           pointer to output
 * @return     The function returns one of the two:
 *              1. The incremented output pointer for a successful operation, or
 *              2. NULL if implementation is not available.
 *
 * @details   Supported framework: TensorFlow Lite
 */
int8_t *arm_nn_mat_mult_s8(const int8_t *input_row,
                           const int8_t *input_col,
                           const uint16_t output_ch,
                           const uint16_t col_batches,
                           const int32_t *output_shift,
                           const int32_t *output_mult,
                           const int32_t out_offset,
                           const int32_t col_offset,
                           const int32_t row_offset,
                           const int16_t out_activation_min,
                           const int16_t out_activation_max,
                           const uint16_t row_len,
                           const int32_t *const bias,
                           int8_t *out);

/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization for 16 bits convolution.
 * @param[in]       input_a     pointer to operand A
 * @param[in]       input_b     pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch   number of rows of A
 * @param[in]       out_shift   pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult    pointer to per output channel requantization multiplier parameter.
 * @param[in]       activation_min   minimum value to clamp the output to. Range : int16
 * @param[in]       activation_max   maximum value to clamp the output to. Range : int16
 * @param[in]       num_col_a   number of columns of A
 * @param[in]       bias_data   pointer to struct with bias vector. The length of this vector is equal to the number
 *                              of output columns (or RHS input rows). The vector can be int32 or int64, indicated by
 *                              a flag in the struct.
 * @param[in,out]   out_0       pointer to output
 * @return     The function returns one of the two:
 *              1. The incremented output pointer for a successful operation, or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of the weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements per output channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *            Supported framework: TensorFlow Lite micro.
 */
int16_t *arm_nn_mat_mult_kernel_s16(const int8_t *input_a,
                                    const int16_t *input_b,
                                    const int32_t output_ch,
                                    const int32_t *out_shift,
                                    const int32_t *out_mult,
                                    const int32_t activation_min,
                                    const int32_t activation_max,
                                    const int32_t num_col_a,
                                    const cmsis_nn_bias_data *const bias_data,
                                    int16_t *out_0);

/**
 * @brief General Vector by Matrix multiplication with requantization and storage of result.
 * @param[in]       row_elements          number of row elements
 * @param[in]       skipped_row_elements  number of row elements skipped due to padding.
 *                                        row_elements + skipped_row_elements = (kernel_x * kernel_y) * input_ch
 * @param[in]       row_base_ref          pointer to row operand
 * @param[in]       col_base_ref          pointer to col operand
 * @param[in]       out_ch                Number of output channels
 * @param[in]       conv_params           Pointer to convolution parameters like offsets and activation values
 * @param[in]       quant_params          Pointer to per-channel quantization parameters
 * @param[in]       bias                  Pointer to optional per-channel bias
 * @param[out]      output                Pointer to output where int8 results are stored.
 * @return     The function performs matrix (row_base_ref) by vector (col_base_ref) multiplication;
 *             the scaled result is stored in memory.
 *
 * @details Pseudo-code
 *      *output = 0
 *      sum_col = 0
 *      for (j = 0; j < out_ch; j++)
 *          for (i = 0; i < row_elements; i++)
 *              *output += row_base_ref[i] * col_base_ref[i]
 *              sum_col += col_base_ref[i]
 *          scale sum_col using quant_params and bias
 *          store result in 'output'
 *
 *
 */
arm_cmsis_nn_status arm_nn_mat_mul_core_1x_s8(int32_t row_elements,
                                              const int32_t skipped_row_elements,
                                              const int8_t *row_base_ref,
                                              const int8_t *col_base_ref,
                                              const int32_t out_ch,
                                              const cmsis_nn_conv_params *conv_params,
                                              const cmsis_nn_per_channel_quant_params *quant_params,
                                              const int32_t *bias,
                                              int8_t *output);
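
/*
 * Reference sketch (illustrative, scalar view of the pseudo-code above; the
 * exact row stride and the requantization step are simplified, and
 * requantize() is a hypothetical stand-in for the library's requantize
 * helpers):
 *
 *   for (int j = 0; j < out_ch; j++)
 *   {
 *       const int8_t *row = row_base_ref + j * (row_elements + skipped_row_elements);
 *       int32_t acc = 0;
 *       int32_t sum_col = 0;
 *       for (int i = 0; i < row_elements; i++)
 *       {
 *           acc += row[i] * col_base_ref[i];
 *           sum_col += col_base_ref[i];
 *       }
 *       acc += sum_col * conv_params->input_offset; // fold in the input offset
 *       if (bias)
 *       {
 *           acc += bias[j];
 *       }
 *       output[j] = (int8_t)requantize(acc, quant_params->multiplier[j], quant_params->shift[j]);
 *   }
 */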

/**
 * @brief General Vector by Matrix multiplication with requantization, storage of result, and int4 weights packed
 * into an int8 buffer.
 * @param[in]       row_elements          number of row elements
 * @param[in]       skipped_row_elements  number of row elements skipped due to padding.
 *                                        row_elements + skipped_row_elements = (kernel_x * kernel_y) * input_ch
 * @param[in]       row_base_ref          pointer to row operand
 * @param[in]       col_base_ref          pointer to col operand as packed int4
 * @param[in]       out_ch                Number of output channels
 * @param[in]       conv_params           Pointer to convolution parameters like offsets and activation values
 * @param[in]       quant_params          Pointer to per-channel quantization parameters
 * @param[in]       bias                  Pointer to optional per-channel bias
 * @param[out]      output                Pointer to output where int8 results are stored.
 * @return     The function performs matrix (row_base_ref) by vector (col_base_ref) multiplication;
 *             the scaled result is stored in memory.
 *
 * @details Pseudo-code as int8 example. Int4 filter data will be unpacked.
 *      *output = 0
 *      sum_col = 0
 *      for (j = 0; j < out_ch; j++)
 *          for (i = 0; i < row_elements; i++)
 *              *output += row_base_ref[i] * col_base_ref[i]
 *              sum_col += col_base_ref[i]
 *          scale sum_col using quant_params and bias
 *          store result in 'output'
 *
 *
 */
arm_cmsis_nn_status arm_nn_mat_mul_core_1x_s4(int32_t row_elements,
                                              const int32_t skipped_row_elements,
                                              const int8_t *row_base_ref,
                                              const int8_t *col_base_ref,
                                              const int32_t out_ch,
                                              const cmsis_nn_conv_params *conv_params,
                                              const cmsis_nn_per_channel_quant_params *quant_params,
                                              const int32_t *bias,
                                              int8_t *output);

/**
 * @brief Matrix-multiplication with requantization & activation function for four rows and one column
 * @param[in]       row_elements  number of row elements
 * @param[in]       offset        offset between rows. Can be the same as row_elements,
 *                                e.g. in a 1x1 conv scenario with stride of 1.
 * @param[in]       row_base      pointer to row operand
 * @param[in]       col_base      pointer to col operand
 * @param[in]       out_ch        Number of output channels
 * @param[in]       conv_params   Pointer to convolution parameters like offsets and activation values
 * @param[in]       quant_params  Pointer to per-channel quantization parameters
 * @param[in]       bias          Pointer to per-channel bias
 * @param[out]      output        Pointer to output where int8 results are stored.
 *
 * @return     The function returns the updated output pointer or NULL if implementation is not available.
 *
 * @details Compliant to TFLM int8 specification. MVE implementation only.
 */
int8_t *arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
                                  const int32_t offset,
                                  const int8_t *row_base,
                                  const int8_t *col_base,
                                  const int32_t out_ch,
                                  const cmsis_nn_conv_params *conv_params,
                                  const cmsis_nn_per_channel_quant_params *quant_params,
                                  const int32_t *bias,
                                  int8_t *output);

/**
 * @brief General Matrix-multiplication function with per-channel requantization.
 *        This function assumes:
 *        - LHS input matrix NOT transposed (nt)
 *        - RHS input matrix transposed (t)
 *        - RHS is int8 packed with 2x int4
 *        - LHS is int8
 *
 *  @note This operation also performs the broadcast bias addition before the requantization
 *
 * @param[in]  lhs                Pointer to the LHS input matrix
 * @param[in]  rhs                Pointer to the RHS input matrix
 * @param[in]  bias               Pointer to the bias vector. The length of this vector is equal to the number of
 *                                output columns (or RHS input rows)
 * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
 * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
 *                                The length of this vector is equal to the number of output columns (or RHS input
 *                                rows)
 * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
 *                                of this vector is equal to the number of output columns (or RHS input rows)
 * @param[in]  lhs_rows           Number of LHS input rows
 * @param[in]  rhs_rows           Number of RHS input rows
 * @param[in]  rhs_cols           Number of LHS/RHS input columns
 * @param[in]  lhs_offset         Offset to be applied to the LHS input value
 * @param[in]  dst_offset         Offset to be applied to the output result
 * @param[in]  activation_min     Minimum value to clamp down the output. Range : int8
 * @param[in]  activation_max     Maximum value to clamp up the output. Range : int8
 * @param[in]  lhs_cols_offset    Column offset between subsequent lhs_rows
 *
 * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s4(const int8_t *lhs,
                                            const int8_t *rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t lhs_cols_offset);

/**
 * @brief General Matrix-multiplication function with per-channel requantization.
 *        This function assumes:
 *        - LHS input matrix NOT transposed (nt)
 *        - RHS input matrix transposed (t)
 *
 *  @note This operation also performs the broadcast bias addition before the requantization
 *
 * @param[in]  lhs                Pointer to the LHS input matrix
 * @param[in]  rhs                Pointer to the RHS input matrix
 * @param[in]  bias               Pointer to the bias vector. The length of this vector is equal to the number of
 *                                output columns (or RHS input rows)
 * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
 * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
 *                                The length of this vector is equal to the number of output columns (or RHS input
 *                                rows)
 * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
 *                                of this vector is equal to the number of output columns (or RHS input rows)
 * @param[in]  lhs_rows           Number of LHS input rows
 * @param[in]  rhs_rows           Number of RHS input rows
 * @param[in]  rhs_cols           Number of LHS/RHS input columns
 * @param[in]  lhs_offset         Offset to be applied to the LHS input value
 * @param[in]  dst_offset         Offset to be applied to the output result
 * @param[in]  activation_min     Minimum value to clamp down the output. Range : int8
 * @param[in]  activation_max     Maximum value to clamp up the output. Range : int8
 * @param[in]  row_address_offset Address offset between rows in output. NOTE: Only used for MVEI extension.
 * @param[in]  lhs_cols_offset    Column offset between subsequent lhs_rows
 *
 * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8(const int8_t *lhs,
                                            const int8_t *rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t row_address_offset,
                                            const int32_t lhs_cols_offset);

/**
 * @brief General Matrix-multiplication function with per-channel requantization and int16 input (LHS) and output.
 *        This function assumes:
 *        - LHS input matrix NOT transposed (nt)
 *        - RHS input matrix transposed (t)
 *
 *  @note This operation also performs the broadcast bias addition before the requantization
 *
 * @param[in]  lhs                Pointer to the LHS input matrix
 * @param[in]  rhs                Pointer to the RHS input matrix
 * @param[in]  bias_data          Pointer to struct with bias vector. The length of this vector is equal to the
 *                                number of output columns (or RHS input rows). The vector can be int32 or int64,
 *                                indicated by a flag in the struct.
 * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
 * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
 *                                The length of this vector is equal to the number of output columns (or RHS input
 *                                rows)
 * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
 *                                of this vector is equal to the number of output columns (or RHS input rows)
 * @param[in]  lhs_rows           Number of LHS input rows
 * @param[in]  rhs_rows           Number of RHS input rows
 * @param[in]  rhs_cols           Number of LHS/RHS input columns
 * @param[in]  activation_min     Minimum value to clamp down the output. Range : int16
 * @param[in]  activation_max     Maximum value to clamp up the output. Range : int16
 *
 * @details MVE implementation only.
 *
 * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code> or
 *                                  <code>ARM_CMSIS_NN_NO_IMPL_ERROR</code> if no MVE implementation is available
 *
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s16(const int16_t *lhs,
                                             const int8_t *rhs,
                                             const cmsis_nn_bias_data *bias_data,
                                             int16_t *dst,
                                             const int32_t *dst_multipliers,
                                             const int32_t *dst_shifts,
                                             const int32_t lhs_rows,
                                             const int32_t rhs_rows,
                                             const int32_t rhs_cols,
                                             const int32_t activation_min,
                                             const int32_t activation_max);

/**
 * @brief General Matrix-multiplication function with int8 input and int32 output.
 *        This function assumes:
 *        - LHS input matrix NOT transposed (nt)
 *        - RHS input matrix transposed (t)
 *
 * @note  Dst/output buffer must be zeroed out before calling this function.
 *
 * @param[in]  lhs                Pointer to the LHS input matrix
 * @param[in]  rhs                Pointer to the RHS input matrix
 * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
 * @param[in]  lhs_rows           Number of LHS input rows
 * @param[in]  rhs_rows           Number of LHS input columns/RHS input rows
 * @param[in]  rhs_cols           Number of RHS input columns
 * @param[in]  lhs_offset         Offset to be applied to the LHS input value
 * @param[in]  dst_idx_offset     Offset between subsequent output results
 *
 * @return     The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s8_s32(const int8_t *lhs,
                                                const int8_t *rhs,
                                                int32_t *dst,
                                                const int32_t lhs_rows,
                                                const int32_t rhs_rows,
                                                const int32_t rhs_cols,
                                                const int32_t lhs_offset,
                                                const int32_t dst_idx_offset);
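
/*
 * Usage sketch (illustrative; buffer names and sizes are hypothetical): the
 * destination buffer must be zeroed first because the function accumulates
 * into it.
 *
 *   int32_t dst[LHS_ROWS * RHS_ROWS];
 *   memset(dst, 0, sizeof(dst));
 *   arm_nn_mat_mult_nt_t_s8_s32(lhs, rhs, dst, lhs_rows, rhs_rows, rhs_cols,
 *                               lhs_offset, 1); // dst_idx_offset of 1: contiguous output
 */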

/**
 * @brief s4 Vector by Matrix (transposed) multiplication
 *
 * @param[in]      lhs             Input left-hand side vector
 * @param[in]      packed_rhs      Input right-hand side matrix (transposed)
 * @param[in]      bias            Input bias
 * @param[out]     dst             Output vector
 * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side vector.
 *                                 Range: -127 to 128
 * @param[in]      dst_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]      dst_multiplier  Output multiplier
 * @param[in]      dst_shift       Output shift
 * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
 * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 *
 * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s4(const int8_t *lhs,
                                             const int8_t *packed_rhs,
                                             const int32_t *bias,
                                             int8_t *dst,
                                             const int32_t lhs_offset,
                                             const int32_t dst_offset,
                                             const int32_t dst_multiplier,
                                             const int32_t dst_shift,
                                             const int32_t rhs_cols,
                                             const int32_t rhs_rows,
                                             const int32_t activation_min,
                                             const int32_t activation_max);

/**
 * @brief s8 Vector by Matrix (transposed) multiplication
 *
 * @param[in]      lhs             Input left-hand side vector
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      kernel_sum      Kernel sums of the kernels (rhs). See arm_vector_sum_s8 for more info.
 * @param[in]      bias            Input bias
 * @param[out]     dst             Output vector
 * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side vector.
 *                                 Range: -127 to 128
 * @param[in]      dst_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]      dst_multiplier  Output multiplier
 * @param[in]      dst_shift       Output shift
 * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
 * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      address_offset  Memory position offset for dst. First output is stored at 'dst', the
 *                                 second at 'dst + address_offset' and so on. Default value is typically 1.
 * @param[in]      rhs_offset      Offset to be added to the input values of the right-hand side matrix.
 *                                 Range: -127 to 128
 *
 * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s8(const int8_t *lhs,
                                             const int8_t *rhs,
                                             const int32_t *kernel_sum,
                                             const int32_t *bias,
                                             int8_t *dst,
                                             const int32_t lhs_offset,
                                             const int32_t dst_offset,
                                             const int32_t dst_multiplier,
                                             const int32_t dst_shift,
                                             const int32_t rhs_cols,
                                             const int32_t rhs_rows,
                                             const int32_t activation_min,
                                             const int32_t activation_max,
                                             const int32_t address_offset,
                                             const int32_t rhs_offset);
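
/*
 * Usage sketch (illustrative; all buffers and sizes are hypothetical): one
 * fully-connected output vector, with kernel sums precomputed via
 * arm_vector_sum_s8.
 *
 *   arm_nn_vec_mat_mult_t_s8(input, weights, kernel_sum, bias, output,
 *                            input_offset, output_offset,
 *                            out_multiplier, out_shift,
 *                            rhs_cols, rhs_rows,
 *                            -128, 127, // int8 activation range
 *                            1,         // address_offset: contiguous output
 *                            0);        // rhs_offset: symmetric weights
 */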

/**
 * @brief s16 Vector by Matrix (transposed) multiplication
 *
 * @param[in]      lhs             Input left-hand side vector
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      bias            Input bias
 * @param[out]     dst             Output vector
 * @param[in]      dst_multiplier  Output multiplier
 * @param[in]      dst_shift       Output shift
 * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
 * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
 *
 * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_s16(const int16_t *lhs,
                                              const int8_t *rhs,
                                              const int64_t *bias,
                                              int16_t *dst,
                                              const int32_t dst_multiplier,
                                              const int32_t dst_shift,
                                              const int32_t rhs_cols,
                                              const int32_t rhs_rows,
                                              const int32_t activation_min,
                                              const int32_t activation_max);

/**
 * @brief s8 Vector by Matrix (transposed) multiplication with s16 output
 *
 * @param[in]      lhs             Input left-hand side vector
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[out]     dst             Output vector
 * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side
 *                                 vector. Range: -127 to 128
 * @param[in]      scatter_offset  Address offset for dst. First output is stored at 'dst', the
 *                                 second at 'dst + scatter_offset' and so on.
 * @param[in]      dst_multiplier  Output multiplier
 * @param[in]      dst_shift       Output shift
 * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
 * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
 *
 * @return         The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 *
 */
arm_cmsis_nn_status arm_nn_vec_mat_mult_t_svdf_s8(const int8_t *lhs,
                                                  const int8_t *rhs,
                                                  int16_t *dst,
                                                  const int32_t lhs_offset,
                                                  const int32_t scatter_offset,
                                                  const int32_t dst_multiplier,
                                                  const int32_t dst_shift,
                                                  const int32_t rhs_cols,
                                                  const int32_t rhs_rows,
                                                  const int32_t activation_min,
                                                  const int32_t activation_max);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in padded cases where
 *        the padding is -lhs_offset (Range: int8). Dimensions are the same for lhs and rhs.
 *
 * @param[in]      lhs             Input left-hand side matrix
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      lhs_offset      LHS matrix offset (input offset). Range: -127 to 128
 * @param[in]      active_ch       Subset of total_ch processed
 * @param[in]      total_ch        Number of channels in LHS/RHS
 * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels
 * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels
 * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels
 * @param[out]     out             Output pointer
 *
 * @return         The function returns one of the two
 *                  - Updated output pointer if an implementation is available
 *                  - NULL if no implementation is available.
 *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
 *                  - Output shift
 *                  - Output multiplier
 *                  - Output bias
 *                  - rhs
 */
arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_padded_s8(const int8_t *lhs,
                                                         const int8_t *rhs,
                                                         const int32_t lhs_offset,
                                                         const int32_t active_ch,
                                                         const int32_t total_ch,
                                                         const int32_t *out_shift,
                                                         const int32_t *out_mult,
                                                         const int32_t out_offset,
                                                         const int32_t activation_min,
                                                         const int32_t activation_max,
                                                         const uint16_t row_x_col,
                                                         const int32_t *const output_bias,
                                                         int8_t *out);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
 *        Dimensions are the same for lhs and rhs.
 *
 * @param[in]      lhs             Input left-hand side matrix
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      lhs_offset      LHS matrix offset (input offset). Range: -127 to 128
 * @param[in]      active_ch       Subset of total_ch processed
 * @param[in]      total_ch        Number of channels in LHS/RHS
 * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
 * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
 * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[out]     out             Output pointer
 *
 * @return         The function returns one of the two
 *                  - Updated output pointer if an implementation is available
 *                  - NULL if no implementation is available.
 *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
 *                  - Output shift
 *                  - Output multiplier
 *                  - Output bias
 *                  - rhs
 */
arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_s8(const int8_t *lhs,
                                                  const int8_t *rhs,
                                                  const int32_t lhs_offset,
                                                  const int32_t active_ch,
                                                  const int32_t total_ch,
                                                  const int32_t *out_shift,
                                                  const int32_t *out_mult,
                                                  const int32_t out_offset,
                                                  const int32_t activation_min,
                                                  const int32_t activation_max,
                                                  const uint16_t row_x_col,
                                                  const int32_t *const output_bias,
                                                  int8_t *out);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases. rhs
 *        consists of packed int4 data. Dimensions are the same for lhs and rhs.
 *
 * @param[in]      lhs             Input left-hand side matrix
 * @param[in]      rhs             Input right-hand side matrix (transposed). Consists of int4 data packed in an
 *                                 int8 buffer.
 * @param[in]      lhs_offset      LHS matrix offset (input offset). Range: -127 to 128
 * @param[in]      active_ch       Subset of total_ch processed
 * @param[in]      total_ch        Number of channels in LHS/RHS
 * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
 * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
 * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[out]     out             Output pointer
 *
 * @return         The function returns one of the two
 *                  - Updated output pointer if an implementation is available
 *                  - NULL if no implementation is available.
 *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
 *                  - Output shift
 *                  - Output multiplier
 *                  - Output bias
 *                  - rhs
 */
arm_cmsis_nn_status arm_nn_depthwise_conv_nt_t_s4(const int8_t *lhs,
                                                  const int8_t *rhs,
                                                  const int32_t lhs_offset,
                                                  const int32_t active_ch,
                                                  const int32_t total_ch,
                                                  const int32_t *out_shift,
                                                  const int32_t *out_mult,
                                                  const int32_t out_offset,
                                                  const int32_t activation_min,
                                                  const int32_t activation_max,
                                                  const uint16_t row_x_col,
                                                  const int32_t *const output_bias,
                                                  int8_t *out);

/**
 * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
 *        Dimensions are the same for lhs and rhs.
 *
 * @param[in]      lhs             Input left-hand side matrix
 * @param[in]      rhs             Input right-hand side matrix (transposed)
 * @param[in]      num_ch          Number of channels in LHS/RHS
 * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
 * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
 * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
 * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
 * @param[in]      row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
 * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[out]     out             Output pointer
 *
 * @return         The function returns one of the two
 *                  - Updated output pointer if an implementation is available
 *                  - NULL if no implementation is available.
 *
 * @note           If the number of channels is not a multiple of 4, up to 3 elements outside the boundary will be
 *                 read for the following:
 *                  - Output shift
 *                  - Output multiplier
 *                  - Output bias
 *                  - rhs
 */
int16_t *arm_nn_depthwise_conv_nt_t_s16(const int16_t *lhs,
                                        const int8_t *rhs,
                                        const uint16_t num_ch,
                                        const int32_t *out_shift,
                                        const int32_t *out_mult,
                                        const int32_t activation_min,
                                        const int32_t activation_max,
                                        const uint16_t row_x_col,
                                        const int64_t *const output_bias,
                                        int16_t *out);

/**
  @brief         Read 2 s16 elements and post increment pointer.
  @param[in]     in_q15   Pointer to pointer that holds address of input.
  @return        q31 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_q15x2_ia(const int16_t **in_q15)
{
    int32_t val;

    memcpy(&val, *in_q15, 4);
    *in_q15 += 2;

    return (val);
}
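
/*
 * Usage sketch (illustrative): the pointer passed by address is advanced past
 * the elements that were read.
 *
 *   const int16_t buf[4] = {1, 2, 3, 4};
 *   const int16_t *p = buf;
 *   int32_t v0 = arm_nn_read_q15x2_ia(&p); // packs {1, 2}; p now points at buf[2]
 *   int32_t v1 = arm_nn_read_q15x2_ia(&p); // packs {3, 4}
 */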

/**
  @brief         Read 4 s8 from s8 pointer and post increment pointer.
  @param[in]     in_s8       Pointer to pointer that holds address of input.
  @return        q31 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x4_ia(const int8_t **in_s8)
{
    int32_t val;
    memcpy(&val, *in_s8, 4);
    *in_s8 += 4;

    return (val);
}

/**
  @brief         Read 2 s8 from s8 pointer and post increment pointer.
  @param[in]     in_s8    Pointer to pointer that holds address of input.
  @return        q31 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x2_ia(const int8_t **in_s8)
{
    int32_t val;
    memcpy(&val, *in_s8, 2);
    *in_s8 += 2;

    return (val);
}

/**
  @brief         Read 2 int16 values from int16 pointer.
  @param[in]     in     pointer to address of input.
  @return        s32 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s16x2(const int16_t *in)
{
    int32_t val;
    memcpy(&val, in, 4);

    return (val);
}

/**
  @brief         Read 4 s8 values.
  @param[in]     in_s8       pointer to address of input.
  @return        s32 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x4(const int8_t *in_s8)
{
    int32_t val;
    memcpy(&val, in_s8, 4);

    return (val);
}

/**
  @brief         Read 2 s8 values.
  @param[in]     in_s8    pointer to address of input.
  @return        s32 value
 */
__STATIC_FORCEINLINE int32_t arm_nn_read_s8x2(const int8_t *in_s8)
{
    int32_t val;
    memcpy(&val, in_s8, 2);

    return (val);
}

/**
  @brief         Write four s8 to s8 pointer and increment pointer afterwards.
  @param[in]     in       Double pointer to input value
  @param[in]     value    Four bytes to copy
 */
__STATIC_FORCEINLINE void arm_nn_write_s8x4_ia(int8_t **in, int32_t value)
{
    memcpy(*in, &value, 4);
    *in += 4;
}

/**
 * @brief           memset optimized for MVE
 * @param[in, out]  dst         Destination pointer
 * @param[in]       val         Value to set
 * @param[in]       block_size  Number of bytes to set.
 *
 */
__STATIC_FORCEINLINE void arm_memset_s8(int8_t *dst, const int8_t val, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    __asm volatile("   vdup.8                  q0, %[set_val]             \n"
                   "   wlstp.8                 lr, %[cnt], 1f             \n"
                   "2:                                                    \n"
                   "   vstrb.8                 q0, [%[in]], #16           \n"
                   "   letp                    lr, 2b                     \n"
                   "1:                                                    \n"
                   : [in] "+r"(dst)
                   : [cnt] "r"(block_size), [set_val] "r"(val)
                   : "q0", "memory", "r14");
#else
    memset(dst, val, block_size);
#endif
}
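
/*
 * Usage sketch (illustrative): clearing a scratch buffer. On MVE targets this
 * runs the tail-predicated vector loop above; elsewhere it falls back to memset.
 *
 *   int8_t scratch[64];
 *   arm_memset_s8(scratch, 0, sizeof(scratch));
 */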

#if defined(ARM_MATH_DSP)

/**
 * @brief read and expand one s4 word into two s16 words.
 */
__STATIC_FORCEINLINE void read_and_pad_s4(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int16_t in = arm_nn_read_s8x2(source);
    int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);

    *out1 = SXTB16_RORn(__sxtb16(inA << 4), 4);
    *out2 = SXTB16_RORn(__sxtb16(inA), 4);
}
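
/*
 * Example (illustrative): bytes 0x21 and 0x43 hold the s4 sequence 1, 2, 3, 4
 * (low nibble first). out1 receives the sign-extended even elements and out2
 * the odd ones, each as 2 x s16 lanes:
 *
 *   const int8_t packed[2] = {0x21, 0x43};
 *   int32_t even, odd;
 *   read_and_pad_s4(packed, &even, &odd); // even: (1, 3), odd: (2, 4)
 */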

/**
 * @brief read and expand one s4 word into two s16 words.
 * @details   The s4 elements are not evenly aligned on the byte boundary, so 3 bytes need to be read instead of 2.
 *            In other words, the first nibble to read starts in the middle of a byte.
 *            byte index, s4 element
 *            0,          s4_x
 *            0,          s4_0
 *            1,          s4_1
 *            1,          s4_2
 *            2,          s4_3
 *            2,          s4_x
 */
__STATIC_FORCEINLINE void read_and_pad_s4_uneven(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA1 = (source[0] & 0xFF) | ((source[1] & 0xFF) << 16);
    int32_t inA2 = (source[1] & 0xFF) | ((source[2] & 0xFF) << 16);

    *out1 = SXTB16_RORn(__sxtb16(inA2 << 4), 4);
    *out2 = SXTB16_RORn(__sxtb16(inA1), 4);
}

/**
 * @brief read and expand one s4 word into two s16 words with ordering.
 */
__STATIC_FORCEINLINE void read_and_pad_s4_ordered(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int16_t in = arm_nn_read_s8x2(source);
    int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);
    int32_t inAbuf1 = SXTB16_RORn(__sxtb16(inA), 4);
    int32_t inAbuf2 = SXTB16_RORn(__sxtb16(inA << 4), 4);
    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #else
    *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #endif
}

/**
 * @brief read and expand one s8 word into two s16 words with ordering.
 */
__STATIC_FORCEINLINE const int8_t *read_and_pad(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA = arm_nn_read_s8x4_ia(&source);
    int32_t inAbuf1 = SXTB16_RORn((uint32_t)inA, 8);
    int32_t inAbuf2 = SXTB16(inA);

    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #else
    *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #endif

    return source;
}
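
/*
 * Usage sketch (illustrative, assuming the SMLAD intrinsic from
 * arm_nn_compiler.h): a dot product of two s8 vectors, four elements per
 * iteration, by expanding each 4-byte word into 2x2 s16 lanes. Both
 * operands get the same lane ordering, so the pairing stays consistent.
 *
 *     int32_t sum = 0;
 *     for (int32_t i = 0; i < count; i += 4)
 *     {
 *         int32_t a1, a2, b1, b2;
 *         lhs = read_and_pad(lhs, &a1, &a2);
 *         rhs = read_and_pad(rhs, &b1, &b2);
 *         sum = SMLAD(a1, b1, sum); // two 16x16 + 32 MACs
 *         sum = SMLAD(a2, b2, sum); // remaining two lanes
 *     }
 */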

/**
 * @brief read and expand one s8 word into two s16 words with ordering and addition.
 */
__STATIC_FORCEINLINE void read_pad_and_add_s8(const int8_t *source, int32_t *out1, int32_t *out2, const uint32_t add)
{
    int32_t inA = arm_nn_read_s8x4(source);
    int32_t inAbuf1 = SXTAB16_RORn(add, (uint32_t)inA, 8);
    int32_t inAbuf2 = SXTAB16(add, inA);

    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #else
    *out1 = (int32_t)(PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(PKHBT(inAbuf2, inAbuf1, 16));
    #endif
}

/**
 * @brief read and expand two bytes into one word with ordering.
 */
__STATIC_FORCEINLINE void read_and_pad_s8x2(const int8_t *source, int32_t *out)
{
    int16_t in = arm_nn_read_s8x2(source);
    int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);
    *out = SXTB16(inA);
}

/**
 * @brief read and expand two bytes into one word with ordering and addition.
 */
__STATIC_FORCEINLINE void read_pad_and_add_s8x2(const int8_t *source, int32_t *out, const uint32_t add)
{
    int16_t in = arm_nn_read_s8x2(source);
    int32_t inA = (in & 0x00FF) | ((in & 0xFF00) << 8);
    *out = SXTAB16(add, inA);
}

/**
 * @brief read and expand one s8 word into two s16 words with no additional ordering.
 */
__STATIC_FORCEINLINE const int8_t *read_and_pad_reordered(const int8_t *source, int32_t *out1, int32_t *out2)
{
    int32_t inA = arm_nn_read_s8x4_ia(&source);
    #ifndef ARM_MATH_BIG_ENDIAN
    *out2 = SXTB16(ROR((uint32_t)inA, 8));
    *out1 = SXTB16(inA);
    #else
    *out1 = SXTB16(ROR((uint32_t)inA, 8));
    *out2 = SXTB16(inA);
    #endif

    return source;
}

#endif /* ARM_MATH_DSP */

/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization and 4 bit weights.
 * @param[in]       input_a            pointer to operand A, int8 packed with 2x int4.
 * @param[in]       input_b            pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch          number of rows of A
 * @param[in]       out_shift          pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult           pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset         output tensor offset.
 * @param[in]       activation_min     minimum value to clamp the output to. Range : int8
 * @param[in]       activation_max     maximum value to clamp the output to. Range : int8
 * @param[in]       num_col_a          number of columns of A
 * @param[in]       output_bias        per output channel bias. Range : int32
 * @param[in,out]   out_0              pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *            Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_mat_mult_kernel_s4_s16(const int8_t *input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int32_t activation_min,
                                      const int32_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0);

/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization.
 * @param[in]       input_a            pointer to operand A
 * @param[in]       input_b            pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch          number of rows of A
 * @param[in]       out_shift          pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult           pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset         output tensor offset.
 * @param[in]       activation_min     minimum value to clamp the output to. Range : int8
 * @param[in]       activation_max     maximum value to clamp the output to. Range : int8
 * @param[in]       num_col_a          number of columns of A
 * @param[in]       aligned_num_col_a  number of columns of A aligned by 4
 * @param[in]       output_bias        per output channel bias. Range : int32
 * @param[in,out]   out_0              pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *            Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_mat_mult_kernel_s8_s16(const int8_t *input_a,
                                      const int16_t *input_b,
                                      const uint16_t output_ch,
                                      const int32_t *out_shift,
                                      const int32_t *out_mult,
                                      const int32_t out_offset,
                                      const int16_t activation_min,
                                      const int16_t activation_max,
                                      const int32_t num_col_a,
                                      const int32_t aligned_num_col_a,
                                      const int32_t *const output_bias,
                                      int8_t *out_0);

/**
 * @brief Matrix-multiplication function for convolution with per-channel requantization, supporting an address offset
 * between rows.
 * @param[in]       input_a            pointer to operand A
 * @param[in]       input_b            pointer to operand B, always consists of 2 vectors.
 * @param[in]       output_ch          number of rows of A
 * @param[in]       out_shift          pointer to per output channel requantization shift parameter.
 * @param[in]       out_mult           pointer to per output channel requantization multiplier parameter.
 * @param[in]       out_offset         output tensor offset.
 * @param[in]       activation_min     minimum value to clamp the output to. Range : int8
 * @param[in]       activation_max     maximum value to clamp the output to. Range : int8
 * @param[in]       num_col_a          number of columns of A
 * @param[in]       aligned_num_col_a  number of columns of A aligned by 4
 * @param[in]       output_bias        per output channel bias. Range : int32
 * @param[in]       row_address_offset address offset between rows in the output
 * @param[in,out]   out_0              pointer to output
 * @return     The function returns one of the two
 *              1. The incremented output pointer for a successful operation or
 *              2. NULL if implementation is not available.
 *
 * @details   This function does the matrix multiplication of weight matrix for all output channels
 *            with 2 columns from im2col and produces two elements/output_channel. The outputs are
 *            clamped in the range provided by activation min and max.
 *
 *            This function is slightly less performant than arm_nn_mat_mult_kernel_s8_s16, but allows support for
 *            grouped convolution.
 *            Supported framework: TensorFlow Lite micro.
 */
int8_t *arm_nn_mat_mult_kernel_row_offset_s8_s16(const int8_t *input_a,
                                                 const int16_t *input_b,
                                                 const uint16_t output_ch,
                                                 const int32_t *out_shift,
                                                 const int32_t *out_mult,
                                                 const int32_t out_offset,
                                                 const int16_t activation_min,
                                                 const int16_t activation_max,
                                                 const int32_t num_col_a,
                                                 const int32_t aligned_num_col_a,
                                                 const int32_t *const output_bias,
                                                 const int32_t row_address_offset,
                                                 int8_t *out_0);

/**
 * @brief Common softmax function for s8 input and s8 or s16 output
 * @param[in]  input          Pointer to the input tensor
 * @param[in]  num_rows       Number of rows in the input tensor
 * @param[in]  row_size       Number of elements in each input row
 * @param[in]  mult           Input quantization multiplier
 * @param[in]  shift          Input quantization shift within the range [0, 31]
 * @param[in]  diff_min       Minimum difference with max in row. Used to check if
 *                            the quantized exponential operation can be performed
 * @param[in]  int16_output   Non-zero value indicates s16 output, otherwise s8 output
 * @param[out] output         Pointer to the output tensor
 *
 * @note Supported framework: TensorFlow Lite micro (bit-accurate)
 *
 */
void arm_nn_softmax_common_s8(const int8_t *input,
                              const int32_t num_rows,
                              const int32_t row_size,
                              const int32_t mult,
                              const int32_t shift,
                              const int32_t diff_min,
                              const bool int16_output,
                              void *output);

/**
 * @brief macro for adding rounding offset
 */
#ifndef ARM_NN_TRUNCATE
    #define NN_ROUND(out_shift) ((0x1 << out_shift) >> 1)
#else
    #define NN_ROUND(out_shift) 0
#endif
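
/*
 * Typical use (illustrative): a rounding arithmetic right shift of an
 * accumulator, which degrades to plain truncation when ARM_NN_TRUNCATE is
 * defined.
 *
 *     acc = (acc + NN_ROUND(out_shift)) >> out_shift;
 */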

// Macros for shortening quantization functions' names and avoiding long lines
#define MUL_SAT(a, b) arm_nn_doubling_high_mult((a), (b))
#define MUL_SAT_MVE(a, b) arm_doubling_high_mult_mve_32x4((a), (b))
#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))

#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
#define DIV_POW2_MVE(a, b) arm_divide_by_power_of_two_mve((a), (b))

#define EXP_ON_NEG(x) arm_nn_exp_on_negative_values((x))
#define ONE_OVER1(x) arm_nn_one_over_one_plus_x_for_x_in_0_1((x))

/**
 * @brief           Saturating doubling high multiply. Result matches
 *                  NEON instruction VQRDMULH.
 * @param[in]       m1        Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @param[in]       m2        Multiplier. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_doubling_high_mult(const int32_t m1, const int32_t m2)
{
    int32_t result = 0;
    // Rounding offset to add for a right shift of 31
    int64_t mult = 1 << 30;

    if ((m1 < 0) ^ (m2 < 0))
    {
        mult = 1 - mult;
    }
    // Gets resolved as a SMLAL instruction
    mult = mult + (int64_t)m1 * m2;

    // Utilize all of the upper 32 bits. This is the doubling step
    // as well.
    result = (int32_t)(mult / (1ll << 31));

    if ((m1 == m2) && (m1 == (int32_t)NN_Q31_MIN))
    {
        result = NN_Q31_MAX;
    }
    return result;
}
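
/*
 * Worked example (illustrative): (1 << 30) represents 0.5 in Q31, so
 * multiplying it by itself yields 0.25.
 *
 *     int32_t r = arm_nn_doubling_high_mult(1 << 30, 1 << 30);
 *     // r == (1 << 29), i.e. 0.25 in Q31
 */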

/**
 * @brief           Doubling high multiply without saturation. This is intended
 *                  for requantization where the scale is a positive integer
 *
 * @param[in]       m1        Multiplicand. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @param[in]       m2        Multiplier. Range: {NN_Q31_MIN, NN_Q31_MAX}
 * @return          Result of multiplication.
 * @note            The result of this matches that of the NEON instruction
 *                  VQRDMULH for m1 in range {NN_Q31_MIN, NN_Q31_MAX} and m2 in
 *                  range {NN_Q31_MIN + 1, NN_Q31_MAX}. Saturation occurs when
 *                  m1 equals m2 equals NN_Q31_MIN and that is not handled by
 *                  this function.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_doubling_high_mult_no_sat(const int32_t m1, const int32_t m2)
{
    int32_t result = 0;
    union arm_nn_long_long mult;

    // Rounding offset to add for a right shift of 31
    mult.word.low = 1 << 30;
    mult.word.high = 0;

    // Gets resolved as a SMLAL instruction
    mult.long_long = mult.long_long + (int64_t)m1 * m2;

    // Utilize all of the upper 32 bits. This is the doubling step
    // as well.
    result = (int32_t)(mult.long_long >> 31);

    return result;
}

/**
 * @brief           Rounding divide by power of two.
 * @param[in]       dividend - Dividend
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_divide_by_power_of_two(const int32_t dividend, const int32_t exponent)
{
    int32_t result = 0;
    const int32_t remainder_mask = (1 << exponent) - 1;
    int32_t remainder = remainder_mask & dividend;

    // Basic division
    result = dividend >> exponent;

    // Adjust 'result' for rounding (mid point away from zero)
    int32_t threshold = remainder_mask >> 1;
    if (result < 0)
    {
        threshold++;
    }
    if (remainder > threshold)
    {
        result++;
    }

    return result;
}
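
/*
 * Worked examples (illustrative): midpoints are rounded away from zero.
 *
 *     arm_nn_divide_by_power_of_two(5, 1);  //  2.5 ->  3
 *     arm_nn_divide_by_power_of_two(-5, 1); // -2.5 -> -3
 *     arm_nn_divide_by_power_of_two(6, 2);  //  1.5 ->  2
 */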

/**
 * @brief           Requantize a given value.
 * @details         Essentially returns (val * multiplier)/(2 ^ shift), with rounding that differs
 *                  depending on whether CMSIS_NN_USE_SINGLE_ROUNDING is defined.
 * @param[in]       val         Value to be requantized
 * @param[in]       multiplier  Multiplier. Range {NN_Q31_MIN + 1, Q32_MAX}
 * @param[in]       shift       Shift. Range: {-31, 30}
 *                              Default branch:
 *                                  If shift is positive left shift 'val * multiplier' with shift
 *                                  If shift is negative right shift 'val * multiplier' with abs(shift)
 *                              Single round branch:
 *                                  Input to total_shift in the divide by '2 ^ total_shift'
 *
 * @return          Default branch:
 *                      Returns (val * multiplier) with rounding divided by (2 ^ shift) with rounding
 *                  Single round branch:
 *                      Returns (val * multiplier)/(2 ^ (31 - shift)) with rounding
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_requantize(const int32_t val, const int32_t multiplier, const int32_t shift)
{
#ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int64_t total_shift = 31 - shift;
    const int64_t new_val = val * (int64_t)multiplier;

    int32_t result = new_val >> (total_shift - 1);
    result = (result + 1) >> 1;

    return result;
#else
    return arm_nn_divide_by_power_of_two(arm_nn_doubling_high_mult_no_sat(val * (1 << LEFT_SHIFT(shift)), multiplier),
                                         RIGHT_SHIFT(shift));
#endif
}
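
/*
 * Worked example (illustrative): with multiplier (1 << 30), i.e. 0.5 in
 * Q31, and shift 0, both rounding branches halve the input.
 *
 *     int32_t r = arm_nn_requantize(100, 1 << 30, 0); // r == 50
 */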

/**
 * @brief           Requantize a given 64 bit value.
 * @param[in]       val                 Value to be requantized in the range {-(1<<47)} to {(1<<47) - 1}
 * @param[in]       reduced_multiplier  Reduced multiplier, i.e. a multiplier in the range
 *                                      {NN_Q31_MIN + 1, Q32_MAX} reduced to {Q16_MIN + 1, Q16_MAX}
 * @param[in]       shift               Left or right shift for 'val * multiplier' in the range {-31} to {7}
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32_t arm_nn_requantize_s64(const int64_t val,
                                                   const int32_t reduced_multiplier,
                                                   const int32_t shift)
{
    const int64_t new_val = val * reduced_multiplier;

    int32_t result = new_val >> (14 - shift); // 64->32 bit reduction
    result = (result + 1) >> 1;               // Last shift position and insert round

    return result;
}
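
/*
 * Usage sketch (illustrative): the reduced multiplier is typically derived
 * from a Q31 requantization multiplier with the REDUCE_MULTIPLIER macro.
 *
 *     const int32_t reduced_mult = REDUCE_MULTIPLIER(out_mult);
 *     const int32_t out = arm_nn_requantize_s64(acc64, reduced_mult, out_shift);
 */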

/**
 * @brief           memcpy optimized for MVE
 * @param[in, out]  dst         Destination pointer
 * @param[in]       src         Source pointer.
 * @param[in]       block_size  Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_s8(int8_t *__RESTRICT dst, const int8_t *__RESTRICT src, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    __asm volatile("   wlstp.8                 lr, %[cnt], 1f             \n"
                   "2:                                                    \n"
                   "   vldrb.8                 q0, [%[in]], #16           \n"
                   "   vstrb.8                 q0, [%[out]], #16          \n"
                   "   letp                    lr, 2b                     \n"
                   "1:                                                    \n"
                   : [in] "+r"(src), [out] "+r"(dst)
                   : [cnt] "r"(block_size)
                   : "q0", "memory", "r14");
#else
    memcpy(dst, src, block_size);
#endif
}

/**
 * @brief           memcpy wrapper for int16
 * @param[in, out]  dst         Destination pointer
 * @param[in]       src         Source pointer.
 * @param[in]       block_size  Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_q15(int16_t *__RESTRICT dst, const int16_t *__RESTRICT src, uint32_t block_size)
{
    memcpy(dst, src, block_size);
}

#if defined(ARM_MATH_MVEI)
/**
 * @brief           Vector saturating doubling high multiply returning high half.
 * @param[in]       m1        Multiplicand
 * @param[in]       m2        Multiplier
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve(const int32x4_t m1, const int32_t m2)
{
    return vqrdmulhq_n_s32(m1, m2);
}

/**
 * @brief           Vector rounding divide by power of two.
 * @param[in]       dividend - Dividend vector
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve(const int32x4_t dividend, const int32_t exponent)
{
    const int32x4_t shift = vdupq_n_s32(-exponent);
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}

/**
 * @brief           Requantize a given vector.
 * @param[in]       val         Vector to be requantized
 * @param[in]       multiplier  multiplier
 * @param[in]       shift       shift
 *
 * @return          Returns (val * multiplier)/(2 ^ shift) with different rounding. See arm_nn_requantize for details.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve(const int32x4_t val, const int32_t multiplier, const int32_t shift)
{
    #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int right_shift = MIN(-1, shift);
    const int left_shift = shift - right_shift;

    const int32x4_t left_shift_dup = vdupq_n_s32(left_shift);
    const int32x4_t right_shift_dup = vdupq_n_s32(right_shift);

    int32x4_t result = vqdmulhq_n_s32(vshlq_s32(val, left_shift_dup), multiplier);
    result = vrshlq_s32(result, right_shift_dup);

    return result;
    #else
    return arm_divide_by_power_of_two_mve(
        arm_doubling_high_mult_mve(vshlq_s32(val, vdupq_n_s32(LEFT_SHIFT(shift))), multiplier), RIGHT_SHIFT(shift));
    #endif
}

/**
 * @brief           Vector saturating doubling high multiply with predication returning high half.
 * @param[in]       m1        Multiplicand
 * @param[in]       m2        Multiplier
 * @param[in]       p         Vector predication mask
 * @param[in]       v_zero    Vector of zeroes for merging predication intrinsic
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_pred(const int32x4_t m1,
                                                               const int32_t m2,
                                                               const mve_pred16_t p,
                                                               const int32x4_t v_zero)
{
    return vqrdmulhq_m_n_s32(v_zero, m1, m2, p);
}

/**
 * @brief           Vector rounding divide by power of two with predication.
 * @param[in]       dividend - Dividend vector
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @param[in]       p        - Vector predication mask
 * @param[in]       v_zero   - Vector of zeroes for merging predication intrinsic
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_pred(const int32x4_t dividend,
                                                                   const int32_t exponent,
                                                                   const mve_pred16_t p,
                                                                   const int32x4_t v_zero)
{
    const int32x4_t shift = vdupq_x_n_s32(-exponent, p);
    const int32x4_t fixup = vshrq_x_n_s32(vandq_x_s32(dividend, shift, p), 31, p);
    const int32x4_t fixed_up_dividend = vqaddq_m_s32(v_zero, dividend, fixup, p);
    return vrshlq_m_s32(v_zero, fixed_up_dividend, shift, p);
}

/**
 * @brief           Requantize a given vector with predication.
 * @param[in]       val         Vector to be requantized
 * @param[in]       multiplier  multiplier
 * @param[in]       shift       shift
 * @param[in]       p           Vector predication mask
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_pred(const int32x4_t val,
                                                       const int32_t multiplier,
                                                       const int32_t shift,
                                                       const mve_pred16_t p)
{
    #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int right_shift = MIN(-1, shift);
    const int left_shift = shift - right_shift;
    const int32x4_t v_zero = vcreateq_s32(0, 0);

    const int32x4_t left_shift_dup = vdupq_x_n_s32(left_shift, p);
    const int32x4_t right_shift_dup = vdupq_x_n_s32(right_shift, p);

    int32x4_t result = vqrdmulhq_m_n_s32(v_zero, vshlq_m_s32(v_zero, val, left_shift_dup, p), multiplier, p);
    result = vrshlq_m_s32(v_zero, result, right_shift_dup, p);

    return result;
    #else
    const int32x4_t v_zero = vcreateq_s32(0, 0);
    return arm_divide_by_power_of_two_mve_pred(
        arm_doubling_high_mult_mve_pred(
            vshlq_m_s32(v_zero, val, vdupq_x_n_s32(LEFT_SHIFT(shift), p), p), multiplier, p, v_zero),
        RIGHT_SHIFT(shift),
        p,
        v_zero);
    #endif
}

/**
 * @brief           Vector saturating doubling high multiply with a vector of multipliers, returning high half.
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_32x4(const int32x4_t m1, const int32x4_t m2)
{
    return vqrdmulhq_s32(m1, m2);
}

/**
 * @brief           Vector rounding divide by power of two with a vector of exponents.
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_32x4(const int32x4_t dividend, const int32x4_t exponent)
{
    const int32x4_t shift = -exponent;
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}

/**
 * @brief           Requantize a given vector with per-lane multipliers and shifts.
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_32x4(const int32x4_t val,
                                                       const int32x4_t multiplier,
                                                       const int32x4_t shift)
{
    #ifdef CMSIS_NN_USE_SINGLE_ROUNDING
    const int32x4_t right_shift = vminq_s32(vdupq_n_s32(-1), shift);
    const int32x4_t left_shift = vqsubq_s32(shift, right_shift);

    int32x4_t result = vqdmulhq_s32(vshlq_s32(val, left_shift), multiplier);
    result = vrshlq_s32(result, right_shift);

    return result;
    #else
    const int32x4_t zz = vdupq_n_s32(0);
    const mve_pred16_t p = vcmpgtq_n_s32(shift, 0);

    const int32x4_t left_shift = vpselq_s32(shift, zz, p);
    const int32x4_t right_shift = -vpselq_s32(zz, shift, p);

    return arm_divide_by_power_of_two_mve_32x4(arm_doubling_high_mult_mve_32x4(vshlq_s32(val, left_shift), multiplier),
                                               right_shift);
    #endif
}
#endif /* ARM_MATH_MVEI */

// @note The following functions are used only for the softmax layer, scaled bits = 5 assumed

__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
    int32_t mask = 0;
    int32_t shift = 24;

    const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
    const int32_t remainder = val_mod_minus_quarter - val;
    const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
    const int32_t x2 = MUL_SAT(x, x);

    int32_t result = 1895147668 +
        MUL_SAT(1895147668, x + DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));

#define SELECT_IF_NON_ZERO(x)                                                                                          \
    {                                                                                                                  \
        mask = MASK_IF_NON_ZERO(remainder & (1 << shift++));                                                           \
        result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result);                                                  \
    }

    SELECT_IF_NON_ZERO(1672461947)
    SELECT_IF_NON_ZERO(1302514674)
    SELECT_IF_NON_ZERO(790015084)
    SELECT_IF_NON_ZERO(290630308)
    SELECT_IF_NON_ZERO(39332535)
    SELECT_IF_NON_ZERO(720401)
    SELECT_IF_NON_ZERO(242)

#undef SELECT_IF_NON_ZERO

    mask = MASK_IF_ZERO(val);
    return SELECT_USING_MASK(mask, NN_Q31_MAX, result);
}
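
/*
 * Illustrative property: exp(0) == 1, which saturates to NN_Q31_MAX in the
 * Q0.31 output format.
 *
 *     int32_t one = arm_nn_exp_on_negative_values(0); // == NN_Q31_MAX
 */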

__STATIC_FORCEINLINE int32_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
{
    const int32_t thresh = ((1 << (31 - exp)) - 1);
    int32_t result = val << exp;
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), NN_Q31_MAX, result);
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), NN_Q31_MIN, result);
    return result;
}

__STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
{
    const int64_t sum = (int64_t)val + (int64_t)NN_Q31_MAX;
    const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
    int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);

    const int32_t shift = (1 << 29);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);

    return MUL_POW2(x, 1);
}
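
/*
 * Note (descriptive, not normative): the three identical updates above are
 * Newton-Raphson iterations x <- x + x * (1 - half_denominator * x),
 * refining x towards 1 / half_denominator. The final MUL_POW2(x, 1) folds
 * in the earlier halving of the denominator, yielding 1 / (1 + val) in
 * Q0.31.
 */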

/**
  @brief         Write 2 s16 elements and post increment pointer.
  @param[in]     dest_q15  Pointer to pointer that holds address of destination.
  @param[in]     src_q31   Input value to be written.
 */
__STATIC_FORCEINLINE void arm_nn_write_q15x2_ia(int16_t **dest_q15, int32_t src_q31)
{
    int32_t val = src_q31;

    memcpy(*dest_q15, &val, 4);
    *dest_q15 += 2;
}

/**
  @brief         Write 2 s8 elements and post increment pointer.
  @param[in]     dst  Pointer to pointer that holds address of destination.
  @param[in]     src  Input value to be written.
 */
__STATIC_FORCEINLINE void arm_nn_write_s8x2_ia(int8_t **dst, int16_t src)
{
    memcpy(*dst, &src, 2);
    *dst += 2;
}
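
/*
 * Usage sketch (illustrative only, little-endian layout shown): stream s8
 * results out two at a time with post-incremented writes.
 *
 *     int8_t *out = dst;
 *     for (int32_t i = 0; i < len; i += 2)
 *     {
 *         const int16_t pair = (int16_t)((vals[i] & 0xFF) | (vals[i + 1] << 8));
 *         arm_nn_write_s8x2_ia(&out, pair);
 *     }
 */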

// Support functions for LSTM
/**
 * @brief Update LSTM function for an iteration step using s8 input and output, and s16 internally.
 *
 * @param[in]   data_in                         Data input pointer
 * @param[in]   hidden_in                       Hidden state/ recurrent input pointer
 * @param[out]  hidden_out                      Hidden state/ recurrent output pointer
 * @param[in]   params                          Struct containing all information about the lstm operator, see
 * arm_nn_types.
 * @param[in]   buffers                         Struct containing pointers to all temporary scratch buffers needed for
 * the lstm operator, see arm_nn_types.
 * @param[in]   batch_offset                    Number of timesteps between consecutive batches.
 * E.g. for params->time_major = true, all batches for t=0 are stored sequentially, so batch offset = 1.
 * For params->time_major = false, all time steps are stored continuously before the next batch, so
 * batch offset = params->time_steps.
 * @return                                      The function returns ARM_CMSIS_NN_SUCCESS
 */
arm_cmsis_nn_status arm_nn_lstm_step_s8(const int8_t *data_in,
                                        const int8_t *hidden_in,
                                        int8_t *hidden_out,
                                        const cmsis_nn_lstm_params *params,
                                        cmsis_nn_lstm_context *buffers,
                                        const int32_t batch_offset);

/**
 * @brief Update LSTM function for an iteration step using s16 input and output, and s16 internally.
 *
 * @param[in]   data_in                         Data input pointer
 * @param[in]   hidden_in                       Hidden state/ recurrent input pointer
 * @param[out]  hidden_out                      Hidden state/ recurrent output pointer
 * @param[in]   params                          Struct containing all information about the lstm operator, see
 * arm_nn_types.
 * @param[in]   buffers                         Struct containing pointers to all temporary scratch buffers needed for
 * the lstm operator, see arm_nn_types.
 * @param[in]   batch_offset                    Number of timesteps between consecutive batches.
 * E.g. for params->time_major = true, all batches for t=0 are stored sequentially, so batch offset = 1.
 * For params->time_major = false, all time steps are stored continuously before the next batch, so
 * batch offset = params->time_steps.
 * @return                                      The function returns ARM_CMSIS_NN_SUCCESS
 */
arm_cmsis_nn_status arm_nn_lstm_step_s16(const int16_t *data_in,
                                         const int16_t *hidden_in,
                                         int16_t *hidden_out,
                                         const cmsis_nn_lstm_params *params,
                                         cmsis_nn_lstm_context *buffers,
                                         const int32_t batch_offset);

/**
 * @brief Updates a LSTM gate for an iteration step of LSTM function, int8x8_16 version.
 *
 * @param[in]   data_in                         Data input pointer
 * @param[in]   hidden_in                       Hidden state/ recurrent input pointer
 * @param[in]   gate_data                       Struct containing all information about the gate calculation, see
 * arm_nn_types.
 * @param[in]   params                          Struct containing all information about the lstm_operation, see
 * arm_nn_types
 * @param[out]  output                          Hidden state/ recurrent output pointer
 * @param[in]   batch_offset                    Number of timesteps between consecutive batches, see
 * arm_nn_lstm_step_s8.
 * @return                                      The function returns ARM_CMSIS_NN_SUCCESS
 */
arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s8_s16(const int8_t *data_in,
                                                      const int8_t *hidden_in,
                                                      const cmsis_nn_lstm_gate *gate_data,
                                                      const cmsis_nn_lstm_params *params,
                                                      int16_t *output,
                                                      const int32_t batch_offset);

/**
 * @brief Updates a LSTM gate for an iteration step of LSTM function, int16x8_16 version.
 *
 * @param[in]   data_in                         Data input pointer
 * @param[in]   hidden_in                       Hidden state/ recurrent input pointer
 * @param[in]   gate_data                       Struct containing all information about the gate calculation, see
 * arm_nn_types.
 * @param[in]   params                          Struct containing all information about the lstm_operation, see
 * arm_nn_types
 * @param[out]  output                          Hidden state/ recurrent output pointer
 * @param[in]   batch_offset                    Number of timesteps between consecutive batches, see
 * arm_nn_lstm_step_s16.
 * @return                                      The function returns ARM_CMSIS_NN_SUCCESS
 */
arm_cmsis_nn_status arm_nn_lstm_calculate_gate_s16(const int16_t *data_in,
                                                   const int16_t *hidden_in,
                                                   const cmsis_nn_lstm_gate *gate_data,
                                                   const cmsis_nn_lstm_params *params,
                                                   int16_t *output,
                                                   const int32_t batch_offset);

/**
 * @brief The result of the multiplication is accumulated to the passed result buffer.
 * Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed of input vectors
 * independent from each other).
 *
 * @param[in]   lhs              Batched vector
 * @param[in]   rhs              Weights - input matrix (H(Rows)xW(Columns))
 * @param[in]   effective_bias   Bias + lhs_offset * kernel_sum term precalculated into a constant vector.
 * @param[out]  dst              Output
 * @param[in]   dst_multiplier   Multiplier for quantization
 * @param[in]   dst_shift        Shift for quantization
 * @param[in]   rhs_cols         Vector/matrix column length
 * @param[in]   rhs_rows         Row count of matrix
 * @param[in]   batches          Batch size
 * @param[in]   batch_offset     Number of timesteps between consecutive batches in input, see arm_nn_lstm_step_s8.
 *                               Note that the output is always stored with sequential batches.
 * @return                       The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 */
arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s8_s16(const int8_t *lhs,
                                                         const int8_t *rhs,
                                                         const int32_t *effective_bias,
                                                         int16_t *dst,
                                                         const int32_t dst_multiplier,
                                                         const int32_t dst_shift,
                                                         const int32_t rhs_cols,
                                                         const int32_t rhs_rows,
                                                         const int32_t batches,
                                                         const int32_t batch_offset);

/**
 * @brief The result of the multiplication is accumulated to the passed result buffer.
 * Multiplies a matrix by a "batched" vector (i.e. a matrix with a batch dimension composed of input vectors
 * independent from each other).
 *
 * @param[in]   lhs              Batched vector
 * @param[in]   rhs              Weights - input matrix (H(Rows)xW(Columns))
 * @param[in]   effective_bias   Bias + lhs_offset * kernel_sum term precalculated into a constant vector.
 * @param[out]  dst              Output
 * @param[in]   dst_multiplier   Multiplier for quantization
 * @param[in]   dst_shift        Shift for quantization
 * @param[in]   rhs_cols         Vector/matrix column length
 * @param[in]   rhs_rows         Row count of matrix
 * @param[in]   batches          Batch size
 * @param[in]   batch_offset     Number of timesteps between consecutive batches in input, see arm_nn_lstm_step_s16.
 *                               Note that the output is always stored with sequential batches.
 * @return                       The function returns <code>ARM_CMSIS_NN_SUCCESS</code>
 */
arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s16(const int16_t *lhs,
                                                      const int8_t *rhs,
                                                      const int64_t *effective_bias,
                                                      int16_t *dst,
                                                      const int32_t dst_multiplier,
                                                      const int32_t dst_shift,
                                                      const int32_t rhs_cols,
                                                      const int32_t rhs_rows,
                                                      const int32_t batches,
                                                      const int32_t batch_offset);

/**
 * @brief s16 elementwise multiplication with s8 output
 * @param[in]       input_1_vect        pointer to input vector 1
 * @param[in]       input_2_vect        pointer to input vector 2
 * @param[in,out]   output              pointer to output vector
 * @param[in]       out_offset          output offset
 * @param[in]       out_mult            output multiplier
 * @param[in]       out_shift           output shift
 * @param[in]       block_size          number of samples per batch
 * @param[in]       batch_size          number of batches
 * @param[in]       batch_offset        Number of timesteps between consecutive batches in output, see
 * arm_nn_lstm_step_s8. Note that it is assumed that the input is stored with sequential batches.
 * @return          The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @details   Supported framework: TensorFlow Lite micro
 */
arm_cmsis_nn_status arm_elementwise_mul_s16_s8(const int16_t *input_1_vect,
                                               const int16_t *input_2_vect,
                                               int8_t *output,
                                               const int32_t out_offset,
                                               const int32_t out_mult,
                                               const int32_t out_shift,
                                               const int32_t block_size,
                                               const int32_t batch_size,
                                               const int32_t batch_offset);

/**
 * @brief s16 elementwise multiplication with s16 output
 * @param[in]       input_1_vect        pointer to input vector 1
 * @param[in]       input_2_vect        pointer to input vector 2
 * @param[in,out]   output              pointer to output vector
 * @param[in]       out_offset          output offset
 * @param[in]       out_mult            output multiplier
 * @param[in]       out_shift           output shift
 * @param[in]       block_size          number of samples per batch
 * @param[in]       batch_size          number of batches
 * @param[in]       batch_offset        Number of timesteps between consecutive batches in output, see
 * arm_nn_lstm_step_s16. Note that it is assumed that the input is stored with sequential batches.
 * @return          The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @details   Supported framework: TensorFlow Lite micro
 */
arm_cmsis_nn_status arm_elementwise_mul_s16_batch_offset(const int16_t *input_1_vect,
                                                         const int16_t *input_2_vect,
                                                         int16_t *output,
                                                         const int32_t out_offset,
                                                         const int32_t out_mult,
                                                         const int32_t out_shift,
                                                         const int32_t block_size,
                                                         const int32_t batch_size,
                                                         const int32_t batch_offset);

/**
 * @brief s16 elementwise multiplication. The result of the multiplication is accumulated to the passed result buffer.
 * @param[in]       input_1_vect        pointer to input vector 1
 * @param[in]       input_2_vect        pointer to input vector 2
 * @param[in]       input_1_offset      offset for input 1. Not used.
 * @param[in]       input_2_offset      offset for input 2. Not used.
 * @param[in,out]   output              pointer to output vector
 * @param[in]       out_offset          output offset. Not used.
 * @param[in]       out_mult            output multiplier
 * @param[in]       out_shift           output shift
 * @param[in]       out_activation_min  minimum value to clamp output to. Min: -32768
 * @param[in]       out_activation_max  maximum value to clamp output to. Max: 32767
 * @param[in]       block_size          number of samples
 * @return          The function returns ARM_CMSIS_NN_SUCCESS
 *
 * @details   Supported framework: TensorFlow Lite micro
 */
arm_cmsis_nn_status arm_elementwise_mul_acc_s16(const int16_t *input_1_vect,
                                                const int16_t *input_2_vect,
                                                const int32_t input_1_offset,
                                                const int32_t input_2_offset,
                                                int16_t *output,
                                                const int32_t out_offset,
                                                const int32_t out_mult,
                                                const int32_t out_shift,
                                                const int32_t out_activation_min,
                                                const int32_t out_activation_max,
                                                const int32_t block_size);

#ifdef __cplusplus
}
#endif

#endif /* ARM_NNSUPPORTFUNCTIONS_H */