1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_solve_upper_triangular_f16.c
4  * Description:  Solve linear system UT X = A with UT upper triangular matrix
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 
30 #include "dsp/matrix_functions_f16.h"
31 
32 #if defined(ARM_FLOAT16_SUPPORTED)
33 
34 
35 /**
36   @ingroup groupMatrix
37  */
38 
39 
40 /**
41   @addtogroup MatrixInv
42   @{
43  */
44 
45 /**
46    * @brief Solve UT . X = A where UT is an upper triangular matrix
47    * @param[in]  ut  The upper triangular matrix
48    * @param[in]  a  The matrix a
49    * @param[out] dst The solution X of UT . X = A
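50    * @return ARM_MATH_SUCCESS when the system was solved, ARM_MATH_SINGULAR if a diagonal element of UT is zero (the system can't be solved), or ARM_MATH_SIZE_MISMATCH when ARM_MATH_MATRIX_CHECK is defined and the matrix dimensions are inconsistent.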
51   */
52 
53 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
54 
55 #include "arm_helium_utils.h"
56 
arm_mat_solve_upper_triangular_f16(const arm_matrix_instance_f16 * ut,const arm_matrix_instance_f16 * a,arm_matrix_instance_f16 * dst)57   arm_status arm_mat_solve_upper_triangular_f16(
58   const arm_matrix_instance_f16 * ut,
59   const arm_matrix_instance_f16 * a,
60   arm_matrix_instance_f16 * dst)
61   {
62 arm_status status;                             /* status of matrix inverse */
63 
64 
65 #ifdef ARM_MATH_MATRIX_CHECK
66 
67   /* Check for matrix mismatch condition */
68   if ((ut->numRows != ut->numCols) ||
69       (ut->numRows != a->numRows)   )
70   {
71     /* Set status as ARM_MATH_SIZE_MISMATCH */
72     status = ARM_MATH_SIZE_MISMATCH;
73   }
74   else
75 
76 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
77 
78   {
79 
80     int i,j,k,n,cols;
81 
82     n = dst->numRows;
83     cols = dst->numCols;
84 
85     float16_t *pX = dst->pData;
86     float16_t *pUT = ut->pData;
87     float16_t *pA = a->pData;
88 
89     float16_t *ut_row;
90     float16_t *a_col;
91 
92     _Float16 invUT;
93 
94     f16x8_t vecA;
95     f16x8_t vecX;
96 
97     for(i=n-1; i >= 0 ; i--)
98     {
99       for(j=0; j+7 < cols; j +=8)
100       {
101             vecA = vld1q_f16(&pA[i * cols + j]);
102 
103             for(k=n-1; k > i; k--)
104             {
105                 vecX = vld1q_f16(&pX[cols*k+j]);
106                 vecA = vfmsq(vecA,vdupq_n_f16(pUT[n*i + k]),vecX);
107             }
108 
109             if ((_Float16)pUT[n*i + i]==0.0f16)
110             {
111               return(ARM_MATH_SINGULAR);
112             }
113 
114             invUT = 1.0f16 / (_Float16)pUT[n*i + i];
115             vecA = vmulq(vecA,vdupq_n_f16(invUT));
116 
117 
118             vst1q(&pX[i*cols+j],vecA);
119       }
120 
121       for(; j < cols; j ++)
122       {
123             a_col = &pA[j];
124 
125             ut_row = &pUT[n*i];
126 
127             _Float16 tmp=a_col[i * cols];
128 
129             for(k=n-1; k > i; k--)
130             {
131                 tmp -= (_Float16)ut_row[k] * (_Float16)pX[cols*k+j];
132             }
133 
134             if ((_Float16)ut_row[i]==0.0f16)
135             {
136               return(ARM_MATH_SINGULAR);
137             }
138             tmp = tmp / (_Float16)ut_row[i];
139             pX[i*cols+j] = tmp;
140        }
141 
142     }
143     status = ARM_MATH_SUCCESS;
144 
145   }
146 
147 
148   /* Return to application */
149   return (status);
150 }
151 
152 #else
arm_mat_solve_upper_triangular_f16(const arm_matrix_instance_f16 * ut,const arm_matrix_instance_f16 * a,arm_matrix_instance_f16 * dst)153   arm_status arm_mat_solve_upper_triangular_f16(
154   const arm_matrix_instance_f16 * ut,
155   const arm_matrix_instance_f16 * a,
156   arm_matrix_instance_f16 * dst)
157   {
158 arm_status status;                             /* status of matrix inverse */
159 
160 
161 #ifdef ARM_MATH_MATRIX_CHECK
162 
163   /* Check for matrix mismatch condition */
164   if ((ut->numRows != ut->numCols) ||
165       (ut->numRows != a->numRows)   )
166   {
167     /* Set status as ARM_MATH_SIZE_MISMATCH */
168     status = ARM_MATH_SIZE_MISMATCH;
169   }
170   else
171 
172 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
173 
174   {
175 
176     int i,j,k,n,cols;
177 
178     n = dst->numRows;
179     cols = dst->numCols;
180 
181     float16_t *pX = dst->pData;
182     float16_t *pUT = ut->pData;
183     float16_t *pA = a->pData;
184 
185     float16_t *ut_row;
186     float16_t *a_col;
187 
188     for(j=0; j < cols; j ++)
189     {
190        a_col = &pA[j];
191 
192        for(i=n-1; i >= 0 ; i--)
193        {
194             ut_row = &pUT[n*i];
195 
196             float16_t tmp=a_col[i * cols];
197 
198             for(k=n-1; k > i; k--)
199             {
200                 tmp -= (_Float16)ut_row[k] * (_Float16)pX[cols*k+j];
201             }
202 
203             if ((_Float16)ut_row[i]==0.0f16)
204             {
205               return(ARM_MATH_SINGULAR);
206             }
207             tmp = (_Float16)tmp / (_Float16)ut_row[i];
208             pX[i*cols+j] = tmp;
209        }
210 
211     }
212     status = ARM_MATH_SUCCESS;
213 
214   }
215 
216 
217   /* Return to application */
218   return (status);
219 }
220 
221 #endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
222 
223 /**
224   @} end of MatrixInv group
225  */
226 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
227