/*
 * SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ----------------------------------------------------------------------
 * Project:      CMSIS NN Library
 * Title:        arm_nn_vec_mat_mul_result_acc_s16
 * Description:  s16 vector by matrix (transposed) multiplication with result accumulation
 *
 * $Date:        26 March 2023
 * $Revision:    V.1.0.0
 *
 * Target :  Arm(R) M-Profile Architecture
 *
 * -------------------------------------------------------------------- */

#include "arm_nnsupportfunctions.h"

/**
 * @ingroup groupSupport
 */

/**
 * @addtogroup supportFC
 * @{
 */

/*
 * s16 vector (lhs) by matrix (transposed) multiplication with result accumulation
 *
 * Refer to the header file for details.
 *
 */
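/*
 * Illustrative call (a sketch only; the argument names mirror this function's
 * parameters and the size notes follow from the implementation below):
 *
 *   arm_nn_vec_mat_mul_result_acc_s16(lhs,            // s16 input, rhs_cols samples per batch,
 *                                                     // batches spaced rhs_cols * batch_offset apart
 *                                     rhs,            // s8 [rhs_rows x rhs_cols] transposed weights
 *                                     effective_bias, // s64 [rhs_rows] bias, added before requantization
 *                                     dst,            // s16 [batches * rhs_rows]; requantized results are
 *                                                     // accumulated onto the existing values and saturated
 *                                     dst_multiplier, dst_shift,
 *                                     rhs_cols, rhs_rows, batches, batch_offset);
 */
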
arm_cmsis_nn_status arm_nn_vec_mat_mul_result_acc_s16(const int16_t *lhs,
                                                      const int8_t *rhs,
                                                      const int64_t *effective_bias,
                                                      int16_t *dst,
                                                      const int32_t dst_multiplier,
                                                      const int32_t dst_shift,
                                                      const int32_t rhs_cols,
                                                      const int32_t rhs_rows,
                                                      const int32_t batches,
                                                      const int32_t batch_offset)
{
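    /* Reduce the output multiplier for use with arm_nn_requantize_s64 on the
       64-bit accumulators. */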

    int32_t reduced_multiplier = REDUCE_MULTIPLIER(dst_multiplier);
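
    /* Each batch reuses the same rhs matrix and effective bias, reads a new lhs
       vector and appends rhs_rows outputs to dst. */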

    for (int batch = 0; batch < batches; batch++)
    {

        const int8_t *rhs_ptr = &rhs[0];
        const int64_t *effective_bias_ptr = &effective_bias[0];

#if defined(ARM_MATH_DSP)
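        /* The fast paths below use 32-bit accumulators for the first MAX_COL_COUNT
           columns; any remaining columns are accumulated in 64-bit scalar tail loops
           to keep the partial sums from overflowing. */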

        int32_t rhs_cols_fast = rhs_cols;

        if (rhs_cols > MAX_COL_COUNT)
        {
            rhs_cols_fast = MAX_COL_COUNT;
        }

    #if defined(ARM_MATH_MVEI)
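        /* MVE path: four rhs rows are processed per outer iteration, eight columns
           per predicated vector iteration. */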
        int32_t row_loop_cnt = rhs_rows / 4;
        const int32_t col_loop_cnt = (rhs_cols_fast + 7) / 8;

        for (int32_t i_row_loop_count = 0; i_row_loop_count < row_loop_cnt; i_row_loop_count++)
        {
            int32_t col_cnt = rhs_cols_fast;

            const int16_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr_0 = rhs_ptr;
            const int8_t *rhs_ptr_1 = rhs_ptr + rhs_cols;
            const int8_t *rhs_ptr_2 = rhs_ptr + rhs_cols * 2;
            const int8_t *rhs_ptr_3 = rhs_ptr + rhs_cols * 3;

            int32_t result_0 = *effective_bias_ptr++;
            int32_t result_1 = *effective_bias_ptr++;
            int32_t result_2 = *effective_bias_ptr++;
            int32_t result_3 = *effective_bias_ptr++;

            for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
            {
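                /* Tail-predicate the remaining columns; inactive lanes load as zero
                   and do not contribute to the dot products. */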
                mve_pred16_t pred = vctp16q(col_cnt);
                col_cnt -= 8;

                int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);

                int16x8_t rhs_input_0 = vldrbq_z_s16(rhs_ptr_0, pred);
                int16x8_t rhs_input_1 = vldrbq_z_s16(rhs_ptr_1, pred);
                int16x8_t rhs_input_2 = vldrbq_z_s16(rhs_ptr_2, pred);
                int16x8_t rhs_input_3 = vldrbq_z_s16(rhs_ptr_3, pred);

                result_0 = vmladavaq_s16(result_0, lhs_input, rhs_input_0);
                result_1 = vmladavaq_s16(result_1, lhs_input, rhs_input_1);
                result_2 = vmladavaq_s16(result_2, lhs_input, rhs_input_2);
                result_3 = vmladavaq_s16(result_3, lhs_input, rhs_input_3);

                lhs_ptr += 8;

                rhs_ptr_0 += 8;
                rhs_ptr_1 += 8;
                rhs_ptr_2 += 8;
                rhs_ptr_3 += 8;
            }

            int64_t result_64_0 = result_0;
            int64_t result_64_1 = result_1;
            int64_t result_64_2 = result_2;
            int64_t result_64_3 = result_3;

            if (rhs_cols > MAX_COL_COUNT)
            {
                for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
                {
                    const int16_t lhs_temp = *lhs_ptr++;

                    result_64_0 += *rhs_ptr_0++ * lhs_temp;
                    result_64_1 += *rhs_ptr_1++ * lhs_temp;
                    result_64_2 += *rhs_ptr_2++ * lhs_temp;
                    result_64_3 += *rhs_ptr_3++ * lhs_temp;
                }
            }
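
            /* Requantize each 64-bit sum, accumulate it onto the existing output
               value and saturate to the q15 range. */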

            int32_t tmp;
            tmp = arm_nn_requantize_s64(result_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = arm_nn_requantize_s64(result_64_1, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = arm_nn_requantize_s64(result_64_2, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = arm_nn_requantize_s64(result_64_3, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            rhs_ptr += 4 * rhs_cols;
        }
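
        /* Handle up to three remaining rows one at a time with the same predicated
           column loop. */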

        for (int8_t rows_left = rhs_rows & 0x3; rows_left > 0; rows_left--)
        {
            int32_t result = *effective_bias_ptr++;

            const int16_t *lhs_ptr = lhs;
            const int8_t *rhs_ptr0 = rhs_ptr;

            int32_t col_cnt = (int32_t)rhs_cols_fast;

            for (int i_col_loop_cnt = 0; i_col_loop_cnt < col_loop_cnt; i_col_loop_cnt++)
            {
                mve_pred16_t pred = vctp16q(col_cnt);
                col_cnt -= 8;

                int16x8_t lhs_input = vldrhq_z_s16(lhs_ptr, pred);
                int16x8_t rhs_input = vldrbq_z_s16(rhs_ptr0, pred);

                result = vmladavaq_p_s16(result, lhs_input, rhs_input, pred);

                lhs_ptr += 8;
                rhs_ptr0 += 8;
            }

            int64_t result_64 = result;

            if (rhs_cols > MAX_COL_COUNT)
            {
                for (int i_rhs_cols = MAX_COL_COUNT; i_rhs_cols < rhs_cols; i_rhs_cols++)
                {
                    const int16_t lhs_temp = *lhs_ptr++;

                    result_64 += *rhs_ptr0++ * lhs_temp;
                }
            }

            int32_t tmp = arm_nn_requantize_s64(result_64, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            rhs_ptr += rhs_cols;
        }

    #else // ARM_MATH_MVEI
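        /* DSP (non-MVE) path: two rhs rows are processed per iteration. read_and_pad
           sign-extends four s8 weights into two packed 16-bit pairs so that SMLAD can
           multiply-accumulate them against packed lhs samples. */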

        const int32_t row_loop_cnt = rhs_rows / 2;

        for (int32_t i = 0; i < row_loop_cnt; i++)
        {

            int64_t acc_64_0 = 0;
            int64_t acc_64_1 = 0;
            int32_t acc_0 = 0;
            int32_t acc_1 = 0;

            const int32_t col_loop_cnt = rhs_cols_fast / 4;

            const int16_t *lhs_vec = lhs;
            const int8_t *rhs_0 = rhs_ptr;
            rhs_ptr += rhs_cols;
            const int8_t *rhs_1 = rhs_ptr;
            rhs_ptr += rhs_cols;

            for (int j = col_loop_cnt; j != 0; j--)
            {
                int32_t ker_0, ker_1, vec_part_0, vec_part_1;

                vec_part_0 = arm_nn_read_q15x2_ia(&lhs_vec);
                vec_part_1 = arm_nn_read_q15x2_ia(&lhs_vec);

                rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

                acc_0 = SMLAD(ker_0, vec_part_0, acc_0);
                acc_0 = SMLAD(ker_1, vec_part_1, acc_0);

                rhs_1 = read_and_pad(rhs_1, &ker_0, &ker_1);

                acc_1 = SMLAD(ker_0, vec_part_0, acc_1);
                acc_1 = SMLAD(ker_1, vec_part_1, acc_1);
            }

            acc_64_0 += acc_0;
            acc_64_1 += acc_1;

            for (int k = col_loop_cnt * 4; k < rhs_cols; k++)
            {
                const int32_t lhs_temp = (*lhs_vec);
                lhs_vec++;
                acc_64_0 += lhs_temp * (*rhs_0);
                rhs_0++;
                acc_64_1 += lhs_temp * (*rhs_1);
                rhs_1++;
            }

            acc_64_0 += *effective_bias_ptr++;
            acc_64_1 += *effective_bias_ptr++;
            int32_t tmp;

            tmp = arm_nn_requantize_s64(acc_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;

            tmp = arm_nn_requantize_s64(acc_64_1, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;
        }
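
        /* Process the final row when the row count is odd. */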

        if (rhs_rows & 0x1)
        {
            int64_t acc_64_0 = 0;
            int32_t acc_0 = 0;
            const int32_t col_loop_cnt = rhs_cols_fast / 4;

            const int16_t *lhs_vec = lhs;
            const int8_t *rhs_0 = rhs_ptr;

            for (int i = col_loop_cnt; i != 0; i--)
            {
                int32_t ker_0, ker_1, vec;
                rhs_0 = read_and_pad(rhs_0, &ker_0, &ker_1);

                vec = arm_nn_read_q15x2_ia(&lhs_vec);
                acc_0 = SMLAD(ker_0, vec, acc_0);

                vec = arm_nn_read_q15x2_ia(&lhs_vec);
                acc_0 = SMLAD(ker_1, vec, acc_0);
            }

            acc_64_0 += acc_0;

            for (int j = col_loop_cnt * 4; j < rhs_cols; j++)
            {
                const int32_t lhs_temp = (*lhs_vec);
                lhs_vec++;
                acc_64_0 += lhs_temp * (*rhs_0);
                rhs_0++;
            }

            acc_64_0 += *effective_bias_ptr++;

            int32_t tmp;
            tmp = arm_nn_requantize_s64(acc_64_0, reduced_multiplier, dst_shift);
            tmp += (int64_t)*dst;
            tmp = MAX(tmp, NN_Q15_MIN);
            tmp = MIN(tmp, NN_Q15_MAX);
            *dst++ = (int16_t)tmp;
        }

    #endif // ARM_MATH_MVEI
#else      // ARM_MATH_DSP
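        /* Pure C fallback: one row at a time with 64-bit accumulation throughout. */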
        for (int i_row_loop_cnt = 0; i_row_loop_cnt < rhs_rows; i_row_loop_cnt++)
        {
            const int16_t *lhs_ptr = lhs;

            int64_t result = *effective_bias_ptr++;

            for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
            {
                const int64_t rhs_value0 = (int8_t)*rhs_ptr;
                const int64_t lhs_value = *lhs_ptr;

                result += lhs_value * rhs_value0;
                ++rhs_ptr;
                ++lhs_ptr;
            }

            // Quantize down
            result = arm_nn_requantize_s64(result, reduced_multiplier, dst_shift);
            result += (int64_t)*dst;

            // Clamp the result
            result = ((result) > (NN_Q15_MIN) ? (result) : (NN_Q15_MIN));
            result = ((result) < (NN_Q15_MAX) ? (result) : (NN_Q15_MAX));

            *dst++ = (int16_t)result;
        }
#endif     // ARM_MATH_DSP
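
        /* Advance to the input vector of the next batch; batch_offset is the stride,
           in whole vectors, between consecutive batches in lhs. */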

        lhs += rhs_cols * batch_offset;
    }

    return ARM_CMSIS_NN_SUCCESS;
}

/**
 * @} end of Doxygen group
 */