1/*
2 * SPDX-License-Identifier: BSD-3-Clause
3 *
4 * Copyright © 2023 Keith Packard
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 *    notice, this list of conditions and the following disclaimer.
12 *
13 * 2. Redistributions in binary form must reproduce the above
14 *    copyright notice, this list of conditions and the following
15 *    disclaimer in the documentation and/or other materials provided
16 *    with the distribution.
17 *
18 * 3. Neither the name of the copyright holder nor the names of its
19 *    contributors may be used to endorse or promote products derived
20 *    from this software without specific prior written permission.
21 *
22 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26 * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
27 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
30 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
31 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
32 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
33 * OF THE POSSIBILITY OF SUCH DAMAGE.
34 */
35
/*
 * The four rounding directions understood by round(). The comments
 * describe how round() treats each one.
 */
typedef enum {
	TONEAREST,	/* nearest value, exact ties go to the even significand */
	UPWARD,		/* positive inexact values are bumped, negative ones truncated */
	DOWNWARD,	/* negative inexact values are bumped, positive ones truncated */
	TOWARDZERO	/* always truncated */
} rounding_mode_t;
40
/*
 * Magnitude part of an IEEE-style number: either an actual number or
 * one of the special values NaN and infinity. The special values
 * carry no data, so they are void members. The explicit sign (needed
 * for signed zero) is kept alongside this union in float_t.
 */

typedef union {
	real	num;	/* ordinary value */
	void	nan;	/* not a number */
	void	inf;	/* infinity */
} value_t;
51
/*
 * A value together with an explicit sign flag, so that zero, NaN and
 * infinity can all be negative.
 */
typedef struct {
	bool	sign;	/* true when the value is negative */
	value_t	u;	/* number, NaN or infinity */
} float_t;
56
/*
 * Wrap a plain number in a float_t. The sign flag is derived from
 * the number, so a zero argument always produces a positive zero.
 */
float_t
make_float(real num)
{
	bool	negative = num < 0;
	value_t	payload = { .num = num };

	return (float_t) { .sign = negative, .u = payload };
}
68
/*
 * Description of a floating point format.
 */
typedef struct {
	int	bits;		/* significand precision used by round() */
	int	exp_bits;	/* width of the exponent field */
	int	min_exp;	/* exponents below this are rounded with reduced precision */
	int	max_exp;	/* exponents above this overflow */

	/* Range of exponents to iterate over; filled in by generate() */
	int	first_exp;
	int	last_exp;
} format_t;
78
/*
 * An optional format, used to describe an intermediate format that
 * results are rounded to before the final rounding.
 */
typedef union {
	format_t	format;	/* intermediate format present */
	void		none;	/* no intermediate format */
} format_or_none_t;

/* The "no intermediate format" value */
format_or_none_t none_format = { .none = <> };
85
/*
 * Round 'f' to the precision and exponent range of 'format' under
 * rounding mode 'rm'.
 *
 * When 'i_format' holds a format, 'f' is first rounded to that
 * format (using the same rounding mode) and the result of that step
 * is then rounded to 'format'. NaN and infinity are returned
 * unchanged.
 */

float_t
round(float_t f, format_t format, format_or_none_t i_format, rounding_mode_t rm)
{
	/* Optional first rounding step to the intermediate format */
	union switch (i_format) {
	case none:
		break;
	case format inner:
		f = round(f, inner, none_format, rm);
		break;
	}

	union switch (f.u) {
	case nan:
	case inf:
		/* Nothing to round */
		break;
	case num x:
		int	prec = format.bits;

		/* Work on the magnitude; the sign is reapplied at the end */
		real	mag = f.sign ? -x : x;

		int	expo = (mag == 0) ? 0 : ceil(log2(mag));

		/* Denorm means our available precision is reduced */
		int	lost = format.min_exp - expo;
		if (lost > 0)
			prec -= lost;

		/*
		 * Scale the magnitude so that the bits to keep form
		 * the integer portion. Rationals are used here so
		 * that no bits are lost before the rounding decision
		 * is made.
		 */
		real	scaled = abs(mag / (2 ** (expo - prec)));

		/*
		 * The integer portion becomes the significand of the
		 * result; the fractional portion only drives the
		 * rounding decision.
		 */
		int	ipart = floor(scaled);
		real	fpart = scaled - ipart;

		bool	inexact = fpart > 0;
		bool	too_big = expo > format.max_exp;

		/*
		 * Decide what to do first, then do it. 'bump' moves
		 * the significand one step away from zero. 'saturate'
		 * replaces a too-large value with the finite value of
		 * greatest magnitude, which is what directions that
		 * truncate must produce instead of infinity.
		 */
		bool	bump = false;
		bool	saturate = false;

		union switch (rm) {
		case TONEAREST:
			/* An exact tie rounds to the even significand */
			bump = fpart > 0.5 ||
			       (fpart == 0.5 && (ipart & 1) != 0);
			break;
		case UPWARD:
			if (f.sign)
				saturate = too_big;
			else
				bump = inexact;
			break;
		case DOWNWARD:
			if (f.sign)
				bump = inexact;
			else
				saturate = too_big;
			break;
		case TOWARDZERO:
			saturate = too_big;
			break;
		}

		if (bump)
			ipart++;
		if (saturate) {
			expo = format.max_exp;
			ipart = (2 ** prec) - 1;
		}

		/*
		 * Handle underflow in a way that preserves rounding
		 * to a value of smallest magnitude.
		 */
		if (prec < 0) {
			expo -= prec;
			prec = 0;
		}

		/*
		 * Rebuild the significand. Rounding up may have
		 * carried it to 1, in which case it is renormalized.
		 */
		real	frac = ipart / (2 ** prec);
		if (frac >= 1) {
			expo++;
			frac /= 2;
		}

		if (expo > format.max_exp) {
			/* Overflow to infinity */
			f.u.inf = <>;
		} else {
			real	result = frac * 2 ** expo;
			f.u.num = f.sign ? -result : result;
		}
		break;
	}
	return f;
}
234
/*
 * Format 'f' for the generated table, followed by 'suffix'. NaN and
 * infinity are printed as padded names with their sign; numbers are
 * printed from the stored value alone.
 */
string
strfromfloat(float_t f, string suffix)
{
	union switch (f.u) {
	case num x:
		/*
		 * %a format involves conversion to float; the
		 * default has 256 bits of significand, which is
		 * expected to be sufficient for any ieee target
		 */
		if (x != 0)
			return sprintf("%a%s", x, suffix);

		/*
		 * Make sure zero is printed as 0.0 so that the
		 * suffix works
		 */
		return sprintf("%.1f%s", x, suffix);
	case nan:
		if (f.sign)
			return sprintf("%s%s", "          -nan", suffix);
		return sprintf("%s%s", "           nan", suffix);
	case inf:
		if (f.sign)
			return sprintf("%s%s", "          -inf", suffix);
		return sprintf("%s%s", "           inf", suffix);
	}
}
260
/* True when 'f' holds a number rather than NaN or infinity */
bool
isfinite(float_t f)
{
	bool	finite = false;

	union switch (f.u) {
	case num:
		finite = true;
		break;
	default:
		break;
	}
	return finite;
}
271
/* True when 'f' is NaN */
bool
isnan(float_t f)
{
	bool	found = false;

	union switch (f.u) {
	case nan:
		found = true;
		break;
	default:
		break;
	}
	return found;
}
282
/* True when 'f' is an infinity of either sign */
bool
isinf(float_t f)
{
	bool	found = false;

	union switch (f.u) {
	case inf:
		found = true;
		break;
	default:
		break;
	}
	return found;
}
293
/*
 * Exact product of 'a' and 'b'.
 *
 * A NaN operand is returned as is ('a' is checked first). Otherwise
 * the result is negative when exactly one operand is negative.
 * Infinity times zero is NaN; infinity times anything else is
 * infinity.
 */
float_t
times(float_t a, float_t b)
{
	if (isnan(a))
		return a;
	if (isnan(b))
		return b;

	bool	sign = a.sign != b.sign;
	value_t	zero = (value_t.num) 0;
	value_t	not_a_number = { .nan = <> };

	/* Special case inf values -- inf * 0 is nan, but inf * other is inf */
	if (isinf(a)) {
		value_t	u = (b.u == zero) ? not_a_number : a.u;
		return (float_t) { .sign = sign, .u = u };
	}
	if (isinf(b)) {
		value_t	u = (a.u == zero) ? not_a_number : b.u;
		return (float_t) { .sign = sign, .u = u };
	}

	real	product = a.u.num * b.u.num;
	return (float_t) { .sign = sign, .u = { .num = product } };
}
317
/*
 * Exact sum of 'a' and 'b'.
 *
 * A NaN operand is returned as is ('a' is checked first). Adding
 * infinities of opposite sign gives a negative NaN; any other sum
 * involving an infinity is that infinity.
 *
 * For finite operands the sign follows the sum. When the sum is
 * exactly zero the result is negative only if both operands are
 * negative, so (-0) + (-0) is -0 while x + (-x) and (-0) + (+0) are
 * +0.
 *
 * NOTE: IEEE 754 makes an exact zero sum of opposite-signed operands
 * negative when rounding downward. The rounding mode is not
 * available here, so that case is not modelled.
 */
float_t
plus(float_t a, float_t b)
{
	if (isnan(a))
		return a;
	if (isnan(b))
		return b;

	if (isinf(a)) {
		/* inf + -inf is NaN */
		if (isinf(b) && a.sign != b.sign)
			return (float_t) { .sign = true, .u = { .nan = <> } };
		return a;
	}
	if (isinf(b)) {
		return b;
	}

	real v = a.u.num + b.u.num;
	bool sign = v < 0;

	/* An exact zero keeps a negative sign only when both operands had it */
	if (v == 0)
		sign = a.sign && b.sign;
	return (float_t) { .sign = sign, .u = { .num = v } };
}
341
/*
 * Fused multiply-add: x * y + z computed without any intermediate
 * rounding. With the support functions above this is just an exact
 * product followed by an exact sum; rounding is left to the caller.
 */
float_t
fma(float_t x, float_t y, float_t z)
{
	float_t	product = times(x, y);

	return plus(product, z);
}
351
/*
 * Step to the next exponent to test after 'e'. Exponents normally
 * advance by one, with three jumps that skip over uninteresting
 * ranges:
 *
 *	first_exp + 1	-> min_exp - 2
 *	min_exp		-> -1
 *	1		-> last_exp - 2
 *
 * The tests are made in this order, so an earlier match wins when
 * more than one applies.
 */
int next_exp(int e, format_t format)
{
	if (e == format.first_exp + 1)
		return format.min_exp - 2;
	if (e == format.min_exp)
		return -1;
	if (e == 1)
		return format.last_exp - 2;
	return e + 1;
}
365
/*
 * Usual exponent range for the specified number of bits in the
 * exponent. Note that the minimum is off by one for the 80-bit m68k
 * format, which uses a slightly different form for denorm.
 */

int
min_exp(int exp_bits)
{
	int	half_range = 2 ** (exp_bits - 1);

	return 3 - half_range;
}
377
/* Usual maximum exponent for the given exponent field width */
int
max_exp(int exp_bits)
{
	int	half_range = 2 ** (exp_bits - 1);

	return half_range;
}
383
/*
 * Generate a set of test vectors for the specified floating point
 * format.
 *
 * Each line printed is one table entry: a running index in a
 * comment, the three operands x, y and z, and the expected value of
 * x * y + z under each of the four rounding modes, in the order
 * TONEAREST, UPWARD, DOWNWARD, TOWARDZERO.
 *
 *	suf		suffix appended to every printed constant
 *	format		format of the operands and of the results
 *	i_format	optional intermediate format that results are
 *			rounded to before being rounded to 'format'
 *
 * Operands are a fixed significand scaled by the exponents produced
 * by next_exp() and rounded to 'format'. x and z are tried with both
 * signs; y is always positive.
 */
void generate(string suf, format_t format, format_or_none_t i_format)
{
	int bits = format.bits;

	/* Exponent range to walk; only this local copy of 'format' is changed */
	format.first_exp = (format.min_exp - bits - 2);
	format.last_exp = (format.max_exp);

	/* Significand with only its highest and lowest bits set */
	real val = 1 + 2**-(bits-1);

	int i = 0;

	/* Check +/- z */
	for (int zs = -1; zs <= 1; zs += 2) {
		for (int ze = format.first_exp; ze <= format.last_exp; ze = next_exp(ze, format)) {
			float_t z = round(make_float(zs * val * (2 ** ze)), format, none_format, rounding_mode_t.TONEAREST);
			for (int ye = format.first_exp; ye <= format.last_exp; ye = next_exp(ye, format)) {
				float_t y = round(make_float(val * (2 ** ye)), format, none_format, rounding_mode_t.TONEAREST);
				/* Check +/- x */
				for (int xs = -1; xs <= 1; xs += 2) {
					for (int xe = format.first_exp; xe <= format.last_exp; xe = next_exp(xe, format)) {
						float_t x = round(make_float(xs * val * (2 ** xe)), format, none_format, rounding_mode_t.TONEAREST);
						printf(" /* %4d */ { %-17s, %-17s, %-17s, {", i, strfromfloat(x, suf), strfromfloat(y, suf), strfromfloat(z, suf));

						/* Exact result; rounded once per rounding mode below */
						float_t r = fma(x, y, z);
						printf(" %s,", strfromfloat(round(r, format, i_format, rounding_mode_t.TONEAREST), suf));
						printf(" %s,", strfromfloat(round(r, format, i_format, rounding_mode_t.UPWARD), suf));
						printf(" %s,", strfromfloat(round(r, format, i_format, rounding_mode_t.DOWNWARD), suf));
						printf(" %s" , strfromfloat(round(r, format, i_format, rounding_mode_t.TOWARDZERO), suf));
						printf(" } },\n");
						i++;
					}
				}
			}
		}
	}
}
422
/*
 * Formats that tables are generated for. first_exp and last_exp are
 * left unset here; generate() computes them.
 */

/* Used for the float tables (FLT_MANT_DIG == 24) */
format_t ieee_32 = {
	.bits = 24, .exp_bits = 8,
	.min_exp = min_exp(8), .max_exp = max_exp(8)
};

/* Used for the double tables (DBL_MANT_DIG == 53) */
format_t ieee_64 = {
	.bits = 53, .exp_bits = 11,
	.min_exp = min_exp(11), .max_exp = max_exp(11)
};

/* Used for the long double table when LDBL_MANT_DIG == 113 */
format_t ieee_128 = {
	.bits = 113, .exp_bits = 15,
	.min_exp = min_exp(15), .max_exp = max_exp(15)
};

/* 80-bit format with LDBL_MIN_EXP == -16381 */
format_t intel_80 = {
	.bits = 64, .exp_bits = 15,
	.min_exp = min_exp(15), .max_exp = max_exp(15)
};

/* 80-bit format with LDBL_MIN_EXP == -16382, one below the usual minimum */
format_t moto_80 = {
	.bits = 64, .exp_bits = 15,
	.min_exp = min_exp(15) - 1, .max_exp = max_exp(15)
};

/* The 80-bit formats wrapped for use as intermediate formats */
format_or_none_t intel_80_optional = { .format = intel_80 };
format_or_none_t moto_80_optional = { .format = moto_80 };
460
/*
 * Print the complete generated file: a header comment followed by
 * one preprocessor-guarded table for each combination of result
 * format and intermediate format.
 */
void main()
{
	/*
	 * Emit one table, preceded by a blank line: the guard line,
	 * the HAVE_* define, the array opening, the vectors and the
	 * closing lines. 'guard', 'have' and 'decl' are complete
	 * output lines, including the newline.
	 */
	void table(string guard, string have, string decl,
		   string suf, format_t format, format_or_none_t i_format)
	{
		printf("\n");
		printf("%s", guard);
		printf("%s", have);
		printf("%s", decl);
		generate(suf, format, i_format);
		printf("};\n");
		printf("#endif\n");
	}

	printf("/* This file is automatically generated with fma_gen.5c */\n");

	/* float, no intermediate format */
	table("#if __FLT_EVAL_METHOD__ == 0 && FLT_MANT_DIG == 24\n",
	      "#define HAVE_FLOAT_FMA_VEC\n",
	      "static const struct fmaf_vec fmaf_vec[] = {\n",
	      "f", ieee_32, none_format);

	/* float, rounded through intel_80 first */
	table("#if __FLT_EVAL_METHOD__ == 2 && FLT_MANT_DIG == 24 && LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16381\n",
	      "#define HAVE_FLOAT_FMA_VEC\n",
	      "static const struct fmaf_vec fmaf_vec[] = {\n",
	      "f", ieee_32, intel_80_optional);

	/* float, rounded through moto_80 first */
	table("#if __FLT_EVAL_METHOD__ == 2 && FLT_MANT_DIG == 24 && LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16382\n",
	      "#define HAVE_FLOAT_FMA_VEC\n",
	      "static const struct fmaf_vec fmaf_vec[] = {\n",
	      "f", ieee_32, moto_80_optional);

	/* double, no intermediate format */
	table("#if __FLT_EVAL_METHOD__ <= 1 && DBL_MANT_DIG == 53\n",
	      "#define HAVE_DOUBLE_FMA_VEC\n",
	      "static const struct fma_vec fma_vec[] = {\n",
	      "", ieee_64, none_format);

	/* double, rounded through intel_80 first */
	table("#if __FLT_EVAL_METHOD__ == 2 && DBL_MANT_DIG == 53 && LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16381\n",
	      "#define HAVE_DOUBLE_FMA_VEC\n",
	      "static const struct fma_vec fma_vec[] = {\n",
	      "", ieee_64, intel_80_optional);

	/* double, rounded through moto_80 first */
	table("#if __FLT_EVAL_METHOD__ == 2 && DBL_MANT_DIG == 53 && LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16382\n",
	      "#define HAVE_DOUBLE_FMA_VEC\n",
	      "static const struct fma_vec fma_vec[] = {\n",
	      "", ieee_64, moto_80_optional);

	/* long double as intel_80 */
	table("#if LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16381\n",
	      "#define HAVE_LONG_DOUBLE_FMA_VEC\n",
	      "static const struct fmal_vec fmal_vec[] = {\n",
	      "l", intel_80, none_format);

	/* long double as moto_80 */
	table("#if LDBL_MANT_DIG == 64 && LDBL_MIN_EXP == -16382\n",
	      "#define HAVE_LONG_DOUBLE_FMA_VEC\n",
	      "static const struct fmal_vec fmal_vec[] = {\n",
	      "l", moto_80, none_format);

	/* long double as ieee_128 */
	table("#if LDBL_MANT_DIG == 113\n",
	      "#define HAVE_LONG_DOUBLE_FMA_VEC\n",
	      "static const struct fmal_vec fmal_vec[] = {\n",
	      "l", ieee_128, none_format);
}

main();
539