1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_cholesky_f32.c
4  * Description:  Floating-point Cholesky decomposition
5  *
6  * $Date:        05 October 2021
7  * $Revision:    V1.9.1
8  *
9  * Target Processor: Cortex-M and Cortex-A cores
10  * -------------------------------------------------------------------- */
11 /*
12  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13  *
14  * SPDX-License-Identifier: Apache-2.0
15  *
16  * Licensed under the Apache License, Version 2.0 (the License); you may
17  * not use this file except in compliance with the License.
18  * You may obtain a copy of the License at
19  *
20  * www.apache.org/licenses/LICENSE-2.0
21  *
22  * Unless required by applicable law or agreed to in writing, software
23  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25  * See the License for the specific language governing permissions and
26  * limitations under the License.
27  */
28 
29 #include "dsp/matrix_functions.h"
30 #include "dsp/matrix_utils.h"
31 
32 /**
33   @ingroup groupMatrix
34  */
35 
36 /**
37   @defgroup MatrixChol Cholesky and LDLT decompositions
38 
39   Computes the Cholesky or LL^t decomposition of a matrix.
40 
41 
42   If the input matrix does not have a decomposition, then the
43   algorithm terminates and returns error status ARM_MATH_DECOMPOSITION_FAILURE.
44  */
45 
46 /**
47   @addtogroup MatrixChol
48   @{
49  */
50 
/**
   * @brief Floating-point Cholesky decomposition of positive-definite matrix.
   * @param[in]  pSrc   points to the instance of the input floating-point matrix structure.
   * @param[out] pDst   points to the instance of the output floating-point matrix structure.
   * @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed (the dimensions do not match)
                   - \ref ARM_MATH_DECOMPOSITION_FAILURE      : Input matrix cannot be decomposed
   * @par
   * If the matrix is ill conditioned or only semi-definite, then it is better to use the LDL^t decomposition.
   * The decomposition of A returns a lower triangular matrix L such that A = L L^t.
   *
   * @par
   * The destination matrix should be set to 0 before calling the function because
   * the function may not overwrite all output elements.
   */
68 
69 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
70 
71 #include "arm_helium_utils.h"
72 
arm_mat_cholesky_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pDst)73 ARM_DSP_ATTRIBUTE arm_status arm_mat_cholesky_f32(
74   const arm_matrix_instance_f32 * pSrc,
75         arm_matrix_instance_f32 * pDst)
76 {
77 
78   arm_status status;                             /* status of matrix inverse */
79 
80 
81 #ifdef ARM_MATH_MATRIX_CHECK
82 
83   /* Check for matrix mismatch condition */
84   if ((pSrc->numRows != pSrc->numCols) ||
85       (pDst->numRows != pDst->numCols) ||
86       (pSrc->numRows != pDst->numRows)   )
87   {
88     /* Set status as ARM_MATH_SIZE_MISMATCH */
89     status = ARM_MATH_SIZE_MISMATCH;
90   }
91   else
92 
93 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
94 
95   {
96     int i,j,k;
97     int n = pSrc->numRows;
98     float32_t invSqrtVj;
99     float32_t *pA,*pG;
100     int kCnt;
101 
102     mve_pred16_t p0;
103 
104     f32x4_t acc, acc0, acc1, acc2, acc3;
105     f32x4_t vecGi;
106     f32x4_t vecGj,vecGj0,vecGj1,vecGj2,vecGj3;
107 
108 
109     pA = pSrc->pData;
110     pG = pDst->pData;
111 
112     for(i=0 ;i < n ; i++)
113     {
114        for(j=i ; j+3 < n ; j+=4)
115        {
116           pG[(j + 0) * n + i] = pA[(j + 0) * n + i];
117           pG[(j + 1) * n + i] = pA[(j + 1) * n + i];
118           pG[(j + 2) * n + i] = pA[(j + 2) * n + i];
119           pG[(j + 3) * n + i] = pA[(j + 3) * n + i];
120 
121           kCnt = i;
122           acc0 = vdupq_n_f32(0.0f);
123           acc1 = vdupq_n_f32(0.0f);
124           acc2 = vdupq_n_f32(0.0f);
125           acc3 = vdupq_n_f32(0.0f);
126 
127           for(k=0; k < i ; k+=4)
128           {
129              p0 = vctp32q(kCnt);
130 
131              vecGi=vldrwq_z_f32(&pG[i * n + k],p0);
132 
133              vecGj0=vldrwq_z_f32(&pG[(j + 0) * n + k],p0);
134              vecGj1=vldrwq_z_f32(&pG[(j + 1) * n + k],p0);
135              vecGj2=vldrwq_z_f32(&pG[(j + 2) * n + k],p0);
136              vecGj3=vldrwq_z_f32(&pG[(j + 3) * n + k],p0);
137 
138              acc0 = vfmaq_m(acc0, vecGi, vecGj0, p0);
139              acc1 = vfmaq_m(acc1, vecGi, vecGj1, p0);
140              acc2 = vfmaq_m(acc2, vecGi, vecGj2, p0);
141              acc3 = vfmaq_m(acc3, vecGi, vecGj3, p0);
142 
143              kCnt -= 4;
144           }
145           pG[(j + 0) * n + i] -= vecAddAcrossF32Mve(acc0);
146           pG[(j + 1) * n + i] -= vecAddAcrossF32Mve(acc1);
147           pG[(j + 2) * n + i] -= vecAddAcrossF32Mve(acc2);
148           pG[(j + 3) * n + i] -= vecAddAcrossF32Mve(acc3);
149        }
150 
151        for(; j < n ; j++)
152        {
153           pG[j * n + i] = pA[j * n + i];
154 
155           kCnt = i;
156           acc = vdupq_n_f32(0.0f);
157 
158           for(k=0; k < i ; k+=4)
159           {
160              p0 = vctp32q(kCnt);
161 
162              vecGi=vldrwq_z_f32(&pG[i * n + k],p0);
163              vecGj=vldrwq_z_f32(&pG[j * n + k],p0);
164 
165              acc = vfmaq_m(acc, vecGi, vecGj,p0);
166 
167              kCnt -= 4;
168           }
169           pG[j * n + i] -= vecAddAcrossF32Mve(acc);
170        }
171 
172        if (pG[i * n + i] <= 0.0f)
173        {
174          return(ARM_MATH_DECOMPOSITION_FAILURE);
175        }
176 
177        invSqrtVj = 1.0f/sqrtf(pG[i * n + i]);
178        SCALE_COL_F32(pDst,i,invSqrtVj,i);
179     }
180 
181     status = ARM_MATH_SUCCESS;
182 
183   }
184 
185 
186   /* Return to application */
187   return (status);
188 }
189 
190 #else
191 #if defined(ARM_MATH_NEON) && !defined(ARM_MATH_AUTOVECTORIZE)
192 
arm_mat_cholesky_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pDst)193 ARM_DSP_ATTRIBUTE arm_status arm_mat_cholesky_f32(
194   const arm_matrix_instance_f32 * pSrc,
195         arm_matrix_instance_f32 * pDst)
196 {
197 
198   arm_status status;                             /* status of matrix inverse */
199 
200 
201 #ifdef ARM_MATH_MATRIX_CHECK
202 
203   /* Check for matrix mismatch condition */
204   if ((pSrc->numRows != pSrc->numCols) ||
205       (pDst->numRows != pDst->numCols) ||
206       (pSrc->numRows != pDst->numRows)   )
207   {
208     /* Set status as ARM_MATH_SIZE_MISMATCH */
209     status = ARM_MATH_SIZE_MISMATCH;
210   }
211   else
212 
213 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
214 
215   {
216     int i,j,k;
217     int n = pSrc->numRows;
218     float32_t invSqrtVj;
219     float32_t *pA,*pG;
220     int kCnt;
221 
222 
223     f32x4_t acc, acc0, acc1, acc2, acc3;
224     f32x4_t vecGi;
225     f32x4_t vecGj,vecGj0,vecGj1,vecGj2,vecGj3;
226 #if !defined(__aarch64__)
227     f32x2_t tmp = vdup_n_f32(0);
228 #endif
229     float32_t sum=0.0f;
230     float32_t sum0=0.0f,sum1=0.0f,sum2=0.0f,sum3=0.0f;
231 
232 
233     pA = pSrc->pData;
234     pG = pDst->pData;
235 
236     for(i=0 ;i < n ; i++)
237     {
238        for(j=i ; j+3 < n ; j+=4)
239        {
240           pG[(j + 0) * n + i] = pA[(j + 0) * n + i];
241           pG[(j + 1) * n + i] = pA[(j + 1) * n + i];
242           pG[(j + 2) * n + i] = pA[(j + 2) * n + i];
243           pG[(j + 3) * n + i] = pA[(j + 3) * n + i];
244 
245           acc0 = vdupq_n_f32(0.0f);
246           acc1 = vdupq_n_f32(0.0f);
247           acc2 = vdupq_n_f32(0.0f);
248           acc3 = vdupq_n_f32(0.0f);
249 
250           kCnt = i >> 2;
251           k=0;
252           while(kCnt > 0)
253           {
254 
255              vecGi=vld1q_f32(&pG[i * n + k]);
256 
257              vecGj0=vld1q_f32(&pG[(j + 0) * n + k]);
258              vecGj1=vld1q_f32(&pG[(j + 1) * n + k]);
259              vecGj2=vld1q_f32(&pG[(j + 2) * n + k]);
260              vecGj3=vld1q_f32(&pG[(j + 3) * n + k]);
261 
262              acc0 = vfmaq_f32(acc0, vecGi, vecGj0);
263              acc1 = vfmaq_f32(acc1, vecGi, vecGj1);
264              acc2 = vfmaq_f32(acc2, vecGi, vecGj2);
265              acc3 = vfmaq_f32(acc3, vecGi, vecGj3);
266 
267              kCnt--;
268              k+=4;
269           }
270 
271 #if defined(__aarch64__)
272           sum0 = vpadds_f32(vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0)));
273           sum1 = vpadds_f32(vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1)));
274           sum2 = vpadds_f32(vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2)));
275           sum3 = vpadds_f32(vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3)));
276 
277 #else
278           tmp = vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
279           sum0 = vget_lane_f32(tmp, 0) + vget_lane_f32(tmp, 1);
280 
281           tmp = vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1));
282           sum1 = vget_lane_f32(tmp, 0) + vget_lane_f32(tmp, 1);
283 
284           tmp = vpadd_f32(vget_low_f32(acc2), vget_high_f32(acc2));
285           sum2 = vget_lane_f32(tmp, 0) + vget_lane_f32(tmp, 1);
286 
287           tmp = vpadd_f32(vget_low_f32(acc3), vget_high_f32(acc3));
288           sum3 = vget_lane_f32(tmp, 0) + vget_lane_f32(tmp, 1);
289 #endif
290 
291           kCnt = i & 3;
292           while(kCnt > 0)
293           {
294 
295              sum0 = sum0 + pG[i * n + k] * pG[(j + 0) * n + k];
296              sum1 = sum1 + pG[i * n + k] * pG[(j + 1) * n + k];
297              sum2 = sum2 + pG[i * n + k] * pG[(j + 2) * n + k];
298              sum3 = sum3 + pG[i * n + k] * pG[(j + 3) * n + k];
299              kCnt--;
300              k++;
301           }
302 
303           pG[(j + 0) * n + i] -= sum0;
304           pG[(j + 1) * n + i] -= sum1;
305           pG[(j + 2) * n + i] -= sum2;
306           pG[(j + 3) * n + i] -= sum3;
307        }
308 
309        for(; j < n ; j++)
310        {
311           pG[j * n + i] = pA[j * n + i];
312 
313           acc = vdupq_n_f32(0.0f);
314 
315           kCnt = i >> 2;
316           k=0;
317           while(kCnt > 0)
318           {
319 
320              vecGi=vld1q_f32(&pG[i * n + k]);
321              vecGj=vld1q_f32(&pG[j * n + k]);
322 
323              acc = vfmaq_f32(acc, vecGi, vecGj);
324 
325              kCnt--;
326              k+=4;
327           }
328 
329 #if defined(__aarch64__)
330           sum = vpadds_f32(vpadd_f32(vget_low_f32(acc), vget_high_f32(acc)));
331 #else
332           tmp = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
333           sum = vget_lane_f32(tmp, 0) + vget_lane_f32(tmp, 1);
334 #endif
335 
336           kCnt = i & 3;
337           while(kCnt > 0)
338           {
339              sum = sum + pG[i * n + k] * pG[(j + 0) * n + k];
340 
341 
342              kCnt--;
343              k++;
344           }
345 
346           pG[j * n + i] -= sum;
347        }
348 
349        if (pG[i * n + i] <= 0.0f)
350        {
351          return(ARM_MATH_DECOMPOSITION_FAILURE);
352        }
353 
354        invSqrtVj = 1.0f/sqrtf(pG[i * n + i]);
355        SCALE_COL_F32(pDst,i,invSqrtVj,i);
356     }
357 
358     status = ARM_MATH_SUCCESS;
359 
360   }
361 
362 
363   /* Return to application */
364   return (status);
365 }
366 
367 #else
arm_mat_cholesky_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pDst)368 ARM_DSP_ATTRIBUTE arm_status arm_mat_cholesky_f32(
369   const arm_matrix_instance_f32 * pSrc,
370         arm_matrix_instance_f32 * pDst)
371 {
372 
373   arm_status status;                             /* status of matrix inverse */
374 
375 
376 #ifdef ARM_MATH_MATRIX_CHECK
377 
378   /* Check for matrix mismatch condition */
379   if ((pSrc->numRows != pSrc->numCols) ||
380       (pDst->numRows != pDst->numCols) ||
381       (pSrc->numRows != pDst->numRows)   )
382   {
383     /* Set status as ARM_MATH_SIZE_MISMATCH */
384     status = ARM_MATH_SIZE_MISMATCH;
385   }
386   else
387 
388 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
389 
390   {
391     int i,j,k;
392     int n = pSrc->numRows;
393     float32_t invSqrtVj;
394     float32_t *pA,*pG;
395 
396     pA = pSrc->pData;
397     pG = pDst->pData;
398 
399 
400     for(i=0 ; i < n ; i++)
401     {
402        for(j=i ; j < n ; j++)
403        {
404           pG[j * n + i] = pA[j * n + i];
405 
406           for(k=0; k < i ; k++)
407           {
408              pG[j * n + i] = pG[j * n + i] - pG[i * n + k] * pG[j * n + k];
409           }
410        }
411 
412        if (pG[i * n + i] <= 0.0f)
413        {
414          return(ARM_MATH_DECOMPOSITION_FAILURE);
415        }
416 
417        invSqrtVj = 1.0f/sqrtf(pG[i * n + i]);
418        SCALE_COL_F32(pDst,i,invSqrtVj,i);
419 
420     }
421 
422     status = ARM_MATH_SUCCESS;
423 
424   }
425 
426 
427   /* Return to application */
428   return (status);
429 }
430 #endif /* #if defined(ARM_MATH_NEON) */
431 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
432 
433 /**
434   @} end of MatrixChol group
435  */
436