1 /*
2  * Copyright (c) 2019 Kevin Townsend (KTOWN)
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <errno.h>
8 #include <stdio.h>
9 #include <stdbool.h>
10 #include <string.h>
11 #include <zsl/zsl.h>
12 #include <zsl/matrices.h>
13 
14 /*
15  * WARNING: Work in progress!
16  *
17  * The code in this module is very 'naive' in the sense that no attempt
18  * has been made at efficiency. It is written from the perspective
19  * that code should be written to be 'reliable, elegant, efficient' in that
20  * order.
21  *
22  * Clarity and reliability have been absolutely prioritized in this
23  * early stage, with the key goal being good unit test coverage before
24  * moving on to any form of general-purpose or architecture-specific
25  * optimisation.
26  */
27 
28 // TODO: Introduce local macros for bounds/shape checks to avoid duplication!
29 
30 int
zsl_mtx_entry_fn_empty(struct zsl_mtx * m,size_t i,size_t j)31 zsl_mtx_entry_fn_empty(struct zsl_mtx *m, size_t i, size_t j)
32 {
33 	return zsl_mtx_set(m, i, j, 0);
34 }
35 
36 int
zsl_mtx_entry_fn_identity(struct zsl_mtx * m,size_t i,size_t j)37 zsl_mtx_entry_fn_identity(struct zsl_mtx *m, size_t i, size_t j)
38 {
39 	return zsl_mtx_set(m, i, j, i == j ? 1.0 : 0);
40 }
41 
/*
 * Entry callback for zsl_mtx_init() intended to fill 'm' with random values.
 *
 * TODO: Determine an appropriate random number generator. Until then this
 * stores zero, exactly like zsl_mtx_entry_fn_empty().
 */
int
zsl_mtx_entry_fn_random(struct zsl_mtx *m, size_t i, size_t j)
{
	return zsl_mtx_entry_fn_empty(m, i, j);
}
48 
49 int
zsl_mtx_init(struct zsl_mtx * m,zsl_mtx_init_entry_fn_t entry_fn)50 zsl_mtx_init(struct zsl_mtx *m, zsl_mtx_init_entry_fn_t entry_fn)
51 {
52 	int rc;
53 
54 	for (size_t i = 0; i < m->sz_rows; i++) {
55 		for (size_t j = 0; j < m->sz_cols; j++) {
56 			/* If entry_fn is NULL, assign 0.0 values. */
57 			if (entry_fn == NULL) {
58 				rc = zsl_mtx_entry_fn_empty(m, i, j);
59 			} else {
60 				rc = entry_fn(m, i, j);
61 			}
62 			/* Abort if entry_fn returned an error code. */
63 			if (rc) {
64 				return rc;
65 			}
66 		}
67 	}
68 
69 	return 0;
70 }
71 
72 int
zsl_mtx_from_arr(struct zsl_mtx * m,zsl_real_t * a)73 zsl_mtx_from_arr(struct zsl_mtx *m, zsl_real_t *a)
74 {
75 	memcpy(m->data, a, (m->sz_rows * m->sz_cols) * sizeof(zsl_real_t));
76 
77 	return 0;
78 }
79 
80 int
zsl_mtx_copy(struct zsl_mtx * mdest,struct zsl_mtx * msrc)81 zsl_mtx_copy(struct zsl_mtx *mdest, struct zsl_mtx *msrc)
82 {
83 #if CONFIG_ZSL_BOUNDS_CHECKS
84 	/* Ensure that msrc and mdest have the same shape. */
85 	if ((mdest->sz_rows != msrc->sz_rows) ||
86 	    (mdest->sz_cols != msrc->sz_cols)) {
87 		return -EINVAL;
88 	}
89 #endif
90 
91 	/* Make a copy of matrix 'msrc'. */
92 	memcpy(mdest->data, msrc->data, sizeof(zsl_real_t) *
93 	       msrc->sz_rows * msrc->sz_cols);
94 
95 	return 0;
96 }
97 
98 int
zsl_mtx_get(struct zsl_mtx * m,size_t i,size_t j,zsl_real_t * x)99 zsl_mtx_get(struct zsl_mtx *m, size_t i, size_t j, zsl_real_t *x)
100 {
101 #if CONFIG_ZSL_BOUNDS_CHECKS
102 	if ((i >= m->sz_rows) || (j >= m->sz_cols)) {
103 		return -EINVAL;
104 	}
105 #endif
106 
107 	*x = m->data[(i * m->sz_cols) + j];
108 
109 	return 0;
110 }
111 
112 int
zsl_mtx_set(struct zsl_mtx * m,size_t i,size_t j,zsl_real_t x)113 zsl_mtx_set(struct zsl_mtx *m, size_t i, size_t j, zsl_real_t x)
114 {
115 #if CONFIG_ZSL_BOUNDS_CHECKS
116 	if ((i >= m->sz_rows) || (j >= m->sz_cols)) {
117 		return -EINVAL;
118 	}
119 #endif
120 
121 	m->data[(i * m->sz_cols) + j] = x;
122 
123 	return 0;
124 }
125 
126 int
zsl_mtx_get_row(struct zsl_mtx * m,size_t i,zsl_real_t * v)127 zsl_mtx_get_row(struct zsl_mtx *m, size_t i, zsl_real_t *v)
128 {
129 	int rc;
130 	zsl_real_t x;
131 
132 	for (size_t j = 0; j < m->sz_cols; j++) {
133 		rc = zsl_mtx_get(m, i, j, &x);
134 		if (rc) {
135 			return rc;
136 		}
137 		v[j] = x;
138 	}
139 
140 	return 0;
141 }
142 
143 int
zsl_mtx_set_row(struct zsl_mtx * m,size_t i,zsl_real_t * v)144 zsl_mtx_set_row(struct zsl_mtx *m, size_t i, zsl_real_t *v)
145 {
146 	int rc;
147 
148 	for (size_t j = 0; j < m->sz_cols; j++) {
149 		rc = zsl_mtx_set(m, i, j, v[j]);
150 		if (rc) {
151 			return rc;
152 		}
153 	}
154 
155 	return 0;
156 }
157 
158 int
zsl_mtx_get_col(struct zsl_mtx * m,size_t j,zsl_real_t * v)159 zsl_mtx_get_col(struct zsl_mtx *m, size_t j, zsl_real_t *v)
160 {
161 	int rc;
162 	zsl_real_t x;
163 
164 	for (size_t i = 0; i < m->sz_rows; i++) {
165 		rc = zsl_mtx_get(m, i, j, &x);
166 		if (rc) {
167 			return rc;
168 		}
169 		v[i] = x;
170 	}
171 
172 	return 0;
173 }
174 
175 int
zsl_mtx_set_col(struct zsl_mtx * m,size_t j,zsl_real_t * v)176 zsl_mtx_set_col(struct zsl_mtx *m, size_t j, zsl_real_t *v)
177 {
178 	int rc;
179 
180 	for (size_t i = 0; i < m->sz_rows; i++) {
181 		rc = zsl_mtx_set(m, i, j, v[i]);
182 		if (rc) {
183 			return rc;
184 		}
185 	}
186 
187 	return 0;
188 }
189 
190 int
zsl_mtx_unary_op(struct zsl_mtx * m,zsl_mtx_unary_op_t op)191 zsl_mtx_unary_op(struct zsl_mtx *m, zsl_mtx_unary_op_t op)
192 {
193 	/* Execute the unary operation component by component. */
194 	for (size_t i = 0; i < m->sz_cols * m->sz_rows; i++) {
195 		switch (op) {
196 		case ZSL_MTX_UNARY_OP_INCREMENT:
197 			m->data[i] += 1.0;
198 			break;
199 		case ZSL_MTX_UNARY_OP_DECREMENT:
200 			m->data[i] -= 1.0;
201 			break;
202 		case ZSL_MTX_UNARY_OP_NEGATIVE:
203 			m->data[i] = -m->data[i];
204 			break;
205 		case ZSL_MTX_UNARY_OP_ROUND:
206 			m->data[i] = ZSL_ROUND(m->data[i]);
207 			break;
208 		case ZSL_MTX_UNARY_OP_ABS:
209 			m->data[i] = ZSL_ABS(m->data[i]);
210 			break;
211 		case ZSL_MTX_UNARY_OP_FLOOR:
212 			m->data[i] = ZSL_FLOOR(m->data[i]);
213 			break;
214 		case ZSL_MTX_UNARY_OP_CEIL:
215 			m->data[i] = ZSL_CEIL(m->data[i]);
216 			break;
217 		case ZSL_MTX_UNARY_OP_EXP:
218 			m->data[i] = ZSL_EXP(m->data[i]);
219 			break;
220 		case ZSL_MTX_UNARY_OP_LOG:
221 			m->data[i] = ZSL_LOG(m->data[i]);
222 			break;
223 		case ZSL_MTX_UNARY_OP_LOG10:
224 			m->data[i] = ZSL_LOG10(m->data[i]);
225 			break;
226 		case ZSL_MTX_UNARY_OP_SQRT:
227 			m->data[i] = ZSL_SQRT(m->data[i]);
228 			break;
229 		case ZSL_MTX_UNARY_OP_SIN:
230 			m->data[i] = ZSL_SIN(m->data[i]);
231 			break;
232 		case ZSL_MTX_UNARY_OP_COS:
233 			m->data[i] = ZSL_COS(m->data[i]);
234 			break;
235 		case ZSL_MTX_UNARY_OP_TAN:
236 			m->data[i] = ZSL_TAN(m->data[i]);
237 			break;
238 		case ZSL_MTX_UNARY_OP_ASIN:
239 			m->data[i] = ZSL_ASIN(m->data[i]);
240 			break;
241 		case ZSL_MTX_UNARY_OP_ACOS:
242 			m->data[i] = ZSL_ACOS(m->data[i]);
243 			break;
244 		case ZSL_MTX_UNARY_OP_ATAN:
245 			m->data[i] = ZSL_ATAN(m->data[i]);
246 			break;
247 		case ZSL_MTX_UNARY_OP_SINH:
248 			m->data[i] = ZSL_SINH(m->data[i]);
249 			break;
250 		case ZSL_MTX_UNARY_OP_COSH:
251 			m->data[i] = ZSL_COSH(m->data[i]);
252 			break;
253 		case ZSL_MTX_UNARY_OP_TANH:
254 			m->data[i] = ZSL_TANH(m->data[i]);
255 			break;
256 		default:
257 			/* Not yet implemented! */
258 			return -ENOSYS;
259 		}
260 	}
261 
262 	return 0;
263 }
264 
265 int
zsl_mtx_unary_func(struct zsl_mtx * m,zsl_mtx_unary_fn_t fn)266 zsl_mtx_unary_func(struct zsl_mtx *m, zsl_mtx_unary_fn_t fn)
267 {
268 	int rc;
269 
270 	for (size_t i = 0; i < m->sz_rows; i++) {
271 		for (size_t j = 0; j < m->sz_cols; j++) {
272 			/* If fn is NULL, do nothing. */
273 			if (fn != NULL) {
274 				rc = fn(m, i, j);
275 				if (rc) {
276 					return rc;
277 				}
278 			}
279 		}
280 	}
281 
282 	return 0;
283 }
284 
285 int
zsl_mtx_binary_op(struct zsl_mtx * ma,struct zsl_mtx * mb,struct zsl_mtx * mc,zsl_mtx_binary_op_t op)286 zsl_mtx_binary_op(struct zsl_mtx *ma, struct zsl_mtx *mb, struct zsl_mtx *mc,
287 		  zsl_mtx_binary_op_t op)
288 {
289 #if CONFIG_ZSL_BOUNDS_CHECKS
290 	if ((ma->sz_rows != mb->sz_rows) || (mb->sz_rows != mc->sz_rows) ||
291 	    (ma->sz_cols != mb->sz_cols) || (mb->sz_cols != mc->sz_cols)) {
292 		return -EINVAL;
293 	}
294 #endif
295 
296 	/* Execute the binary operation component by component. */
297 	for (size_t i = 0; i < ma->sz_cols * ma->sz_rows; i++) {
298 		switch (op) {
299 		case ZSL_MTX_BINARY_OP_ADD:
300 			mc->data[i] = ma->data[i] + mb->data[i];
301 			break;
302 		case ZSL_MTX_BINARY_OP_SUB:
303 			mc->data[i] = ma->data[i] - mb->data[i];
304 			break;
305 		case ZSL_MTX_BINARY_OP_MULT:
306 			mc->data[i] = ma->data[i] * mb->data[i];
307 			break;
308 		case ZSL_MTX_BINARY_OP_DIV:
309 			if (mb->data[i] == 0.0) {
310 				mc->data[i] = 0.0;
311 			} else {
312 				mc->data[i] = ma->data[i] / mb->data[i];
313 			}
314 			break;
315 		case ZSL_MTX_BINARY_OP_MEAN:
316 			mc->data[i] = (ma->data[i] + mb->data[i]) / 2.0;
317 		case ZSL_MTX_BINARY_OP_EXPON:
318 			mc->data[i] = ZSL_POW(ma->data[i], mb->data[i]);
319 		case ZSL_MTX_BINARY_OP_MIN:
320 			mc->data[i] = ma->data[i] < mb->data[i] ?
321 				      ma->data[i] : mb->data[i];
322 		case ZSL_MTX_BINARY_OP_MAX:
323 			mc->data[i] = ma->data[i] > mb->data[i] ?
324 				      ma->data[i] : mb->data[i];
325 		case ZSL_MTX_BINARY_OP_EQUAL:
326 			mc->data[i] = ma->data[i] == mb->data[i] ? 1.0 : 0.0;
327 		case ZSL_MTX_BINARY_OP_NEQUAL:
328 			mc->data[i] = ma->data[i] != mb->data[i] ? 1.0 : 0.0;
329 		case ZSL_MTX_BINARY_OP_LESS:
330 			mc->data[i] = ma->data[i] < mb->data[i] ? 1.0 : 0.0;
331 		case ZSL_MTX_BINARY_OP_GREAT:
332 			mc->data[i] = ma->data[i] > mb->data[i] ? 1.0 : 0.0;
333 		case ZSL_MTX_BINARY_OP_LEQ:
334 			mc->data[i] = ma->data[i] <= mb->data[i] ? 1.0 : 0.0;
335 		case ZSL_MTX_BINARY_OP_GEQ:
336 			mc->data[i] = ma->data[i] >= mb->data[i] ? 1.0 : 0.0;
337 		default:
338 			/* Not yet implemented! */
339 			return -ENOSYS;
340 		}
341 	}
342 
343 	return 0;
344 }
345 
346 int
zsl_mtx_binary_func(struct zsl_mtx * ma,struct zsl_mtx * mb,struct zsl_mtx * mc,zsl_mtx_binary_fn_t fn)347 zsl_mtx_binary_func(struct zsl_mtx *ma, struct zsl_mtx *mb,
348 		    struct zsl_mtx *mc, zsl_mtx_binary_fn_t fn)
349 {
350 	int rc;
351 
352 #if CONFIG_ZSL_BOUNDS_CHECKS
353 	if ((ma->sz_rows != mb->sz_rows) || (mb->sz_rows != mc->sz_rows) ||
354 	    (ma->sz_cols != mb->sz_cols) || (mb->sz_cols != mc->sz_cols)) {
355 		return -EINVAL;
356 	}
357 #endif
358 
359 	for (size_t i = 0; i < ma->sz_rows; i++) {
360 		for (size_t j = 0; j < ma->sz_cols; j++) {
361 			/* If fn is NULL, do nothing. */
362 			if (fn != NULL) {
363 				rc = fn(ma, mb, mc, i, j);
364 				if (rc) {
365 					return rc;
366 				}
367 			}
368 		}
369 	}
370 
371 	return 0;
372 }
373 
374 int
zsl_mtx_add(struct zsl_mtx * ma,struct zsl_mtx * mb,struct zsl_mtx * mc)375 zsl_mtx_add(struct zsl_mtx *ma, struct zsl_mtx *mb, struct zsl_mtx *mc)
376 {
377 	return zsl_mtx_binary_op(ma, mb, mc, ZSL_MTX_BINARY_OP_ADD);
378 }
379 
/*
 * In-place element-wise sum: ma += mb. The result overwrites 'ma'.
 */
int
zsl_mtx_add_d(struct zsl_mtx *ma, struct zsl_mtx *mb)
{
	return zsl_mtx_add(ma, mb, ma);
}
385 
386 int
zsl_mtx_sum_rows_d(struct zsl_mtx * m,size_t i,size_t j)387 zsl_mtx_sum_rows_d(struct zsl_mtx *m, size_t i, size_t j)
388 {
389 #if CONFIG_ZSL_BOUNDS_CHECKS
390 	if ((i >= m->sz_rows) || (j >= m->sz_rows)) {
391 		return -EINVAL;
392 	}
393 #endif
394 
395 	/* Add row j to row i, element by element. */
396 	for (size_t x = 0; x < m->sz_cols; x++) {
397 		m->data[(i * m->sz_cols) + x] += m->data[(j * m->sz_cols) + x];
398 	}
399 
400 	return 0;
401 }
402 
/*
 * Adds row 'j' scaled by 's' into row 'i', in place: row_i += s * row_j.
 *
 * Both 'i' and 'j' are row indices, so both are validated against sz_rows.
 * The previous check compared 'j' against sz_cols, which rejected valid rows
 * of tall matrices and accepted out-of-range rows of wide ones.
 *
 * @return 0 on success, -EINVAL when CONFIG_ZSL_BOUNDS_CHECKS is enabled and
 *         either row index is out of range.
 */
int
zsl_mtx_sum_rows_scaled_d(struct zsl_mtx *m,
			  size_t i, size_t j, zsl_real_t s)
{
#if CONFIG_ZSL_BOUNDS_CHECKS
	if ((i >= m->sz_rows) || (j >= m->sz_rows)) {
		return -EINVAL;
	}
#endif

	/* Set the values in row 'i' to 'i[n] += j[n] * s' . */
	for (size_t x = 0; x < m->sz_cols; x++) {
		m->data[(i * m->sz_cols) + x] +=
			(m->data[(j * m->sz_cols) + x] * s);
	}

	return 0;
}
420 
421 int
zsl_mtx_sub(struct zsl_mtx * ma,struct zsl_mtx * mb,struct zsl_mtx * mc)422 zsl_mtx_sub(struct zsl_mtx *ma, struct zsl_mtx *mb, struct zsl_mtx *mc)
423 {
424 	return zsl_mtx_binary_op(ma, mb, mc, ZSL_MTX_BINARY_OP_SUB);
425 }
426 
/*
 * In-place element-wise difference: ma -= mb. The result overwrites 'ma'.
 */
int
zsl_mtx_sub_d(struct zsl_mtx *ma, struct zsl_mtx *mb)
{
	return zsl_mtx_sub(ma, mb, ma);
}
432 
433 int
zsl_mtx_mult(struct zsl_mtx * ma,struct zsl_mtx * mb,struct zsl_mtx * mc)434 zsl_mtx_mult(struct zsl_mtx *ma, struct zsl_mtx *mb, struct zsl_mtx *mc)
435 {
436 #if CONFIG_ZSL_BOUNDS_CHECKS
437 	/* Ensure that ma has the same number as columns as mb has rows. */
438 	if (ma->sz_cols != mb->sz_rows) {
439 		return -EINVAL;
440 	}
441 
442 	/* Ensure that mc has ma rows and mb cols */
443 	if ((mc->sz_rows != ma->sz_rows) || (mc->sz_cols != mb->sz_cols)) {
444 		return -EINVAL;
445 	}
446 #endif
447 
448 	ZSL_MATRIX_DEF(ma_copy, ma->sz_rows, ma->sz_cols);
449 	ZSL_MATRIX_DEF(mb_copy, mb->sz_rows, mb->sz_cols);
450 	zsl_mtx_copy(&ma_copy, ma);
451 	zsl_mtx_copy(&mb_copy, mb);
452 
453 	for (size_t i = 0; i < ma_copy.sz_rows; i++) {
454 		for (size_t j = 0; j < mb_copy.sz_cols; j++) {
455 			mc->data[j + i * mb_copy.sz_cols] = 0;
456 			for (size_t k = 0; k < ma_copy.sz_cols; k++) {
457 				mc->data[j + i * mb_copy.sz_cols] +=
458 					ma_copy.data[k + i * ma_copy.sz_cols] *
459 					mb_copy.data[j + k * mb_copy.sz_cols];
460 			}
461 		}
462 	}
463 
464 	return 0;
465 }
466 
/*
 * In-place matrix product: ma = ma * mb, where 'mb' must be square so the
 * result keeps the shape of 'ma'.
 *
 * zsl_mtx_mult() copies its operands before writing the output, so passing
 * 'ma' as both input and output is safe.
 *
 * @return 0 on success, -EINVAL on a shape mismatch (checked when
 *         CONFIG_ZSL_BOUNDS_CHECKS is enabled), or any error reported by
 *         zsl_mtx_mult(). That error was previously discarded.
 */
int
zsl_mtx_mult_d(struct zsl_mtx *ma, struct zsl_mtx *mb)
{
#if CONFIG_ZSL_BOUNDS_CHECKS
	/* Ensure that ma has the same number as columns as mb has rows. */
	if (ma->sz_cols != mb->sz_rows) {
		return -EINVAL;
	}

	/* Ensure that mb is a square matrix. */
	if (mb->sz_rows != mb->sz_cols) {
		return -EINVAL;
	}
#endif

	return zsl_mtx_mult(ma, mb, ma);
}
486 
487 int
zsl_mtx_scalar_mult_d(struct zsl_mtx * m,zsl_real_t s)488 zsl_mtx_scalar_mult_d(struct zsl_mtx *m, zsl_real_t s)
489 {
490 	for (size_t i = 0; i < m->sz_rows * m->sz_cols; i++) {
491 		m->data[i] *= s;
492 	}
493 
494 	return 0;
495 }
496 
497 int
zsl_mtx_scalar_mult_row_d(struct zsl_mtx * m,size_t i,zsl_real_t s)498 zsl_mtx_scalar_mult_row_d(struct zsl_mtx *m, size_t i, zsl_real_t s)
499 {
500 #if CONFIG_ZSL_BOUNDS_CHECKS
501 	if (i >= m->sz_rows) {
502 		return -EINVAL;
503 	}
504 #endif
505 
506 	for (size_t k = 0; k < m->sz_cols; k++) {
507 		m->data[(i * m->sz_cols) + k] *= s;
508 	}
509 
510 	return 0;
511 }
512 
513 int
zsl_mtx_trans(struct zsl_mtx * ma,struct zsl_mtx * mb)514 zsl_mtx_trans(struct zsl_mtx *ma, struct zsl_mtx *mb)
515 {
516 #if CONFIG_ZSL_BOUNDS_CHECKS
517 	/* Ensure that ma and mb have the same shape. */
518 	if ((ma->sz_rows != mb->sz_cols) || (ma->sz_cols != mb->sz_rows)) {
519 		return -EINVAL;
520 	}
521 #endif
522 
523 	zsl_real_t d[ma->sz_cols];
524 
525 	for (size_t i = 0; i < ma->sz_rows; i++) {
526 		zsl_mtx_get_row(ma, i, d);
527 		zsl_mtx_set_col(mb, i, d);
528 	}
529 
530 	return 0;
531 }
532 
/*
 * Computes the adjugate (classical adjoint) of a 3x3 matrix 'm' into 'ma'
 * using closed-form 2x2 cofactor expressions.
 *
 * Each ma->data[k] below holds the cofactor of the transposed position,
 * i.e. ma(r, c) = C(c, r), so 'ma' is the transpose of the cofactor matrix.
 *
 * @param m   Input matrix (must be 3x3).
 * @param ma  Output matrix (must be 3x3).
 *
 * @return 0 on success, -EINVAL if either matrix is not square, or (when
 *         CONFIG_ZSL_BOUNDS_CHECKS is enabled) not 3x3.
 *
 * NOTE(review): 'ma' is written while 'm' is still being read (e.g.
 * ma->data[0] is stored before m->data[0] is read for ma->data[4]), so
 * passing the same matrix for both corrupts the result.
 */
int
zsl_mtx_adjoint_3x3(struct zsl_mtx *m, struct zsl_mtx *ma)
{
	/* Make sure this is a square matrix. */
	if ((m->sz_rows != m->sz_cols) || (ma->sz_rows != ma->sz_cols)) {
		return -EINVAL;
	}

#if CONFIG_ZSL_BOUNDS_CHECKS
	/* Make sure this is a 3x3 matrix. */
	if ((m->sz_rows != 3) || (ma->sz_rows != 3)) {
		return -EINVAL;
	}
#endif

	/*
	 * 3x3 matrix element to array table:
	 *
	 * 1,1 = 0  1,2 = 1  1,3 = 2
	 * 2,1 = 3  2,2 = 4  2,3 = 5
	 * 3,1 = 6  3,2 = 7  3,3 = 8
	 */

	/* Row 1 of adj(m): cofactors C11, C21, C31. */
	ma->data[0] = m->data[4] * m->data[8] - m->data[7] * m->data[5];
	ma->data[1] = m->data[7] * m->data[2] - m->data[1] * m->data[8];
	ma->data[2] = m->data[1] * m->data[5] - m->data[4] * m->data[2];

	/* Row 2 of adj(m): cofactors C12, C22, C32. */
	ma->data[3] = m->data[6] * m->data[5] - m->data[3] * m->data[8];
	ma->data[4] = m->data[0] * m->data[8] - m->data[6] * m->data[2];
	ma->data[5] = m->data[3] * m->data[2] - m->data[0] * m->data[5];

	/* Row 3 of adj(m): cofactors C13, C23, C33. */
	ma->data[6] = m->data[3] * m->data[7] - m->data[6] * m->data[4];
	ma->data[7] = m->data[6] * m->data[1] - m->data[0] * m->data[7];
	ma->data[8] = m->data[0] * m->data[4] - m->data[3] * m->data[1];

	return 0;
}
570 
571 int
zsl_mtx_adjoint(struct zsl_mtx * m,struct zsl_mtx * ma)572 zsl_mtx_adjoint(struct zsl_mtx *m, struct zsl_mtx *ma)
573 {
574 	/* Shortcut for 3x3 matrices. */
575 	if (m->sz_rows == 3) {
576 		return zsl_mtx_adjoint_3x3(m, ma);
577 	}
578 
579 #if CONFIG_ZSL_BOUNDS_CHECKS
580 	/* Make sure this is a square matrix. */
581 	if (m->sz_rows != m->sz_cols) {
582 		return -EINVAL;
583 	}
584 #endif
585 
586 	zsl_real_t sign;
587 	zsl_real_t d;
588 	ZSL_MATRIX_DEF(mr, (m->sz_cols - 1), (m->sz_cols - 1));
589 
590 	for (size_t i = 0; i < m->sz_cols; i++) {
591 		for (size_t j = 0; j < m->sz_cols; j++) {
592 			sign = 1.0;
593 			if ((i + j) % 2 != 0) {
594 				sign = -1.0;
595 			}
596 			zsl_mtx_reduce(m, &mr, i, j);
597 			zsl_mtx_deter(&mr, &d);
598 			d *= sign;
599 			zsl_mtx_set(ma, i, j, d);
600 		}
601 	}
602 
603 	return 0;
604 }
605 
606 #ifndef CONFIG_ZSL_SINGLE_PRECISION
zsl_mtx_vec_wedge(struct zsl_mtx * m,struct zsl_vec * v)607 int zsl_mtx_vec_wedge(struct zsl_mtx *m, struct zsl_vec *v)
608 {
609 #if CONFIG_ZSL_BOUNDS_CHECKS
610 	/* Make sure the dimensions of 'm' and 'v' match. */
611 	if (v->sz != m->sz_cols || v->sz < 4 || m->sz_rows != (m->sz_cols - 1)) {
612 		return -EINVAL;
613 	}
614 #endif
615 
616 	zsl_real_t d;
617 
618 	ZSL_MATRIX_DEF(A, m->sz_cols, m->sz_cols);
619 	ZSL_MATRIX_DEF(Ai, m->sz_cols, m->sz_cols);
620 	ZSL_VECTOR_DEF(Av, m->sz_cols);
621 	ZSL_MATRIX_DEF(b, m->sz_cols, 1);
622 
623 	zsl_mtx_init(&A, NULL);
624 	A.data[(m->sz_cols * m->sz_cols - 1)] = 1.0;
625 
626 	for (size_t i = 0; i < m->sz_rows; i++) {
627 		zsl_mtx_get_row(m, i, Av.data);
628 		zsl_mtx_set_row(&A, i, Av.data);
629 	}
630 
631 	zsl_mtx_deter(&A, &d);
632 	zsl_mtx_inv(&A, &Ai);
633 	zsl_mtx_init(&b, NULL);
634 	b.data[(m->sz_cols - 1)] = d;
635 
636 	zsl_mtx_mult(&Ai, &b, &b);
637 
638 	zsl_vec_from_arr(v, b.data);
639 
640 	return 0;
641 }
642 #endif
643 
644 int
zsl_mtx_reduce(struct zsl_mtx * m,struct zsl_mtx * mr,size_t i,size_t j)645 zsl_mtx_reduce(struct zsl_mtx *m, struct zsl_mtx *mr, size_t i, size_t j)
646 {
647 	size_t u = 0;
648 	zsl_real_t x;
649 	zsl_real_t v[mr->sz_rows * mr->sz_rows];
650 
651 #if CONFIG_ZSL_BOUNDS_CHECKS
652 	/* Make sure mr is 1 less than m. */
653 	if (mr->sz_rows != m->sz_rows - 1) {
654 		return -EINVAL;
655 	}
656 	if (mr->sz_cols != m->sz_cols - 1) {
657 		return -EINVAL;
658 	}
659 	if ((i >= m->sz_rows) || (j >= m->sz_cols)) {
660 		return -EINVAL;
661 	}
662 #endif
663 
664 	for (size_t k = 0; k < m->sz_rows; k++) {
665 		for (size_t g = 0; g < m->sz_rows; g++) {
666 			if (k != i && g != j) {
667 				zsl_mtx_get(m, k, g, &x);
668 				v[u] = x;
669 				u++;
670 			}
671 		}
672 	}
673 
674 	zsl_mtx_from_arr(mr, v);
675 
676 	return 0;
677 }
678 
/*
 * Iteratively strips row 0 and column 0 from a working copy of 'm' until it
 * reaches the size of 'mred', then copies the result into 'mred'. After k
 * productive calls the working copy is the lower-right block of 'm' with its
 * first k rows and columns removed.
 *
 * Callers are expected to call this repeatedly while it returns -EAGAIN:
 *  - First call: 'place1' still has the size of 'm', so 'm' is copied into it.
 *  - Each -EAGAIN call: 'place1' is reduced by one row/column via 'place2'
 *    (the sz_rows/sz_cols fields of both placeholders are modified in place).
 *  - Final call (working size == mred size): the result is copied into
 *    'mred', both placeholders get back the dimensions of 'm', and 0 is
 *    returned.
 *
 * NOTE(review): because the placeholders' dimensions are restored to those
 * of 'm', they presumably need backing storage for at least as many
 * elements as 'm' — confirm against callers. Return codes of
 * zsl_mtx_copy/zsl_mtx_reduce are ignored here.
 */
int
zsl_mtx_reduce_iter(struct zsl_mtx *m, struct zsl_mtx *mred,
				struct zsl_mtx *place1, struct zsl_mtx *place2)
{
	/* TODO: Properly check if matrix is square. */
	/* First call: seed the working copy with the full input matrix. */
	if (m->sz_rows == place1->sz_rows) {
		zsl_mtx_copy(place1, m);
	}

	/* Target size reached: publish the result and reset the placeholders. */
	if (place1->sz_rows == mred->sz_rows) {
		zsl_mtx_copy(mred, place1);

		/* restore the original placeholder size */
		place1->sz_rows = m->sz_rows;
		place1->sz_cols = m->sz_cols;
		place2->sz_rows = m->sz_rows;
		place2->sz_cols = m->sz_cols;
		return 0;
	}

	/* trick the iterative method by generating the inner
	 * call intermediate matrix, adjusting its size
	 */
	place2->sz_rows = place1->sz_rows - 1;
	place2->sz_cols = place1->sz_cols - 1;
	zsl_mtx_reduce(place1, place2, 0, 0);

	/* Do the same with the second placeholder matrix */
	place1->sz_rows--;
	place1->sz_cols--;
	zsl_mtx_copy(place1, place2);

	/* More reduction steps remain; the caller should call again. */
	return -EAGAIN;
}
713 
714 int
zsl_mtx_augm_diag(struct zsl_mtx * m,struct zsl_mtx * maug)715 zsl_mtx_augm_diag(struct zsl_mtx *m, struct zsl_mtx *maug)
716 {
717 	zsl_real_t x;
718 	/* TODO: Properly check if matrix is square, and diff > 0. */
719 	size_t diff = (maug->sz_rows) - (m->sz_rows);
720 
721 	zsl_mtx_init(maug, zsl_mtx_entry_fn_identity);
722 	for (size_t i = 0; i < m->sz_rows; i++) {
723 		for (size_t j = 0; j < m->sz_rows; j++) {
724 			zsl_mtx_get(m, i, j, &x);
725 			zsl_mtx_set(maug, i + diff, j + diff, x);
726 		}
727 	}
728 
729 	return 0;
730 }
731 
732 int
zsl_mtx_deter_3x3(struct zsl_mtx * m,zsl_real_t * d)733 zsl_mtx_deter_3x3(struct zsl_mtx *m, zsl_real_t *d)
734 {
735 	/* Make sure this is a square matrix. */
736 	if (m->sz_rows != m->sz_cols) {
737 		return -EINVAL;
738 	}
739 
740 #if CONFIG_ZSL_BOUNDS_CHECKS
741 	/* Make sure this is a 3x3 matrix. */
742 	if (m->sz_rows != 3) {
743 		return -EINVAL;
744 	}
745 #endif
746 
747 	/*
748 	 * 3x3 matrix element to array table:
749 	 *
750 	 * 1,1 = 0  1,2 = 1  1,3 = 2
751 	 * 2,1 = 3  2,2 = 4  2,3 = 5
752 	 * 3,1 = 6  3,2 = 7  3,3 = 8
753 	 */
754 
755 	*d = m->data[0] * (m->data[4] * m->data[8] - m->data[7] * m->data[5]);
756 	*d -= m->data[3] * (m->data[1] * m->data[8] - m->data[7] * m->data[2]);
757 	*d += m->data[6] * (m->data[1] * m->data[5] - m->data[4] * m->data[2]);
758 
759 	return 0;
760 }
761 
762 int
zsl_mtx_deter(struct zsl_mtx * m,zsl_real_t * d)763 zsl_mtx_deter(struct zsl_mtx *m, zsl_real_t *d)
764 {
765 	/* Shortcut for 1x1 matrices. */
766 	if (m->sz_rows == 1) {
767 		*d = m->data[0];
768 		return 0;
769 	}
770 
771 	/* Shortcut for 2x2 matrices. */
772 	if (m->sz_rows == 2) {
773 		*d = m->data[0] * m->data[3] - m->data[2] * m->data[1];
774 		return 0;
775 	}
776 
777 	/* Shortcut for 3x3 matrices. */
778 	if (m->sz_rows == 3) {
779 		return zsl_mtx_deter_3x3(m, d);
780 	}
781 
782 #if CONFIG_ZSL_BOUNDS_CHECKS
783 	/* Make sure this is a square matrix. */
784 	if (m->sz_rows != m->sz_cols) {
785 		return -EINVAL;
786 	}
787 #endif
788 
789 	/* Full calculation required for non 3x3 matrices. */
790 	int rc;
791 	zsl_real_t dtmp;
792 	zsl_real_t cur;
793 	zsl_real_t sign;
794 	ZSL_MATRIX_DEF(mr, (m->sz_rows - 1), (m->sz_rows - 1));
795 
796 	/* Clear determinant output before starting. */
797 	*d = 0.0;
798 
799 	/*
800 	 * Iterate across row 0, removing columns one by one.
801 	 * Note that these calls are recursive until we reach a 3x3 matrix,
802 	 * which will be calculated using the shortcut at the top of this
803 	 * function.
804 	 */
805 	for (size_t g = 0; g < m->sz_cols; g++) {
806 		zsl_mtx_get(m, 0, g, &cur);     /* Get value at (0, g). */
807 		zsl_mtx_init(&mr, NULL);        /* Clear mr. */
808 		zsl_mtx_reduce(m, &mr, 0, g);   /* Remove row 0, column g. */
809 		rc = zsl_mtx_deter(&mr, &dtmp); /* Calc. determinant of mr. */
810 		sign = 1.0;
811 		if (rc) {
812 			return -EINVAL;
813 		}
814 
815 		/* Uneven elements are negative. */
816 		if (g % 2 != 0) {
817 			sign = -1.0;
818 		}
819 
820 		/* Add current determinant to final output value. */
821 		*d += dtmp * cur * sign;
822 	}
823 
824 	return 0;
825 }
826 
827 int
zsl_mtx_gauss_elim(struct zsl_mtx * m,struct zsl_mtx * mg,struct zsl_mtx * mi,size_t i,size_t j)828 zsl_mtx_gauss_elim(struct zsl_mtx *m, struct zsl_mtx *mg, struct zsl_mtx *mi,
829 		   size_t i, size_t j)
830 {
831 	int rc;
832 	zsl_real_t x, y;
833 	zsl_real_t epsilon = 1E-6;
834 
835 	/* Make a copy of matrix m. */
836 	rc = zsl_mtx_copy(mg, m);
837 	if (rc) {
838 		return -EINVAL;
839 	}
840 
841 	/* Get the value of the element at position (i, j). */
842 	rc = zsl_mtx_get(mg, i, j, &y);
843 	if (rc) {
844 		return rc;
845 	}
846 
847 	/* If this is a zero value, don't do anything. */
848 	if ((y >= 0 && y < epsilon) || (y <= 0 && y > -epsilon)) {
849 		return 0;
850 	}
851 
852 	/* Cycle through the matrix row by row. */
853 	for (size_t p = 0; p < mg->sz_rows; p++) {
854 		/* Skip row 'i'. */
855 		if (p == i) {
856 			p++;
857 		}
858 		if (p == mg->sz_rows) {
859 			break;
860 		}
861 		/* Get the value of (p, j), aborting if value is zero. */
862 		zsl_mtx_get(mg, p, j, &x);
863 		if ((x >= 1E-6) || (x <= -1E-6)) {
864 			rc = zsl_mtx_sum_rows_scaled_d(mg, p, i, -(x / y));
865 
866 			if (rc) {
867 				return -EINVAL;
868 			}
869 			rc = zsl_mtx_sum_rows_scaled_d(mi, p, i, -(x / y));
870 			if (rc) {
871 				return -EINVAL;
872 			}
873 		}
874 	}
875 
876 	return 0;
877 }
878 
/*
 * In-place form of zsl_mtx_gauss_elim(): 'm' is both the source and the
 * matrix that receives the eliminated result.
 */
int
zsl_mtx_gauss_elim_d(struct zsl_mtx *m, struct zsl_mtx *mi, size_t i, size_t j)
{
	struct zsl_mtx *const out = m;

	return zsl_mtx_gauss_elim(m, out, mi, i, j);
}
884 
/*
 * Gauss-Jordan reduction: copies 'm' into 'mg' and reduces 'mg' column by
 * column toward the identity. Every row operation is mirrored on 'mi', so
 * if 'mi' starts as the identity (as zsl_mtx_inv() arranges) it ends up
 * holding the inverse of 'm'.
 *
 * For each diagonal position k:
 *  1. If mg(k, k) is (near-)zero, the first suitable row q with a non-zero
 *     entry in column k is added to row k (in both 'mg' and 'mi').
 *  2. Column k is cleared in all other rows (zsl_mtx_gauss_elim_d).
 *  3. Row k is divided by mg(k, k) (zsl_mtx_norm_elem_d).
 *
 * 'm' itself is not modified. Return codes of the helpers are ignored and
 * this function always returns 0. If no usable pivot is found for a column,
 * steps 2 and 3 leave that column unchanged, because both helpers return
 * early on a (near-)zero pivot.
 */
int
zsl_mtx_gauss_reduc(struct zsl_mtx *m, struct zsl_mtx *mi,
		    struct zsl_mtx *mg)
{
	/* Scratch buffer for one column of 'mg'. */
	zsl_real_t v[m->sz_rows];
	zsl_real_t epsilon = 1E-6;
	zsl_real_t x;
	zsl_real_t y;

	/* Copy the input matrix into 'mg' so all the changes will be done to
	 * 'mg' and the input matrix will not be destroyed. */
	zsl_mtx_copy(mg, m);

	for (size_t k = 0; k < m->sz_rows; k++) {

		/* Get every element in the diagonal. */
		zsl_mtx_get(mg, k, k, &x);

		/* If the diagonal element is zero, find another value in the
		 * same column that isn't zero and add the row containing
		 * the non-zero element to the diagonal element's row. */
		if ((x >= 0 && x < epsilon) || (x <= 0 && x > -epsilon)) {
			zsl_mtx_get_col(mg, k, v);
			for (size_t q = 0; q < m->sz_rows; q++) {
				/* y = diagonal element of candidate row q. */
				zsl_mtx_get(mg, q, q, &y);
				if ((v[q] >= epsilon) || (v[q] <= -epsilon)) {

					/* If the non-zero element found is
					 * above the diagonal, only add its row
					 * if the diagonal element in this row
					 * is zero, to avoid undoing previous
					 * steps. */
					if (q < k && ((y >= epsilon)
						      || (y <= -epsilon))) {
						/* Row q already holds an
						 * earlier pivot: skip it. */
					} else {
						zsl_mtx_sum_rows_d(mg, k, q);
						zsl_mtx_sum_rows_d(mi, k, q);
						break;
					}
				}
			}
		}

		/* Perform the gaussian elimination in the column of the
		 * diagonal element to get rid of all the values in the column
		 * except for the diagonal one. */
		zsl_mtx_gauss_elim_d(mg, mi, k, k);

		/* Divide the diagonal element's row by the diagonal element. */
		zsl_mtx_norm_elem_d(mg, mi, k, k);
	}

	return 0;
}
939 
940 int
zsl_mtx_gram_schmidt(struct zsl_mtx * m,struct zsl_mtx * mort)941 zsl_mtx_gram_schmidt(struct zsl_mtx *m, struct zsl_mtx *mort)
942 {
943 	ZSL_VECTOR_DEF(v, m->sz_rows);
944 	ZSL_VECTOR_DEF(w, m->sz_rows);
945 	ZSL_VECTOR_DEF(q, m->sz_rows);
946 
947 	for (size_t t = 0; t < m->sz_cols; t++) {
948 		zsl_vec_init(&q);
949 		zsl_mtx_get_col(m, t, v.data);
950 		for (size_t g = 0; g < t; g++) {
951 			zsl_mtx_get_col(mort, g, w.data);
952 
953 			/* Calculate the projection of every column vector
954 			 * before 'g' on the 't'th column. */
955 			zsl_vec_project(&w, &v, &w);
956 			zsl_vec_add(&q, &w, &q);
957 		}
958 
959 		/* Substract the sum of the projections on the 't'th column from
960 		 * the 't'th column and set this vector as the 't'th column of
961 		 * the output matrix. */
962 		zsl_vec_sub(&v, &q, &v);
963 		zsl_mtx_set_col(mort, t, v.data);
964 	}
965 
966 	return 0;
967 }
968 
969 int
zsl_mtx_cols_norm(struct zsl_mtx * m,struct zsl_mtx * mnorm)970 zsl_mtx_cols_norm(struct zsl_mtx *m, struct zsl_mtx *mnorm)
971 {
972 	ZSL_VECTOR_DEF(v, m->sz_rows);
973 
974 	for (size_t g = 0; g < m->sz_cols; g++) {
975 		zsl_mtx_get_col(m, g, v.data);
976 		zsl_vec_to_unit(&v);
977 		zsl_mtx_set_col(mnorm, g, v.data);
978 	}
979 
980 	return 0;
981 }
982 
983 int
zsl_mtx_norm_elem(struct zsl_mtx * m,struct zsl_mtx * mn,struct zsl_mtx * mi,size_t i,size_t j)984 zsl_mtx_norm_elem(struct zsl_mtx *m, struct zsl_mtx *mn, struct zsl_mtx *mi,
985 		  size_t i, size_t j)
986 {
987 	int rc;
988 	zsl_real_t x;
989 	zsl_real_t epsilon = 1E-6;
990 
991 	/* Make a copy of matrix m. */
992 	rc = zsl_mtx_copy(mn, m);
993 	if (rc) {
994 		return -EINVAL;
995 	}
996 
997 	/* Get the value to normalise. */
998 	rc = zsl_mtx_get(mn, i, j, &x);
999 	if (rc) {
1000 		return rc;
1001 	}
1002 
1003 	/* If the value is 0.0, abort. */
1004 	if ((x >= 0 && x < epsilon) || (x <= 0 && x > -epsilon)) {
1005 		return 0;
1006 	}
1007 
1008 	rc = zsl_mtx_scalar_mult_row_d(mn, i, (1.0 / x));
1009 	if (rc) {
1010 		return -EINVAL;
1011 	}
1012 
1013 	rc = zsl_mtx_scalar_mult_row_d(mi, i, (1.0 / x));
1014 	if (rc) {
1015 		return -EINVAL;
1016 	}
1017 
1018 	return 0;
1019 }
1020 
/*
 * In-place form of zsl_mtx_norm_elem(): 'm' is both the source and the
 * matrix whose row 'i' is normalised.
 */
int
zsl_mtx_norm_elem_d(struct zsl_mtx *m, struct zsl_mtx *mi, size_t i, size_t j)
{
	struct zsl_mtx *const out = m;

	return zsl_mtx_norm_elem(m, out, mi, i, j);
}
1026 
1027 int
zsl_mtx_inv_3x3(struct zsl_mtx * m,struct zsl_mtx * mi)1028 zsl_mtx_inv_3x3(struct zsl_mtx *m, struct zsl_mtx *mi)
1029 {
1030 	int rc;
1031 	zsl_real_t d;   /* Determinant. */
1032 	zsl_real_t s;   /* Scale factor. */
1033 
1034 	/* Make sure these are square matrices. */
1035 	if ((m->sz_rows != m->sz_cols) || (mi->sz_rows != mi->sz_cols)) {
1036 		return -EINVAL;
1037 	}
1038 
1039 #if CONFIG_ZSL_BOUNDS_CHECKS
1040 	/* Make sure 'm' and 'mi' have the same shape. */
1041 	if (m->sz_rows != mi->sz_rows) {
1042 		return -EINVAL;
1043 	}
1044 	if (m->sz_cols != mi->sz_cols) {
1045 		return -EINVAL;
1046 	}
1047 	/* Make sure these are 3x3 matrices. */
1048 	if ((m->sz_cols != 3) || (mi->sz_cols != 3)) {
1049 		return -EINVAL;
1050 	}
1051 #endif
1052 
1053 	/* Calculate the determinant. */
1054 	rc = zsl_mtx_deter_3x3(m, &d);
1055 	if (rc) {
1056 		goto err;
1057 	}
1058 
1059 	/* Calculate the adjoint matrix. */
1060 	rc = zsl_mtx_adjoint_3x3(m, mi);
1061 	if (rc) {
1062 		goto err;
1063 	}
1064 
1065 	/* Scale the output using the determinant. */
1066 	if (d != 0) {
1067 		s = 1.0 / d;
1068 		rc = zsl_mtx_scalar_mult_d(mi, s);
1069 	} else {
1070 		/* Provide an identity matrix if the determinant is zero. */
1071 		rc = zsl_mtx_init(mi, zsl_mtx_entry_fn_identity);
1072 		if (rc) {
1073 			return -EINVAL;
1074 		}
1075 	}
1076 
1077 	return 0;
1078 err:
1079 	return rc;
1080 }
1081 
1082 int
zsl_mtx_inv(struct zsl_mtx * m,struct zsl_mtx * mi)1083 zsl_mtx_inv(struct zsl_mtx *m, struct zsl_mtx *mi)
1084 {
1085 	int rc;
1086 	zsl_real_t d = 0.0;
1087 
1088 	/* Shortcut for 3x3 matrices. */
1089 	if (m->sz_rows == 3) {
1090 		return zsl_mtx_inv_3x3(m, mi);
1091 	}
1092 
1093 	/* Make sure we have square matrices. */
1094 	if ((m->sz_rows != m->sz_cols) || (mi->sz_rows != mi->sz_cols)) {
1095 		return -EINVAL;
1096 	}
1097 
1098 #if CONFIG_ZSL_BOUNDS_CHECKS
1099 	/* Make sure 'm' and 'mi' have the same shape. */
1100 	if (m->sz_rows != mi->sz_rows) {
1101 		return -EINVAL;
1102 	}
1103 	if (m->sz_cols != mi->sz_cols) {
1104 		return -EINVAL;
1105 	}
1106 #endif
1107 
1108 	/* Make a copy of matrix m on the stack to avoid modifying it. */
1109 	ZSL_MATRIX_DEF(m_tmp, mi->sz_rows, mi->sz_cols);
1110 	rc = zsl_mtx_copy(&m_tmp, m);
1111 	if (rc) {
1112 		return -EINVAL;
1113 	}
1114 
1115 	/* Initialise 'mi' as an identity matrix. */
1116 	rc = zsl_mtx_init(mi, zsl_mtx_entry_fn_identity);
1117 	if (rc) {
1118 		return -EINVAL;
1119 	}
1120 
1121 	/* Make sure the determinant of 'm' is not zero. */
1122 	zsl_mtx_deter(m, &d);
1123 
1124 	if (d == 0) {
1125 		return 0;
1126 	}
1127 
1128 	/* Use Gauss-Jordan elimination for nxn matrices. */
1129 	zsl_mtx_gauss_reduc(m, mi, &m_tmp);
1130 
1131 	return 0;
1132 }
1133 
1134 int
zsl_mtx_cholesky(struct zsl_mtx * m,struct zsl_mtx * l)1135 zsl_mtx_cholesky(struct zsl_mtx *m, struct zsl_mtx *l)
1136 {
1137 #if CONFIG_ZSL_BOUNDS_CHECKS
1138 	/* Make sure 'm' is square. */
1139 	if (m->sz_rows != m->sz_cols) {
1140 		return -EINVAL;
1141 	}
1142 
1143 	/* Make sure 'm' is symmetric. */
1144 	zsl_real_t a, b;
1145 	for (size_t i = 0; i < m->sz_rows; i++) {
1146 		for (size_t j = 0; j < m->sz_rows; j++) {
1147 			zsl_mtx_get(m, i, j, &a);
1148 			zsl_mtx_get(m, j, i, &b);
1149 			if (a != b) {
1150 				return -EINVAL;
1151 			}
1152 		}
1153 	}
1154 
1155 	/* Make sure 'm' and 'l' have the same shape. */
1156 	if (m->sz_rows != l->sz_rows) {
1157 		return -EINVAL;
1158 	}
1159 	if (m->sz_cols != l->sz_cols) {
1160 		return -EINVAL;
1161 	}
1162 #endif
1163 
1164 	zsl_real_t sum, x, y;
1165 	zsl_mtx_init(l, zsl_mtx_entry_fn_empty);
1166 	for (size_t j = 0; j < m->sz_cols; j++) {
1167 		sum = 0.0;
1168 		for (size_t k = 0; k < j; k++) {
1169 			zsl_mtx_get(l, j, k, &x);
1170 			sum += x * x;
1171 		}
1172 		zsl_mtx_get(m, j, j, &x);
1173 		zsl_mtx_set(l, j, j, ZSL_SQRT(x - sum));
1174 
1175 		for (size_t i = j + 1; i < m->sz_cols; i++) {
1176 			sum = 0.0;
1177 			for (size_t k = 0; k < j; k++) {
1178 				zsl_mtx_get(l, j, k, &x);
1179 				zsl_mtx_get(l, i, k, &y);
1180 				sum += y * x;
1181 			}
1182 			zsl_mtx_get(l, j, j, &x);
1183 			zsl_mtx_get(m, i, j, &y);
1184 			zsl_mtx_set(l, i, j, (y - sum) / x);
1185 		}
1186 	}
1187 
1188 	return 0;
1189 }
1190 
1191 int
zsl_mtx_balance(struct zsl_mtx * m,struct zsl_mtx * mout)1192 zsl_mtx_balance(struct zsl_mtx *m, struct zsl_mtx *mout)
1193 {
1194 	int rc;
1195 	bool done = false;
1196 	zsl_real_t sum;
1197 	zsl_real_t row, row2;
1198 	zsl_real_t col, col2;
1199 
1200 	/* Make sure we have square matrices. */
1201 	if ((m->sz_rows != m->sz_cols) || (mout->sz_rows != mout->sz_cols)) {
1202 		return -EINVAL;
1203 	}
1204 
1205 #if CONFIG_ZSL_BOUNDS_CHECKS
1206 	/* Make sure 'm' and 'mout' have the same shape. */
1207 	if (m->sz_rows != mout->sz_rows) {
1208 		return -EINVAL;
1209 	}
1210 	if (m->sz_cols != mout->sz_cols) {
1211 		return -EINVAL;
1212 	}
1213 #endif
1214 
1215 	rc = zsl_mtx_copy(mout, m);
1216 	if (rc) {
1217 		goto err;
1218 	}
1219 
1220 	while (!done) {
1221 		done = true;
1222 
1223 		for (size_t i = 0; i < m->sz_rows; i++) {
1224 			/* Calculate sum of components of each row, column. */
1225 			for (size_t j = 0; j < m->sz_cols; j++) {
1226 				row += ZSL_ABS(mout->data[(i * m->sz_rows) +
1227 							  j]);
1228 				col += ZSL_ABS(mout->data[(j * m->sz_rows) +
1229 							  i]);
1230 			}
1231 
1232 			/* TODO: Extend with a check against epsilon? */
1233 			if (col != 0.0 && row != 0.0) {
1234 				row2 = row / 2.0;
1235 				col2 = 1.0;
1236 				sum = col + row;
1237 
1238 				while (col < row2) {
1239 					col2 *= 2.0;
1240 					col *= 4.0;
1241 				}
1242 
1243 				row2 = row * 2.0;
1244 
1245 				while (col > row2) {
1246 					col2 /= 2.0;
1247 					col /= 4.0;
1248 				}
1249 
1250 				if ((col + row) / col2 < 0.95 * sum) {
1251 					done = false;
1252 					row2 = 1.0 / col2;
1253 
1254 					for (int k = 0; k < m->sz_rows; k++) {
1255 						mout->data[(i * m->sz_rows) + k]
1256 							*= row2;
1257 						mout->data[(k * m->sz_rows) + i]
1258 							*= col2;
1259 					}
1260 				}
1261 			}
1262 
1263 			row = 0.0;
1264 			col = 0.0;
1265 		}
1266 	}
1267 
1268 err:
1269 	return rc;
1270 }
1271 
1272 int
zsl_mtx_householder(struct zsl_mtx * m,struct zsl_mtx * h,bool hessenberg)1273 zsl_mtx_householder(struct zsl_mtx *m, struct zsl_mtx *h, bool hessenberg)
1274 {
1275 	size_t size = m->sz_rows;
1276 
1277 	if (hessenberg == true) {
1278 		size--;
1279 	}
1280 
1281 	ZSL_VECTOR_DEF(v, size);
1282 	ZSL_VECTOR_DEF(v2, m->sz_rows);
1283 	ZSL_VECTOR_DEF(e1, size);
1284 
1285 	ZSL_MATRIX_DEF(mv, size, 1);
1286 	ZSL_MATRIX_DEF(mvt, 1, size);
1287 	ZSL_MATRIX_DEF(id, size, size);
1288 	ZSL_MATRIX_DEF(vvt, size, size);
1289 	ZSL_MATRIX_DEF(h2, size, size);
1290 
1291 	/* Create the e1 vector, i.e. the vector (1, 0, 0, ...). */
1292 	zsl_vec_init(&e1);
1293 	e1.data[0] = 1.0;
1294 
1295 	/* Get the first column of the input matrix. */
1296 	zsl_mtx_get_col(m, 0, v2.data);
1297 	if (hessenberg == true) {
1298 		zsl_vec_get_subset(&v2, 1, size, &v);
1299 	} else {
1300 		zsl_vec_copy(&v, &v2);
1301 	}
1302 
1303 	/* Change the 'sign' value according to the sign of the first
1304 	 * coefficient of the matrix. */
1305 	zsl_real_t sign = 1.0;
1306 
1307 	if (v.data[0] < 0) {
1308 		sign = -1.0;
1309 	}
1310 
1311 	/* Calculate the vector 'v' that will later be used to calculate the
1312 	 * Householder matrix. */
1313 	zsl_vec_scalar_mult(&e1, -sign * zsl_vec_norm(&v));
1314 
1315 	zsl_vec_add(&v, &e1, &v);
1316 
1317 	zsl_vec_scalar_div(&v, zsl_vec_norm(&v));
1318 
1319 	/* Calculate the H householder matrix by doing:
1320 	 * H = IDENTITY - 2 * v * v^t. */
1321 	zsl_mtx_from_arr(&mv, v.data);
1322 	zsl_mtx_trans(&mv, &mvt);
1323 	zsl_mtx_mult(&mv, &mvt, &vvt);
1324 	zsl_mtx_init(&id, zsl_mtx_entry_fn_identity);
1325 	zsl_mtx_scalar_mult_d(&vvt, -2);
1326 	zsl_mtx_add(&id, &vvt, &h2);
1327 
1328 	/* If Hessenberg set to true, augment the output to the size of 'm'.
1329 	 * If Hessenberg set to false, this line of code will do nothing but
1330 	 * copy the matrix 'h2' into the output matrix 'h', */
1331 	zsl_mtx_augm_diag(&h2, h);
1332 
1333 	return 0;
1334 }
1335 
1336 int
zsl_mtx_qrd(struct zsl_mtx * m,struct zsl_mtx * q,struct zsl_mtx * r,bool hessenberg)1337 zsl_mtx_qrd(struct zsl_mtx *m, struct zsl_mtx *q, struct zsl_mtx *r,
1338 	    bool hessenberg)
1339 {
1340 	ZSL_MATRIX_DEF(r2, m->sz_rows, m->sz_cols);
1341 	ZSL_MATRIX_DEF(hess, m->sz_rows, m->sz_cols);
1342 	ZSL_MATRIX_DEF(h, m->sz_rows, m->sz_rows);
1343 	ZSL_MATRIX_DEF(h2, m->sz_rows, m->sz_rows);
1344 	ZSL_MATRIX_DEF(qt, m->sz_rows, m->sz_rows);
1345 
1346 	zsl_mtx_init(&h, NULL);
1347 	zsl_mtx_init(&qt, zsl_mtx_entry_fn_identity);
1348 	zsl_mtx_copy(r, m);
1349 
1350 	for (size_t g = 0; g < (m->sz_rows - 1); g++) {
1351 
1352 		/* Reduce the matrix by 'g' rows and columns each time. */
1353 		ZSL_MATRIX_DEF(mred, (m->sz_rows - g), (m->sz_cols - g));
1354 		ZSL_MATRIX_DEF(hred, (m->sz_rows - g), (m->sz_rows - g));
1355 
1356 		/* allocate the placeholder matrices for the reduction loop */
1357 		ZSL_MATRIX_DEF(place1, r->sz_rows, r->sz_cols);
1358 		ZSL_MATRIX_DEF(place2, r->sz_rows, r->sz_cols);
1359 
1360 		while(zsl_mtx_reduce_iter(r, &mred, &place1, &place2) != 0);
1361 
1362 		/* Calculate the reduced Householder matrix 'hred'. */
1363 		if (hessenberg == true) {
1364 			zsl_mtx_householder(&mred, &hred, true);
1365 		} else {
1366 			zsl_mtx_householder(&mred, &hred, false);
1367 		}
1368 
1369 		/* Augment the Householder matrix to the input matrix size. */
1370 		zsl_mtx_augm_diag(&hred, &h);
1371 		zsl_mtx_mult(&h, r, &r2);
1372 
1373 		/* Multiply this Householder matrix by the previous ones,
1374 		 * stacked in 'qt'. */
1375 		zsl_mtx_mult(&h, &qt, &h2);
1376 		zsl_mtx_copy(&qt, &h2);
1377 		if (hessenberg == true) {
1378 			zsl_mtx_mult(&r2, &h, &hess);
1379 			zsl_mtx_copy(r, &hess);
1380 		} else {
1381 			zsl_mtx_copy(r, &r2);
1382 		}
1383 	}
1384 
1385 	/* Calculate the 'q' matrix by transposing 'qt'. */
1386 	zsl_mtx_trans(&qt, q);
1387 
1388 	return 0;
1389 }
1390 
1391 #ifndef CONFIG_ZSL_SINGLE_PRECISION
1392 int
zsl_mtx_qrd_iter(struct zsl_mtx * m,struct zsl_mtx * mout,size_t iter)1393 zsl_mtx_qrd_iter(struct zsl_mtx *m, struct zsl_mtx *mout, size_t iter)
1394 {
1395 	int rc;
1396 
1397 	ZSL_MATRIX_DEF(q, m->sz_rows, m->sz_rows);
1398 	ZSL_MATRIX_DEF(r, m->sz_rows, m->sz_rows);
1399 
1400 
1401 	/* Make a copy of 'm'. */
1402 	rc = zsl_mtx_copy(mout, m);
1403 	if (rc) {
1404 		return -EINVAL;
1405 	}
1406 
1407 	for (size_t g = 1; g <= iter; g++) {
1408 		/* Perform the QR decomposition. */
1409 		zsl_mtx_qrd(mout, &q, &r, false);
1410 
1411 		/* Multiply the results of the QR decomposition together but
1412 		 * changing its order. */
1413 		zsl_mtx_mult(&r, &q, mout);
1414 	}
1415 
1416 	return 0;
1417 }
1418 #endif
1419 
1420 #ifndef CONFIG_ZSL_SINGLE_PRECISION
1421 int
zsl_mtx_eigenvalues(struct zsl_mtx * m,struct zsl_vec * v,size_t iter)1422 zsl_mtx_eigenvalues(struct zsl_mtx *m, struct zsl_vec *v, size_t iter)
1423 {
1424 	zsl_real_t diag;
1425 	zsl_real_t sdiag;
1426 	size_t real = 0;
1427 
1428 	/* Epsilon is used to check 0 values in the subdiagonal, to determine
1429 	 * if any coimplekx values were found. Increasing the number of
1430 	 * iterations will move these values closer to 0, but when using
1431 	 * single-precision floats the numbers can still be quite large, so
1432 	 * we need to set a delta of +/- 0.001 in this case. */
1433 
1434 	zsl_real_t epsilon = 1E-6;
1435 
1436 	ZSL_MATRIX_DEF(mout, m->sz_rows, m->sz_rows);
1437 	ZSL_MATRIX_DEF(mtemp, m->sz_rows, m->sz_rows);
1438 	ZSL_MATRIX_DEF(mtemp2, m->sz_rows, m->sz_rows);
1439 
1440 	/* Balance the matrix. */
1441 	zsl_mtx_balance(m, &mtemp);
1442 
1443 	/* Put the balanced matrix into hessenberg form. */
1444 	zsl_mtx_qrd(&mtemp, &mout, &mtemp2, true);
1445 
1446 	/* Calculate the upper triangular matrix by using the recursive QR
1447 	 * decomposition method. */
1448 	zsl_mtx_qrd_iter(&mtemp2, &mout, iter);
1449 
1450 	zsl_vec_init(v);
1451 
1452 	/* If the matrix is symmetric, then it will always have real
1453 	 * eigenvalues, so treat this case appart. */
1454 	if (zsl_mtx_is_sym(m) == true) {
1455 		for (size_t g = 0; g < m->sz_rows; g++) {
1456 			zsl_mtx_get(&mout, g, g, &diag);
1457 			v->data[g] = diag;
1458 		}
1459 
1460 		return 0;
1461 	}
1462 
1463 	/*
1464 	 * If any value just below the diagonal is non-zero, it means that the
1465 	 * numbers above and to the right of the non-zero value are a pair of
1466 	 * complex values, a complex number and its conjugate.
1467 	 *
1468 	 * SVD will always return real numbers so this can be ignored, but if
1469 	 * you are calculating eigenvalues outside the SVD method, you may
1470 	 * get complex numbers, which will be indicated with the return error
1471 	 * code '-ECOMPLEXVAL'.
1472 	 *
1473 	 * If the imput matrix has complex eigenvalues, then these will be
1474 	 * ignored and the output vector will not include them.
1475 	 *
1476 	 * NOTE: The real and imaginary parts of the complex numbers are not
1477 	 * available. This only checks if there are any complex eigenvalues and
1478 	 * returns an appropriate error code to alert the user that there are
1479 	 * non-real eigenvalues present.
1480 	 */
1481 
1482 	for (size_t g = 0; g < (m->sz_rows - 1); g++) {
1483 		/* Check if any element just below the diagonal isn't zero. */
1484 		zsl_mtx_get(&mout, g + 1, g, &sdiag);
1485 		if ((sdiag >= epsilon) || (sdiag <= -epsilon)) {
1486 			/* Skip two elements if the element below
1487 			 * is not zero. */
1488 			g++;
1489 		} else {
1490 			/* Get the diagonal element if the element below
1491 			 * is zero. */
1492 			zsl_mtx_get(&mout, g, g, &diag);
1493 			v->data[real] = diag;
1494 			real++;
1495 		}
1496 	}
1497 
1498 
1499 	/* Since it's not possible to check the coefficient below the last
1500 	 * diagonal element, then check the element to its left. */
1501 	zsl_mtx_get(&mout, (m->sz_rows - 1), (m->sz_rows - 2), &sdiag);
1502 	if ((sdiag >= epsilon) || (sdiag <= -epsilon)) {
1503 		/* Do nothing if the element to its left is not zero. */
1504 	} else {
1505 		/* Get the last diagonal element if the element to its left
1506 		 * is zero. */
1507 		zsl_mtx_get(&mout, (m->sz_rows - 1), (m->sz_rows - 1), &diag);
1508 		v->data[real] = diag;
1509 		real++;
1510 	}
1511 
1512 	/* If the number of real eigenvalues ('real' coefficient) is less than
1513 	 * the matrix dimensions, then there must be complex eigenvalues. */
1514 	v->sz = real;
1515 	if (real != m->sz_rows) {
1516 		return -ECOMPLEXVAL;
1517 	}
1518 
1519 	/* Put the zeros to the end. */
1520 	zsl_vec_zte(v);
1521 
1522 	return 0;
1523 }
1524 #endif
1525 
1526 #ifndef CONFIG_ZSL_SINGLE_PRECISION
1527 int
zsl_mtx_eigenvectors(struct zsl_mtx * m,struct zsl_mtx * mev,size_t iter,bool orthonormal)1528 zsl_mtx_eigenvectors(struct zsl_mtx *m, struct zsl_mtx *mev, size_t iter,
1529 		     bool orthonormal)
1530 {
1531 	size_t b = 0;           /* Total number of eigenvectors. */
1532 	size_t e_vals = 0;      /* Number of unique eigenvalues. */
1533 	size_t count = 0;       /* Number of eigenvectors for an eigenvalue. */
1534 	size_t ga = 0;
1535 
1536 	zsl_real_t epsilon = 1E-6;
1537 	zsl_real_t x;
1538 
1539 	/* The vector where all eigenvalues will be stored. */
1540 	ZSL_VECTOR_DEF(k, m->sz_rows);
1541 	/* Temp vector to store column data. */
1542 	ZSL_VECTOR_DEF(f, m->sz_rows);
1543 	/* The vector where all UNIQUE eigenvalues will be stored. */
1544 	ZSL_VECTOR_DEF(o, m->sz_rows);
1545 	/* Temporary mxm identity matrix placeholder. */
1546 	ZSL_MATRIX_DEF(id, m->sz_rows, m->sz_rows);
1547 	/* 'm' minus the eigenvalues * the identity matrix (id). */
1548 	ZSL_MATRIX_DEF(mi, m->sz_rows, m->sz_rows);
1549 	/* Placeholder for zsl_mtx_gauss_reduc calls (required param). */
1550 	ZSL_MATRIX_DEF(mid, m->sz_rows, m->sz_rows);
1551 	/* Matrix containing all column eigenvectors for an eigenvalue. */
1552 	ZSL_MATRIX_DEF(evec, m->sz_rows, m->sz_rows);
1553 	/* Matrix containing all column eigenvectors for an eigenvalue.
1554 	* Two matrices are required for the Gramm-Schmidt operation. */
1555 	ZSL_MATRIX_DEF(evec2, m->sz_rows, m->sz_rows);
1556 	/* Matrix containing all column eigenvectors. */
1557 	ZSL_MATRIX_DEF(mev2, m->sz_rows, m->sz_rows);
1558 
1559 	/* TODO: Check that we have a SQUARE matrix, etc. */
1560 	zsl_mtx_init(&mev2, NULL);
1561 	zsl_vec_init(&o);
1562 	zsl_mtx_eigenvalues(m, &k, iter);
1563 
1564 	/* Copy every non-zero eigenvalue ONCE in the 'o' vector to get rid of
1565 	 * repeated values. */
1566 	for (size_t q = 0; q < m->sz_rows; q++) {
1567 		if ((k.data[q] >= epsilon) || (k.data[q] <= -epsilon)) {
1568 			if (zsl_vec_contains(&o, k.data[q], epsilon) == 0) {
1569 				o.data[e_vals] = k.data[q];
1570 				/* Increment the unique eigenvalue counter. */
1571 				e_vals++;
1572 			}
1573 		}
1574 	}
1575 
1576 	/* If zero is also an eigenvalue, copy it once in 'o'. */
1577 	if (zsl_vec_contains(&k, 0.0, epsilon) > 0) {
1578 		e_vals++;
1579 	}
1580 
1581 	/* Calculates the null space of 'm' minus each eigenvalue times
1582 	 * the identity matrix by performing the gaussian reduction. */
1583 	for (size_t g = 0; g < e_vals; g++) {
1584 		count = 0;
1585 		ga = 0;
1586 
1587 		zsl_mtx_init(&id, zsl_mtx_entry_fn_identity);
1588 		zsl_mtx_scalar_mult_d(&id, -o.data[g]);
1589 		zsl_mtx_add_d(&id, m);
1590 		zsl_mtx_gauss_reduc(&id, &mid, &mi);
1591 
1592 		/* If 'orthonormal' is true, perform the following process. */
1593 		if (orthonormal == true) {
1594 			/* Count how many eigenvectors ('count' coefficient)
1595 			 * there are for each eigenvalue. */
1596 			for (size_t h = 0; h < m->sz_rows; h++) {
1597 				zsl_mtx_get(&mi, h, h, &x);
1598 				if ((x >= 0.0 && x < epsilon) ||
1599 				    (x <= 0.0 && x > -epsilon)) {
1600 					count++;
1601 				}
1602 			}
1603 
1604 			/* Resize evec* placeholders to have 'count' cols. */
1605 			evec.sz_cols = count;
1606 			evec2.sz_cols = count;
1607 
1608 			/* Get all the eigenvectors for each eigenvalue and set
1609 			 * them as the columns of 'evec'. */
1610 			for (size_t h = 0; h < m->sz_rows; h++) {
1611 				zsl_mtx_get(&mi, h, h, &x);
1612 				if ((x >= 0.0 && x < epsilon) ||
1613 				    (x <= 0.0 && x > -epsilon)) {
1614 					zsl_mtx_set(&mi, h, h, -1);
1615 					zsl_mtx_get_col(&mi, h, f.data);
1616 					zsl_vec_neg(&f);
1617 					zsl_mtx_set_col(&evec, ga, f.data);
1618 					ga++;
1619 				}
1620 			}
1621 			/* Orthonormalize the set of eigenvectors for each
1622 			 * eigenvalue using the Gram-Schmidt process. */
1623 			zsl_mtx_gram_schmidt(&evec, &evec2);
1624 			zsl_mtx_cols_norm(&evec2, &evec);
1625 
1626 			/* Place these eigenvectors in the 'mev2' matrix,
1627 			 * that will hold all the eigenvectors for different
1628 			 * eigenvalues. */
1629 			for (size_t gi = 0; gi < count; gi++) {
1630 				zsl_mtx_get_col(&evec, gi, f.data);
1631 				zsl_mtx_set_col(&mev2, b, f.data);
1632 				b++;
1633 			}
1634 
1635 		} else {
1636 			/* Orthonormal is false. */
1637 			/* Get the eigenvectors for every eigenvalue and place
1638 			 * them in 'mev2'. */
1639 			for (size_t h = 0; h < m->sz_rows; h++) {
1640 				zsl_mtx_get(&mi, h, h, &x);
1641 				if ((x >= 0.0 && x < epsilon) ||
1642 				    (x <= 0.0 && x > -epsilon)) {
1643 					zsl_mtx_set(&mi, h, h, -1);
1644 					zsl_mtx_get_col(&mi, h, f.data);
1645 					zsl_vec_neg(&f);
1646 					zsl_mtx_set_col(&mev2, b, f.data);
1647 					b++;
1648 				}
1649 			}
1650 		}
1651 	}
1652 
1653 	/* Since 'b' is the number of eigenvectors, reduce 'mev' (of size
1654 	 * m->sz_rows times b) to erase columns of zeros. */
1655 	mev->sz_cols = b;
1656 
1657 	for (size_t s = 0; s < b; s++) {
1658 		zsl_mtx_get_col(&mev2, s, f.data);
1659 		zsl_mtx_set_col(mev, s, f.data);
1660 	}
1661 
1662 	/* Checks if the number of eigenvectors is the same as the shape of
1663 	 * the input matrix. If the number of eigenvectors is less than
1664 	 * the number of columns in the input matrix 'm', this will be
1665 	 * indicated by EEIGENSIZE as a return code. */
1666 	if (b != m->sz_cols) {
1667 		return -EEIGENSIZE;
1668 	}
1669 
1670 	return 0;
1671 }
1672 #endif
1673 
1674 #ifndef CONFIG_ZSL_SINGLE_PRECISION
1675 int
zsl_mtx_svd(struct zsl_mtx * m,struct zsl_mtx * u,struct zsl_mtx * e,struct zsl_mtx * v,size_t iter)1676 zsl_mtx_svd(struct zsl_mtx *m, struct zsl_mtx *u, struct zsl_mtx *e,
1677 	    struct zsl_mtx *v, size_t iter)
1678 {
1679 	ZSL_MATRIX_DEF(aat, m->sz_rows, m->sz_rows);
1680 	ZSL_MATRIX_DEF(upri, m->sz_rows, m->sz_rows);
1681 	ZSL_MATRIX_DEF(ata, m->sz_cols, m->sz_cols);
1682 	ZSL_MATRIX_DEF(at, m->sz_cols, m->sz_rows);
1683 	ZSL_VECTOR_DEF(ui, m->sz_rows);
1684 	ZSL_MATRIX_DEF(ui2, m->sz_cols, 1);
1685 	ZSL_MATRIX_DEF(ui3, m->sz_rows, 1);
1686 	ZSL_VECTOR_DEF(hu, m->sz_rows);
1687 
1688 	zsl_real_t d;
1689 	size_t pu = 0;
1690 	size_t min = m->sz_cols;
1691 	zsl_real_t epsilon = 1E-6;
1692 
1693 	zsl_mtx_trans(m, &at);
1694 
1695 	/* Calculate 'm' times 'm' transposed and viceversa. */
1696 	zsl_mtx_mult(m, &at, &aat);
1697 	zsl_mtx_mult(&at, m, &ata);
1698 
1699 	/* Set the value 'min' as the minimum of number of columns and number
1700 	 * of rows. */
1701 	if (m->sz_rows <= m->sz_cols) {
1702 		min = m->sz_rows;
1703 	}
1704 
1705 	/* Calculate the eigenvalues of the square matrix 'm' times 'm'
1706 	 * transposed or the square matrix 'm' transposed times 'm', whichever
1707 	 * is smaller in dimensions. */
1708 	ZSL_VECTOR_DEF(ev, min);
1709 	if (min < m->sz_cols) {
1710 		zsl_mtx_eigenvalues(&aat, &ev, iter);
1711 	} else {
1712 		zsl_mtx_eigenvalues(&ata, &ev, iter);
1713 	}
1714 
1715 	/* Place the square root of these eigenvalues in the diagonal entries
1716 	 * of 'e', the sigma matrix. */
1717 	zsl_mtx_init(e, NULL);
1718 	for (size_t g = 0; g < min; g++) {
1719 		zsl_mtx_set(e, g, g, ZSL_SQRT(ev.data[g]));
1720 	}
1721 
1722 	/* Calculate the eigenvectors of 'm' times 'm' transposed and set them
1723 	 * as the columns of the 'v' matrix. */
1724 	zsl_mtx_eigenvectors(&ata, v, iter, true);
1725 	for (size_t i = 0; i < min; i++) {
1726 		zsl_mtx_get_col(v, i, ui.data);
1727 		zsl_mtx_from_arr(&ui2, ui.data);
1728 		zsl_mtx_get(e, i, i, &d);
1729 
1730 		/* Calculate the column vectors of 'u' by dividing these
1731 		 * eniegnvectors by the square root its eigenvalue and
1732 		 * multiplying them by the input matrix. */
1733 		zsl_mtx_mult(m, &ui2, &ui3);
1734 		if ((d >= 0.0 && d < epsilon) || (d <= 0.0 && d > -epsilon)) {
1735 			pu++;
1736 		} else {
1737 			zsl_mtx_scalar_mult_d(&ui3, (1 / d));
1738 			zsl_vec_from_arr(&ui, ui3.data);
1739 			zsl_mtx_set_col(u, i, ui.data);
1740 		}
1741 	}
1742 
1743 	/* Expand the columns of 'u' into an orthonormal basis if there are
1744 	 * zero eigenvalues or if the number of columns in 'm' is less than the
1745 	 * number of rows. */
1746 	zsl_mtx_eigenvectors(&aat, &upri, iter, true);
1747 	for (size_t f = min - pu; f < m->sz_rows; f++) {
1748 		zsl_mtx_get_col(&upri, f, hu.data);
1749 		zsl_mtx_set_col(u, f, hu.data);
1750 	}
1751 
1752 	return 0;
1753 }
1754 #endif
1755 
1756 #ifndef CONFIG_ZSL_SINGLE_PRECISION
1757 int
zsl_mtx_pinv(struct zsl_mtx * m,struct zsl_mtx * pinv,size_t iter)1758 zsl_mtx_pinv(struct zsl_mtx *m, struct zsl_mtx *pinv, size_t iter)
1759 {
1760 	zsl_real_t x;
1761 	size_t min = m->sz_cols;
1762 	zsl_real_t epsilon = 1E-6;
1763 
1764 	ZSL_MATRIX_DEF(u, m->sz_rows, m->sz_rows);
1765 	ZSL_MATRIX_DEF(e, m->sz_rows, m->sz_cols);
1766 	ZSL_MATRIX_DEF(v, m->sz_cols, m->sz_cols);
1767 	ZSL_MATRIX_DEF(et, m->sz_cols, m->sz_rows);
1768 	ZSL_MATRIX_DEF(ut, m->sz_rows, m->sz_rows);
1769 	ZSL_MATRIX_DEF(pas, m->sz_cols, m->sz_rows);
1770 
1771 	/* Determine the SVD decomposition of 'm'. */
1772 	zsl_mtx_svd(m, &u, &e, &v, iter);
1773 
1774 	/* Transpose the 'u' matrix. */
1775 	zsl_mtx_trans(&u, &ut);
1776 
1777 	/* Set the value 'min' as the minimum of number of columns and number
1778 	 * of rows. */
1779 	if (m->sz_rows <= m->sz_cols) {
1780 		min = m->sz_rows;
1781 	}
1782 
1783 	for (size_t g = 0; g < min; g++) {
1784 
1785 		/* Invert the diagonal values in 'e'. If a value is zero, do
1786 		 * nothing to it. */
1787 		zsl_mtx_get(&e, g, g, &x);
1788 		if ((x < epsilon) || (x > -epsilon)) {
1789 			x = 1 / x;
1790 			zsl_mtx_set(&e, g, g, x);
1791 		}
1792 	}
1793 
1794 	/* Transpose the sigma matrix. */
1795 	zsl_mtx_trans(&e, &et);
1796 
1797 	/* Multiply 'u' (transposed) times sigma (transposed and with inverted
1798 	 * eigenvalues) times 'v'. */
1799 	zsl_mtx_mult(&v, &et, &pas);
1800 	zsl_mtx_mult(&pas, &ut, pinv);
1801 
1802 	return 0;
1803 }
1804 #endif
1805 
1806 int
zsl_mtx_min(struct zsl_mtx * m,zsl_real_t * x)1807 zsl_mtx_min(struct zsl_mtx *m, zsl_real_t *x)
1808 {
1809 	zsl_real_t min = m->data[0];
1810 
1811 	for (size_t i = 0; i < m->sz_cols * m->sz_rows; i++) {
1812 		if (m->data[i] < min) {
1813 			min = m->data[i];
1814 		}
1815 	}
1816 
1817 	*x = min;
1818 
1819 	return 0;
1820 }
1821 
1822 int
zsl_mtx_max(struct zsl_mtx * m,zsl_real_t * x)1823 zsl_mtx_max(struct zsl_mtx *m, zsl_real_t *x)
1824 {
1825 	zsl_real_t max = m->data[0];
1826 
1827 	for (size_t i = 0; i < m->sz_cols * m->sz_rows; i++) {
1828 		if (m->data[i] > max) {
1829 			max = m->data[i];
1830 		}
1831 	}
1832 
1833 	*x = max;
1834 
1835 	return 0;
1836 }
1837 
1838 int
zsl_mtx_min_idx(struct zsl_mtx * m,size_t * i,size_t * j)1839 zsl_mtx_min_idx(struct zsl_mtx *m, size_t *i, size_t *j)
1840 {
1841 	zsl_real_t min = m->data[0];
1842 
1843 	*i = 0;
1844 	*j = 0;
1845 
1846 	for (size_t _i = 0; _i < m->sz_rows; _i++) {
1847 		for (size_t _j = 0; _j < m->sz_cols; _j++) {
1848 			if (m->data[_i * m->sz_cols + _j] < min) {
1849 				min = m->data[_i * m->sz_cols + _j];
1850 				*i = _i;
1851 				*j = _j;
1852 			}
1853 		}
1854 	}
1855 
1856 	return 0;
1857 }
1858 
1859 int
zsl_mtx_max_idx(struct zsl_mtx * m,size_t * i,size_t * j)1860 zsl_mtx_max_idx(struct zsl_mtx *m, size_t *i, size_t *j)
1861 {
1862 	zsl_real_t max = m->data[0];
1863 
1864 	*i = 0;
1865 	*j = 0;
1866 
1867 	for (size_t _i = 0; _i < m->sz_rows; _i++) {
1868 		for (size_t _j = 0; _j < m->sz_cols; _j++) {
1869 			if (m->data[_i * m->sz_cols + _j] > max) {
1870 				max = m->data[_i * m->sz_cols + _j];
1871 				*i = _i;
1872 				*j = _j;
1873 			}
1874 		}
1875 	}
1876 
1877 	return 0;
1878 }
1879 
1880 bool
zsl_mtx_is_equal(struct zsl_mtx * ma,struct zsl_mtx * mb)1881 zsl_mtx_is_equal(struct zsl_mtx *ma, struct zsl_mtx *mb)
1882 {
1883 	int res;
1884 
1885 	/* Make sure shape is the same. */
1886 	if ((ma->sz_rows != mb->sz_rows) || (ma->sz_cols != mb->sz_cols)) {
1887 		return false;
1888 	}
1889 
1890 	res = memcmp(ma->data, mb->data,
1891 		     sizeof(zsl_real_t) * (ma->sz_rows + ma->sz_cols));
1892 
1893 	return res == 0 ? true : false;
1894 }
1895 
1896 bool
zsl_mtx_is_notneg(struct zsl_mtx * m)1897 zsl_mtx_is_notneg(struct zsl_mtx *m)
1898 {
1899 	for (size_t i = 0; i < m->sz_rows * m->sz_cols; i++) {
1900 		if (m->data[i] < 0.0) {
1901 			return false;
1902 		}
1903 	}
1904 
1905 	return true;
1906 }
1907 
1908 bool
zsl_mtx_is_sym(struct zsl_mtx * m)1909 zsl_mtx_is_sym(struct zsl_mtx *m)
1910 {
1911 	zsl_real_t x;
1912 	zsl_real_t y;
1913 	zsl_real_t diff;
1914 	zsl_real_t epsilon = 1E-6;
1915 
1916 	for (size_t i = 0; i < m->sz_rows; i++) {
1917 		for (size_t j = 0; j < m->sz_cols; j++) {
1918 			zsl_mtx_get(m, i, j, &x);
1919 			zsl_mtx_get(m, j, i, &y);
1920 			diff = x - y;
1921 			if (diff >= epsilon || diff <= -epsilon) {
1922 				return false;
1923 			}
1924 		}
1925 	}
1926 
1927 	return true;
1928 }
1929 
1930 int
zsl_mtx_print(struct zsl_mtx * m)1931 zsl_mtx_print(struct zsl_mtx *m)
1932 {
1933 	int rc;
1934 	zsl_real_t x;
1935 
1936 	for (size_t i = 0; i < m->sz_rows; i++) {
1937 		for (size_t j = 0; j < m->sz_cols; j++) {
1938 			rc = zsl_mtx_get(m, i, j, &x);
1939 			if (rc) {
1940 				printf("Error reading (%zu,%zu)!\n", i, j);
1941 				return -EINVAL;
1942 			}
1943 			/* Print the current floating-point value. */
1944 			printf("%f ", x);
1945 		}
1946 		printf("\n");
1947 	}
1948 
1949 	printf("\n");
1950 
1951 	return 0;
1952 }
1953