1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_ldl_f32.c
4  * Description:  Floating-point LDL decomposition
5  *
6  * $Date:        23 April 2021
7  * $Revision:    V1.9.0
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 #include "dsp/matrix_utils.h"
31 
32 
33 
34 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
35 
36 /**
37   @ingroup groupMatrix
38  */
39 
40 /**
41   @addtogroup MatrixChol
42   @{
43  */
44 
45 /**
46    * @brief Floating-point LDL^t decomposition of positive semi-definite matrix.
47    * @param[in]  pSrc   points to the instance of the input floating-point matrix structure.
48    * @param[out] pl   points to the instance of the output floating-point triangular matrix structure.
49    * @param[out] pd   points to the instance of the output floating-point diagonal matrix structure.
50    * @param[out] pp   points to the instance of the output floating-point permutation vector.
51    * @return The function returns ARM_MATH_SIZE_MISMATCH, if the dimensions do not match.
52    * @return        execution status
53                    - \ref ARM_MATH_SUCCESS       : Operation successful
54                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
55                    - \ref ARM_MATH_DECOMPOSITION_FAILURE      : Input matrix cannot be decomposed
56    * @par
57    *  Computes the LDL^t decomposition of a matrix A such that P A P^t = L D L^t.
58    */
arm_mat_ldlt_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pl,arm_matrix_instance_f32 * pd,uint16_t * pp)59 ARM_DSP_ATTRIBUTE arm_status arm_mat_ldlt_f32(
60   const arm_matrix_instance_f32 * pSrc,
61   arm_matrix_instance_f32 * pl,
62   arm_matrix_instance_f32 * pd,
63   uint16_t * pp)
64 {
65 
66   arm_status status;                             /* status of matrix inverse */
67 
68 
69 #ifdef ARM_MATH_MATRIX_CHECK
70 
71   /* Check for matrix mismatch condition */
72   if ((pSrc->numRows != pSrc->numCols) ||
73       (pl->numRows != pl->numCols) ||
74       (pd->numRows != pd->numCols) ||
75       (pl->numRows != pd->numRows)   )
76   {
77     /* Set status as ARM_MATH_SIZE_MISMATCH */
78     status = ARM_MATH_SIZE_MISMATCH;
79   }
80   else
81 
82 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
83 
84   {
85 
86     const int n=pSrc->numRows;
87     int fullRank = 1, diag,k;
88     float32_t *pA;
89 
90     memset(pd->pData,0,sizeof(float32_t)*n*n);
91     memcpy(pl->pData,pSrc->pData,n*n*sizeof(float32_t));
92     pA = pl->pData;
93 
94     int cnt = n;
95     uint16x8_t vecP;
96 
97     for(int k=0;k < n; k+=8)
98     {
99       mve_pred16_t p0;
100       p0 = vctp16q(cnt);
101 
102       vecP = vidupq_u16((uint16_t)k, 1);
103 
104       vstrhq_p(&pp[k], vecP, p0);
105 
106       cnt -= 8;
107     }
108 
109 
110     for(k=0;k < n; k++)
111     {
112         /* Find pivot */
113         float32_t m=F32_MIN,a;
114         int j=k;
115 
116 
117         for(int r=k;r<n;r++)
118         {
119            if (pA[r*n+r] > m)
120            {
121              m = pA[r*n+r];
122              j = r;
123            }
124         }
125 
126         if(j != k)
127         {
128           SWAP_ROWS_F32(pl,0,k,j);
129           SWAP_COLS_F32(pl,0,k,j);
130         }
131 
132 
133         pp[k] = j;
134 
135         a = pA[k*n+k];
136 
137         if (fabsf(a) < 1.0e-8f)
138         {
139 
140             fullRank = 0;
141             break;
142         }
143 
144         float32_t invA;
145 
146         invA = 1.0f / a;
147 
148         int32x4_t vecOffs;
149         int w;
150         vecOffs = vidupq_u32((uint32_t)0, 1);
151         vecOffs = vmulq_n_s32(vecOffs,n);
152 
153         for(w=k+1; w<n; w+=4)
154         {
155           int cnt = n - k - 1;
156 
157           f32x4_t vecX;
158 
159           f32x4_t vecA;
160           f32x4_t vecW0,vecW1, vecW2, vecW3;
161 
162           mve_pred16_t p0;
163 
164           vecW0 = vdupq_n_f32(pA[(w + 0)*n+k]);
165           vecW1 = vdupq_n_f32(pA[(w + 1)*n+k]);
166           vecW2 = vdupq_n_f32(pA[(w + 2)*n+k]);
167           vecW3 = vdupq_n_f32(pA[(w + 3)*n+k]);
168 
169           for(int x=k+1;x<n;x += 4)
170           {
171              p0 = vctp32q(cnt);
172 
173              //pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * (pA[x*n+k] * invA);
174 
175 
176              vecX = vldrwq_gather_shifted_offset_z_f32(&pA[x*n+k], (uint32x4_t)vecOffs, p0);
177              vecX = vmulq_m_n_f32(vuninitializedq_f32(),vecX,invA,p0);
178 
179 
180              vecA = vldrwq_z_f32(&pA[(w + 0)*n+x],p0);
181              vecA = vfmsq_m(vecA, vecW0, vecX, p0);
182              vstrwq_p(&pA[(w + 0)*n+x], vecA, p0);
183 
184              vecA = vldrwq_z_f32(&pA[(w + 1)*n+x],p0);
185              vecA = vfmsq_m(vecA, vecW1, vecX, p0);
186              vstrwq_p(&pA[(w + 1)*n+x], vecA, p0);
187 
188              vecA = vldrwq_z_f32(&pA[(w + 2)*n+x],p0);
189              vecA = vfmsq_m(vecA, vecW2, vecX, p0);
190              vstrwq_p(&pA[(w + 2)*n+x], vecA, p0);
191 
192              vecA = vldrwq_z_f32(&pA[(w + 3)*n+x],p0);
193              vecA = vfmsq_m(vecA, vecW3, vecX, p0);
194              vstrwq_p(&pA[(w + 3)*n+x], vecA, p0);
195 
196              cnt -= 4;
197           }
198         }
199 
200         for(; w<n; w++)
201         {
202           int cnt = n - k - 1;
203 
204           f32x4_t vecA,vecX,vecW;
205 
206 
207           mve_pred16_t p0;
208 
209           vecW = vdupq_n_f32(pA[w*n+k]);
210 
211           for(int x=k+1;x<n;x += 4)
212           {
213              p0 = vctp32q(cnt);
214 
215              //pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * (pA[x*n+k] * invA);
216 
217              vecA = vldrwq_z_f32(&pA[w*n+x],p0);
218 
219              vecX = vldrwq_gather_shifted_offset_z_f32(&pA[x*n+k], (uint32x4_t)vecOffs, p0);
220              vecX = vmulq_m_n_f32(vuninitializedq_f32(),vecX,invA,p0);
221 
222              vecA = vfmsq_m(vecA, vecW, vecX, p0);
223 
224              vstrwq_p(&pA[w*n+x], vecA, p0);
225 
226              cnt -= 4;
227           }
228         }
229 
230         for(int w=k+1;w<n;w++)
231         {
232                pA[w*n+k] = pA[w*n+k] * invA;
233         }
234 
235 
236 
237     }
238 
239 
240 
241     diag=k;
242     if (!fullRank)
243     {
244       diag--;
245       for(int row=0; row < n;row++)
246       {
247         mve_pred16_t p0;
248         int cnt= n-k;
249         f32x4_t zero=vdupq_n_f32(0.0f);
250 
251         for(int col=k; col < n;col += 4)
252         {
253            p0 = vctp32q(cnt);
254 
255            vstrwq_p(&pl->pData[row*n+col], zero, p0);
256 
257            cnt -= 4;
258         }
259       }
260     }
261 
262     for(int row=0; row < n;row++)
263     {
264        mve_pred16_t p0;
265        int cnt= n-row-1;
266        f32x4_t zero=vdupq_n_f32(0.0f);
267 
268        for(int col=row+1; col < n;col+=4)
269        {
270          p0 = vctp32q(cnt);
271 
272          vstrwq_p(&pl->pData[row*n+col], zero, p0);
273 
274          cnt -= 4;
275        }
276     }
277 
278     for(int d=0; d < diag;d++)
279     {
280       pd->pData[d*n+d] = pl->pData[d*n+d];
281       pl->pData[d*n+d] = 1.0;
282     }
283 
284     status = ARM_MATH_SUCCESS;
285 
286   }
287 
288 
289   /* Return to application */
290   return (status);
291 }
292 #else
293 
294 
295 /**
296   @ingroup groupMatrix
297  */
298 
299 /**
300   @addtogroup MatrixChol
301   @{
302  */
303 
304 /**
305    * @brief Floating-point LDL^t decomposition of positive semi-definite matrix.
306    * @param[in]  pSrc   points to the instance of the input floating-point matrix structure.
307    * @param[out] pl   points to the instance of the output floating-point triangular matrix structure.
308    * @param[out] pd   points to the instance of the output floating-point diagonal matrix structure.
309    * @param[out] pp   points to the instance of the output floating-point permutation vector.
310    * @return The function returns ARM_MATH_SIZE_MISMATCH, if the dimensions do not match.
311    * @return        execution status
312                    - \ref ARM_MATH_SUCCESS       : Operation successful
313                    - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
314                    - \ref ARM_MATH_DECOMPOSITION_FAILURE      : Input matrix cannot be decomposed
315    * @par
316    *  Computes the LDL^t decomposition of a matrix A such that P A P^t = L D L^t.
317    */
arm_mat_ldlt_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pl,arm_matrix_instance_f32 * pd,uint16_t * pp)318 ARM_DSP_ATTRIBUTE arm_status arm_mat_ldlt_f32(
319   const arm_matrix_instance_f32 * pSrc,
320   arm_matrix_instance_f32 * pl,
321   arm_matrix_instance_f32 * pd,
322   uint16_t * pp)
323 {
324 
325   arm_status status;                             /* status of matrix inverse */
326 
327 
328 #ifdef ARM_MATH_MATRIX_CHECK
329 
330   /* Check for matrix mismatch condition */
331   if ((pSrc->numRows != pSrc->numCols) ||
332       (pl->numRows != pl->numCols) ||
333       (pd->numRows != pd->numCols) ||
334       (pl->numRows != pd->numRows)   )
335   {
336     /* Set status as ARM_MATH_SIZE_MISMATCH */
337     status = ARM_MATH_SIZE_MISMATCH;
338   }
339   else
340 
341 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
342 
343   {
344 
345     const int n=pSrc->numRows;
346     int fullRank = 1, diag,k;
347     float32_t *pA;
348     int row,d;
349 
350     memset(pd->pData,0,sizeof(float32_t)*n*n);
351     memcpy(pl->pData,pSrc->pData,n*n*sizeof(float32_t));
352     pA = pl->pData;
353 
354     for(k=0;k < n; k++)
355     {
356       pp[k] = k;
357     }
358 
359 
360     for(k=0;k < n; k++)
361     {
362         /* Find pivot */
363         float32_t m=F32_MIN,a;
364         int j=k;
365 
366 
367         int r;
368 
369         for(r=k;r<n;r++)
370         {
371            if (pA[r*n+r] > m)
372            {
373              m = pA[r*n+r];
374              j = r;
375            }
376         }
377 
378         if(j != k)
379         {
380           SWAP_ROWS_F32(pl,0,k,j);
381           SWAP_COLS_F32(pl,0,k,j);
382         }
383 
384 
385         pp[k] = j;
386 
387         a = pA[k*n+k];
388 
389         if (fabsf(a) < 1.0e-8f)
390         {
391 
392             fullRank = 0;
393             break;
394         }
395 
396         for(int w=k+1;w<n;w++)
397         {
398           int x;
399           for(x=k+1;x<n;x++)
400           {
401              pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * pA[x*n+k] / a;
402           }
403         }
404 
405         for(int w=k+1;w<n;w++)
406         {
407                pA[w*n+k] = pA[w*n+k] / a;
408         }
409 
410 
411 
412     }
413 
414 
415 
416     diag=k;
417     if (!fullRank)
418     {
419       diag--;
420       for(row=0; row < n;row++)
421       {
422         int col;
423         for(col=k; col < n;col++)
424         {
425            pl->pData[row*n+col]=0.0;
426         }
427       }
428     }
429 
430     for(row=0; row < n;row++)
431     {
432        int col;
433        for(col=row+1; col < n;col++)
434        {
435          pl->pData[row*n+col] = 0.0;
436        }
437     }
438 
439     for(d=0; d < diag;d++)
440     {
441       pd->pData[d*n+d] = pl->pData[d*n+d];
442       pl->pData[d*n+d] = 1.0;
443     }
444 
445     status = ARM_MATH_SUCCESS;
446 
447   }
448 
449 
450   /* Return to application */
451   return (status);
452 }
453 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
454 
455 /**
456   @} end of MatrixChol group
457  */
458