1#!/usr/bin/env python
2#
3# Copyright 2020-2021 Espressif Systems (Shanghai) CO LTD
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17# This program creates archives compatible with ESP32-S* ROM DFU implementation.
18#
19# The archives are in CPIO format. Each file which needs to be flashed is added to the archive
20# as a separate file. In addition to that, a special index file, 'dfuinfo0.dat', is created.
21# This file must be the first one in the archive. It contains binary structures describing each
22# subsequent file (for example, where the file needs to be flashed/loaded).
23
24from __future__ import print_function, unicode_literals
25
26import argparse
27import hashlib
28import json
29import os
30import struct
31import zlib
32from collections import namedtuple
33from functools import partial
34
35from future.utils import iteritems
36
37try:
38    import typing
39except ImportError:
40    # Only used for type annotations
41    pass
42
43try:
44    from itertools import izip as zip  # type: ignore
45except ImportError:
46    # Python 3
47    pass
48
# CPIO ("new ASCII") format related things
# Magic number identifying the "newc" (new ASCII, without checksum) CPIO format.
CPIO_MAGIC = b'070701'
# Header layout: the 6-byte magic followed by thirteen 8-character ASCII hex fields
# (native byte order, no padding).
CPIO_STRUCT = b'=6s' + b'8s' * 13
# Field names matching CPIO_STRUCT, in on-disk order.
CPIOHeader = namedtuple(
    'CPIOHeader',
    [
        'magic',
        'ino',
        'mode',
        'uid',
        'gid',
        'nlink',
        'mtime',
        'filesize',
        'devmajor',
        'devminor',
        'rdevmajor',
        'rdevminor',
        'namesize',
        'check',
    ],
)
# Name of the terminating entry that every CPIO archive must end with.
CPIO_TRAILER = 'TRAILER!!!'
72
73
def make_cpio_header(
    filename_len, file_len, is_trailer=False
):  # type: (int, int, bool) -> CPIOHeader
    """ Build a CPIOHeader for an entry with the given name and data lengths """

    def field(value):  # type: (int) -> bytes
        # Every numeric header field is encoded as 8 lowercase ASCII hex digits.
        return format(value, '08x').encode('ascii')

    zero = field(0)
    if is_trailer:
        # Trailer entry: no mode bits, link count of 1.
        mode, nlink = zero, field(1)
    else:
        # Regular file entry (mode 0100644).
        mode, nlink = field(0o0100644), zero
    return CPIOHeader(
        magic=CPIO_MAGIC,
        ino=zero,
        mode=mode,
        uid=zero,
        gid=zero,
        nlink=nlink,
        mtime=zero,
        filesize=field(file_len),
        devmajor=zero,
        devminor=zero,
        rdevmajor=zero,
        rdevminor=zero,
        namesize=field(filename_len),
        check=zero,
    )
101
102
# DFU format related things
# Structure of one entry in dfuinfo0.dat
# (little-endian: address, flags, 64-byte name, 16-byte MD5 digest).
DFUINFO_STRUCT = b'<I I 64s 16s'
DFUInfo = namedtuple('DFUInfo', ['address', 'flags', 'name', 'md5'])
# Index file describing all other entries; written as the first CPIO entry.
DFUINFO_FILE = 'dfuinfo0.dat'
# Structure which gets added at the end of the entire DFU file
DFUSUFFIX_STRUCT = b'<H H H H 3s B'
DFUSuffix = namedtuple(
    'DFUSuffix', ['bcd_device', 'pid', 'vid', 'bcd_dfu', 'sig', 'len']
)
# 12346 == 0x303A, used as the vid field of the DFU suffix.
ESPRESSIF_VID = 12346
# This CRC32 gets added after DFUSUFFIX_STRUCT
DFUCRC_STRUCT = b'<I'
116
117
def dfu_crc(data, crc=0):  # type: (bytes, int) -> int
    """ Compute CRC32/JAMCRC (bitwise inverse of CRC32) of data, optionally chained via 'crc' """
    # XOR against an all-ones mask == 0xFFFFFFFF - value for 32-bit values,
    # i.e. the bitwise complement of the standard CRC32.
    mask = 0xFFFFFFFF
    return mask ^ (zlib.crc32(data, crc) & mask)
122
123
def pad_bytes(b, multiple, padding=b'\x00'):  # type: (bytes, int, bytes) -> bytes
    """ Extend 'b' with 'padding' until its length is a multiple of 'multiple' """
    remainder = len(b) % multiple
    if remainder:
        b += padding * (multiple - remainder)
    return b
128
129
class EspDfuWriter(object):
    """ Accumulates binaries to be flashed and serializes them into a single DFU file """

    def __init__(self, dest_file, pid, part_size):  # type: (typing.BinaryIO, int, int) -> None
        self.dest = dest_file
        self.pid = pid
        self.part_size = part_size
        self.entries = []  # type: typing.List[bytes]
        self.index = []  # type: typing.List[DFUInfo]

    def add_file(self, flash_addr, path):  # type: (int, str) -> None
        """
        Register a file to be written into flash starting at the given address.

        To avoid time-outs while erasing large regions, the file is stored as a
        sequence of chunks no larger than self.part_size. The first chunk keeps
        the original base name; every following chunk i is named "<name>.<i>"
        and placed directly after the previous chunk in flash, e.g.:
        1. app.bin   at flash_addr
        2. app.bin.1 at flash_addr + self.part_size
        3. app.bin.2 at flash_addr + 2 * self.part_size
        ...
        """
        base_name = os.path.basename(path)
        with open(path, 'rb') as src:
            chunk_no = 0
            while True:
                chunk = src.read(self.part_size)
                if not chunk:
                    break
                entry_name = base_name if chunk_no == 0 else '{}.{}'.format(base_name, chunk_no)
                self._add_cpio_flash_entry(entry_name, flash_addr, chunk)
                flash_addr += len(chunk)
                chunk_no += 1

    def finish(self):  # type: () -> None
        """ Serialize all registered files into the destination DFU file """
        # The index file describing every other entry goes first in the archive.
        index_blob = b''.join(struct.pack(DFUINFO_STRUCT, *item) for item in self.index)
        self._add_cpio_entry(DFUINFO_FILE, index_blob, first=True)

        # The CPIO format requires a trailer entry at the very end.
        self._add_cpio_entry(CPIO_TRAILER, b'', trailer=True)

        # Concatenate all entries and pad the archive to the CPIO block size.
        blob = pad_bytes(b''.join(self.entries), 10240)

        # Append the DFU suffix, then the CRC over everything before it.
        suffix = DFUSuffix(0xFFFF, self.pid, ESPRESSIF_VID, 0x0100, b'UFD', 16)
        blob += struct.pack(DFUSUFFIX_STRUCT, *suffix)
        blob += struct.pack(DFUCRC_STRUCT, dfu_crc(blob))

        self.dest.write(blob)

    def _add_cpio_flash_entry(
        self, filename, flash_addr, data
    ):  # type: (str, int, bytes) -> None
        # Record where the chunk goes and its MD5, then store the chunk itself.
        digest = hashlib.md5(data).digest()
        self.index.append(
            DFUInfo(
                address=flash_addr,
                flags=0,
                name=filename.encode('utf-8'),
                md5=digest,
            )
        )
        self._add_cpio_entry(filename, data)

    def _add_cpio_entry(
        self, filename, data, first=False, trailer=False
    ):  # type: (str, bytes, bool, bool) -> None
        # CPIO stores the NUL-terminated name right after the header; both the
        # header+name region and the data are padded to 4-byte boundaries.
        name_bytes = filename.encode('utf-8') + b'\x00'
        header = make_cpio_header(len(name_bytes), len(data), is_trailer=trailer)
        record = pad_bytes(
            struct.pack(CPIO_STRUCT, *header) + name_bytes, 4
        ) + pad_bytes(data, 4)
        if first:
            self.entries.insert(0, record)
        else:
            self.entries.append(record)
206
207
def action_write(args):  # type: (typing.Mapping[str, typing.Any]) -> None
    """ Handler for the 'write' command: build the DFU image from args['files'] """
    writer = EspDfuWriter(args['output_file'], args['pid'], args['part_size'])
    for address, path in args['files']:
        print('Adding {} at {:#x}'.format(path, address))
        writer.add_file(address, path)
    writer.finish()
    print('"{}" has been written. You may proceed with DFU flashing.'.format(args['output_file'].name))
    if args['part_size'] % (4 * 1024) != 0:
        print('WARNING: Partition size of DFU is not multiple of 4k (4096). You might get unexpected behavior.')
217
218
def main():  # type: () -> None
    """ Parse command line arguments and dispatch to the requested command """
    parser = argparse.ArgumentParser()

    # Provision to add "info" command
    subparsers = parser.add_subparsers(dest='command')
    write_parser = subparsers.add_parser('write')
    write_parser.add_argument('-o', '--output-file',
                              help='Filename for storing the output DFU image',
                              required=True,
                              type=argparse.FileType('wb'))
    write_parser.add_argument('--pid',
                              required=True,
                              type=lambda h: int(h, 16),
                              help='Hexadecimal product identifier')
    write_parser.add_argument('--json',
                              help='Optional file for loading "flash_files" dictionary with <address> <file> items')
    write_parser.add_argument('--part-size',
                              default=os.environ.get('ESP_DFU_PART_SIZE', 512 * 1024),
                              type=lambda x: int(x, 0),
                              help='Larger files are split-up into smaller partitions of this size')
    write_parser.add_argument('files',
                              metavar='<address> <file>', help='Add <file> at <address>',
                              nargs='*')

    args = parser.parse_args()

    # On Python 3, subcommands are optional by default; without this check a
    # missing command would surface as a bare KeyError in the dispatch below.
    if args.command is None:
        parser.error('command is required')

    def check_file(file_name):  # type: (str) -> str
        """ Raise unless file_name refers to an existing regular file """
        if not os.path.isfile(file_name):
            raise RuntimeError('{} is not a regular file!'.format(file_name))
        return file_name

    files = []
    if args.files:
        if len(args.files) % 2 != 0:
            # Without this check, zip() below would silently drop a trailing
            # <file> argument that has no matching <address>.
            raise RuntimeError('Please specify the positional arguments in pairs: <address> <file>')
        files += [(int(addr, 0), check_file(f_name)) for addr, f_name in zip(args.files[::2], args.files[1::2])]

    if args.json:
        json_dir = os.path.dirname(os.path.abspath(args.json))

        def process_json_file(path):  # type: (str) -> str
            '''
            The input path is relative to json_dir. This function makes it relative to the current working
            directory.
            '''
            return check_file(os.path.relpath(os.path.join(json_dir, path), start=os.curdir))

        with open(args.json) as f:
            files += [(int(addr, 0),
                       process_json_file(f_name)) for addr, f_name in iteritems(json.load(f)['flash_files'])]

    # Deduplicate by address (a later entry for the same address wins) and sort
    # by address; names may arrive as bytes on Python 2, so decode those.
    files = sorted([(addr, f_name.decode('utf-8') if isinstance(f_name, type(b'')) else f_name) for addr, f_name in iteritems(dict(files))],
                   key=lambda x: x[0])  # remove possible duplicates and sort based on the address

    cmd_args = {'output_file': args.output_file,
                'files': files,
                'pid': args.pid,
                'part_size': args.part_size,
                }

    {'write': action_write
     }[args.command](cmd_args)
279
280
# Script entry point
if __name__ == '__main__':
    main()
283