#!/usr/bin/env python3
#
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later

"""Audit validity date of X509 crt/crl/csr.

This script is used to audit the validity date of crt/crl/csr used for testing.
It prints the information about X.509 objects excluding the objects that
are valid throughout the desired validity period. The data are collected
from tests/data_files/ and tests/suites/*.data files by default.
"""

import os
import re
import typing
import argparse
import datetime
import glob
import logging
import hashlib
from enum import Enum

# The script requires cryptography >= 35.0.0 which is only available
# for Python >= 3.6.
import cryptography
from cryptography import x509

from generate_test_code import FileWrapper

import scripts_path # pylint: disable=unused-import
from mbedtls_dev import build_tree
from mbedtls_dev import logging_util

def check_cryptography_version():
    """Raise an exception if the installed cryptography is older than 35.

    Only the leading run of digits of ``cryptography.__version__`` (the
    major version) is inspected.
    """
    match = re.match(r'^[0-9]+', cryptography.__version__)
    if match is None or int(match.group(0)) < 35:
        raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
                        + " ({} is too old)".format(cryptography.__version__))

class DataType(Enum):
    CRT = 1 # Certificate
    CRL = 2 # Certificate Revocation List
    CSR = 3 # Certificate Signing Request


class DataFormat(Enum):
    PEM = 1 # Privacy-Enhanced Mail
    DER = 2 # Distinguished Encoding Rules


class AuditData:
    """Store data location, type and validity period of X.509 objects."""
    #pylint: disable=too-few-public-methods
    def __init__(self, data_type: DataType, x509_obj):
        """
        :param data_type: kind of X.509 object held by ``x509_obj``.
        :param x509_obj: parsed X.509 object; it must provide
                         ``public_bytes`` and, depending on ``data_type``,
                         the validity attributes read by
                         ``fill_validity_duration``.
        """
        self.data_type = data_type
        # the locations that the x509 object could be found
        self.locations = [] # type: typing.List[str]
        self.fill_validity_duration(x509_obj)
        self._obj = x509_obj
        # The identifier is the SHA-1 of the DER serialization, so the same
        # object yields the same identifier whatever format it was read from.
        encoding = cryptography.hazmat.primitives.serialization.Encoding.DER
        self._identifier = hashlib.sha1(self._obj.public_bytes(encoding)).hexdigest()

    @property
    def identifier(self):
        """
        Identifier of the underlying X.509 object, which is consistent across
        different runs.
        """
        return self._identifier

    def fill_validity_duration(self, x509_obj):
        """Read validity period from an X.509 object.

        Sets ``self.not_valid_before`` and ``self.not_valid_after``.

        :raises ValueError: if ``self.data_type`` is not a supported DataType.
        """
        # Certificate expires after "not_valid_after"
        # Certificate is invalid before "not_valid_before"
        if self.data_type == DataType.CRT:
            self.not_valid_after = x509_obj.not_valid_after
            self.not_valid_before = x509_obj.not_valid_before
        # CertificateRevocationList expires after "next_update"
        # CertificateRevocationList is invalid before "last_update"
        elif self.data_type == DataType.CRL:
            self.not_valid_after = x509_obj.next_update
            self.not_valid_before = x509_obj.last_update
        # CertificateSigningRequest is always valid.
        elif self.data_type == DataType.CSR:
            self.not_valid_after = datetime.datetime.max
            self.not_valid_before = datetime.datetime.min
        else:
            raise ValueError("Unsupported file_type: {}".format(self.data_type))


class X509Parser:
    """A parser class to parse crt/crl/csr file or data in PEM/DER format."""
    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'
    # Accept both LF and CRLF after the BEGIN line so that PEM data with
    # Windows line endings is still recognised as PEM.
    PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\r?\n'
    PEM_TAGS = {
        DataType.CRT: 'CERTIFICATE',
        DataType.CRL: 'X509 CRL',
        DataType.CSR: 'CERTIFICATE REQUEST'
    }

    def __init__(self,
                 backends:
                 typing.Dict[DataType,
                             typing.Dict[DataFormat,
                                         typing.Callable[[bytes], object]]]) \
    -> None:
        """
        :param backends: for every DataType, a mapping from DataFormat to
                         the loader callable used for that format.
        """
        self.backends = backends
        self.__generate_parsers()

    def __generate_parser(self, data_type: DataType):
        """Parser generator for a specific DataType"""
        tag = self.PEM_TAGS[data_type]
        pem_loader = self.backends[data_type][DataFormat.PEM]
        der_loader = self.backends[data_type][DataFormat.DER]
        def wrapper(data: bytes):
            pem_type = X509Parser.pem_data_type(data)
            # It is in PEM format with target tag
            if pem_type == tag:
                return pem_loader(data)
            # It is in PEM format without target tag
            if pem_type:
                return None
            # It might be in DER format
            try:
                return der_loader(data)
            except ValueError:
                return None
        wrapper.__name__ = "{}.parser[{}]".format(type(self).__name__, tag)
        return wrapper

    def __generate_parsers(self):
        """Generate parsers for all supported DataType"""
        self.parsers = {data_type: self.__generate_parser(data_type)
                        for data_type in self.PEM_TAGS}

    def __getitem__(self, item):
        """Return the parser callable registered for the DataType ``item``."""
        return self.parsers[item]

    @staticmethod
    def pem_data_type(data: bytes) -> typing.Optional[str]:
        """Get the tag from the data in PEM format

        :param data: data to be checked in binary mode.
        :return: PEM tag or None when no tag detected.
        """
        m = re.search(X509Parser.PEM_TAG_REGEX, data)
        if m is None:
            return None
        return m.group('type').decode('UTF-8')

    @staticmethod
    def check_hex_string(hex_str: str) -> bool:
        """Check if the hex string is possibly DER data.

        The string is accepted when it starts with a SEQUENCE tag (0x30)
        followed by a definite length that matches exactly the number of
        remaining bytes.
        """
        hex_len = len(hex_str)
        # At least 6 hex char for 3 bytes: Type + Length + Content
        if hex_len < 6:
            return False
        # Check if Type (1 byte) is SEQUENCE.
        if hex_str[0:2] != '30':
            return False
        # Check LENGTH (1 byte) value
        content_len = int(hex_str[2:4], base=16)
        consumed = 4
        if content_len in (128, 255):
            # Indefinite or Reserved
            return False
        if content_len > 127:
            # Definite, Long
            length_len = (content_len - 128) * 2
            # The length field itself is truncated: cannot be valid DER.
            if hex_len < consumed + length_len:
                return False
            content_len = int(hex_str[consumed:consumed+length_len], base=16)
            consumed += length_len
        # Check LENGTH
        return hex_len == content_len * 2 + consumed


class Auditor:
    """
    A base class that uses X509Parser to parse files to a list of AuditData.

    A subclass must implement the following methods:
      - collect_default_files: Return a list of file names that are used
        for parsing (auditing) by default. The list will be stored in
        Auditor.default_files.
      - parse_file: Method that parses a single file to a list of AuditData.

    A subclass may override the following methods:
      - parse_bytes: By default, it parses `bytes` that contains only one
        valid X.509 data(DER/PEM format) to an X.509 object.
      - walk_all: By default, it iterates over all the files in the provided
        file name list, calls `parse_file` for each file and stores the
        results by extending the `results` passed to the function.
    """
    def __init__(self, logger):
        """
        :param logger: logger used to report parsing problems.
        """
        self.logger = logger
        self.default_files = self.collect_default_files()
        self.parser = X509Parser({
            DataType.CRT: {
                DataFormat.PEM: x509.load_pem_x509_certificate,
                DataFormat.DER: x509.load_der_x509_certificate
            },
            DataType.CRL: {
                DataFormat.PEM: x509.load_pem_x509_crl,
                DataFormat.DER: x509.load_der_x509_crl
            },
            DataType.CSR: {
                DataFormat.PEM: x509.load_pem_x509_csr,
                DataFormat.DER: x509.load_der_x509_csr
            },
        })

    def collect_default_files(self) -> typing.List[str]:
        """Collect the default files for parsing."""
        raise NotImplementedError

    def parse_file(self, filename: str) -> typing.List[AuditData]:
        """
        Parse a list of AuditData from file.

        :param filename: name of the file to parse.
        :return list of AuditData parsed from the file.
        """
        raise NotImplementedError

    def parse_bytes(self, data: bytes):
        """Parse AuditData from bytes.

        Every DataType is tried in turn; the first parser that yields an
        object wins. A ValueError raised by a parser is logged as a warning
        and treated as "not this type".

        :return: an AuditData, or None if no parser accepted the data.
        """
        for data_type in list(DataType):
            try:
                result = self.parser[data_type](data)
            except ValueError as val_error:
                result = None
                self.logger.warning(val_error)
            if result is not None:
                return AuditData(data_type, result)
        return None

    def walk_all(self,
                 results: typing.Dict[str, AuditData],
                 file_list: typing.Optional[typing.List[str]] = None) \
        -> None:
        """
        Iterate over all the files in the list and get audit data. The
        results will be written to `results` passed to this function.

        :param results: The dictionary used to store the parsed
                        AuditData. The keys of this dictionary should
                        be the identifier of the AuditData.
        :param file_list: files to parse; `self.default_files` when None.
        """
        if file_list is None:
            file_list = self.default_files
        for filename in file_list:
            for d in self.parse_file(filename):
                if d.identifier in results:
                    # Same object seen before: only record the new locations.
                    results[d.identifier].locations.extend(d.locations)
                else:
                    results[d.identifier] = d

    @staticmethod
    def find_test_dir():
        """Get the relative path for the Mbed TLS test directory."""
        return os.path.relpath(
            os.path.join(build_tree.guess_mbedtls_root(), 'tests'))


class TestDataAuditor(Auditor):
    """Class for auditing files in `tests/data_files/`"""

    def collect_default_files(self):
        """Collect all files in `tests/data_files/`"""
        test_dir = self.find_test_dir()
        test_data_glob = os.path.join(test_dir, 'data_files/**')
        data_files = [f for f in glob.glob(test_data_glob, recursive=True)
                      if os.path.isfile(f)]
        return data_files

    def parse_file(self, filename: str) -> typing.List[AuditData]:
        """
        Parse a list of AuditData from data file.

        Each PEM block found in the file is parsed on its own and located
        as `<filename>#<index>` (1-based). When the file holds no PEM block
        at all, the whole content is tried as a single DER object.

        :param filename: name of the file to parse.
        :return list of AuditData parsed from the file.
        """
        with open(filename, 'rb') as f:
            data = f.read()

        results = []
        # Try to parse all PEM blocks.
        is_pem = False
        for idx, m in enumerate(re.finditer(X509Parser.PEM_REGEX, data, flags=re.S), 1):
            is_pem = True
            result = self.parse_bytes(m.group(0))
            if result is not None:
                result.locations.append("{}#{}".format(filename, idx))
                results.append(result)

        # Might be DER format.
        if not is_pem:
            result = self.parse_bytes(data)
            if result is not None:
                result.locations.append("{}".format(filename))
                results.append(result)

        return results


def parse_suite_data(data_f):
    """
    Parses .data file for test arguments that possibly have a
    valid X.509 data. If you need a more precise parser, please
    use generate_test_code.parse_test_data instead.

    :param data_f: file object of the data file.
    :return: Generator that yields test function argument list.
    """
    for line in data_f:
        line = line.strip()
        # Skip comments
        if line.startswith('#'):
            continue

        # Check parameters line
        match = re.search(r'\A\w+(.*:)?\"', line)
        if match:
            # Read test vectors: split on ':' that is not escaped with '\'
            parts = re.split(r'(?<!\\):', line)
            parts = [x for x in parts if x]
            # The first part is the test function name, the rest are its
            # arguments.
            args = parts[1:]
            yield args


class SuiteDataAuditor(Auditor):
    """Class for auditing files in `tests/suites/*.data`"""

    def collect_default_files(self):
        """Collect all files in `tests/suites/*.data`"""
        test_dir = self.find_test_dir()
        suites_data_folder = os.path.join(test_dir, 'suites')
        data_files = glob.glob(os.path.join(suites_data_folder, '*.data'))
        return data_files

    def parse_file(self, filename: str):
        """
        Parse a list of AuditData from test suite data file.

        Only quoted, purely hexadecimal arguments that look like a DER
        SEQUENCE are handed to the X.509 parsers. Locations are recorded as
        `<filename>:<line>:#<argument index>` (argument index is 1-based).

        :param filename: name of the file to parse.
        :return list of AuditData parsed from the file.
        """
        audit_data_list = []
        # NOTE(review): the FileWrapper is never explicitly closed here;
        # confirm whether it supports the context-manager protocol and, if
        # so, wrap this in a `with` statement.
        data_f = FileWrapper(filename)
        for test_args in parse_suite_data(data_f):
            for idx, test_arg in enumerate(test_args):
                match = re.match(r'"(?P<data>[0-9a-fA-F]+)"', test_arg)
                if not match:
                    continue
                hex_data = match.group('data')
                if not X509Parser.check_hex_string(hex_data):
                    continue
                audit_data = self.parse_bytes(bytes.fromhex(hex_data))
                if audit_data is None:
                    continue
                audit_data.locations.append("{}:{}:#{}".format(filename,
                                                               data_f.line_no,
                                                               idx + 1))
                audit_data_list.append(audit_data)

        return audit_data_list


def list_all(audit_data: AuditData):
    """Print one tab-separated line per location of `audit_data`.

    The columns are: identifier, not-valid-before, not-valid-after,
    data type name and location.
    """
    for loc in audit_data.locations:
        print("{}\t{:20}\t{:20}\t{:3}\t{}".format(
            audit_data.identifier,
            audit_data.not_valid_before.isoformat(timespec='seconds'),
            audit_data.not_valid_after.isoformat(timespec='seconds'),
            audit_data.data_type.name,
            loc))


def main():
    """
    Parse the command line, audit the selected files and print the X.509
    objects that are not valid throughout the requested period (or all
    objects with --all).
    """
    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument('-a', '--all',
                        action='store_true',
                        help='list the information of all the files')
    parser.add_argument('-v', '--verbose',
                        action='store_true', dest='verbose',
                        help='show logs')
    parser.add_argument('--from', dest='start_date',
                        help=('Start of desired validity period (UTC, YYYY-MM-DD). '
                              'Default: today'),
                        metavar='DATE')
    parser.add_argument('--to', dest='end_date',
                        help=('End of desired validity period (UTC, YYYY-MM-DD). '
                              'Default: --from'),
                        metavar='DATE')
    parser.add_argument('--data-files', action='append', nargs='*',
                        help='data files to audit',
                        metavar='FILE')
    parser.add_argument('--suite-data-files', action='append', nargs='*',
                        help='suite data files to audit',
                        metavar='FILE')

    args = parser.parse_args()

    # start main routine
    # setup logger
    logger = logging.getLogger()
    logging_util.configure_logger(logger)
    logger.setLevel(logging.DEBUG if args.verbose else logging.ERROR)

    td_auditor = TestDataAuditor(logger)
    sd_auditor = SuiteDataAuditor(logger)

    data_files = []
    suite_data_files = []
    if args.data_files is None and args.suite_data_files is None:
        data_files = td_auditor.default_files
        suite_data_files = sd_auditor.default_files
    else:
        # Each option may be given several times with several values:
        # flatten the resulting list of lists.
        if args.data_files is not None:
            data_files = [x for l in args.data_files for x in l]
        if args.suite_data_files is not None:
            suite_data_files = [x for l in args.suite_data_files for x in l]

    # validity period start date
    if args.start_date:
        start_date = datetime.datetime.fromisoformat(args.start_date)
    else:
        # The period is documented as UTC and is compared against naive
        # datetimes, so take the current time in UTC (not local time) and
        # drop the tzinfo to keep the comparison naive.
        start_date = datetime.datetime.now(
            datetime.timezone.utc).replace(tzinfo=None)
    # validity period end date
    if args.end_date:
        end_date = datetime.datetime.fromisoformat(args.end_date)
    else:
        end_date = start_date

    # go through all the files
    audit_results = {}
    td_auditor.walk_all(audit_results, data_files)
    sd_auditor.walk_all(audit_results, suite_data_files)

    logger.info("Total: %d objects found!", len(audit_results))

    if args.all:
        selected = list(audit_results.values())
    else:
        # we filter out the files whose validity duration covers the provided
        # duration.
        selected = [d for d in audit_results.values()
                    if start_date < d.not_valid_before
                    or d.not_valid_after < end_date]

    # output the results, earliest expiry first
    for d in sorted(selected, key=lambda d: d.not_valid_after):
        list_all(d)

    logger.debug("Done!")

check_cryptography_version()
if __name__ == "__main__":
    main()