import filecmp
import hashlib
import itertools
import os
import os.path
import random
import struct
import subprocess
import sys
import tempfile
from functools import partial

IMAGES_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "images")

from conftest import need_to_install_package_err

import pytest

try:
    from esptool.util import byte
    from esptool.uf2_writer import UF2Writer
    from esptool.targets import CHIP_DEFS
except ImportError:
    need_to_install_package_err()


def read_image(filename):
    with open(os.path.join(IMAGES_DIR, filename), "rb") as f:
        return f.read()


@pytest.mark.host_test
class TestMergeBin:
    def run_merge_bin(self, chip, offsets_names, options=[], allow_warnings=False):
        """Run merge_bin on a list of (offset, filename) tuples
        with output to a named temporary file.

        Filenames are relative to the 'test/images' directory.

        Returns the contents of the merged file if successful.
        """
        output_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            output_file.close()

            cmd = [
                sys.executable,
                "-m",
                "esptool",
                "--chip",
                chip,
                "merge_bin",
                "-o",
                output_file.name,
            ] + options
            for offset, name in offsets_names:
                cmd += [hex(offset), name]
            print("\nExecuting {}".format(" ".join(cmd)))

            output = subprocess.check_output(
                cmd, cwd=IMAGES_DIR, stderr=subprocess.STDOUT
            )
            output = output.decode("utf-8")
            print(output)
            if not allow_warnings:
                assert (
                    "warning" not in output.lower()
                ), "merge_bin should not output warnings"

            with open(output_file.name, "rb") as f:
                return f.read()
        except subprocess.CalledProcessError as e:
            print(e.output)
            raise
        finally:
            os.unlink(output_file.name)
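
    # For reference, the command assembled above ends up looking roughly like
    # (a sketch with hypothetical file names, not literal test output):
    #   python -m esptool --chip esp32 merge_bin -o /tmp/tmpXXXX \
    #       0x1000 bootloader.bin 0x10000 app.bin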

    def assertAllFF(self, some_bytes):
        # this may need some improving as the failed assert messages may be
        # very long and/or useless!
        assert b"\xFF" * len(some_bytes) == some_bytes

    def test_simple_merge(self):
        merged = self.run_merge_bin(
            "esp8266",
            [(0x0, "one_kb.bin"), (0x1000, "one_kb.bin"), (0x10000, "one_kb.bin")],
        )
        one_kb = read_image("one_kb.bin")

        assert len(one_kb) == 0x400

        assert len(merged) == 0x10400
        assert merged[:0x400] == one_kb
        assert merged[0x1000:0x1400] == one_kb
        assert merged[0x10000:] == one_kb

        self.assertAllFF(merged[0x400:0x1000])
        self.assertAllFF(merged[0x1400:0x10000])

    def test_args_out_of_order(self):
        # no matter which order we supply arguments, the output should be the same
        args = [(0x0, "one_kb.bin"), (0x1000, "one_kb.bin"), (0x10000, "one_kb.bin")]
        merged_orders = [
            self.run_merge_bin("esp8266", perm_args)
            for perm_args in itertools.permutations(args)
        ]
        for m in merged_orders:
            assert m == merged_orders[0]

    def test_error_overlap(self, capsys):
        args = [(0x1000, "one_mb.bin"), (0x20000, "one_kb.bin")]
        for perm_args in itertools.permutations(args):
            with pytest.raises(subprocess.CalledProcessError):
                self.run_merge_bin("esp32", perm_args)
            output = capsys.readouterr().out
            assert "overlap" in output

    def test_leading_padding(self):
        merged = self.run_merge_bin("esp32c3", [(0x100000, "one_mb.bin")])
        self.assertAllFF(merged[:0x100000])
        assert read_image("one_mb.bin") == merged[0x100000:]

    def test_update_bootloader_params(self):
        merged = self.run_merge_bin(
            "esp32",
            [
                (0x1000, "bootloader_esp32.bin"),
                (0x10000, "ram_helloworld/helloworld-esp32.bin"),
            ],
            ["--flash_size", "2MB", "--flash_mode", "dout"],
        )
        self.assertAllFF(merged[:0x1000])

        bootloader = read_image("bootloader_esp32.bin")
        helloworld = read_image("ram_helloworld/helloworld-esp32.bin")

        # test the bootloader is unchanged apart from the header
        # (updating the header doesn't change CRC,
        # and doesn't update the SHA although it will invalidate it!)
        assert merged[0x1010 : 0x1000 + len(bootloader)] == bootloader[0x10:]

        # check the individual bytes in the header are as expected
        merged_hdr = merged[0x1000:0x1010]
        bootloader_hdr = bootloader[:0x10]
        assert bootloader_hdr[:2] == merged_hdr[:2]
        assert byte(merged_hdr, 2) == 3  # flash mode dout
        assert byte(merged_hdr, 3) & 0xF0 == 0x10  # flash size 2MB (ESP32)
        # flash freq is unchanged
        assert byte(bootloader_hdr, 3) & 0x0F == byte(merged_hdr, 3) & 0x0F
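        # The two checks above rely on the ESP image header packing flash size
        # in the high nibble of byte 3 and flash frequency in the low nibble,
        # i.e. roughly: size_code, freq_code = hdr[3] >> 4, hdr[3] & 0x0F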
        assert bootloader_hdr[4:] == merged_hdr[4:]  # remaining fields are unchanged

        # check all the padding is as expected
        self.assertAllFF(merged[0x1000 + len(bootloader) : 0x10000])
        assert merged[0x10000 : 0x10000 + len(helloworld)] == helloworld

    def test_target_offset(self):
        merged = self.run_merge_bin(
            "esp32",
            [
                (0x1000, "bootloader_esp32.bin"),
                (0x10000, "ram_helloworld/helloworld-esp32.bin"),
            ],
            ["--target-offset", "0x1000"],
        )

        bootloader = read_image("bootloader_esp32.bin")
        helloworld = read_image("ram_helloworld/helloworld-esp32.bin")
        assert bootloader == merged[: len(bootloader)]
        assert helloworld == merged[0xF000 : 0xF000 + len(helloworld)]
        self.assertAllFF(merged[0x1000 + len(bootloader) : 0xF000])

    def test_fill_flash_size(self):
        merged = self.run_merge_bin(
            "esp32c3", [(0x0, "bootloader_esp32c3.bin")], ["--fill-flash-size", "4MB"]
        )
        bootloader = read_image("bootloader_esp32c3.bin")

        assert len(merged) == 0x400000
        assert bootloader == merged[: len(bootloader)]
        self.assertAllFF(merged[len(bootloader) :])

    def test_fill_flash_size_w_target_offset(self):
        merged = self.run_merge_bin(
            "esp32",
            [
                (0x1000, "bootloader_esp32.bin"),
                (0x10000, "ram_helloworld/helloworld-esp32.bin"),
            ],
            ["--target-offset", "0x1000", "--fill-flash-size", "2MB"],
        )

        # the full flash size is reduced by the target offset
        assert len(merged) == 0x200000 - 0x1000

        bootloader = read_image("bootloader_esp32.bin")
        helloworld = read_image("ram_helloworld/helloworld-esp32.bin")
        assert bootloader == merged[: len(bootloader)]
        assert helloworld == merged[0xF000 : 0xF000 + len(helloworld)]
        self.assertAllFF(merged[0xF000 + len(helloworld) :])

    def test_merge_mixed(self):
        # convert the bootloader to Intel HEX format
        hex_data = self.run_merge_bin(
            "esp32",
            [(0x1000, "bootloader_esp32.bin")],
            options=["--format", "hex"],
            allow_warnings=True,
        )
        # create a temp file with the hex content
        with tempfile.NamedTemporaryFile(suffix=".hex", delete=False) as f:
            f.write(hex_data)
        # merge the hex file with a bin file; the binary output should be the
        # same as when merging bin + bin
        try:
            merged = self.run_merge_bin(
                "esp32",
                [(0x1000, f.name), (0x10000, "ram_helloworld/helloworld-esp32.bin")],
                ["--target-offset", "0x1000", "--fill-flash-size", "2MB"],
            )
        finally:
            os.unlink(f.name)
        # the full flash size is reduced by the target offset
        assert len(merged) == 0x200000 - 0x1000

        bootloader = read_image("bootloader_esp32.bin")
        helloworld = read_image("ram_helloworld/helloworld-esp32.bin")
        assert bootloader == merged[: len(bootloader)]
        assert helloworld == merged[0xF000 : 0xF000 + len(helloworld)]
        self.assertAllFF(merged[0xF000 + len(helloworld) :])

    def test_merge_bin2hex(self):
        merged = self.run_merge_bin(
            "esp32",
            [
                (0x1000, "bootloader_esp32.bin"),
            ],
            options=["--format", "hex"],
            allow_warnings=True,
        )
        lines = merged.splitlines()
        # Intel HEX record format - :0300300002337A1E
        # :03        0030  00    02337A 1E
        #  ^byte_cnt ^addr ^type ^data  ^checksum

        # check the starting address - 0x1000 passed from the arg
        assert lines[0][3:7] == b"1000"
        # pick a random line for testing the format
        line = lines[random.randrange(0, len(lines))]
        assert line[0] == ord(":")
        data_len = int(line[1:3], 16)
        # : + len + addr + type + data + checksum
        assert len(line) == 1 + 2 + 4 + 2 + data_len * 2 + 2
        # last line is always :00000001FF
        assert lines[-1] == b":00000001FF"
        # convert back and verify the result against the source bin file
        with tempfile.NamedTemporaryFile(suffix=".hex", delete=False) as hex_file:
            hex_file.write(merged)
        try:
            merged_bin = self.run_merge_bin(
                "esp32",
                [(0x1000, hex_file.name)],
                options=["--format", "raw"],
            )
        finally:
            os.unlink(hex_file.name)
        source = read_image("bootloader_esp32.bin")
        # verify that padding was done correctly
        assert b"\xFF" * 0x1000 == merged_bin[:0x1000]
        # verify the file itself
        assert source == merged_bin[0x1000:]
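
    # Checksum refresher for the record format checked above (a sketch, not used
    # by the test): the final byte is the two's complement of the sum of all
    # preceding record bytes, e.g. for b":0300300002337A1E":
    #   assert (-sum(bytes.fromhex("0300300002337A"))) & 0xFF == 0x1E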

    def test_hex_header_raw_file(self):
        # use a raw binary file starting with a colon
        with tempfile.NamedTemporaryFile(delete=False) as f:
            f.write(b":")
        try:
            merged = self.run_merge_bin("esp32", [(0x0, f.name)])
            assert merged == b":"
        finally:
            os.unlink(f.name)


class UF2Block(object):
    def __init__(self, bs):
        self.length = len(bs)

        # See https://github.com/microsoft/uf2 for the format
        first_part = "<" + "I" * 8
        # the data payload sits between the eight-word header and the final magic
        last_part = "<I"

        first_part_len = struct.calcsize(first_part)
        last_part_len = struct.calcsize(last_part)

        (
            self.magicStart0,
            self.magicStart1,
            self.flags,
            self.targetAddr,
            self.payloadSize,
            self.blockNo,
            self.numBlocks,
            self.familyID,
        ) = struct.unpack(first_part, bs[:first_part_len])

        self.data = bs[first_part_len:-last_part_len]

        (self.magicEnd,) = struct.unpack(last_part, bs[-last_part_len:])

    def __len__(self):
        return self.length


class UF2BlockReader(object):
    def __init__(self, f_name):
        self.f_name = f_name

    def get(self):
        with open(self.f_name, "rb") as f:
            for chunk in iter(partial(f.read, UF2Writer.UF2_BLOCK_SIZE), b""):
                yield UF2Block(chunk)
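

# Usage sketch for the reader above (hypothetical file name):
#   for block in UF2BlockReader("merged.uf2").get():
#       print(hex(block.targetAddr), block.blockNo, "/", block.numBlocks)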


class BinaryWriter(object):
    def __init__(self, f_name):
        self.f_name = f_name

    def append(self, data):
        # the file is reopened on every append to make sure it is never left open
        with open(self.f_name, "ab") as f:
            f.write(data)


@pytest.mark.host_test
class TestUF2:
    def generate_binary(self, size):
        with tempfile.NamedTemporaryFile(delete=False) as f:
            for _ in range(size):
                f.write(struct.pack("B", random.randrange(0, 1 << 7)))
            return f.name

    @staticmethod
    def generate_chipID():
        chip, rom = random.choice(list(CHIP_DEFS.items()))
        family_id = rom.UF2_FAMILY_ID
        return chip, family_id

    def generate_uf2(
        self,
        of_name,
        chip_id,
        iter_addr_offset_tuples,
        chunk_size=None,
        md5_enable=True,
    ):
        com_args = [
            sys.executable,
            "-m",
            "esptool",
            "--chip",
            chip_id,
            "merge_bin",
            "--format",
            "uf2",
            "-o",
            of_name,
        ]
        if not md5_enable:
            com_args.append("--md5-disable")
        com_args += [] if chunk_size is None else ["--chunk-size", str(chunk_size)]
        file_args = list(
            itertools.chain(*[(hex(addr), f) for addr, f in iter_addr_offset_tuples])
        )

        output = subprocess.check_output(com_args + file_args, stderr=subprocess.STDOUT)
        output = output.decode("utf-8")
        print(output)
        assert "warning" not in output.lower(), "merge_bin should not output warnings"

        exp_list = [f"Adding {f} at {hex(addr)}" for addr, f in iter_addr_offset_tuples]
        exp_list += [
            f"bytes to file {of_name}, ready to be flashed with any ESP USB Bridge"
        ]
        for e in exp_list:
            assert e in output

        return of_name
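
    # generate_uf2 above ends up invoking roughly the following command
    # (a sketch with hypothetical file names):
    #   python -m esptool --chip esp32 merge_bin --format uf2 -o out.uf2 \
    #       --chunk-size 256 0x1000 app.bin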

    def process_blocks(self, uf2block, expected_chip_id, md5_enable=True):
        flags = UF2Writer.UF2_FLAG_FAMILYID_PRESENT
        if md5_enable:
            flags |= UF2Writer.UF2_FLAG_MD5_PRESENT

        parsed_binaries = []

        block_list = []  # collect block numbers here
        total_blocks = set()  # collect total block numbers here
        for block in UF2BlockReader(uf2block).get():
            if block.blockNo == 0:
                # new file has been detected
                base_addr = block.targetAddr
                current_addr = base_addr
                binary_writer = BinaryWriter(self.generate_binary(0))

            assert len(block) == UF2Writer.UF2_BLOCK_SIZE
            assert block.magicStart0 == UF2Writer.UF2_FIRST_MAGIC
            assert block.magicStart1 == UF2Writer.UF2_SECOND_MAGIC
            assert block.flags & flags == flags

            assert len(block.data) == UF2Writer.UF2_DATA_SIZE
            payload = block.data[: block.payloadSize]
            if md5_enable:
                md5_obj = hashlib.md5(payload)
                md5_part = block.data[
                    block.payloadSize : block.payloadSize + UF2Writer.UF2_MD5_PART_SIZE
                ]
                address, length = struct.unpack("<II", md5_part[: -md5_obj.digest_size])
                md5sum = md5_part[-md5_obj.digest_size :]
                assert address == block.targetAddr
                assert length == block.payloadSize
                assert md5sum == md5_obj.digest()
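                # Layout implied by the slicing above: the MD5 part directly
                # follows the payload and holds a little-endian uint32 address,
                # uint32 length, and the 16-byte MD5 digest of the payload.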

            assert block.familyID == expected_chip_id
            assert block.magicEnd == UF2Writer.UF2_FINAL_MAGIC

            assert current_addr == block.targetAddr
            binary_writer.append(payload)

            block_list.append(block.blockNo)
            total_blocks.add(block.numBlocks)
            if block.blockNo == block.numBlocks - 1:
                assert block_list == list(range(block.numBlocks))
                # we have found all blocks and in the right order
                assert total_blocks == {
                    block.numBlocks
                }  # numBlocks are the same in all the blocks
                del block_list[:]
                total_blocks.clear()

                parsed_binaries += [(base_addr, binary_writer.f_name)]

            current_addr += block.payloadSize
        return parsed_binaries

    def common(self, t, chunk_size=None, md5_enable=True):
        # t is a list of (address, binary file name) tuples
        of_name = self.generate_binary(0)
        try:
            chip_name, chip_id = self.generate_chipID()
            self.generate_uf2(of_name, chip_name, t, chunk_size, md5_enable)
            parsed_t = self.process_blocks(of_name, chip_id, md5_enable)

            assert len(t) == len(parsed_t)
            for (orig_addr, orig_fname), (addr, fname) in zip(t, parsed_t):
                assert orig_addr == addr
                assert filecmp.cmp(orig_fname, fname)
        finally:
            os.unlink(of_name)
            for _, file_name in t:
                os.unlink(file_name)

    def test_simple(self):
        self.common([(0, self.generate_binary(1))])

    def test_more_files(self):
        self.common(
            [(0x100, self.generate_binary(1)), (0x1000, self.generate_binary(1))]
        )

    def test_larger_files(self):
        self.common(
            [(0x100, self.generate_binary(6)), (0x1000, self.generate_binary(8))]
        )

    def test_boundaries(self):
        self.common(
            [
                (0x100, self.generate_binary(UF2Writer.UF2_DATA_SIZE)),
                (0x2000, self.generate_binary(UF2Writer.UF2_DATA_SIZE + 1)),
                (0x3000, self.generate_binary(UF2Writer.UF2_DATA_SIZE - 1)),
            ]
        )

    def test_files_with_more_blocks(self):
        self.common(
            [
                (0x100, self.generate_binary(3 * UF2Writer.UF2_DATA_SIZE)),
                (0x2000, self.generate_binary(2 * UF2Writer.UF2_DATA_SIZE + 1)),
                (0x3000, self.generate_binary(2 * UF2Writer.UF2_DATA_SIZE - 1)),
            ]
        )

    def test_very_large_files(self):
        self.common(
            [
                (0x100, self.generate_binary(20 * UF2Writer.UF2_DATA_SIZE + 5)),
                (0x10000, self.generate_binary(50 * UF2Writer.UF2_DATA_SIZE + 100)),
                (0x100000, self.generate_binary(100 * UF2Writer.UF2_DATA_SIZE)),
            ]
        )

    def test_chunk_size(self):
        chunk_size = 256
        self.common(
            [
                (0x1000, self.generate_binary(chunk_size)),
                (0x2000, self.generate_binary(chunk_size + 1)),
                (0x3000, self.generate_binary(chunk_size - 1)),
            ],
            chunk_size,
        )

    def test_md5_disable(self):
        self.common(
            [(0x100, self.generate_binary(1)), (0x2000, self.generate_binary(1))],
            md5_enable=False,
        )