1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_ldl_f32.c
4 * Description: Floating-point LDL decomposition
5 *
6 * $Date: 23 April 2021
7 * $Revision: V1.9.0
8 *
9 * Target Processor: Cortex-M and Cortex-A cores
10 * -------------------------------------------------------------------- */
11 /*
12 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
13 *
14 * SPDX-License-Identifier: Apache-2.0
15 *
16 * Licensed under the Apache License, Version 2.0 (the License); you may
17 * not use this file except in compliance with the License.
18 * You may obtain a copy of the License at
19 *
20 * www.apache.org/licenses/LICENSE-2.0
21 *
22 * Unless required by applicable law or agreed to in writing, software
23 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
24 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25 * See the License for the specific language governing permissions and
26 * limitations under the License.
27 */
28
29 #include "dsp/matrix_functions.h"
30
31
32
33
34
35 #if defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE)
36
37
/// @private
/// Swap rows i and j of the n x n row-major float32 matrix A in place,
/// four elements per iteration with a tail predicate.
/// Expects an int `n` (the matrix dimension) in scope at the expansion site.
#define SWAP_ROWS_F32(A,i,j)                                   \
  do {                                                         \
    int swp_cnt_ = n;                                          \
                                                               \
    for (int swp_w_ = 0; swp_w_ < n; swp_w_ += 4)              \
    {                                                          \
      f32x4_t swp_a_, swp_b_;                                  \
      mve_pred16_t swp_p_ = vctp32q(swp_cnt_);                 \
                                                               \
      swp_a_ = vldrwq_z_f32(&(A)[(i)*n + swp_w_], swp_p_);     \
      swp_b_ = vldrwq_z_f32(&(A)[(j)*n + swp_w_], swp_p_);     \
                                                               \
      vstrwq_p(&(A)[(i)*n + swp_w_], swp_b_, swp_p_);          \
      vstrwq_p(&(A)[(j)*n + swp_w_], swp_a_, swp_p_);          \
                                                               \
      swp_cnt_ -= 4;                                           \
    }                                                          \
  } while (0)

/// @private
/// Swap columns i and j of the n x n row-major float32 matrix A in place.
/// Expects an int `n` (the matrix dimension) in scope at the expansion site.
#define SWAP_COLS_F32(A,i,j)                                   \
  do {                                                         \
    for (int swp_w_ = 0; swp_w_ < n; swp_w_++)                 \
    {                                                          \
      float32_t swp_tmp_ = (A)[swp_w_*n + (i)];                \
      (A)[swp_w_*n + (i)] = (A)[swp_w_*n + (j)];               \
      (A)[swp_w_*n + (j)] = swp_tmp_;                          \
    }                                                          \
  } while (0)
67
68 /**
69 @ingroup groupMatrix
70 */
71
72 /**
73 @addtogroup MatrixChol
74 @{
75 */
76
77 /**
78 * @brief Floating-point LDL^t decomposition of positive semi-definite matrix.
79 * @param[in] pSrc points to the instance of the input floating-point matrix structure.
80 * @param[out] pl points to the instance of the output floating-point triangular matrix structure.
81 * @param[out] pd points to the instance of the output floating-point diagonal matrix structure.
82 * @param[out] pp points to the instance of the output floating-point permutation vector.
83 * @return The function returns ARM_MATH_SIZE_MISMATCH, if the dimensions do not match.
84 * @return execution status
85 - \ref ARM_MATH_SUCCESS : Operation successful
86 - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
87 - \ref ARM_MATH_DECOMPOSITION_FAILURE : Input matrix cannot be decomposed
88 * @par
89 * Computes the LDL^t decomposition of a matrix A such that P A P^t = L D L^t.
90 */
arm_mat_ldlt_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pl,arm_matrix_instance_f32 * pd,uint16_t * pp)91 arm_status arm_mat_ldlt_f32(
92 const arm_matrix_instance_f32 * pSrc,
93 arm_matrix_instance_f32 * pl,
94 arm_matrix_instance_f32 * pd,
95 uint16_t * pp)
96 {
97
98 arm_status status; /* status of matrix inverse */
99
100
101 #ifdef ARM_MATH_MATRIX_CHECK
102
103 /* Check for matrix mismatch condition */
104 if ((pSrc->numRows != pSrc->numCols) ||
105 (pl->numRows != pl->numCols) ||
106 (pd->numRows != pd->numCols) ||
107 (pl->numRows != pd->numRows) )
108 {
109 /* Set status as ARM_MATH_SIZE_MISMATCH */
110 status = ARM_MATH_SIZE_MISMATCH;
111 }
112 else
113
114 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
115
116 {
117
118 const int n=pSrc->numRows;
119 int fullRank = 1, diag,k;
120 float32_t *pA;
121
122 memcpy(pl->pData,pSrc->pData,n*n*sizeof(float32_t));
123 pA = pl->pData;
124
125 int cnt = n;
126 uint16x8_t vecP;
127
128 for(int k=0;k < n; k+=8)
129 {
130 mve_pred16_t p0;
131 p0 = vctp16q(cnt);
132
133 vecP = vidupq_u16((uint16_t)k, 1);
134
135 vstrhq_p(&pp[k], vecP, p0);
136
137 cnt -= 8;
138 }
139
140
141 for(k=0;k < n; k++)
142 {
143 /* Find pivot */
144 float32_t m=F32_MIN,a;
145 int j=k;
146
147
148 for(int r=k;r<n;r++)
149 {
150 if (pA[r*n+r] > m)
151 {
152 m = pA[r*n+r];
153 j = r;
154 }
155 }
156
157 if(j != k)
158 {
159 SWAP_ROWS_F32(pA,k,j);
160 SWAP_COLS_F32(pA,k,j);
161 }
162
163
164 pp[k] = j;
165
166 a = pA[k*n+k];
167
168 if (fabs(a) < 1.0e-8)
169 {
170
171 fullRank = 0;
172 break;
173 }
174
175 float32_t invA;
176
177 invA = 1.0f / a;
178
179 int32x4_t vecOffs;
180 int w;
181 vecOffs = vidupq_u32((uint32_t)0, 1);
182 vecOffs = vmulq_n_s32(vecOffs,n);
183
184 for(w=k+1; w<n; w+=4)
185 {
186 int cnt = n - k - 1;
187
188 f32x4_t vecX;
189
190 f32x4_t vecA;
191 f32x4_t vecW0,vecW1, vecW2, vecW3;
192
193 mve_pred16_t p0;
194
195 vecW0 = vdupq_n_f32(pA[(w + 0)*n+k]);
196 vecW1 = vdupq_n_f32(pA[(w + 1)*n+k]);
197 vecW2 = vdupq_n_f32(pA[(w + 2)*n+k]);
198 vecW3 = vdupq_n_f32(pA[(w + 3)*n+k]);
199
200 for(int x=k+1;x<n;x += 4)
201 {
202 p0 = vctp32q(cnt);
203
204 //pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * (pA[x*n+k] * invA);
205
206
207 vecX = vldrwq_gather_shifted_offset_z_f32(&pA[x*n+k], vecOffs, p0);
208 vecX = vmulq_m_n_f32(vuninitializedq_f32(),vecX,invA,p0);
209
210
211 vecA = vldrwq_z_f32(&pA[(w + 0)*n+x],p0);
212 vecA = vfmsq_m(vecA, vecW0, vecX, p0);
213 vstrwq_p(&pA[(w + 0)*n+x], vecA, p0);
214
215 vecA = vldrwq_z_f32(&pA[(w + 1)*n+x],p0);
216 vecA = vfmsq_m(vecA, vecW1, vecX, p0);
217 vstrwq_p(&pA[(w + 1)*n+x], vecA, p0);
218
219 vecA = vldrwq_z_f32(&pA[(w + 2)*n+x],p0);
220 vecA = vfmsq_m(vecA, vecW2, vecX, p0);
221 vstrwq_p(&pA[(w + 2)*n+x], vecA, p0);
222
223 vecA = vldrwq_z_f32(&pA[(w + 3)*n+x],p0);
224 vecA = vfmsq_m(vecA, vecW3, vecX, p0);
225 vstrwq_p(&pA[(w + 3)*n+x], vecA, p0);
226
227 cnt -= 4;
228 }
229 }
230
231 for(; w<n; w++)
232 {
233 int cnt = n - k - 1;
234
235 f32x4_t vecA,vecX,vecW;
236
237
238 mve_pred16_t p0;
239
240 vecW = vdupq_n_f32(pA[w*n+k]);
241
242 for(int x=k+1;x<n;x += 4)
243 {
244 p0 = vctp32q(cnt);
245
246 //pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * (pA[x*n+k] * invA);
247
248 vecA = vldrwq_z_f32(&pA[w*n+x],p0);
249
250 vecX = vldrwq_gather_shifted_offset_z_f32(&pA[x*n+k], vecOffs, p0);
251 vecX = vmulq_m_n_f32(vuninitializedq_f32(),vecX,invA,p0);
252
253 vecA = vfmsq_m(vecA, vecW, vecX, p0);
254
255 vstrwq_p(&pA[w*n+x], vecA, p0);
256
257 cnt -= 4;
258 }
259 }
260
261 for(int w=k+1;w<n;w++)
262 {
263 pA[w*n+k] = pA[w*n+k] * invA;
264 }
265
266
267
268 }
269
270
271
272 diag=k;
273 if (!fullRank)
274 {
275 diag--;
276 for(int row=0; row < n;row++)
277 {
278 mve_pred16_t p0;
279 int cnt= n-k;
280 f32x4_t zero=vdupq_n_f32(0.0f);
281
282 for(int col=k; col < n;col += 4)
283 {
284 p0 = vctp32q(cnt);
285
286 vstrwq_p(&pl->pData[row*n+col], zero, p0);
287
288 cnt -= 4;
289 }
290 }
291 }
292
293 for(int row=0; row < n;row++)
294 {
295 mve_pred16_t p0;
296 int cnt= n-row-1;
297 f32x4_t zero=vdupq_n_f32(0.0f);
298
299 for(int col=row+1; col < n;col+=4)
300 {
301 p0 = vctp32q(cnt);
302
303 vstrwq_p(&pl->pData[row*n+col], zero, p0);
304
305 cnt -= 4;
306 }
307 }
308
309 for(int d=0; d < diag;d++)
310 {
311 pd->pData[d*n+d] = pl->pData[d*n+d];
312 pl->pData[d*n+d] = 1.0;
313 }
314
315 status = ARM_MATH_SUCCESS;
316
317 }
318
319
320 /* Return to application */
321 return (status);
322 }
323 #else
324
/// @private
/// Swap rows i and j of the n x n row-major float32 matrix A in place.
/// Expects an int `n` (the matrix dimension) in scope at the expansion site.
#define SWAP_ROWS_F32(A,i,j)                                   \
  do {                                                         \
    for (int swp_w_ = 0; swp_w_ < n; swp_w_++)                 \
    {                                                          \
      float32_t swp_tmp_ = (A)[(i)*n + swp_w_];                \
      (A)[(i)*n + swp_w_] = (A)[(j)*n + swp_w_];               \
      (A)[(j)*n + swp_w_] = swp_tmp_;                          \
    }                                                          \
  } while (0)

/// @private
/// Swap columns i and j of the n x n row-major float32 matrix A in place.
/// Expects an int `n` (the matrix dimension) in scope at the expansion site.
#define SWAP_COLS_F32(A,i,j)                                   \
  do {                                                         \
    for (int swp_w_ = 0; swp_w_ < n; swp_w_++)                 \
    {                                                          \
      float32_t swp_tmp_ = (A)[swp_w_*n + (i)];                \
      (A)[swp_w_*n + (i)] = (A)[swp_w_*n + (j)];               \
      (A)[swp_w_*n + (j)] = swp_tmp_;                          \
    }                                                          \
  } while (0)
344
345 /**
346 @ingroup groupMatrix
347 */
348
349 /**
350 @addtogroup MatrixChol
351 @{
352 */
353
354 /**
355 * @brief Floating-point LDL^t decomposition of positive semi-definite matrix.
356 * @param[in] pSrc points to the instance of the input floating-point matrix structure.
357 * @param[out] pl points to the instance of the output floating-point triangular matrix structure.
358 * @param[out] pd points to the instance of the output floating-point diagonal matrix structure.
359 * @param[out] pp points to the instance of the output floating-point permutation vector.
360 * @return The function returns ARM_MATH_SIZE_MISMATCH, if the dimensions do not match.
361 * @return execution status
362 - \ref ARM_MATH_SUCCESS : Operation successful
363 - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed
364 - \ref ARM_MATH_DECOMPOSITION_FAILURE : Input matrix cannot be decomposed
365 * @par
366 * Computes the LDL^t decomposition of a matrix A such that P A P^t = L D L^t.
367 */
arm_mat_ldlt_f32(const arm_matrix_instance_f32 * pSrc,arm_matrix_instance_f32 * pl,arm_matrix_instance_f32 * pd,uint16_t * pp)368 arm_status arm_mat_ldlt_f32(
369 const arm_matrix_instance_f32 * pSrc,
370 arm_matrix_instance_f32 * pl,
371 arm_matrix_instance_f32 * pd,
372 uint16_t * pp)
373 {
374
375 arm_status status; /* status of matrix inverse */
376
377
378 #ifdef ARM_MATH_MATRIX_CHECK
379
380 /* Check for matrix mismatch condition */
381 if ((pSrc->numRows != pSrc->numCols) ||
382 (pl->numRows != pl->numCols) ||
383 (pd->numRows != pd->numCols) ||
384 (pl->numRows != pd->numRows) )
385 {
386 /* Set status as ARM_MATH_SIZE_MISMATCH */
387 status = ARM_MATH_SIZE_MISMATCH;
388 }
389 else
390
391 #endif /* #ifdef ARM_MATH_MATRIX_CHECK */
392
393 {
394
395 const int n=pSrc->numRows;
396 int fullRank = 1, diag,k;
397 float32_t *pA;
398
399 memcpy(pl->pData,pSrc->pData,n*n*sizeof(float32_t));
400 pA = pl->pData;
401
402 for(int k=0;k < n; k++)
403 {
404 pp[k] = k;
405 }
406
407
408 for(k=0;k < n; k++)
409 {
410 /* Find pivot */
411 float32_t m=F32_MIN,a;
412 int j=k;
413
414
415 for(int r=k;r<n;r++)
416 {
417 if (pA[r*n+r] > m)
418 {
419 m = pA[r*n+r];
420 j = r;
421 }
422 }
423
424 if(j != k)
425 {
426 SWAP_ROWS_F32(pA,k,j);
427 SWAP_COLS_F32(pA,k,j);
428 }
429
430
431 pp[k] = j;
432
433 a = pA[k*n+k];
434
435 if (fabs(a) < 1.0e-8)
436 {
437
438 fullRank = 0;
439 break;
440 }
441
442 for(int w=k+1;w<n;w++)
443 {
444 for(int x=k+1;x<n;x++)
445 {
446 pA[w*n+x] = pA[w*n+x] - pA[w*n+k] * pA[x*n+k] / a;
447 }
448 }
449
450 for(int w=k+1;w<n;w++)
451 {
452 pA[w*n+k] = pA[w*n+k] / a;
453 }
454
455
456
457 }
458
459
460
461 diag=k;
462 if (!fullRank)
463 {
464 diag--;
465 for(int row=0; row < n;row++)
466 {
467 for(int col=k; col < n;col++)
468 {
469 pl->pData[row*n+col]=0.0;
470 }
471 }
472 }
473
474 for(int row=0; row < n;row++)
475 {
476 for(int col=row+1; col < n;col++)
477 {
478 pl->pData[row*n+col] = 0.0;
479 }
480 }
481
482 for(int d=0; d < diag;d++)
483 {
484 pd->pData[d*n+d] = pl->pData[d*n+d];
485 pl->pData[d*n+d] = 1.0;
486 }
487
488 status = ARM_MATH_SUCCESS;
489
490 }
491
492
493 /* Return to application */
494 return (status);
495 }
496 #endif /* defined(ARM_MATH_MVEF) && !defined(ARM_MATH_AUTOVECTORIZE) */
497
498 /**
499 @} end of MatrixChol group
500 */
501