1 /*
2  * Copyright (C) 2010-2021 Arm Limited or its affiliates. All rights reserved.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  *
6  * Licensed under the Apache License, Version 2.0 (the License); you may
7  * not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
14  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 /* ----------------------------------------------------------------------
20  * Project:      CMSIS NN Library
21  * Title:        arm_nnsupportfunctions.h
22  * Description:  Public header file of support functions for CMSIS NN Library
23  *
24  * $Date:        15. April 2021
25  * $Revision:    V.5.5.0
26  *
27  * Target Processor:  Cortex-M CPUs
28  * -------------------------------------------------------------------- */
29 
30 #ifndef _ARM_NNSUPPORTFUNCTIONS_H_
31 #define _ARM_NNSUPPORTFUNCTIONS_H_
32 
33 #include "arm_common_tables.h"
34 #include "arm_math_types.h"
35 
36 #ifdef __cplusplus
37 extern "C" {
38 #endif
39 
/* Shift helpers: a requantization shift is encoded as one signed value;
   a positive value selects a left shift, a negative value a right shift.
   NOTE: function-like macros may evaluate their arguments more than once. */
#define LEFT_SHIFT(_shift) ((_shift) > 0 ? (_shift) : 0)
#define RIGHT_SHIFT(_shift) ((_shift) > 0 ? 0 : -(_shift))
/* Branch-free select: build an all-ones/all-zeros mask, then pick a or b.
   Fully parenthesized so the expansions compose safely in larger expressions. */
#define MASK_IF_ZERO(x) (((x) == 0) ? ~0 : 0)
#define MASK_IF_NON_ZERO(x) (((x) != 0) ? ~0 : 0)
#define SELECT_USING_MASK(mask, a, b) (((mask) & (a)) ^ (~(mask) & (b)))

#define MAX(A, B) ((A) > (B) ? (A) : (B))
#define MIN(A, B) ((A) < (B) ? (A) : (B))
/* Note: argument order is (value, upper bound, lower bound). */
#define CLAMP(x, h, l) MAX(MIN((x), (h)), (l))
49 
/**
 * @brief Union for SIMD access of q31/q15/q7 types
 *
 * Lets one 32-bit word be read or written as a single q31, as two q15
 * half-words, or as four q7 bytes. Lane order of half_words/bytes follows
 * the target's byte order.
 */
union arm_nnword
{
    q31_t word;
    /**< q31 type */
    q15_t half_words[2];
    /**< q15 type */
    q7_t bytes[4];
    /**< q7 type */
};

/**
 * @brief Struct/union pair for accessing an int64_t as two 32-bit halves
 *
 * NOTE(review): the low/high member order corresponds to a little-endian
 * int64_t layout — confirm big-endian builds account for this at the use
 * sites.
 */
struct arm_nn_double
{
    uint32_t low;  /**< least-significant 32 bits */
    int32_t high;  /**< most-significant 32 bits (signed) */
};

union arm_nn_long_long
{
    int64_t long_long;        /**< full 64-bit value */
    struct arm_nn_double word; /**< same value as two 32-bit halves */
};
77 
78 /**
79  * @defgroup nndata_convert Neural Network Data Conversion Functions
80  *
81  * Perform data type conversion in-between neural network operations
82  *
83  */
84 
85 /**
86  * @brief Converts the elements of the q7 vector to q15 vector without left-shift
87  * @param[in]       *pSrc points to the q7 input vector
88  * @param[out]      *pDst points to the q15 output vector
89  * @param[in]       blockSize length of the input vector
90  *
91  */
92 void arm_q7_to_q15_no_shift(const q7_t *pSrc, q15_t *pDst, uint32_t blockSize);
93 
94 /**
95  * @brief Non-saturating addition of elements of a q7 vector
96  * @param[in]       *input Pointer to the q7 input vector
97  * @param[out]      *output Pointer to the q31 output variable.
98  * @param[in]       block_size length of the input vector
99  * \par Description:
100  *
101  * 2^24 samples can be added without saturating the result.
102  *
103  * The equation used for the conversion process is:
104  *
105  * <pre>
106  *  sum = input[0] + input[1] + .. + input[block_size -1]
107  * </pre>
108  *
109  * */
110 void arm_nn_add_q7(const q7_t *input, q31_t *output, uint32_t block_size);
111 
112 /**
113  * @brief  Converts the elements of the q7 vector to reordered q15 vector without left-shift
114  * @param[in]       *pSrc points to the q7 input vector
115  * @param[out]      *pDst points to the q15 output vector
116  * @param[in]       blockSize length of the input vector
117  * @return none.
118  *
119  */
120 void arm_q7_to_q15_reordered_no_shift(const q7_t *pSrc, q15_t *pDst, uint32_t blockSize);
121 
122 /**
123  * @brief Converts the elements from a q7 vector to a q15 vector with an added offset
124  * @param[in]    src        pointer to the q7 input vector
125  * @param[out]   dst        pointer to the q15 output vector
126  * @param[in]    block_size length of the input vector
127  * @param[in]    offset     q7 offset to be added to each input vector element.
128  *
129  * \par Description:
130  *
131  * The equation used for the conversion process is:
132  *
133  * <pre>
134  *  dst[n] = (q15_t) src[n] + offset;   0 <= n < block_size.
135  * </pre>
136  *
137  */
138 void arm_q7_to_q15_with_offset(const q7_t *src, q15_t *dst, uint32_t block_size, q15_t offset);
139 
140 /**
141  * @brief Converts the elements of the q7 vector to reordered q15 vector with an added offset
142  * @param[in]       src        pointer to the q7 input vector
143  * @param[out]      dst        pointer to the q15 output vector
144  * @param[in]       block_size length of the input vector
145  * @param[in]       offset     offset to be added to each input vector element.
146  * @return none.
147  *
148  * @details  This function does the q7 to q15 expansion with re-ordering of bytes. Re-ordering is a consequence of
149  *           the sign extension intrinsic(DSP extension). The tail (i.e., last (N % 4) elements) retains its
150  * original order.
151  *
152  */
153 void arm_q7_to_q15_reordered_with_offset(const q7_t *src, q15_t *dst, uint32_t block_size, q15_t offset);
154 
155 /**
 * @brief Accumulates the elements of a q7 vector into a q15 vector
 * @param[out]   *dst       points to the q15 output/accumulator vector
 * @param[in]    *src       points to the q7 input vector
 * @param[in]    block_size length of the input vector
160  *
161  * \par Description:
162  *
163  * The equation used for the conversion process is:
164  *
165  * <pre>
166  *  dst[n] += (q15_t) src[n] ;   0 <= n < block_size.
167  * </pre>
168  *
169  */
170 void arm_nn_accumulate_q7_to_q15(q15_t *dst, const q7_t *src, uint32_t block_size);
171 
172 /**
173  * @brief Depthwise conv on an im2col buffer where the input channel equals output channel.
174  * @param[in]    row     pointer to row
175  * @param[in]    col     pointer to im2col buffer, always consists of 2 columns.
176  * @param[in]    num_ch   number of channels
177  * @param[in]    out_shift  pointer to per output channel requantization shift parameter.
178  * @param[in]    out_mult   pointer to per output channel requantization multiplier parameter.
179  * @param[in]    out_offset      output tensor offset.
180  * @param[in]    activation_min   minimum value to clamp the output to. Range : int8
181  * @param[in]    activation_max   maximum value to clamp the output to. Range : int8
182  * @param[in]    kernel_size   number of elements in one column.
183  * @param[in]    output_bias per output channel bias. Range : int32
184  * @param[out]   out         pointer to output
185  * @return     The function returns one of the two
186  *              1. The incremented output pointer for a successful operation or
187  *              2. NULL if implementation is not available.
188  *
189  * @details     Supported framework: TensorFlow Lite micro.
190  */
191 q7_t *arm_nn_depthwise_conv_s8_core(const q7_t *row,
192                                     const q15_t *col,
193                                     const uint16_t num_ch,
194                                     const int32_t *out_shift,
195                                     const int32_t *out_mult,
196                                     const int32_t out_offset,
197                                     const int32_t activation_min,
198                                     const int32_t activation_max,
199                                     const uint16_t kernel_size,
200                                     const int32_t *const output_bias,
201                                     q7_t *out);
202 
203 /**
204  * @brief General Matrix-multiplication function with per-channel requantization.
205  * @param[in]       input_row    pointer to row operand
206  * @param[in]       input_col    pointer to col operand
207  * @param[in]       output_ch    number of rows of input_row
208  * @param[in]       col_batches  number of column batches. Range: 1 to 4
209  * @param[in]       output_shift  pointer to per output channel requantization shift parameter.
210  * @param[in]       output_mult   pointer to per output channel requantization multiplier parameter.
211  * @param[in]       out_offset    output tensor offset.
212  * @param[in]       col_offset    input tensor(col) offset.
213  * @param[in]       row_offset    kernel offset(row). Not used.
214  * @param[in]       out_activation_min   minimum value to clamp the output to. Range : int8
215  * @param[in]       out_activation_max   maximum value to clamp the output to. Range : int8
216  * @param[in]       row_len       number of elements in each row
217  * @param[in]       bias          per output channel bias. Range : int32
218  * @param[in,out]   out           pointer to output
219  * @return     The function returns one of the two
220  *              1. The incremented output pointer for a successful operation or
221  *              2. NULL if implementation is not available.
222  *
223  * @details   Supported framework: TensorFlow Lite
224  */
225 q7_t *arm_nn_mat_mult_s8(const q7_t *input_row,
226                          const q7_t *input_col,
227                          const uint16_t output_ch,
228                          const uint16_t col_batches,
229                          const int32_t *output_shift,
230                          const int32_t *output_mult,
231                          const int32_t out_offset,
232                          const int32_t col_offset,
233                          const int32_t row_offset,
234                          const int16_t out_activation_min,
235                          const int16_t out_activation_max,
236                          const uint16_t row_len,
237                          const int32_t *const bias,
238                          q7_t *out);
239 
240 /**
241  * @brief General Matrix-multiplication without requantization for one row & one column
242  * @param[in]       row_elements  number of row elements
243  * @param[in]       row_base      pointer to row operand
244  * @param[in]       col_base      pointer to col operand
245  * @param[out]      sum_col       pointer to store sum of column elements
246  * @param[out]      output        pointer to store result of multiply-accumulate
247  * @return     The function returns the multiply-accumulated result of the row by column.
248  *
249  * @details Pseudo-code
250  *      *output = 0
251  *      sum_col = 0
252  *      for (i = 0; i < row_elements; i++)
253  *          *output += row_base[i] * col_base[i]
254  *          sum_col += col_base[i]
255  *
256  */
257 arm_status arm_nn_mat_mul_core_1x_s8(int32_t row_elements,
258                                      const int8_t *row_base,
259                                      const int8_t *col_base,
260                                      int32_t *const sum_col,
261                                      int32_t *const output);
262 
263 /**
264  * @brief General Matrix-multiplication without requantization for four rows and one column
265  * @param[in]       row_elements  number of row elements
266  * @param[in]       offset        offset between rows. Can be the same as row_elements.
267  *                                For e.g, in a 1x1 conv scenario with stride as 1.
268  * @param[in]       row_base      pointer to row operand
269  * @param[in]       col_base      pointer to col operand
270  * @param[out]      sum_col       pointer to store sum of column elements
271  * @param[out]      output        pointer to store result(4 int32's) of multiply-accumulate
272  * @return     The function returns the multiply-accumulated result of the row by column
273  *
274  * @details Pseudo-code
275  *      output[0] = 0
276  *         ..
277  *      output[3] = 0
278  *      sum_col = 0
279  *      for (i = 0; i < row_elements; i++)
280  *          output[0] += row_base[i] * col_base[i]
281  *                ..
282  *          output[3] += row_base[i + (row_elements * 3)] * col_base[i]
283  *          sum_col += col_base[i]
284  */
285 arm_status arm_nn_mat_mul_core_4x_s8(const int32_t row_elements,
286                                      const int32_t offset,
287                                      const int8_t *row_base,
288                                      const int8_t *col_base,
289                                      int32_t *const sum_col,
290                                      int32_t *const output);
291 
292 /**
293  * @brief General Matrix-multiplication function with per-channel requantization.
294  *        This function assumes:
295  *        - LHS input matrix NOT transposed (nt)
296  *        - RHS input matrix transposed (t)
297  *
298  *  @note This operation also performs the broadcast bias addition before the requantization
299  *
300  * @param[in]  lhs                Pointer to the LHS input matrix
301  * @param[in]  rhs                Pointer to the RHS input matrix
302  * @param[in]  bias               Pointer to the bias vector. The length of this vector is equal to the number of
303  * output columns (or RHS input rows)
304  * @param[out] dst                Pointer to the output matrix with "m" rows and "n" columns
305  * @param[in]  dst_multipliers    Pointer to the multipliers vector needed for the per-channel requantization.
306  *                                The length of this vector is equal to the number of output columns (or RHS input
307  * rows)
308  * @param[in]  dst_shifts         Pointer to the shifts vector needed for the per-channel requantization. The length
309  * of this vector is equal to the number of output columns (or RHS input rows)
310  * @param[in]  lhs_rows           Number of LHS input rows
311  * @param[in]  rhs_rows           Number of RHS input rows
312  * @param[in]  rhs_cols           Number of LHS/RHS input columns
313  * @param[in]  lhs_offset         Offset to be applied to the LHS input value
314  * @param[in]  dst_offset         Offset to be applied the output result
315  * @param[in]  activation_min     Minimum value to clamp down the output. Range : int8
316  * @param[in]  activation_max     Maximum value to clamp up the output. Range : int8
317  *
318  * @return     The function returns <code>ARM_MATH_SUCCESS</code>
319  *
320  */
321 arm_status arm_nn_mat_mult_nt_t_s8(const q7_t *lhs,
322                                    const q7_t *rhs,
323                                    const q31_t *bias,
324                                    q7_t *dst,
325                                    const int32_t *dst_multipliers,
326                                    const int32_t *dst_shifts,
327                                    const int32_t lhs_rows,
328                                    const int32_t rhs_rows,
329                                    const int32_t rhs_cols,
330                                    const int32_t lhs_offset,
331                                    const int32_t dst_offset,
332                                    const int32_t activation_min,
333                                    const int32_t activation_max);
334 
335 /**
336  * @brief s8 Vector by Matrix (transposed) multiplication
337  *
338  * @param[in]      lhs             Input left-hand side vector
339  * @param[in]      rhs             Input right-hand side matrix (transposed)
340  * @param[in]      bias            Input bias
341  * @param[out]     dst             Output vector
342  * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side vector.
343  *                                 Range: -127 to 128
344  * @param[in]      rhs_offset      Not used
345  * @param[in]      dst_offset      Offset to be added to the output values. Range: -127 to 128
346  * @param[in]      dst_multiplier  Output multiplier
347  * @param[in]      dst_shift       Output shift
348  * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
349  * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
350  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
351  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
352  *
353  * @return         The function returns <code>ARM_MATH_SUCCESS</code>
354  *
355  */
356 arm_status arm_nn_vec_mat_mult_t_s8(const q7_t *lhs,
357                                     const q7_t *rhs,
358                                     const q31_t *bias,
359                                     q7_t *dst,
360                                     const int32_t lhs_offset,
361                                     const int32_t rhs_offset,
362                                     const int32_t dst_offset,
363                                     const int32_t dst_multiplier,
364                                     const int32_t dst_shift,
365                                     const int32_t rhs_cols,
366                                     const int32_t rhs_rows,
367                                     const int32_t activation_min,
368                                     const int32_t activation_max);
369 
370 /**
371  * @brief s8 Vector by Matrix (transposed) multiplication with s16 output
372  *
373  * @param[in]      lhs             Input left-hand side vector
374  * @param[in]      rhs             Input right-hand side matrix (transposed)
375  * @param[out]     dst             Output vector
376  * @param[in]      lhs_offset      Offset to be added to the input values of the left-hand side
377  *                                 vector. Range: -127 to 128
378  * @param[in]      rhs_offset      Not used
379  * @param[in]      scatter_offset  Address offset for dst. First output is stored at 'dst', the
380  *                                 second at 'dst + scatter_offset' and so on.
381  * @param[in]      dst_multiplier  Output multiplier
382  * @param[in]      dst_shift       Output shift
383  * @param[in]      rhs_cols        Number of columns in the right-hand side input matrix
384  * @param[in]      rhs_rows        Number of rows in the right-hand side input matrix
385  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int16
386  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int16
387  *
388  * @return         The function returns <code>ARM_MATH_SUCCESS</code>
389  *
390  */
391 arm_status arm_nn_vec_mat_mult_t_svdf_s8(const q7_t *lhs,
392                                          const q7_t *rhs,
393                                          q15_t *dst,
394                                          const int32_t lhs_offset,
395                                          const int32_t rhs_offset,
396                                          const int32_t scatter_offset,
397                                          const int32_t dst_multiplier,
398                                          const int32_t dst_shift,
399                                          const int32_t rhs_cols,
400                                          const int32_t rhs_rows,
401                                          const int32_t activation_min,
402                                          const int32_t activation_max);
403 
404 /**
405  * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in padded cases where
406  *        the padding is -lhs_offset(Range: int8). Dimensions are the same for lhs and rhs.
407  *
408  * @param[in]      lhs             Input left-hand side matrix
409  * @param[in]      rhs             Input right-hand side matrix (transposed)
410  * @param[in]      lhs_offset      LHS matrix offset(input offset). Range: -127 to 128
411  * @param[in]      num_ch          Number of channels in LHS/RHS
412  * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels
413  * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels
414  * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
415  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
416  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
417  * @param[in]       row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
418  * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels
 * @param[out]     out             Output pointer
420  *
421  * @return         The function returns one of the two
422  *                  - Updated output pointer if an implementation is available
423  *                  - NULL if no implementation is available.
424  *
 * @note           If number of channels is not a multiple of 4, up to 3 elements outside the boundary will be read
426  * out for the following.
427  *                  - Output shift
428  *                  - Output multiplier
429  *                  - Output bias
430  *                  - rhs
431  */
432 q7_t *arm_nn_depthwise_conv_nt_t_padded_s8(const q7_t *lhs,
433                                            const q7_t *rhs,
434                                            const int32_t lhs_offset,
435                                            const uint16_t num_ch,
436                                            const int32_t *out_shift,
437                                            const int32_t *out_mult,
438                                            const int32_t out_offset,
439                                            const int32_t activation_min,
440                                            const int32_t activation_max,
441                                            const uint16_t row_x_col,
442                                            const int32_t *const output_bias,
443                                            q7_t *out);
444 
445 /**
446  * @brief Depthwise convolution of transposed rhs matrix with 4 lhs matrices. To be used in non-padded cases.
447  *        Dimensions are the same for lhs and rhs.
448  *
449  * @param[in]      lhs             Input left-hand side matrix
450  * @param[in]      rhs             Input right-hand side matrix (transposed)
451  * @param[in]      lhs_offset      LHS matrix offset(input offset). Range: -127 to 128
452  * @param[in]      num_ch          Number of channels in LHS/RHS
453  * @param[in]      out_shift       Per channel output shift. Length of vector is equal to number of channels.
454  * @param[in]      out_mult        Per channel output multiplier. Length of vector is equal to number of channels.
455  * @param[in]      out_offset      Offset to be added to the output values. Range: -127 to 128
456  * @param[in]      activation_min  Minimum value to clamp the output to. Range: int8
457  * @param[in]      activation_max  Maximum value to clamp the output to. Range: int8
458  * @param[in]       row_x_col       (row_dimension * col_dimension) of LHS/RHS matrix
459  * @param[in]      output_bias     Per channel output bias. Length of vector is equal to number of channels.
 * @param[out]     out             Output pointer
461  *
462  * @return         The function returns one of the two
463  *                  - Updated output pointer if an implementation is available
464  *                  - NULL if no implementation is available.
465  *
 * @note           If number of channels is not a multiple of 4, up to 3 elements outside the boundary will be read
467  * out for the following.
468  *                  - Output shift
469  *                  - Output multiplier
470  *                  - Output bias
471  *                  - rhs
472  */
473 q7_t *arm_nn_depthwise_conv_nt_t_s8(const q7_t *lhs,
474                                     const q7_t *rhs,
475                                     const int32_t lhs_offset,
476                                     const uint16_t num_ch,
477                                     const int32_t *out_shift,
478                                     const int32_t *out_mult,
479                                     const int32_t out_offset,
480                                     const int32_t activation_min,
481                                     const int32_t activation_max,
482                                     const uint16_t row_x_col,
483                                     const int32_t *const output_bias,
484                                     q7_t *out);
485 
486 /**
487   @brief         Read 2 q15 elements and post increment pointer.
488   @param[in]     in_q15   Pointer to pointer that holds address of input.
489   @return        q31 value
490  */
arm_nn_read_q15x2_ia(const q15_t ** in_q15)491 __STATIC_FORCEINLINE q31_t arm_nn_read_q15x2_ia(const q15_t **in_q15)
492 {
493     q31_t val;
494 
495     memcpy(&val, *in_q15, 4);
496     *in_q15 += 2;
497 
498     return (val);
499 }
500 
501 /**
502   @brief         Read 4 q7 from q7 pointer and post increment pointer.
503   @param[in]     in_q7       Pointer to pointer that holds address of input.
504   @return        q31 value
505  */
arm_nn_read_q7x4_ia(const q7_t ** in_q7)506 __STATIC_FORCEINLINE q31_t arm_nn_read_q7x4_ia(const q7_t **in_q7)
507 {
508     q31_t val;
509     memcpy(&val, *in_q7, 4);
510     *in_q7 += 4;
511 
512     return (val);
513 }
514 
515 /**
516   @brief         Read 2 q15 from q15 pointer.
517   @param[in]     in_q15   pointer to address of input.
518   @return        q31 value
519  */
arm_nn_read_q15x2(const q15_t * in_q15)520 __STATIC_FORCEINLINE q31_t arm_nn_read_q15x2(const q15_t *in_q15)
521 {
522     q31_t val;
523     memcpy(&val, in_q15, 4);
524 
525     return (val);
526 }
527 
528 /**
529   @brief         Read 4 q7 values.
530   @param[in]     in_q7       pointer to address of input.
531   @return        q31 value
532  */
arm_nn_read_q7x4(const q7_t * in_q7)533 __STATIC_FORCEINLINE q31_t arm_nn_read_q7x4(const q7_t *in_q7)
534 {
535     q31_t val;
536     memcpy(&val, in_q7, 4);
537 
538     return (val);
539 }
540 
/**
 * @brief           memset optimized for MVE
 * @param[in, out]  dst         Destination pointer
 * @param[in]       val         Value to set
 * @param[in]       block_size  Number of bytes to set.
 *
 * Falls back to the standard library memset when MVE (Helium) is not
 * available.
 */
__STATIC_FORCEINLINE void arm_memset_q7(q7_t *dst, const q7_t val, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    /* Duplicate val into all byte lanes of q0, then store 16 bytes per
       iteration using a tail-predicated low-overhead loop (wlstp/letp):
       the final iteration is predicated so no bytes beyond block_size are
       written. lr (r14) holds the loop count, hence the clobber. */
    __asm volatile("   vdup.8                  q0, %[set_val]             \n"
                   "   wlstp.8                 lr, %[cnt], 1f             \n"
                   "2:                                                    \n"
                   "   vstrb.8                 q0, [%[in]], 16            \n"
                   "   letp                    lr, 2b                     \n"
                   "1:                                                    \n"
                   : [ in ] "+r"(dst)
                   : [ cnt ] "r"(block_size), [ set_val ] "r"(val)
                   : "q0", "memory", "r14");
#else
    memset(dst, val, block_size);
#endif
}
564 
565 #if defined(ARM_MATH_DSP)
566 
/**
 * @brief read and expand one q7 word into two q15 words
 *
 * Reads 4 q7 values (advancing source by 4) and sign-extends them to q15,
 * packed two per q31 output. __SXTB16 extends bytes 0 and 2; rotating by 8
 * first exposes bytes 1 and 3. The PKH re-pack restores the original element
 * order: out1 holds elements {0,1} and out2 holds elements {2,3} (the roles
 * are swapped for big-endian builds).
 *
 * @return the source pointer advanced past the 4 consumed q7 values
 */

__STATIC_FORCEINLINE const q7_t *read_and_pad(const q7_t *source, q31_t *out1, q31_t *out2)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);
    q31_t inAbuf1 = __SXTB16(__ROR((uint32_t)inA, 8)); /* elements 1 and 3 */
    q31_t inAbuf2 = __SXTB16(inA);                     /* elements 0 and 2 */

#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = (int32_t)(__PKHTB(inAbuf1, inAbuf2, 16));
    *out1 = (int32_t)(__PKHBT(inAbuf2, inAbuf1, 16));
#else
    *out1 = (int32_t)(__PKHTB(inAbuf1, inAbuf2, 16));
    *out2 = (int32_t)(__PKHBT(inAbuf2, inAbuf1, 16));
#endif

    return source;
}
587 
/**
 * @brief read and expand one q7 word into two q15 words with reordering
 *
 * Same expansion as read_and_pad but without the PKH re-pack, so the outputs
 * hold the elements interleaved: on little-endian, out1 = elements {0,2} and
 * out2 = elements {1,3} (swapped for big-endian builds). Advances source by
 * 4 and returns the incremented pointer. Cheaper than read_and_pad when the
 * consumer tolerates the reordering (e.g. matched reordered operands).
 */

__STATIC_FORCEINLINE const q7_t *read_and_pad_reordered(const q7_t *source, q31_t *out1, q31_t *out2)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);
#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = __SXTB16(__ROR((uint32_t)inA, 8)); /* elements 1 and 3 */
    *out1 = __SXTB16(inA);                     /* elements 0 and 2 */
#else
    *out1 = __SXTB16(__ROR((uint32_t)inA, 8));
    *out2 = __SXTB16(inA);
#endif

    return source;
}
605 
/**
 * @brief read and expand one q7 word into two q15 words with reordering and add an offset
 *
 * Identical expansion to read_and_pad_reordered (out1 = elements {0,2},
 * out2 = elements {1,3} on little-endian; swapped on big-endian), followed
 * by a saturating per-halfword addition of offset (__QADD16) to every
 * expanded element. Advances source by 4 and returns the incremented
 * pointer. The offset is expected packed as the same q15 value in both
 * halfwords — NOTE(review): callers must pre-duplicate it; not checked here.
 */
__STATIC_FORCEINLINE const q7_t *
read_and_pad_reordered_with_offset(const q7_t *source, q31_t *out1, q31_t *out2, q31_t offset)
{
    q31_t inA = arm_nn_read_q7x4_ia(&source);

#ifndef ARM_MATH_BIG_ENDIAN
    *out2 = __SXTB16(__ROR((uint32_t)inA, 8)); /* elements 1 and 3 */
    *out1 = __SXTB16(inA);                     /* elements 0 and 2 */
#else
    *out1 = __SXTB16(__ROR((uint32_t)inA, 8));
    *out2 = __SXTB16(inA);
#endif
    /* Saturating add of the packed offset to each q15 lane. */
    *out1 = __QADD16(*out1, offset);
    *out2 = __QADD16(*out2, offset);

    return source;
}
626 
627 #endif
628 
629 /**
630  * @defgroup NNBasicMath Basic Math Functions for Neural Network Computation
631  *
632  * Basic Math Functions for Neural Network Computation
633  *
634  */
635 
636 /**
 * @brief           q15 vector multiplication with variable output shifts
638  * @param[in]       *pSrcA        pointer to the first input vector
639  * @param[in]       *pSrcB        pointer to the second input vector
640  * @param[out]      *pDst         pointer to the output vector
641  * @param[in]       out_shift     amount of right-shift for output
642  * @param[in]       blockSize     number of samples in each vector
643  * @return none.
644  *
645  * <b>Scaling and Overflow Behavior:</b>
646  * \par
647  * The function uses saturating arithmetic.
648  * Results outside of the allowable q15 range [0x8000 0x7FFF] will be saturated.
649  */
650 
651 void arm_nn_mult_q15(q15_t *pSrcA, q15_t *pSrcB, q15_t *pDst, const uint16_t out_shift, uint32_t blockSize);
652 
653 /**
654  * @brief           q7 vector multiplication with variable output shifts
655  * @param[in]       *pSrcA        pointer to the first input vector
656  * @param[in]       *pSrcB        pointer to the second input vector
657  * @param[out]      *pDst         pointer to the output vector
658  * @param[in]       out_shift     amount of right-shift for output
659  * @param[in]       blockSize     number of samples in each vector
660  * @return none.
661  *
662  * <b>Scaling and Overflow Behavior:</b>
663  * \par
664  * The function uses saturating arithmetic.
665  * Results outside of the allowable q7 range [0x80 0x7F] will be saturated.
666  */
667 
668 void arm_nn_mult_q7(q7_t *pSrcA, q7_t *pSrcB, q7_t *pDst, const uint16_t out_shift, uint32_t blockSize);
669 
/**
 * @brief macro for adding rounding offset
 *
 * Evaluates to half of the divisor implied by out_shift (i.e. 2^(out_shift-1),
 * or 0 when out_shift is 0) so that a subsequent right shift rounds to
 * nearest. Defining ARM_NN_TRUNCATE disables rounding entirely.
 * The argument is parenthesized so expressions such as NN_ROUND(a | b)
 * expand correctly.
 */
#ifndef ARM_NN_TRUNCATE
#define NN_ROUND(out_shift) ((0x1u << (out_shift)) >> 1)
#else
#define NN_ROUND(out_shift) 0
#endif

// Macros for shortening quantization functions' names and avoid long lines
#define MUL_SAT(a, b) arm_nn_doubling_high_mult((a), (b))
#define MUL_SAT_MVE(a, b) arm_doubling_high_mult_mve_32x4((a), (b))
#define MUL_POW2(a, b) arm_nn_mult_by_power_of_two((a), (b))

#define DIV_POW2(a, b) arm_nn_divide_by_power_of_two((a), (b))
#define DIV_POW2_MVE(a, b) arm_divide_by_power_of_two_mve((a), (b))

#define EXP_ON_NEG(x) arm_nn_exp_on_negative_values((x))
#define ONE_OVER1(x) arm_nn_one_over_one_plus_x_for_x_in_0_1((x))
689 
/**
 * @brief           Saturating doubling high multiply. Result matches
 *                  NEON instruction VQRDMULH.
 * @param[in]       m1        Multiplicand. Range: {Q31_MIN, Q31_MAX}
 * @param[in]       m2        Multiplier. Range: {Q31_MIN, Q31_MAX}
 * @return          Result of multiplication.
 *
 * Computes round(2 * m1 * m2 / 2^32) with saturation, i.e. the rounded
 * high half of the doubled 64-bit product.
 */
__STATIC_FORCEINLINE q31_t arm_nn_doubling_high_mult(const q31_t m1, const q31_t m2)
{
    q31_t result = 0;
    // Rounding offset to add for a right shift of 31
    q63_t mult = 1 << 30;

    if ((m1 < 0) ^ (m2 < 0))
    {
        // Product will be negative: signed division truncates toward zero,
        // so flip the rounding offset to keep rounding symmetric.
        mult = 1 - mult;
    }
    // Gets resolved as a SMLAL instruction
    mult = mult + (q63_t)m1 * m2;

    // Utilize all of the upper 32 bits. This is the doubling step
    // as well. Division (not a shift) is deliberate: >> on a negative
    // value would round toward negative infinity instead of toward zero.
    result = (int32_t)(mult / (1ll << 31));

    // Q31_MIN * Q31_MIN is the one case whose doubled high half overflows;
    // saturate to Q31_MAX to match VQRDMULH behavior.
    if ((m1 == m2) && (m1 == (int32_t)Q31_MIN))
    {
        result = Q31_MAX;
    }
    return result;
}
721 
722 /**
723  * @brief           Doubling high multiply without saturation. This is intended
724  *                  for requantization where the scale is a positive integer
725  *
726  * @param[in]       m1        Multiplicand. Range: {Q31_MIN, Q31_MAX}
727  * @param[in]       m2        Multiplier Range: {Q31_MIN, Q31_MAX}
728  * @return          Result of multiplication.
729  * @note            The result of this matches that of neon instruction
730  *                  VQRDMULH for m1 in range {Q31_MIN, Q31_MAX} and m2 in
731  *                  range {Q31_MIN + 1, Q31_MAX}. Saturation occurs when
732  *                  m1 equals m2 equals Q31_MIN and that is not handled by
733  *                  this function.
734  *
735  */
arm_nn_doubling_high_mult_no_sat(const q31_t m1,const q31_t m2)736 __STATIC_FORCEINLINE q31_t arm_nn_doubling_high_mult_no_sat(const q31_t m1, const q31_t m2)
737 {
738     q31_t result = 0;
739     union arm_nn_long_long mult;
740 
741     // Rounding offset to add for a right shift of 31
742     mult.word.low = 1 << 30;
743     mult.word.high = 0;
744 
745     // Gets resolved as a SMLAL instruction
746     mult.long_long = mult.long_long + (q63_t)m1 * m2;
747 
748     // Utilize all of the upper 32 bits. This is the doubling step
749     // as well.
750     result = (int32_t)(mult.long_long >> 31);
751 
752     return result;
753 }
754 
755 /**
756  * @brief           Rounding divide by power of two.
757  * @param[in]       dividend - Dividend
758  * @param[in]       exponent - Divisor = power(2, exponent)
759  *                             Range: [0, 31]
760  * @return          Rounded result of division. Midpoint is rounded away from zero.
761  *
762  */
arm_nn_divide_by_power_of_two(const q31_t dividend,const q31_t exponent)763 __STATIC_FORCEINLINE q31_t arm_nn_divide_by_power_of_two(const q31_t dividend, const q31_t exponent)
764 {
765     q31_t result = 0;
766     const q31_t remainder_mask = (1 << exponent) - 1;
767     int32_t remainder = remainder_mask & dividend;
768 
769     // Basic division
770     result = dividend >> exponent;
771 
772     // Adjust 'result' for rounding (mid point away from zero)
773     q31_t threshold = remainder_mask >> 1;
774     if (result < 0)
775     {
776         threshold++;
777     }
778     if (remainder > threshold)
779     {
780         result++;
781     }
782 
783     return result;
784 }
785 
/**
 * @brief           Requantize a given value.
 * @param[in]       val         Value to be requantized
 * @param[in]       multiplier  multiplier. Range {Q31_MIN + 1, Q31_MAX}
 * @param[in]       shift       left or right shift for 'val * multiplier'
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE q31_t arm_nn_requantize(const q31_t val, const q31_t multiplier, const q31_t shift)
{
    /* Positive 'shift' is applied as a left shift on 'val' before the doubling
       high multiply; negative 'shift' becomes a rounding right shift of the
       product (one of LEFT_SHIFT/RIGHT_SHIFT is always zero). */
    return arm_nn_divide_by_power_of_two(arm_nn_doubling_high_mult_no_sat(val * (1 << LEFT_SHIFT(shift)), multiplier),
                                         RIGHT_SHIFT(shift));
}
800 
/**
 * @brief           memcpy optimized for MVE
 * @param[in, out]  dst         Destination pointer
 * @param[in]       src         Source pointer.
 * @param[in]       block_size  Number of bytes to copy.
 *
 */
__STATIC_FORCEINLINE void arm_memcpy_q7(q7_t *__RESTRICT dst, const q7_t *__RESTRICT src, uint32_t block_size)
{
#if defined(ARM_MATH_MVEI)
    /* Tail-predicated low-overhead loop (WLSTP.8 / LETP) copying 16 bytes per
       iteration through q0; the loop predication masks the final partial
       vector, so no scalar tail handling is needed. Clobbers q0 and lr (r14). */
    __asm volatile("   wlstp.8                 lr, %[cnt], 1f             \n"
                   "2:                                                    \n"
                   "   vldrb.8                 q0, [%[in]], 16            \n"
                   "   vstrb.8                 q0, [%[out]], 16           \n"
                   "   letp                    lr, 2b                     \n"
                   "1:                                                    \n"
                   : [ in ] "+r"(src), [ out ] "+r"(dst)
                   : [ cnt ] "r"(block_size)
                   : "q0", "memory", "r14");
#else
    /* Fall back to the library memcpy when MVE is not available. */
    memcpy(dst, src, block_size);
#endif
}
824 
825 #if defined(ARM_MATH_MVEI)
/**
 * @brief           Vector saturating doubling high multiply returning high half.
 * @param[in]       m1        Multiplicand vector
 * @param[in]       m2        Scalar multiplier, duplicated across all lanes
 * @return          Result of multiplication.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve(const int32x4_t m1, const q31_t m2)
{
    /* VQRDMULH: per-lane saturating rounding doubling multiply, high half. */
    return vqrdmulhq_n_s32(m1, m2);
}
837 
/**
 * @brief           Vector rounding divide by power of two.
 * @param[in]       dividend - Dividend vector
 * @param[in]       exponent - Divisor = power(2, exponent)
 *                             Range: [0, 31]
 * @return          Rounded result of division. Midpoint is rounded away from zero.
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve(const int32x4_t dividend, const q31_t exponent)
{
    const int32x4_t shift = vdupq_n_s32(-exponent);
    /* fixup is -1 in lanes where the dividend is negative (and exponent > 0):
       the sign bit survives the AND with the negative shift value. Subtracting
       it before the rounding shift makes ties round away from zero. */
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    /* VRSHL with a negative shift count: rounding shift right by 'exponent'. */
    return vrshlq_s32(fixed_up_dividend, shift);
}
853 
/**
 * @brief           Requantize a given vector.
 * @param[in]       val         Vector to be requantized
 * @param[in]       multiplier  multiplier applied to all lanes
 * @param[in]       shift       left (positive) or right (negative) shift
 *
 * @return          Returns (val * multiplier)/(2 ^ shift)
 *
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve(const int32x4_t val, const q31_t multiplier, const q31_t shift)
{
    /* Vector counterpart of arm_nn_requantize: left shift first when
       shift > 0, otherwise rounding right shift after the high multiply. */
    return arm_divide_by_power_of_two_mve(
        arm_doubling_high_mult_mve(vshlq_s32(val, vdupq_n_s32(LEFT_SHIFT(shift))), multiplier), RIGHT_SHIFT(shift));
}
868 
/**
 * @brief           Vector saturating doubling high multiply returning high half,
 *                  with a per-lane (vector) multiplier.
 * @param[in]       m1        Multiplicand vector
 * @param[in]       m2        Multiplier vector
 * @return          Per-lane result of the multiplication.
 */
__STATIC_FORCEINLINE int32x4_t arm_doubling_high_mult_mve_32x4(const int32x4_t m1, const int32x4_t m2)
{
    return vqrdmulhq_s32(m1, m2);
}
873 
/**
 * @brief           Vector rounding divide by power of two, with a per-lane
 *                  (vector) exponent. Midpoint is rounded away from zero.
 * @param[in]       dividend   Dividend vector
 * @param[in]       exponent   Per-lane exponent vector; divisor = 2^exponent
 * @return          Rounded result of the per-lane division.
 */
__STATIC_FORCEINLINE int32x4_t arm_divide_by_power_of_two_mve_32x4(const int32x4_t dividend, const int32x4_t exponent)
{
    const int32x4_t shift = -exponent;
    /* -1 fixup in lanes with a negative dividend (exponent > 0) so that the
       rounding shift below rounds ties away from zero. */
    const int32x4_t fixup = vshrq_n_s32(vandq_s32(dividend, shift), 31);
    const int32x4_t fixed_up_dividend = vqaddq_s32(dividend, fixup);
    return vrshlq_s32(fixed_up_dividend, shift);
}
881 
/**
 * @brief           Requantize a vector with per-lane multiplier and shift.
 * @param[in]       val         Vector to be requantized
 * @param[in]       multiplier  Per-lane multiplier vector
 * @param[in]       shift       Per-lane shift vector; positive = left shift,
 *                              negative = rounding right shift
 * @return          Returns (val * multiplier)/(2 ^ shift) per lane.
 */
__STATIC_FORCEINLINE int32x4_t arm_requantize_mve_32x4(const int32x4_t val,
                                                       const int32x4_t multiplier,
                                                       const int32x4_t shift)
{
    const int32x4_t zz = vdupq_n_s32(0);
    const mve_pred16_t p = vcmpgtq_n_s32(shift, 0);

    /* Split 'shift' into a left-shift part (lanes where shift > 0) and a
       right-shift magnitude (lanes where shift <= 0); each is zero in the
       lanes where it does not apply. */
    const int32x4_t left_shift = vpselq_s32(shift, zz, p);
    const int32x4_t right_shift = -vpselq_s32(zz, shift, p);

    return arm_divide_by_power_of_two_mve_32x4(arm_doubling_high_mult_mve_32x4(vshlq_s32(val, left_shift), multiplier),
                                               right_shift);
}
895 #endif
896 
897 // @note The following functions are used only for softmax layer, scaled bits = 5 assumed
898 
// Fixed-point exp() for non-positive inputs. Used only by the softmax layer;
// per the note above, the input is assumed to carry 5 scaled (integer) bits,
// hence 'shift' starting at 24 below. exp(0) saturates to Q31_MAX (~1.0).
__STATIC_FORCEINLINE int32_t arm_nn_exp_on_negative_values(int32_t val)
{
    int32_t mask = 0;
    int32_t shift = 24;

    /* Reduce 'val' modulo 2^shift into [-(2^shift), 0); 'remainder' holds the
       (non-positive) multiple of 2^shift that was removed. */
    const int32_t val_mod_minus_quarter = (val & ((1 << shift) - 1)) - (1 << shift);
    const int32_t remainder = val_mod_minus_quarter - val;
    /* Rescale the reduced value and offset it for the polynomial below. */
    const int32_t x = (val_mod_minus_quarter << 5) + (1 << 28);
    const int32_t x2 = MUL_SAT(x, x);

    /* Polynomial approximation of exp on the reduced interval. The magic
       constants presumably encode the series coefficients in fixed point
       (gemmlowp-style) -- TODO confirm against the reference implementation. */
    int32_t result = 1895147668 +
        MUL_SAT(1895147668, x + DIV_POW2(MUL_SAT(DIV_POW2(MUL_SAT(x2, x2), 2) + MUL_SAT(x2, x), 715827883) + x2, 1));

/* For each set bit of 'remainder' (bit 24 upward), multiply 'result' by the
   matching precomputed factor. NOTE: 'shift' is post-incremented inside the
   macro, so the invocation order below matters. */
#define SELECT_IF_NON_ZERO(x)                                                                                          \
    {                                                                                                                  \
        mask = MASK_IF_NON_ZERO(remainder & (1 << shift++));                                                           \
        result = SELECT_USING_MASK(mask, MUL_SAT(result, x), result);                                                  \
    }

    SELECT_IF_NON_ZERO(1672461947)
    SELECT_IF_NON_ZERO(1302514674)
    SELECT_IF_NON_ZERO(790015084)
    SELECT_IF_NON_ZERO(290630308)
    SELECT_IF_NON_ZERO(39332535)
    SELECT_IF_NON_ZERO(720401)
    SELECT_IF_NON_ZERO(242)

#undef SELECT_IF_NON_ZERO

    /* exp(0) == 1.0, which saturates to Q31_MAX. */
    mask = MASK_IF_ZERO(val);
    return SELECT_USING_MASK(mask, Q31_MAX, result);
}
931 
/**
 * @brief  Saturating multiplication of 'val' by 2^exp, clamped to Q31 range.
 */
__STATIC_FORCEINLINE q31_t arm_nn_mult_by_power_of_two(const int32_t val, const int32_t exp)
{
    /* Largest magnitude that survives a left shift by 'exp' without overflow. */
    const int32_t thresh = ((1 << (31 - exp)) - 1);
    int32_t result = val << exp;
    /* Branchless clamp to Q31_MAX / Q31_MIN when 'val' is out of range. */
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val > thresh), Q31_MAX, result);
    result = SELECT_USING_MASK(MASK_IF_NON_ZERO(val < -thresh), Q31_MIN, result);
    return result;
}
940 
/**
 * @brief  Fixed-point 1 / (1 + x) for x in [0, 1], used by the softmax layer.
 *         Computed as a Newton-Raphson reciprocal of the half denominator
 *         (1 + x) / 2, then doubled at the end.
 */
__STATIC_FORCEINLINE int32_t arm_nn_one_over_one_plus_x_for_x_in_0_1(int32_t val)
{
    /* half_denominator = (val + 1.0) / 2, rounded away from zero. */
    const int64_t sum = (int64_t)val + (int64_t)Q31_MAX;
    const int32_t half_denominator = (int32_t)((sum + (sum >= 0 ? 1 : -1)) / 2L);
    /* Initial reciprocal estimate; the constants presumably encode
       48/17 - (32/17) * d, the classic Newton-Raphson seed, in fixed point
       -- TODO confirm against the reference implementation. */
    int32_t x = 1515870810 + MUL_SAT(half_denominator, -1010580540);

    /* 'shift' represents 1.0 in Q2.29; three iterations of the refinement
       x += x * (1 - half_denominator * x), each roughly doubling precision. */
    const int32_t shift = (1 << 29);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);
    x += MUL_POW2(MUL_SAT(x, shift - MUL_SAT(half_denominator, x)), 2);

    /* Double the result to undo the halving of the denominator. */
    return MUL_POW2(x, 1);
}
954 
955 /**
956   @brief         Write 2 q15 elements and post increment pointer.
957   @param[in]     dest_q15  Pointer to pointer that holds address of destination.
958   @param[in]     src_q31   Input value to be written.
959   @return        none
960  */
arm_nn_write_q15x2_ia(q15_t ** dest_q15,q31_t src_q31)961 __STATIC_FORCEINLINE void arm_nn_write_q15x2_ia(q15_t **dest_q15, q31_t src_q31)
962 {
963     q31_t val = src_q31;
964 
965     memcpy(*dest_q15, &val, 4);
966     *dest_q15 += 2;
967 }
968 
969 #ifdef __cplusplus
970 }
971 #endif
972 
973 #endif
974