/*
 * SPDX-FileCopyrightText: Copyright 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_mat_mult_nt_t_s4
 * Description:  Matrix multiplication support function with the right-hand-side (rhs) matrix transposed, and 4-bit rhs.
 *
 * $Date:        10 April 2024
 * $Revision:    V.1.1.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportConvolution
 * @{
 */

/*
 * s4 matrix multiplication with the right-hand-side matrix transposed
 *
 * Refer to the header file for details.
 *
 */
arm_cmsis_nn_status arm_nn_mat_mult_nt_t_s4(const int8_t *lhs,
                                            const int8_t *packed_rhs,
                                            const int32_t *bias,
                                            int8_t *dst,
                                            const int32_t *dst_multipliers,
                                            const int32_t *dst_shifts,
                                            const int32_t lhs_rows,
                                            const int32_t rhs_rows,
                                            const int32_t rhs_cols,
                                            const int32_t lhs_offset,
                                            const int32_t dst_offset,
                                            const int32_t activation_min,
                                            const int32_t activation_max,
                                            const int32_t lhs_cols_offset)
{
#if defined(ARM_MATH_MVEI)
    int i_items = 0;
    const int rhs_cols_offset = rhs_cols % 16;
    const int32_t blk_cnt = rhs_cols >> 4;
    const mve_pred16_t lower_nibble_mask = 21845; // 0101010101010101
    const uint8x16_t gather_offset = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
    const uint32x4_t scatter_offset = {0, (uint32_t)rhs_rows, (uint32_t)rhs_rows * 2, (uint32_t)rhs_rows * 3};

    for (; i_items <= (lhs_rows - 4); i_items += 4)
    {
        int8_t const *col_base = packed_rhs;
        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;
            int32_t acc_n3 = 0;

            int8_t const *ip_row_0 = lhs;
            int8_t const *ip_row_1 = lhs + lhs_cols_offset;
            int8_t const *ip_row_2 = lhs + (2 * lhs_cols_offset);
            int8_t const *ip_row_3 = lhs + (3 * lhs_cols_offset);
            int32_t sum_tmp = 0;

            mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
            if ((rhs_cols & 0x1) & (i & 0x1))
            {
                rmdr_mask >>= 1;
                int32_t col = col_base[0] >> 4;
                sum_tmp = col;
                acc_n0 += ip_row_0[0] * col;
                acc_n1 += ip_row_1[0] * col;
                acc_n2 += ip_row_2[0] * col;
                acc_n3 += ip_row_3[0] * col;
                ++col_base;
                ++ip_row_0;
                ++ip_row_1;
                ++ip_row_2;
                ++ip_row_3;
            }
            for (int j = blk_cnt; j > 0; --j)
            {
                int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
                col_base += 8;
                col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                col_vec = vshrq_n_s8(col_vec, 4);

                sum_tmp = vaddvaq_s8(sum_tmp, col_vec);

                int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
                ip_row_0 += 16;
                acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_1);
                ip_row_1 += 16;
                acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_2);
                ip_row_2 += 16;
                acc_n2 = vmladavaq_s8(acc_n2, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_3);
                ip_row_3 += 16;
                acc_n3 = vmladavaq_s8(acc_n3, col_vec, lhs_vec);
            }
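            /* Handle any leftover columns (rhs_cols % 16) with a predicated tail. With an odd
               rhs_cols the packed nibble stream changes phase on every other rhs row; that row's
               leading high nibble was consumed above and rmdr_mask was shortened by one lane to
               match. */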
            if (rmdr_mask)
            {
                int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
                col_base += rhs_cols_offset >> 1;
                col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                col_vec = vshrq_n_s8(col_vec, 4);

                sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);

                int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
                acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
                acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_2, rmdr_mask);
                acc_n2 = vmladavaq_p_s8(acc_n2, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_3, rmdr_mask);
                acc_n3 = vmladavaq_p_s8(acc_n3, col_vec, lhs_vec, rmdr_mask);
            }

            int32x4_t res = {acc_n0, acc_n1, acc_n2, acc_n3};
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            res = vaddq_n_s32(res, sum_tmp);

            res = arm_requantize_mve(res, dst_multipliers[i], dst_shifts[i]);
            res = vaddq_n_s32(res, dst_offset);

            res = vmaxq_s32(res, vdupq_n_s32(activation_min));
            res = vminq_s32(res, vdupq_n_s32(activation_max));

            vstrbq_scatter_offset_s32(dst, scatter_offset, res);
            dst++;
        }
        lhs += 4 * lhs_cols_offset;
        dst += (3 * rhs_rows);
    }

    if (lhs_rows % 4 == 3)
    {
        int8_t const *col_base = packed_rhs;
        const mve_pred16_t requant_mask = vctp32q(3);

        for (int i = 0; i < rhs_rows; i++)
        {
            int32_t acc_n0 = 0;
            int32_t acc_n1 = 0;
            int32_t acc_n2 = 0;

            int8_t const *ip_row_0 = lhs;
            int8_t const *ip_row_1 = lhs + lhs_cols_offset;
            int8_t const *ip_row_2 = lhs + (2 * lhs_cols_offset);
            int32_t sum_tmp = 0;

            mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
            if ((rhs_cols & 0x1) & (i & 0x1))
            {
                rmdr_mask >>= 1;
                int32_t col = col_base[0] >> 4;
                sum_tmp = col;
                acc_n0 += ip_row_0[0] * col;
                acc_n1 += ip_row_1[0] * col;
                acc_n2 += ip_row_2[0] * col;
                ++col_base;
                ++ip_row_0;
                ++ip_row_1;
                ++ip_row_2;
            }
            for (int j = blk_cnt; j > 0; --j)
            {
                int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
                col_base += 8;
                col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                col_vec = vshrq_n_s8(col_vec, 4);

                sum_tmp = vaddvaq_s8(sum_tmp, col_vec);

                int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
                ip_row_0 += 16;
                acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_1);
                ip_row_1 += 16;
                acc_n1 = vmladavaq_s8(acc_n1, col_vec, lhs_vec);

                lhs_vec = vldrbq_s8(ip_row_2);
                ip_row_2 += 16;
                acc_n2 = vmladavaq_s8(acc_n2, col_vec, lhs_vec);
            }
            if (rmdr_mask)
            {
                int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
                col_base += rhs_cols_offset >> 1;
                col_vec = vrshlq_m_n_s8(col_vec, 4, (lower_nibble_mask & rmdr_mask));
                col_vec = vshrq_n_s8(col_vec, 4);

                sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);

                int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
                acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_1, rmdr_mask);
                acc_n1 = vmladavaq_p_s8(acc_n1, col_vec, lhs_vec, rmdr_mask);

                lhs_vec = vldrbq_z_s8(ip_row_2, rmdr_mask);
                acc_n2 = vmladavaq_p_s8(acc_n2, col_vec, lhs_vec, rmdr_mask);
            }

            int32x4_t res = {acc_n0, acc_n1, acc_n2, 0};
            sum_tmp *= lhs_offset;
            if (bias)
            {
                sum_tmp += bias[i];
            }
            res = vaddq_x_n_s32(res, sum_tmp, requant_mask);

            res = arm_requantize_mve_pred(res, dst_multipliers[i], dst_shifts[i], requant_mask);
            res = vaddq_x_n_s32(res, dst_offset, requant_mask);

            res = vmaxq_x_s32(res, vdupq_n_s32(activation_min), requant_mask);
            res = vminq_x_s32(res, vdupq_n_s32(activation_max), requant_mask);

            vstrbq_scatter_offset_p_s32(dst, scatter_offset, res, requant_mask);
            dst++;
        }
        lhs += 3 * lhs_cols_offset;
        dst += (2 * rhs_rows);
    }
    else
    {
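        /* The remaining one or two lhs rows are handled one row at a time. Per-row sums are
           staged in acc[] so that requantization can still run four rhs rows at a time. */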
        for (; i_items < lhs_rows; i_items++)
        {
            int32_t acc[4];
            const int32_t *multipliers = dst_multipliers;
            const int32_t *shifts = dst_shifts;
            const int8_t *col_base = packed_rhs;
            int col_inc = rhs_cols_offset >> 1;

            for (int i = 0; i < rhs_rows; i++)
            {
                int32_t acc_n0 = 0;
                const int8_t *ip_row_0 = lhs;
                int32_t sum_tmp = 0;

                mve_pred16_t rmdr_mask = vctp8q(rhs_cols_offset);
                if ((rhs_cols & 0x1) & (i & 0x1))
                {
                    rmdr_mask >>= 1;
                    int32_t col = col_base[0] >> 4;
                    sum_tmp += col;
                    acc_n0 += ip_row_0[0] * col;
                    ++col_base;
                    ++ip_row_0;
                }
                for (int j = blk_cnt; j > 0; --j)
                {
                    int8x16_t col_vec = vldrbq_gather_offset_s8(col_base, gather_offset);
                    col_base += 8;
                    col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                    col_vec = vshrq_n_s8(col_vec, 4);

                    sum_tmp = vaddvaq_s8(sum_tmp, col_vec);

                    int8x16_t lhs_vec = vldrbq_s8(ip_row_0);
                    ip_row_0 += 16;
                    acc_n0 = vmladavaq_s8(acc_n0, col_vec, lhs_vec);
                }
                if (rmdr_mask)
                {
                    int8x16_t col_vec = vldrbq_gather_offset_z_s8(col_base, gather_offset, rmdr_mask);
                    col_base += col_inc;
                    col_vec = vrshlq_m_n_s8(col_vec, 4, lower_nibble_mask);
                    col_vec = vshrq_n_s8(col_vec, 4);

                    sum_tmp = vaddvaq_p_s8(sum_tmp, col_vec, rmdr_mask);

                    int8x16_t lhs_vec = vldrbq_z_s8(ip_row_0, rmdr_mask);
                    acc_n0 = vmladavaq_p_s8(acc_n0, col_vec, lhs_vec, rmdr_mask);
                }

                sum_tmp *= lhs_offset;
                sum_tmp += acc_n0;
                if (bias)
                {
                    sum_tmp += bias[i];
                }
                const int32_t index = i & 0x3;
                acc[index] = sum_tmp;

                if (index == 3)
                {
                    int32x4_t res = vldrwq_s32(acc);
                    res = arm_requantize_mve_32x4(res, vldrwq_s32(multipliers), vldrwq_s32(shifts));
                    multipliers += 4;
                    shifts += 4;
                    res = vaddq_n_s32(res, dst_offset);

                    res = vmaxq_s32(res, vdupq_n_s32(activation_min));
                    res = vminq_s32(res, vdupq_n_s32(activation_max));

                    vstrbq_s32(dst, res);
                    dst += 4;
                }
            }
            lhs += lhs_cols_offset;

            const int32_t tail_rows = rhs_rows & 0x3;
            for (int i = 0; i < tail_rows; i++)
            {
                int32_t acc_n0 = acc[i];
                acc_n0 = arm_nn_requantize(acc_n0, multipliers[i], shifts[i]);
                acc_n0 += dst_offset;
                acc_n0 = MAX(acc_n0, activation_min);
                acc_n0 = MIN(acc_n0, activation_max);
                *dst++ = (int8_t)acc_n0;
            }
        }
    }
#elif defined(ARM_MATH_DSP)
    const int32_t lhs_cols_off1 = lhs_cols_offset - 4;
    const int16_t i16_lhs_offset = (int16_t)lhs_offset;
    const uint32_t ui32_lhs_offset_i16x2 = PKHBT(i16_lhs_offset, i16_lhs_offset, 16);
    const int32_t rhs_cols_int4 = rhs_cols >> 1;

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 4); rhs_rows_idx += 4)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_rows_idx = lhs_rows >> 1;
        while (lhs_rows_idx)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res10 = 0;
            int32_t res11 = 0;

            int32_t spillover00 = 0;
            int32_t spillover01 = 0;
            int32_t spillover10 = 0;
            int32_t spillover11 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_idx];
                res01 = bias[rhs_rows_idx + 2];
                res10 = bias[rhs_rows_idx];
                res11 = bias[rhs_rows_idx + 2];
                spillover00 = bias[rhs_rows_idx + 1];
                spillover01 = bias[rhs_rows_idx + 3];
                spillover10 = bias[rhs_rows_idx + 1];
                spillover11 = bias[rhs_rows_idx + 3];
            }

            int32_t rhs_cols_idx = 0;

            int32_t lhs_low, rhs_low0, rhs_high0, lhs_high, rhs_low1, rhs_high1;
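            /* Each unrolled step below consumes four columns: read_and_pad_s4() expands four
               packed s4 values into two 16x2 operands, and the lhs bytes are widened with the
               offset pre-added via SXTAB16/SXTAB16_RORn, so each SMLAD performs two
               multiply-accumulates. */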
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;
                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;
                res11 += lhs_low * rhs_low1;
                res11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;
                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res01 += lhs_low * rhs_low1;
                res10 += lhs_high * rhs_low0;
                res11 += lhs_high * rhs_low1;

                lhs_ptr -= rhs_cols - 1;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                spillover00 += lhs_low * rhs_high0;
                spillover01 += lhs_low * rhs_high1;
                spillover10 += lhs_high * rhs_high0;
                spillover11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                ++lhs_ptr;
            }
            else
            {
                lhs_ptr -= rhs_cols;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[2] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[2] = (int8_t)res11;
            dst_ptr -= rhs_rows;

            res00 = spillover00;
            res01 = spillover01;
            res10 = spillover10;
            res11 = spillover11;

            rhs_cols_idx = 0;
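            /* Second pass: rhs rows (rhs_rows_idx + 1) and (rhs_rows_idx + 3) continue from the
               same packed byte streams. With an odd rhs_cols their first column is the high
               nibble already accumulated into the spillover terms above. */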
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                // 4 x MAC res10, res11
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res11 = SMLAD(rhs_low1, lhs_low, res11);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
                res11 = SMLAD(rhs_high1, lhs_high, res11);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;
                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;
                res11 += lhs_low * rhs_low1;
                res11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[1] = (int8_t)res00;
            dst_ptr[3] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[1] = (int8_t)res10;
            dst_ptr[3] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            lhs_rows_idx--;
        }

        // Left-over rows
        if (lhs_rows % 2)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res01 = 0;

            int32_t spillover00 = 0;
            int32_t spillover01 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_idx];
                spillover00 = bias[rhs_rows_idx + 1];
                res01 = bias[rhs_rows_idx + 2];
                spillover01 = bias[rhs_rows_idx + 3];
            }

            int32_t rhs_cols_idx = 0;

            int32_t lhs_low, rhs_low0, rhs_high0, lhs_high, rhs_low1, rhs_high1;
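            /* Same two-pass scheme for a single leftover lhs row: only res00/res01 are
               accumulated here and the odd-column high nibbles are deferred via spillover00/01. */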
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;
                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;
                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_ptr -= rhs_cols - 1;
                lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res01 += lhs_low * rhs_low1;
                spillover00 += lhs_high * rhs_high0;
                spillover01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                ++lhs_ptr;
            }
            else
            {
                lhs_ptr -= rhs_cols;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[2] = (int8_t)res01;
            res00 = spillover00;
            res01 = spillover01;

            rhs_cols_idx = 0;

            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                read_and_pad_s4((const int8_t *)&packed_rhs_ptr[rhs_cols], &rhs_low1, &rhs_high1);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 4 x MAC res00, res01
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
                res01 = SMLAD(rhs_low1, lhs_low, res01);
                res01 = SMLAD(rhs_high1, lhs_high, res01);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;
                rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[1] = (int8_t)res00;
            dst_ptr[3] = (int8_t)res01;
        }
        packed_rhs += 2 * rhs_cols;
        dst += 4;
    }

    int8_t rhs_spilled_col = 0;
    const int32_t rhs_rows_finished = rhs_rows - (rhs_rows % 4);
    // Left over rhs rows will be in the range 0 -> 3
    for (int rhs_rows_idx = 0; rhs_rows_idx < rhs_rows % 4; ++rhs_rows_idx)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        int32_t lhs_rows_idx = lhs_rows >> 1;
        while (lhs_rows_idx)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res10 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_finished + rhs_rows_idx];
                res10 = bias[rhs_rows_finished + rhs_rows_idx];
            }

            // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
            // differently by dealing with the leftover column from rhs_row_idx[0]
            if (rhs_cols % 2 && rhs_rows_idx == 1)
            {
                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * rhs_spilled_col;
                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                res10 += lhs_low * rhs_spilled_col;
                ++lhs_ptr;
            }

            int32_t rhs_cols_idx = 0;

            int32_t lhs_low, rhs_low0, rhs_high0, lhs_high;
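            /* Main loop: 16 columns per iteration, unrolled as four blocks of four columns; each
               block feeds two SMLAD multiply-accumulates per lhs row. */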
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                // 2 x MAC res10
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res10 = SMLAD(rhs_high0, lhs_high, res10);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                // 2 x MAC res10
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res10 = SMLAD(rhs_high0, lhs_high, res10);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                // 2 x MAC res10
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res10 = SMLAD(rhs_high0, lhs_high, res10);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                // 2 x MAC res10
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                // 2 x MAC res10
                lhs_high = arm_nn_read_s8x4((const int8_t *)&lhs_ptr[lhs_cols_off1]);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);
                res10 = SMLAD(rhs_low0, lhs_low, res10);
                res10 = SMLAD(rhs_high0, lhs_high, res10);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2 && !(rhs_rows_idx % 2))
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_low0;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(
                res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
            res10 = arm_nn_requantize(
                res10, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);

            // Add offset
            res00 += dst_offset;
            res10 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr += rhs_rows;

            lhs_rows_idx--;
        }
        if (lhs_rows % 2)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_finished + rhs_rows_idx];
            }

            // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
            // differently by dealing with the leftover column from rhs_row_idx[0]
            if (rhs_cols % 2 && rhs_rows_idx == 1)
            {
                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * rhs_spilled_col;
                ++lhs_ptr;
            }

            int32_t rhs_cols_idx = 0;

            int32_t lhs_low, rhs_low0, rhs_high0, lhs_high;
            for (; rhs_cols_idx <= (rhs_cols - 16); rhs_cols_idx += 16)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);

                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
            }
            for (; rhs_cols_idx <= (rhs_cols - 4); rhs_cols_idx += 4)
            {
                read_and_pad_s4(packed_rhs_ptr, &rhs_low0, &rhs_high0);
                packed_rhs_ptr += 2;

                lhs_high = arm_nn_read_s8x4_ia((const int8_t **)&lhs_ptr);
                lhs_low = SXTAB16(ui32_lhs_offset_i16x2, lhs_high);
                lhs_high = SXTAB16_RORn(ui32_lhs_offset_i16x2, lhs_high, 8);

                // 2 x MAC res00
                res00 = SMLAD(rhs_low0, lhs_low, res00);
                res00 = SMLAD(rhs_high0, lhs_high, res00);
            }
            for (; rhs_cols_idx <= rhs_cols - 2; rhs_cols_idx += 2)
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                rhs_high0 = packed_rhs_ptr[0] >> 4;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2 && !(rhs_rows_idx % 2))
            {
                rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(
                res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
        }
        if (rhs_cols % 2 && !(rhs_rows_idx % 2))
        {
            rhs_spilled_col = packed_rhs[rhs_cols_int4] >> 4;
            packed_rhs += rhs_cols_int4 + 1;
        }
        else
        {
            rhs_spilled_col = 0;
            packed_rhs += rhs_cols_int4;
        }
        ++dst;
    }
#else
    const int32_t rhs_cols_int4 = rhs_cols >> 1;

    for (int32_t rhs_rows_idx = 0; rhs_rows_idx <= (rhs_rows - 4); rhs_rows_idx += 4)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = (lhs_rows >> 1); lhs_rows_idx > 0; --lhs_rows_idx)
        {
            // To avoid the issue of packed values leaking into the next rhs row
            // we instead evaluate the rhs rows in pairs like so:
            // rhs[0] and rhs[2], rhs[1] and rhs[3]
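            // Packed layout, for reference: byte k of a row's stream holds column 2k in its low
            // nibble and column 2k + 1 in its high nibble, so with an odd rhs_cols the final high
            // nibble of one row is column 0 of the next row and must be spilled over.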
            // Start processing rhs_row[0] and rhs_row[2]
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res01 = 0;
            int32_t res10 = 0;
            int32_t res11 = 0;

            int32_t spillover00 = 0;
            int32_t spillover01 = 0;
            int32_t spillover10 = 0;
            int32_t spillover11 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_idx];
                res01 = bias[rhs_rows_idx + 2];
                res10 = bias[rhs_rows_idx];
                res11 = bias[rhs_rows_idx + 2];
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;
                res11 += lhs_low * rhs_low1;
                res11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res01 += lhs_low * rhs_low1;
                res10 += lhs_high * rhs_low0;
                res11 += lhs_high * rhs_low1;

                lhs_ptr -= rhs_cols - 1;

                lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                spillover00 += lhs_low * rhs_high0;
                spillover01 += lhs_low * rhs_high1;
                spillover10 += lhs_high * rhs_high0;
                spillover11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                ++lhs_ptr;
            }
            else
            {
                lhs_ptr -= rhs_cols;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[2] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr[2] = (int8_t)res11;
            dst_ptr -= rhs_rows;

            // Start processing rhs_row[1] and rhs_row[3]
            res00 = spillover00;
            res01 = spillover01;
            res10 = spillover10;
            res11 = spillover11;
            if (bias)
            {
                res00 += bias[rhs_rows_idx + 1];
                res01 += bias[rhs_rows_idx + 3];
                res10 += bias[rhs_rows_idx + 1];
                res11 += bias[rhs_rows_idx + 3];
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low0;
                res10 += lhs_high * rhs_high0;
                res11 += lhs_low * rhs_low1;
                res11 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
            res10 = arm_nn_requantize(res10, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res11 = arm_nn_requantize(res11, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);
            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            res10 += dst_offset;
            res11 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);
            res11 = MAX(res11, activation_min);
            res11 = MIN(res11, activation_max);

            dst_ptr[1] = (int8_t)res00;
            dst_ptr[3] = (int8_t)res01;
            dst_ptr += rhs_rows;
            dst_ptr[1] = (int8_t)res10;
            dst_ptr[3] = (int8_t)res11;
            dst_ptr += rhs_rows;

            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;
        }

        // Left-over row
        if (lhs_rows % 2)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res01 = 0;

            int32_t spillover00 = 0;
            int32_t spillover01 = 0;

            if (bias)
            {
                res00 += bias[rhs_rows_idx];
                res01 += bias[rhs_rows_idx + 2];
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                lhs_ptr -= rhs_cols - 1;
                int32_t lhs_high = (int8_t)lhs_ptr[0] + lhs_offset;

                res00 += lhs_low * rhs_low0;
                res01 += lhs_low * rhs_low1;
                spillover00 = lhs_high * rhs_high0;
                spillover01 = lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                ++lhs_ptr;
            }
            else
            {
                lhs_ptr -= rhs_cols;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx], dst_shifts[rhs_rows_idx]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 2], dst_shifts[rhs_rows_idx + 2]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr[2] = (int8_t)res01;

            res00 = spillover00;
            res01 = spillover01;
            if (bias)
            {
                res00 += bias[rhs_rows_idx + 1];
                res01 += bias[rhs_rows_idx + 3];
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low0 = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high0 = packed_rhs_ptr[0] >> 4;
                int8_t rhs_low1 = (int8_t)(packed_rhs_ptr[rhs_cols] << 4) >> 4;
                int8_t rhs_high1 = packed_rhs_ptr[rhs_cols] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low0;
                res00 += lhs_high * rhs_high0;
                res01 += lhs_low * rhs_low1;
                res01 += lhs_high * rhs_high1;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }

            // Quantize down
            res00 = arm_nn_requantize(res00, dst_multipliers[rhs_rows_idx + 1], dst_shifts[rhs_rows_idx + 1]);
            res01 = arm_nn_requantize(res01, dst_multipliers[rhs_rows_idx + 3], dst_shifts[rhs_rows_idx + 3]);

            // Add offset
            res00 += dst_offset;
            res01 += dst_offset;
            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res01 = MAX(res01, activation_min);
            res01 = MIN(res01, activation_max);

            dst_ptr[1] = (int8_t)res00;
            dst_ptr[3] = (int8_t)res01;
        }
        packed_rhs += 2 * rhs_cols;
        dst += 4;
    }

    int32_t spillover00 = 0;
    const int32_t rhs_rows_finished = rhs_rows - (rhs_rows % 4);
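    // spillover00 carries the sign-extended high nibble that spills from one leftover rhs row
    // into the next when rhs_cols is odd (the counterpart of rhs_spilled_col in the DSP path).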
    // Left over rhs rows will be in the range 0 -> 3
    for (int rhs_rows_idx = 0; rhs_rows_idx < rhs_rows % 4; ++rhs_rows_idx)
    {
        const int8_t *lhs_ptr = &lhs[0];
        int8_t *dst_ptr = &dst[0];

        for (int32_t lhs_rows_idx = (lhs_rows >> 1); lhs_rows_idx > 0; --lhs_rows_idx)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;
            int32_t res10 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_finished + rhs_rows_idx];
                res10 = bias[rhs_rows_finished + rhs_rows_idx];
            }

            // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
            // differently by dealing with the leftover column from rhs_row_idx[0]
            if (rhs_cols % 2 && rhs_rows_idx == 1)
            {
                int32_t lhs_low = lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * spillover00;
                lhs_low = lhs_ptr[lhs_cols_offset] + lhs_offset;
                res10 += lhs_low * spillover00;
                ++lhs_ptr;
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high = packed_rhs_ptr[0] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low;
                res00 += lhs_high * rhs_high;

                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                lhs_high = (int8_t)lhs_ptr[lhs_cols_offset + 1] + lhs_offset;
                res10 += lhs_low * rhs_low;
                res10 += lhs_high * rhs_high;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2 && !(rhs_rows_idx % 2))
            {
                int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * rhs_low;
                lhs_low = (int8_t)lhs_ptr[lhs_cols_offset] + lhs_offset;
                res10 += lhs_low * rhs_low;
                ++lhs_ptr;
            }
            lhs_ptr -= rhs_cols;
            lhs_ptr += 2 * lhs_cols_offset;

            // Quantize down
            res00 = arm_nn_requantize(
                res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);
            res10 = arm_nn_requantize(
                res10, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);

            // Add offset
            res00 += dst_offset;
            res10 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);
            res10 = MAX(res10, activation_min);
            res10 = MIN(res10, activation_max);

            dst_ptr[0] = (int8_t)res00;
            dst_ptr += rhs_rows;
            dst_ptr[0] = (int8_t)res10;
            dst_ptr += rhs_rows;
        }
        if (lhs_rows % 2)
        {
            const int8_t *packed_rhs_ptr = &packed_rhs[0];

            int32_t res00 = 0;

            if (bias)
            {
                res00 = bias[rhs_rows_finished + rhs_rows_idx];
            }

            // Since there can only be 3 rhs rows here we only need to treat rhs_row_idx[1]
            // differently by dealing with the leftover column from rhs_row_idx[0]
            if (rhs_cols % 2 && rhs_rows_idx == 1)
            {
                int32_t lhs_low = lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * spillover00;
                ++lhs_ptr;
            }

            for (int32_t rhs_cols_idx = rhs_cols_int4; rhs_cols_idx != 0; --rhs_cols_idx)
            {
                int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int8_t rhs_high = packed_rhs_ptr[0] >> 4;

                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                int32_t lhs_high = (int8_t)lhs_ptr[1] + lhs_offset;
                res00 += lhs_low * rhs_low;
                res00 += lhs_high * rhs_high;

                ++packed_rhs_ptr;
                lhs_ptr += 2;
            }
            if (rhs_cols % 2 && (rhs_rows_idx != 1))
            {
                int8_t rhs_low = (int8_t)(packed_rhs_ptr[0] << 4) >> 4;
                int32_t lhs_low = (int8_t)lhs_ptr[0] + lhs_offset;
                res00 += lhs_low * rhs_low;
                ++lhs_ptr;
            }

            // Quantize down
            res00 = arm_nn_requantize(
                res00, dst_multipliers[rhs_rows_finished + rhs_rows_idx], dst_shifts[rhs_rows_finished + rhs_rows_idx]);

            // Add offset
            res00 += dst_offset;

            // Clamp the result
            res00 = MAX(res00, activation_min);
            res00 = MIN(res00, activation_max);

            dst_ptr[0] = (int8_t)res00;
        }
        if (rhs_cols % 2 && !(rhs_rows_idx % 2))
        {
            spillover00 = packed_rhs[rhs_cols_int4] >> 4;
            packed_rhs += rhs_cols_int4 + (rhs_cols & 0x1);
        }
        else
        {
            spillover00 = 0;
            packed_rhs += rhs_cols_int4;
        }
        ++dst;
    }
#endif
    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */