# Copyright (c) 2020 Arm Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Corrupt the image hash and/or signature TLV of an MCUboot image.

The script reads an MCUboot image and locates its TLV area, skipping the
protected TLV area if one is present. It then increments (mod 256) the first
value byte of each selected TLV entry and writes the result to a new file.
The input file is never modified.
"""

import argparse
import logging
import struct
import sys
from shutil import copyfile

from imgtool.image import (IMAGE_HEADER_SIZE, IMAGE_MAGIC,
                           TLV_INFO_MAGIC, TLV_PROT_INFO_MAGIC, TLV_VALUES)

# Size in bytes of the per-entry TLV header (type, pad, length) that
# precedes every TLV value.
_TLV_HEADER_SIZE = 4

# Reverse lookup from TLV type code to a printable name, built once.
_TLV_TYPE_NAMES = {v: f"IMAGE_TLV_{k}" for k, v in TLV_VALUES.items()}

# TLV_VALUES names that may carry an image signature or image hash. Names
# differ between imgtool versions, so only those actually defined by the
# installed imgtool are used (see _tlv_types).
_SIGNATURE_TLV_NAMES = ('RSA2048', 'RSA3072', 'ECDSA224', 'ECDSA256',
                        'ECDSASIG', 'ED25519')
_HASH_TLV_NAMES = ('SHA256', 'SHA384', 'SHA512')


def _tlv_types(names):
    """Return the TLV type codes of those *names* defined in TLV_VALUES."""
    return frozenset(TLV_VALUES[name] for name in names if name in TLV_VALUES)


_SIGNATURE_TLV_TYPES = _tlv_types(_SIGNATURE_TLV_NAMES)
_HASH_TLV_TYPES = _tlv_types(_HASH_TLV_NAMES)


def get_tlv_type_string(tlv_type):
    """Return a printable name such as ``IMAGE_TLV_SHA256`` for *tlv_type*.

    Unknown codes are rendered as ``UNKNOWN(<code>)``.
    """
    return _TLV_TYPE_NAMES.get(tlv_type, "UNKNOWN({:d})".format(tlv_type))


class ImageHeader:
    """Decoded fields of the fixed-size MCUboot image header."""

    def __init__(self):
        self.ih_magic = 0
        self.ih_load_addr = 0
        self.ih_hdr_size = 0
        self.ih_protect_tlv_size = 0
        self.ih_img_size = 0
        self.ih_flags = 0
        self.iv_major = 0
        self.iv_minor = 0
        self.iv_revision = 0
        self.iv_build_num = 0
        self._pad1 = 0

    @staticmethod
    def read_from_binary(in_file):
        """Read and decode a header at the current position of *in_file*.

        Raises:
            struct.error: if fewer than IMAGE_HEADER_SIZE bytes remain.
        """
        h = ImageHeader()

        (h.ih_magic, h.ih_load_addr, h.ih_hdr_size, h.ih_protect_tlv_size,
         h.ih_img_size, h.ih_flags, h.iv_major, h.iv_minor, h.iv_revision,
         h.iv_build_num, h._pad1
         ) = struct.unpack('<IIHHIIBBHII', in_file.read(IMAGE_HEADER_SIZE))
        return h

    def __repr__(self):
        return "\n".join([
            " ih_magic = 0x{:X}".format(self.ih_magic),
            " ih_load_addr = " + str(self.ih_load_addr),
            " ih_hdr_size = " + str(self.ih_hdr_size),
            " ih_protect_tlv_size = " + str(self.ih_protect_tlv_size),
            " ih_img_size = " + str(self.ih_img_size),
            " ih_flags = " + str(self.ih_flags),
            " iv_major = " + str(self.iv_major),
            " iv_minor = " + str(self.iv_minor),
            " iv_revision = " + str(self.iv_revision),
            " iv_build_num = " + str(self.iv_build_num),
            " _pad1 = " + str(self._pad1)])


class ImageTLVInfo:
    """TLV area info header: a magic number and the total TLV area size.

    As used by damage_image, ``it_tlv_tot`` covers the whole TLV area,
    including this info header itself.
    """

    def __init__(self):
        self.format_string = '<HH'

        self.it_magic = 0
        self.it_tlv_tot = 0

    @staticmethod
    def read_from_binary(in_file):
        """Read and decode an info header at the current file position."""
        i = ImageTLVInfo()

        # Use the instance's own layout so the format and size stay in sync.
        (i.it_magic, i.it_tlv_tot) = struct.unpack(i.format_string,
                                                   in_file.read(len(i)))
        return i

    def __repr__(self):
        return "\n".join([
            " it_magic = 0x{:X}".format(self.it_magic),
            " it_tlv_tot = " + str(self.it_tlv_tot)])

    def __len__(self):
        """Return the encoded size of the info header in bytes."""
        return struct.calcsize(self.format_string)


class ImageTLV:
    """A single TLV entry: type, value length and raw value."""

    def __init__(self):
        self.it_value = 0
        self.it_type = 0
        self.it_len = 0

    @staticmethod
    def read_from_binary(in_file):
        """Read one TLV entry (header and value) at the current file position."""
        tlv = ImageTLV()
        # The middle byte is padding and is discarded.
        (tlv.it_type, _, tlv.it_len) = struct.unpack(
            '<BBH', in_file.read(_TLV_HEADER_SIZE))
        # it_value is the 1-tuple returned by struct.unpack; the raw value
        # bytes are it_value[0]. The tuple shape is kept for compatibility.
        tlv.it_value = struct.unpack('<{:d}s'.format(tlv.it_len),
                                     in_file.read(tlv.it_len))
        return tlv

    def __len__(self):
        """Return the on-disk size of the entry: header plus value."""
        return _TLV_HEADER_SIZE + self.it_len


def get_arguments():
    """Parse and return the command-line arguments."""
    parser = argparse.ArgumentParser(description='Corrupt an MCUBoot image')
    parser.add_argument("-i", "--in-file", required=True,
                        help='The input image to be corrupted (read only)')
    parser.add_argument("-o", "--out-file", required=True,
                        help='the corrupted image')
    parser.add_argument('-a', '--image-hash',
                        default=False,
                        action="store_true",
                        required=False,
                        help='Corrupt the image hash')
    parser.add_argument('-s', '--signature',
                        default=False,
                        action="store_true",
                        required=False,
                        help='Corrupt the signature of the image')
    return parser.parse_args()


def damage_tlv(image_offset, tlv_off, tlv, out_file_content):
    """Corrupt the first value byte of *tlv* inside *out_file_content*.

    The byte is incremented modulo 256, which guarantees it changes.

    Args:
        image_offset: offset of the image within the file.
        tlv_off: offset of the TLV entry header, relative to the image.
        tlv: the parsed ImageTLV entry located at *tlv_off*.
        out_file_content: bytearray of the output file; modified in place.
    """
    if tlv.it_len == 0:
        # Nothing to corrupt; the original code would raise IndexError here.
        logging.warning("  TLV {} has no value bytes, not damaged".format(
            get_tlv_type_string(tlv.it_type)))
        return

    # The value starts right after the per-entry TLV header.
    damage_offset = image_offset + tlv_off + _TLV_HEADER_SIZE
    logging.info(" Damaging TLV at offset 0x{:X}...".format(damage_offset))
    first_byte = tlv.it_value[0][0]
    out_file_content[damage_offset] = (first_byte + 1) % 256


def is_valid_signature(tlv):
    """Return True if *tlv* is a signature TLV (RSA, ECDSA or Ed25519).

    Only signature types defined by the installed imgtool's TLV_VALUES
    are recognised.
    """
    return tlv.it_type in _SIGNATURE_TLV_TYPES


def _is_image_hash(tlv):
    """Return True if *tlv* holds the image hash (SHA256/384/512)."""
    return tlv.it_type in _HASH_TLV_TYPES


def damage_image(args, in_file, out_file_content, image_offset):
    """Corrupt the TLVs selected by *args* for the image at *image_offset*.

    The header and TLV area are parsed from *in_file*. For each signature
    TLV (when ``args.signature``) and image-hash TLV (when
    ``args.image_hash``), one byte is damaged in *out_file_content*.

    Raises:
        Exception: on a bad image or TLV magic, or on an inconsistent
            protected TLV size.
    """
    in_file.seek(image_offset, 0)

    # Find the Image header
    image_header = ImageHeader.read_from_binary(in_file)
    if image_header.ih_magic != IMAGE_MAGIC:
        raise Exception(
            "Invalid magic in image_header: 0x{:X} instead of 0x{:X}".format(
                image_header.ih_magic, IMAGE_MAGIC))

    # The TLV area starts right after the header and the image payload.
    tlv_info_offset = image_header.ih_hdr_size + image_header.ih_img_size
    in_file.seek(image_offset + tlv_info_offset, 0)

    tlv_info = ImageTLVInfo.read_from_binary(in_file)
    if tlv_info.it_magic == TLV_PROT_INFO_MAGIC:
        logging.debug("Protected TLV found at offset 0x{:X}".format(
            tlv_info_offset))
        if image_header.ih_protect_tlv_size != tlv_info.it_tlv_tot:
            raise Exception("Invalid prot TLV len ({:d} vs. {:d})".format(
                image_header.ih_protect_tlv_size, tlv_info.it_tlv_tot))

        # The unprotected TLV area immediately follows the protected one.
        tlv_info_offset += tlv_info.it_tlv_tot
        in_file.seek(image_offset + tlv_info_offset)
        tlv_info = ImageTLVInfo.read_from_binary(in_file)

    else:
        if image_header.ih_protect_tlv_size != 0:
            raise Exception("No prot TLV was found.")

    logging.debug("Unprotected TLV found at offset 0x{:X}".format(
        tlv_info_offset))
    if tlv_info.it_magic != TLV_INFO_MAGIC:
        raise Exception(
            "Invalid magic in tlv info: 0x{:X} instead of 0x{:X}".format(
                tlv_info.it_magic, TLV_INFO_MAGIC))

    tlv_off = tlv_info_offset + len(ImageTLVInfo())
    tlv_end = tlv_info_offset + tlv_info.it_tlv_tot

    found_signature = False
    found_hash = False

    # iterate over the TLV entries
    while tlv_off < tlv_end:
        in_file.seek(image_offset + tlv_off, 0)
        tlv = ImageTLV.read_from_binary(in_file)

        logging.debug(" tlv {:24s} len = {:4d}, size = {:4d}".format(
            get_tlv_type_string(tlv.it_type), tlv.it_len, len(tlv)))

        if args.signature and is_valid_signature(tlv):
            damage_tlv(image_offset, tlv_off, tlv, out_file_content)
            found_signature = True
        elif args.image_hash and _is_image_hash(tlv):
            damage_tlv(image_offset, tlv_off, tlv, out_file_content)
            found_hash = True

        tlv_off += len(tlv)

    # Make a requested-but-impossible corruption visible instead of silent.
    if args.signature and not found_signature:
        logging.warning("No signature TLV found, signature was not damaged")
    if args.image_hash and not found_hash:
        logging.warning("No image hash TLV found, hash was not damaged")


def main():
    """Copy the input image to the output path, corrupting selected TLVs."""
    args = get_arguments()

    logging.debug("The script was started")

    # NOTE(review): this writes an unmodified copy first. The file is fully
    # rewritten below, so if damage_image raises, out_file is left as an
    # uncorrupted copy of the input.
    copyfile(args.in_file, args.out_file)

    # Context managers close the files even if parsing raises.
    with open(args.in_file, 'rb') as in_file:
        out_file_content = bytearray(in_file.read())
        damage_image(args, in_file, out_file_content, 0)

    with open(args.out_file, 'wb') as file_to_damage:
        file_to_damage.write(out_file_content)


if __name__ == "__main__":
    logging.basicConfig(format='%(levelname)5s: %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)

    main()