1# Copyright (c) 2020 Arm Limited
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import argparse
16import logging
17import struct
18import sys
19
20from imgtool.image import (IMAGE_HEADER_SIZE, IMAGE_MAGIC,
21                           TLV_INFO_MAGIC, TLV_PROT_INFO_MAGIC, TLV_VALUES)
22from shutil import copyfile
23
24
def get_tlv_type_string(tlv_type):
    """Return a printable label for the numeric TLV type *tlv_type*.

    Known codes (those appearing as values in imgtool's ``TLV_VALUES``)
    become ``IMAGE_TLV_<NAME>``; anything else is rendered as
    ``UNKNOWN(<code>)``.
    """
    names_by_code = dict((v, f"IMAGE_TLV_{k}") for k, v in TLV_VALUES.items())
    fallback = "UNKNOWN({:d})".format(tlv_type)
    if tlv_type in names_by_code:
        return names_by_code[tlv_type]
    return fallback
28
29
class ImageHeader:
    """Decoded fixed-size header found at the start of an image.

    The on-disk layout is described by ``FORMAT`` (little-endian, no
    padding between fields).
    """

    # Field order / widths:
    #   ih_magic (u32), ih_load_addr (u32),
    #   ih_hdr_size (u16), ih_protect_tlv_size (u16),
    #   ih_img_size (u32), ih_flags (u32),
    #   iv_major (u8), iv_minor (u8), iv_revision (u16), iv_build_num (u32),
    #   _pad1 (u32)
    FORMAT = '<IIHHIIBBHII'

    def __init__(self):
        self.ih_magic = 0
        self.ih_load_addr = 0
        self.ih_hdr_size = 0
        self.ih_protect_tlv_size = 0
        self.ih_img_size = 0
        self.ih_flags = 0
        self.iv_major = 0
        self.iv_minor = 0
        self.iv_revision = 0
        self.iv_build_num = 0
        self._pad1 = 0

    @classmethod
    def read_from_binary(cls, in_file):
        """Read one header from the current position of binary file *in_file*.

        Exactly ``len(ImageHeader())`` bytes are consumed. The byte count is
        derived from ``FORMAT`` itself, so the read size can never disagree
        with the layout being unpacked.

        :returns: a populated ``ImageHeader``.
        :raises struct.error: if fewer bytes than the header size are available.
        """
        h = cls()
        raw = in_file.read(struct.calcsize(cls.FORMAT))
        (h.ih_magic, h.ih_load_addr, h.ih_hdr_size, h.ih_protect_tlv_size, h.ih_img_size,
            h.ih_flags, h.iv_major, h.iv_minor, h.iv_revision, h.iv_build_num, h._pad1
         ) = struct.unpack(cls.FORMAT, raw)
        return h

    def __len__(self):
        """Size in bytes of the encoded header."""
        return struct.calcsize(self.FORMAT)

    def __repr__(self):
        # Multi-line, indented dump intended for log output.
        return "\n".join([
            "    ih_magic = 0x{:X}".format(self.ih_magic),
            "    ih_load_addr = " + str(self.ih_load_addr),
            "    ih_hdr_size = " + str(self.ih_hdr_size),
            "    ih_protect_tlv_size = " + str(self.ih_protect_tlv_size),
            "    ih_img_size = " + str(self.ih_img_size),
            "    ih_flags = " + str(self.ih_flags),
            "    iv_major = " + str(self.iv_major),
            "    iv_minor = " + str(self.iv_minor),
            "    iv_revision = " + str(self.iv_revision),
            "    iv_build_num = " + str(self.iv_build_num),
            "    _pad1 = " + str(self._pad1)])
67
68
class ImageTLVInfo:
    """Info header that precedes a TLV area (protected or unprotected).

    ``it_magic`` identifies the kind of area and ``it_tlv_tot`` is the size of
    the area; callers compute the area end as ``info_offset + it_tlv_tot``,
    with TLV entries starting right after this header.
    """

    # Little-endian: it_magic (u16), it_tlv_tot (u16).
    FORMAT = '<HH'

    def __init__(self):
        # Kept as an instance attribute for backward compatibility.
        self.format_string = self.FORMAT

        self.it_magic = 0
        self.it_tlv_tot = 0

    @classmethod
    def read_from_binary(cls, in_file):
        """Read one info header from the current position of binary file *in_file*.

        The read size and unpack format both come from ``format_string`` so
        they cannot drift apart.

        :returns: a populated ``ImageTLVInfo``.
        :raises struct.error: if fewer bytes than the header size are available.
        """
        info = cls()
        (info.it_magic, info.it_tlv_tot) = struct.unpack(info.format_string,
                                                         in_file.read(len(info)))
        return info

    def __repr__(self):
        return "\n".join([
            "    it_magic = 0x{:X}".format(self.it_magic),
            "    it_tlv_tot = " + str(self.it_tlv_tot)])

    def __len__(self):
        """Size in bytes of the encoded info header."""
        return struct.calcsize(self.format_string)
90
91
class ImageTLV:
    """One TLV entry.

    On disk an entry is a 4-byte header (type byte, one ignored byte,
    little-endian u16 length) followed by ``it_len`` value bytes.
    """

    def __init__(self):
        self.it_value = 0
        self.it_type = 0
        self.it_len = 0

    @staticmethod
    def read_from_binary(in_file):
        """Read one TLV entry from the current position of binary file *in_file*.

        Note: ``struct.unpack`` always returns a tuple, so ``it_value`` ends
        up holding a 1-tuple ``(value_bytes,)`` rather than the bytes
        themselves.
        """
        entry = ImageTLV()
        header = in_file.read(4)
        entry.it_type, _reserved, entry.it_len = struct.unpack('<BBH', header)
        payload = in_file.read(entry.it_len)
        entry.it_value = struct.unpack('<{:d}s'.format(entry.it_len), payload)
        return entry

    def __len__(self):
        """On-disk size of the entry: header plus value, rounded up to the alignment."""
        header_size = 4
        alignment = 1  # 1 means entries are not padded
        slots = (header_size + self.it_len + (alignment - 1)) // alignment
        return int(slots) * alignment
108
109
def get_arguments():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    Attributes: ``in_file`` and ``out_file`` (required paths) plus the
    boolean switches ``image_hash`` and ``signature`` (default False) that
    select which TLVs get corrupted.
    """
    parser = argparse.ArgumentParser(description='Corrupt an MCUBoot image')

    # Mandatory file paths.
    parser.add_argument("-i", "--in-file", required=True,
                        help='The input image to be corrupted (read only)')
    parser.add_argument("-o", "--out-file", required=True,
                        help='the corrupted image')

    # Optional on/off switches choosing what to corrupt.
    switches = (
        ('-a', '--image-hash', 'Corrupt the image hash'),
        ('-s', '--signature', 'Corrupt the signature of the image'),
    )
    for short_flag, long_flag, help_text in switches:
        parser.add_argument(short_flag, long_flag,
                            action="store_true",
                            default=False,
                            required=False,
                            help=help_text)

    return parser.parse_args()
125
126
def damage_tlv(image_offset, tlv_off, tlv, out_file_content):
    """Corrupt one TLV entry by incrementing (mod 256) the first byte of its value.

    :param image_offset: offset of the image within ``out_file_content``.
    :param tlv_off: offset of the TLV entry's 4-byte header, relative to the
        image start.
    :param tlv: the parsed TLV entry; only ``it_len`` is consulted.
    :param out_file_content: mutable byte buffer (e.g. ``bytearray``) holding
        the image; modified in place.
    :raises ValueError: if the TLV value is empty, so there is no byte of
        this entry to corrupt.
    """
    # Skip the 4-byte TLV header (type, reserved byte, length) to reach the value.
    damage_offset = image_offset + tlv_off + 4
    if tlv.it_len == 0:
        # Without this check the write would land in the next entry's header.
        raise ValueError("Cannot damage empty TLV at offset 0x{:X}".format(damage_offset))
    logging.info("        Damaging TLV at offset 0x{:X}...".format(damage_offset))
    # Bump the byte already in the output buffer instead of going through
    # tlv.it_value: no copy of the whole value is needed and this does not
    # depend on how the parser packages it_value.
    out_file_content[damage_offset] = (out_file_content[damage_offset] + 1) % 256
133
134
def is_valid_signature(tlv):
    """Return True if *tlv* is an image-signature TLV.

    Only the TLV type is checked; the signature itself is not verified.
    RSA, ECDSA and Ed25519 signature types are recognised. Names that the
    installed imgtool's ``TLV_VALUES`` does not define are skipped, so this
    works across imgtool versions with differing tables.
    """
    signature_names = ('RSA2048', 'RSA3072', 'ECDSA224', 'ECDSA256', 'ECDSASIG', 'ED25519')
    signature_types = {TLV_VALUES[name] for name in signature_names if name in TLV_VALUES}
    return tlv.it_type in signature_types
137
138
def _read_unprotected_tlv_info(in_file, image_header, image_offset):
    """Locate and validate the unprotected TLV info header of an image.

    The TLV area starts right after the header and payload
    (``ih_hdr_size + ih_img_size``). A protected TLV area found there is
    validated against the header and skipped.

    :returns: ``(tlv_info_offset, tlv_info)``; the offset is relative to the
        image start.
    :raises ValueError: if the protected area size disagrees with the header,
        or the unprotected TLV info magic is wrong.
    """
    tlv_info_offset = image_header.ih_hdr_size + image_header.ih_img_size
    in_file.seek(image_offset + tlv_info_offset, 0)
    tlv_info = ImageTLVInfo.read_from_binary(in_file)

    if tlv_info.it_magic == TLV_PROT_INFO_MAGIC:
        logging.debug("Protected TLV found at offset 0x{:X}".format(tlv_info_offset))
        if image_header.ih_protect_tlv_size != tlv_info.it_tlv_tot:
            raise ValueError("Invalid prot TLV len ({:d} vs. {:d})".format(
                image_header.ih_protect_tlv_size, tlv_info.it_tlv_tot))
        # The unprotected TLV area immediately follows the protected one.
        tlv_info_offset += tlv_info.it_tlv_tot
        in_file.seek(image_offset + tlv_info_offset, 0)
        tlv_info = ImageTLVInfo.read_from_binary(in_file)
    elif image_header.ih_protect_tlv_size != 0:
        raise ValueError("No prot TLV was found.")

    logging.debug("Unprotected TLV found at offset 0x{:X}".format(tlv_info_offset))
    if tlv_info.it_magic != TLV_INFO_MAGIC:
        raise ValueError("Invalid magic in tlv info: 0x{:X} instead of 0x{:X}".format(
            tlv_info.it_magic, TLV_INFO_MAGIC))
    return tlv_info_offset, tlv_info


def damage_image(args, in_file, out_file_content, image_offset):
    """Corrupt the TLVs selected by *args* in the image at *image_offset*.

    Parses the image header, finds the unprotected TLV area and walks its
    entries. Signature TLVs (if ``args.signature``) and image-hash TLVs (if
    ``args.image_hash``) have the first byte of their value modified in
    ``out_file_content`` via ``damage_tlv``. A warning is logged when a
    requested target is not present.

    :param args: parsed command-line arguments (see ``get_arguments``).
    :param in_file: binary file object positioned anywhere; it is seeked as needed.
    :param out_file_content: mutable copy of the file contents; modified in place.
    :param image_offset: byte offset of the image header within the file.
    :raises ValueError: on bad image magic or malformed TLV info headers.
    """
    in_file.seek(image_offset, 0)

    image_header = ImageHeader.read_from_binary(in_file)
    if image_header.ih_magic != IMAGE_MAGIC:
        raise ValueError("Invalid magic in image_header: 0x{:X} instead of 0x{:X}".format(
            image_header.ih_magic, IMAGE_MAGIC))

    tlv_info_offset, tlv_info = _read_unprotected_tlv_info(in_file, image_header, image_offset)

    # Accept any image-hash TLV type that this imgtool's TLV_VALUES defines.
    hash_types = {TLV_VALUES[name] for name in ('SHA256', 'SHA384', 'SHA512')
                  if name in TLV_VALUES}

    tlv_off = tlv_info_offset + len(ImageTLVInfo())
    tlv_end = tlv_info_offset + tlv_info.it_tlv_tot
    damaged_signature = False
    damaged_hash = False

    # Iterate over the TLV entries.
    while tlv_off < tlv_end:
        in_file.seek(image_offset + tlv_off, 0)
        tlv = ImageTLV.read_from_binary(in_file)

        logging.debug("    tlv {:24s} len = {:4d}, size = {:4d}".format(
            get_tlv_type_string(tlv.it_type), tlv.it_len, len(tlv)))

        if args.signature and is_valid_signature(tlv):
            damage_tlv(image_offset, tlv_off, tlv, out_file_content)
            damaged_signature = True
        elif args.image_hash and tlv.it_type in hash_types:
            damage_tlv(image_offset, tlv_off, tlv, out_file_content)
            damaged_hash = True

        # len(tlv) includes the 4-byte entry header, so the loop always advances.
        tlv_off += len(tlv)

    # Make a no-op run visible instead of silently writing an intact image.
    if args.signature and not damaged_signature:
        logging.warning("No signature TLV found: signature was not corrupted")
    if args.image_hash and not damaged_hash:
        logging.warning("No image hash TLV found: image hash was not corrupted")
186
187
def main():
    """Command-line entry point.

    Reads ``--in-file`` into memory and corrupts the selected TLVs of the
    image at offset 0 via ``damage_image``. Only after that succeeds is the
    result written to ``--out-file``, so a malformed input never leaves an
    unmodified copy behind as the "corrupted" image.

    :raises shutil.SameFileError: if input and output are the same file
        (the input is documented as read only).
    """
    import os
    from shutil import SameFileError

    args = get_arguments()

    logging.debug("The script was started")

    # Never overwrite the (read-only) input image.
    if os.path.exists(args.out_file) and os.path.samefile(args.in_file, args.out_file):
        raise SameFileError("{!r} and {!r} are the same file".format(args.in_file, args.out_file))

    with open(args.in_file, 'rb') as in_file:
        out_file_content = bytearray(in_file.read())
        damage_image(args, in_file, out_file_content, 0)

    with open(args.out_file, 'wb') as file_to_damage:
        file_to_damage.write(out_file_content)
205
206
if __name__ == "__main__":
    # Send every record (DEBUG and above) to stdout, prefixed with its level name.
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format='%(levelname)5s: %(message)s',
    )
    main()
212