1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_solve_lower_triangular_f16.c
4  * Description:  Solve linear system LT X = A with LT lower triangular matrix
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions_f16.h"
30 
31 #if defined(ARM_FLOAT16_SUPPORTED)
32 /**
33   @ingroup groupMatrix
34  */
35 
36 
37 /**
38   @addtogroup MatrixInv
39   @{
40  */
41 
42 
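   /**
   * @brief Solve LT . X = A where LT is a lower triangular matrix
   * @param[in]  lt  The lower triangular matrix (must be square)
   * @param[in]  a   The right-hand-side matrix A; must have as many rows as lt
   * @param[out] dst The solution X of LT . X = A (same dimensions as a)
   * @return ARM_MATH_SUCCESS when the system was solved,
   *         ARM_MATH_SINGULAR if a diagonal element of lt is zero (the system
   *         can't be solved), or ARM_MATH_SIZE_MISMATCH when
   *         ARM_MATH_MATRIX_CHECK is enabled and the dimensions are inconsistent.
   */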
50 
51 #if defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE)
52 
53 #include "arm_helium_utils.h"
54 
arm_mat_solve_lower_triangular_f16(const arm_matrix_instance_f16 * lt,const arm_matrix_instance_f16 * a,arm_matrix_instance_f16 * dst)55   arm_status arm_mat_solve_lower_triangular_f16(
56   const arm_matrix_instance_f16 * lt,
57   const arm_matrix_instance_f16 * a,
58   arm_matrix_instance_f16 * dst)
59   {
60   arm_status status;                             /* status of matrix inverse */
61 
62 
63 #ifdef ARM_MATH_MATRIX_CHECK
64 
65   /* Check for matrix mismatch condition */
66   if ((lt->numRows != lt->numCols) ||
67       (a->numRows != a->numCols) ||
68       (lt->numRows != a->numRows)   )
69   {
70     /* Set status as ARM_MATH_SIZE_MISMATCH */
71     status = ARM_MATH_SIZE_MISMATCH;
72   }
73   else
74 
75 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
76 
77   {
78     /* a1 b1 c1   x1 = a1
79           b2 c2   x2   a2
80              c3   x3   a3
81 
82     x3 = a3 / c3
83     x2 = (a2 - c2 x3) / b2
84 
85     */
86     int i,j,k,n;
87 
88     n = dst->numRows;
89 
90     float16_t *pX = dst->pData;
91     float16_t *pLT = lt->pData;
92     float16_t *pA = a->pData;
93 
94     float16_t *lt_row;
95     float16_t *a_col;
96 
97     _Float16 invLT;
98 
99     f16x8_t vecA;
100     f16x8_t vecX;
101 
102     for(i=0; i < n ; i++)
103     {
104 
105       for(j=0; j+7 < n; j += 8)
106       {
107             vecA = vld1q_f16(&pA[i * n + j]);
108 
109             for(k=0; k < i; k++)
110             {
111                 vecX = vld1q_f16(&pX[n*k+j]);
112                 vecA = vfmsq(vecA,vdupq_n_f16(pLT[n*i + k]),vecX);
113             }
114 
115             if (pLT[n*i + i]==0.0f16)
116             {
117               return(ARM_MATH_SINGULAR);
118             }
119 
120             invLT = 1.0f16 / (_Float16)pLT[n*i + i];
121             vecA = vmulq(vecA,vdupq_n_f16(invLT));
122             vst1q(&pX[i*n+j],vecA);
123 
124        }
125 
126        for(; j < n; j ++)
127        {
128             a_col = &pA[j];
129             lt_row = &pLT[n*i];
130 
131             _Float16 tmp=a_col[i * n];
132 
133             for(k=0; k < i; k++)
134             {
135                 tmp -= (_Float16)lt_row[k] * (_Float16)pX[n*k+j];
136             }
137 
138             if (lt_row[i]==0.0f16)
139             {
140               return(ARM_MATH_SINGULAR);
141             }
142             tmp = tmp / (_Float16)lt_row[i];
143             pX[i*n+j] = tmp;
144         }
145 
146     }
147     status = ARM_MATH_SUCCESS;
148 
149   }
150 
151   /* Return to application */
152   return (status);
153 }
154 
155 #else
arm_mat_solve_lower_triangular_f16(const arm_matrix_instance_f16 * lt,const arm_matrix_instance_f16 * a,arm_matrix_instance_f16 * dst)156   arm_status arm_mat_solve_lower_triangular_f16(
157   const arm_matrix_instance_f16 * lt,
158   const arm_matrix_instance_f16 * a,
159   arm_matrix_instance_f16 * dst)
160   {
161   arm_status status;                             /* status of matrix inverse */
162 
163 
164 #ifdef ARM_MATH_MATRIX_CHECK
165 
166   /* Check for matrix mismatch condition */
167   if ((lt->numRows != lt->numCols) ||
168       (a->numRows != a->numCols) ||
169       (lt->numRows != a->numRows)   )
170   {
171     /* Set status as ARM_MATH_SIZE_MISMATCH */
172     status = ARM_MATH_SIZE_MISMATCH;
173   }
174   else
175 
176 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
177 
178   {
179     /* a1 b1 c1   x1 = a1
180           b2 c2   x2   a2
181              c3   x3   a3
182 
183     x3 = a3 / c3
184     x2 = (a2 - c2 x3) / b2
185 
186     */
187     int i,j,k,n;
188 
189     n = dst->numRows;
190 
191     float16_t *pX = dst->pData;
192     float16_t *pLT = lt->pData;
193     float16_t *pA = a->pData;
194 
195     float16_t *lt_row;
196     float16_t *a_col;
197 
198     for(j=0; j < n; j ++)
199     {
200        a_col = &pA[j];
201 
202        for(i=0; i < n ; i++)
203        {
204             lt_row = &pLT[n*i];
205 
206             float16_t tmp=a_col[i * n];
207 
208             for(k=0; k < i; k++)
209             {
210                 tmp -= lt_row[k] * pX[n*k+j];
211             }
212 
213             if (lt_row[i]==0.0f)
214             {
215               return(ARM_MATH_SINGULAR);
216             }
217             tmp = tmp / lt_row[i];
218             pX[i*n+j] = tmp;
219        }
220 
221     }
222     status = ARM_MATH_SUCCESS;
223 
224   }
225 
226   /* Return to application */
227   return (status);
228 }
#endif /* defined(ARM_MATH_MVE_FLOAT16) && !defined(ARM_MATH_AUTOVECTORIZE) */
230 
231 /**
232   @} end of MatrixInv group
233  */
234 #endif /* #if defined(ARM_FLOAT16_SUPPORTED) */
235