1#
2# Copyright (c) 2022 Nordic Semiconductor ASA
3#
4# SPDX-License-Identifier: Apache-2.0
5#
6
7import numpy
8import sys
9import os
10
def decode_test(out_dir=None):
	"""Write fp_bytes_decode.bin: every float16 bit pattern widened to float32.

	All 0x10000 possible 16-bit patterns are reinterpreted as float16 and
	converted to float32. The file contains the 0x10000 float32 results in
	little-endian byte order, followed by the same values in big-endian byte
	order, so entry i of each half is the float32 value of the float16 whose
	bit pattern is i.

	:param out_dir: Directory to write the file into. Defaults to sys.argv[1],
		which preserves the original command-line behaviour.
	"""
	if out_dir is None:
		out_dir = sys.argv[1]

	num_start = 0
	num_end = 0x10000

	# Reinterpret the 16-bit patterns in place as float16; view() avoids the
	# tobytes()/frombuffer() round trip and its extra buffer copy.
	a = numpy.arange(num_start, num_end, dtype=numpy.uint16).view(numpy.float16)
	with open(os.path.join(out_dir, "fp_bytes_decode.bin"), 'wb') as f:
		# "<f" / ">f" are little- / big-endian float32.
		f.write(a.astype("<f").tobytes() + a.astype(">f").tobytes())
18
def encode_test(out_dir=None, num_start=0x33000001, num_end=0x477ff000):
	"""Write fp_bytes_encode.bin: float32 -> float16 rounding boundaries.

	Every float32 bit pattern in [num_start, num_end) is converted to float16.
	Wherever the resulting float16 changes, the offset of that float32 pattern
	relative to num_start is recorded. That offset is the first float32 (in
	bit-pattern order) that rounds to each new float16 value. The offsets are
	written as uint32, first all little-endian, then all big-endian.

	With the default range, the first pattern is just above 2**-25 (half the
	smallest float16 subnormal). The range stops before 65520.0, the value
	halfway between float16 max (65504) and infinity. The recorded boundaries
	then correspond to float16 bit patterns 0x0002 ... 0x7BFF.

	:param out_dir: Directory to write the file into. Defaults to sys.argv[1],
		which preserves the original command-line behaviour.
	:param num_start: First float32 bit pattern to convert (inclusive).
	:param num_end: Last float32 bit pattern to convert (exclusive).
	:raises ValueError: If the bit-pattern range is empty or outside uint32.
	:raises RuntimeError: If successive boundaries do not step through every
		float16 bit pattern exactly once, in increasing order.
	"""
	if out_dir is None:
		out_dir = sys.argv[1]
	if not 0 <= num_start < num_end <= 0x100000000:
		raise ValueError(f"invalid float32 bit-pattern range [{num_start:#x}, {num_end:#x})")

	# Reinterpret the uint32 patterns as float32 with view() instead of a
	# tobytes()/frombuffer() round trip. The default range has ~344M entries,
	# so that copy alone costs ~1.4 GB.
	a = numpy.arange(num_start, num_end, dtype=numpy.uint32)
	# Native-endian float16 viewed as native-endian uint16, so the bit patterns
	# are correct on any host byte order.
	h = a.view(numpy.float32).astype(numpy.float16).view(numpy.uint16)
	del a  # free the large uint32 array before the comparison allocates more

	# Offsets (relative to num_start) where the float16 bit pattern changes.
	c = numpy.flatnonzero(h[1:] != h[:-1]) + 1

	# Sanity check: each boundary must land on the next float16 pattern. For the
	# default range this is exactly the pattern sequence 0x0002 ... 0x7BFF.
	# An explicit raise (not assert) keeps the check active under "python -O".
	expected = numpy.arange(int(h[0]) + 1, int(h[-1]) + 1)
	if not numpy.array_equal(h[c], expected):
		raise RuntimeError("float32 -> float16 rounding boundaries are not contiguous")

	with open(os.path.join(out_dir, "fp_bytes_encode.bin"), 'wb') as f:
		# "<I" / ">I" are little- / big-endian 32-bit unsigned ints.
		f.write(c.astype("<I").tobytes() + c.astype(">I").tobytes())
30
def print_help():
	"""Print usage information for this script to stdout.

	The usage line documents the optional mode argument ("decode" or
	"encode") that the __main__ block accepts after the output directory.
	"""
	print("Generate bin files with results from converting between float16 and float32 (both ways)")
	print()
	print(f"Usage: {sys.argv[0]} <directory to place bin files in> [decode|encode]")
	print()
	print("Without a mode argument both fp_bytes_decode.bin and fp_bytes_encode.bin are generated.")
35
if __name__ == "__main__":
	# Usage: <script> <out_dir> [decode|encode]; with no mode both files are generated.
	if "--help" in sys.argv or "-h" in sys.argv or len(sys.argv) < 2:
		print_help()
	elif len(sys.argv) < 3:
		decode_test()
		encode_test()
	elif sys.argv[2] == "decode":
		decode_test()
	elif sys.argv[2] == "encode":
		encode_test()
	else:
		# An unrecognised mode used to fall through silently, writing no files
		# and exiting with status 0; report it so calling scripts notice.
		print(f"Unknown mode '{sys.argv[2]}', expected 'decode' or 'encode'", file=sys.stderr)
		print_help()
		sys.exit(1)
46