1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_solve_lower_triangular_f32.c
4  * Description:  Solve linear system LT X = A with LT lower triangular matrix
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 
31 /**
32   @ingroup groupMatrix
33  */
34 
35 
36 /**
37   @addtogroup MatrixInv
38   @{
39  */
40 
41 
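   /**
   * @brief Solve LT . X = A where LT is a lower triangular matrix
   * @param[in]  lt  The lower triangular (square) matrix LT
   * @param[in]  a   The right-hand side matrix A (one system per column)
   * @param[out] dst The solution X of LT . X = A
   * @return ARM_MATH_SUCCESS when the system was solved,
   *         ARM_MATH_SINGULAR if a diagonal entry of LT is zero (the system can't be solved),
   *         ARM_MATH_SIZE_MISMATCH if ARM_MATH_MATRIX_CHECK is defined and the dimensions are inconsistent.
   */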
49 
50 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
51 
52 #include "arm_helium_utils.h"
53 
arm_mat_solve_lower_triangular_f32(const arm_matrix_instance_f32 * lt,const arm_matrix_instance_f32 * a,arm_matrix_instance_f32 * dst)54   arm_status arm_mat_solve_lower_triangular_f32(
55   const arm_matrix_instance_f32 * lt,
56   const arm_matrix_instance_f32 * a,
57   arm_matrix_instance_f32 * dst)
58   {
59   arm_status status;                             /* status of matrix inverse */
60 
61 
62 #ifdef ARM_MATH_MATRIX_CHECK
63 
64   /* Check for matrix mismatch condition */
65   if ((lt->numRows != lt->numCols) ||
66       (lt->numRows != a->numRows)   )
67   {
68     /* Set status as ARM_MATH_SIZE_MISMATCH */
69     status = ARM_MATH_SIZE_MISMATCH;
70   }
71   else
72 
73 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
74 
75   {
76     /* a1 b1 c1   x1 = a1
77           b2 c2   x2   a2
78              c3   x3   a3
79 
80     x3 = a3 / c3
81     x2 = (a2 - c2 x3) / b2
82 
83     */
84     int i,j,k,n,cols;
85 
86     n = dst->numRows;
87     cols = dst->numCols;
88 
89     float32_t *pX = dst->pData;
90     float32_t *pLT = lt->pData;
91     float32_t *pA = a->pData;
92 
93     float32_t *lt_row;
94     float32_t *a_col;
95 
96     float32_t invLT;
97 
98     f32x4_t vecA;
99     f32x4_t vecX;
100 
101     for(i=0; i < n ; i++)
102     {
103 
104       for(j=0; j+3 < cols; j += 4)
105       {
106             vecA = vld1q_f32(&pA[i * cols + j]);
107 
108             for(k=0; k < i; k++)
109             {
110                 vecX = vld1q_f32(&pX[cols*k+j]);
111                 vecA = vfmsq(vecA,vdupq_n_f32(pLT[n*i + k]),vecX);
112             }
113 
114             if (pLT[n*i + i]==0.0f)
115             {
116               return(ARM_MATH_SINGULAR);
117             }
118 
119             invLT = 1.0f / pLT[n*i + i];
120             vecA = vmulq(vecA,vdupq_n_f32(invLT));
121             vst1q(&pX[i*cols+j],vecA);
122 
123        }
124 
125        for(; j < cols; j ++)
126        {
127             a_col = &pA[j];
128             lt_row = &pLT[n*i];
129 
130             float32_t tmp=a_col[i * cols];
131 
132             for(k=0; k < i; k++)
133             {
134                 tmp -= lt_row[k] * pX[cols*k+j];
135             }
136 
137             if (lt_row[i]==0.0f)
138             {
139               return(ARM_MATH_SINGULAR);
140             }
141             tmp = tmp / lt_row[i];
142             pX[i*cols+j] = tmp;
143         }
144 
145     }
146     status = ARM_MATH_SUCCESS;
147 
148   }
149 
150   /* Return to application */
151   return (status);
152 }
153 #else
154 #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
arm_mat_solve_lower_triangular_f32(const arm_matrix_instance_f32 * lt,const arm_matrix_instance_f32 * a,arm_matrix_instance_f32 * dst)155   arm_status arm_mat_solve_lower_triangular_f32(
156   const arm_matrix_instance_f32 * lt,
157   const arm_matrix_instance_f32 * a,
158   arm_matrix_instance_f32 * dst)
159   {
160   arm_status status;                             /* status of matrix inverse */
161 
162 
163 #ifdef ARM_MATH_MATRIX_CHECK
164 
165   /* Check for matrix mismatch condition */
166   if ((lt->numRows != lt->numCols) ||
167       (lt->numRows != a->numRows)   )
168   {
169     /* Set status as ARM_MATH_SIZE_MISMATCH */
170     status = ARM_MATH_SIZE_MISMATCH;
171   }
172   else
173 
174 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
175 
176   {
177     /* a1 b1 c1   x1 = a1
178           b2 c2   x2   a2
179              c3   x3   a3
180 
181     x3 = a3 / c3
182     x2 = (a2 - c2 x3) / b2
183 
184     */
185     int i,j,k,n,cols;
186 
187     n = dst->numRows;
188     cols = dst->numCols;
189 
190     float32_t *pX = dst->pData;
191     float32_t *pLT = lt->pData;
192     float32_t *pA = a->pData;
193 
194     float32_t *lt_row;
195     float32_t *a_col;
196 
197     float32_t invLT;
198 
199     f32x4_t vecA;
200     f32x4_t vecX;
201 
202     for(i=0; i < n ; i++)
203     {
204 
205       for(j=0; j+3 < cols; j += 4)
206       {
207             vecA = vld1q_f32(&pA[i * cols + j]);
208 
209             for(k=0; k < i; k++)
210             {
211                 vecX = vld1q_f32(&pX[cols*k+j]);
212                 vecA = vfmsq_f32(vecA,vdupq_n_f32(pLT[n*i + k]),vecX);
213             }
214 
215             if (pLT[n*i + i]==0.0f)
216             {
217               return(ARM_MATH_SINGULAR);
218             }
219 
220             invLT = 1.0f / pLT[n*i + i];
221             vecA = vmulq_f32(vecA,vdupq_n_f32(invLT));
222             vst1q_f32(&pX[i*cols+j],vecA);
223 
224        }
225 
226        for(; j < cols; j ++)
227        {
228             a_col = &pA[j];
229             lt_row = &pLT[n*i];
230 
231             float32_t tmp=a_col[i * cols];
232 
233             for(k=0; k < i; k++)
234             {
235                 tmp -= lt_row[k] * pX[cols*k+j];
236             }
237 
238             if (lt_row[i]==0.0f)
239             {
240               return(ARM_MATH_SINGULAR);
241             }
242             tmp = tmp / lt_row[i];
243             pX[i*cols+j] = tmp;
244         }
245 
246     }
247     status = ARM_MATH_SUCCESS;
248 
249   }
250 
251   /* Return to application */
252   return (status);
253 }
254 #else
arm_mat_solve_lower_triangular_f32(const arm_matrix_instance_f32 * lt,const arm_matrix_instance_f32 * a,arm_matrix_instance_f32 * dst)255   arm_status arm_mat_solve_lower_triangular_f32(
256   const arm_matrix_instance_f32 * lt,
257   const arm_matrix_instance_f32 * a,
258   arm_matrix_instance_f32 * dst)
259   {
260   arm_status status;                             /* status of matrix inverse */
261 
262 
263 #ifdef ARM_MATH_MATRIX_CHECK
264   /* Check for matrix mismatch condition */
265   if ((lt->numRows != lt->numCols) ||
266       (lt->numRows != a->numRows)   )
267   {
268     /* Set status as ARM_MATH_SIZE_MISMATCH */
269     status = ARM_MATH_SIZE_MISMATCH;
270   }
271   else
272 
273 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
274 
275   {
276     /* a1 b1 c1   x1 = a1
277           b2 c2   x2   a2
278              c3   x3   a3
279 
280     x3 = a3 / c3
281     x2 = (a2 - c2 x3) / b2
282 
283     */
284     int i,j,k,n,cols;
285 
286     float32_t *pX = dst->pData;
287     float32_t *pLT = lt->pData;
288     float32_t *pA = a->pData;
289 
290     float32_t *lt_row;
291     float32_t *a_col;
292 
293     n = dst->numRows;
294     cols = dst -> numCols;
295 
296 
297     for(j=0; j < cols; j ++)
298     {
299        a_col = &pA[j];
300 
301        for(i=0; i < n ; i++)
302        {
303             float32_t tmp=a_col[i * cols];
304 
305             lt_row = &pLT[n*i];
306 
307             for(k=0; k < i; k++)
308             {
309                 tmp -= lt_row[k] * pX[cols*k+j];
310             }
311 
312             if (lt_row[i]==0.0f)
313             {
314               return(ARM_MATH_SINGULAR);
315             }
316             tmp = tmp / lt_row[i];
317             pX[i*cols+j] = tmp;
318        }
319 
320     }
321     status = ARM_MATH_SUCCESS;
322 
323   }
324 
325   /* Return to application */
326   return (status);
327 }
328 #endif /* #if defined(ARM_MATH_NEON) */
329 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
330 
331 /**
332   @} end of MatrixInv group
333  */
334