1#!/usr/bin/env python3
2#
3# Copyright The Mbed TLS Contributors
4# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
5
6"""Audit validity date of X509 crt/crl/csr.
7
8This script is used to audit the validity date of crt/crl/csr used for testing.
9It prints the information about X.509 objects excluding the objects that
10are valid throughout the desired validity period. The data are collected
11from tests/data_files/ and tests/suites/*.data files by default.
12"""
13
14import os
15import re
16import typing
17import argparse
18import datetime
19import glob
20import logging
21import hashlib
22from enum import Enum
23
24# The script requires cryptography >= 35.0.0 which is only available
25# for Python >= 3.6.
26import cryptography
27from cryptography import x509
28
29from generate_test_code import FileWrapper
30
31import scripts_path # pylint: disable=unused-import
32from mbedtls_dev import build_tree
33from mbedtls_dev import logging_util
34
def check_cryptography_version(version: typing.Optional[str] = None):
    """Raise an exception if the cryptography package is too old.

    The script requires cryptography >= 35.0.0; only the leading major
    version number of the version string is inspected.

    :param version: version string to check. Defaults to the version of the
                    installed ``cryptography`` package.
    :raises Exception: if the version has no leading number or its major
                       version is below 35.
    """
    if version is None:
        version = cryptography.__version__
    match = re.match(r'^[0-9]+', version)
    if match is None or int(match.group(0)) < 35:
        # Keep a space between the requirement and the parenthesised detail.
        raise Exception("audit-validity-dates requires cryptography >= 35.0.0"
                        + " ({} is too old)".format(version))
40
class DataType(Enum):
    """Kind of X.509 object handled by the auditor.

    CRT -- Certificate
    CRL -- Certificate Revocation List
    CSR -- Certificate Signing Request
    """
    CRT = 1
    CRL = 2
    CSR = 3
45
46
class DataFormat(Enum):
    """Serialization format of X.509 data.

    PEM -- Privacy-Enhanced Mail (text between -----BEGIN/-----END markers)
    DER -- Distinguished Encoding Rules (binary)
    """
    PEM = 1
    DER = 2
50
51
class AuditData:
    """Store data location, type and validity period of X.509 objects."""
    #pylint: disable=too-few-public-methods
    def __init__(self, data_type: DataType, x509_obj):
        """
        :param data_type: kind of the wrapped X.509 object.
        :param x509_obj: the parsed X.509 object. It must provide
                         `public_bytes()` and the validity attributes read by
                         `fill_validity_duration` for `data_type`.
        """
        self.data_type = data_type
        # the locations that the x509 object could be found
        self.locations = [] # type: typing.List[str]
        self.fill_validity_duration(x509_obj)
        self._obj = x509_obj
        encoding = cryptography.hazmat.primitives.serialization.Encoding.DER
        # SHA-1 of the DER encoding: the same object found in several places
        # (PEM file, DER file, hex string) gets the same identifier.
        self._identifier = hashlib.sha1(self._obj.public_bytes(encoding)).hexdigest()

    @property
    def identifier(self):
        """
        Identifier of the underlying X.509 object, which is consistent across
        different runs.
        """
        return self._identifier

    def fill_validity_duration(self, x509_obj):
        """
        Read validity period from an X.509 object.

        Sets `not_valid_before` and `not_valid_after` (datetimes) according
        to `self.data_type`.

        :raises ValueError: if `self.data_type` is not a supported DataType.
        """
        # Certificate expires after "not_valid_after"
        # Certificate is invalid before "not_valid_before"
        if self.data_type == DataType.CRT:
            self.not_valid_after = x509_obj.not_valid_after
            self.not_valid_before = x509_obj.not_valid_before
        # CertificateRevocationList expires after "next_update"
        # CertificateRevocationList is invalid before "last_update"
        elif self.data_type == DataType.CRL:
            # nextUpdate is optional in a CRL and is reported as None when
            # absent. Treat that as "never expires" so that comparisons,
            # sorting and isoformat() on not_valid_after keep working.
            next_update = x509_obj.next_update
            self.not_valid_after = (datetime.datetime.max
                                    if next_update is None else next_update)
            self.not_valid_before = x509_obj.last_update
        # CertificateSigningRequest is always valid.
        elif self.data_type == DataType.CSR:
            self.not_valid_after = datetime.datetime.max
            self.not_valid_before = datetime.datetime.min
        else:
            raise ValueError("Unsupported file_type: {}".format(self.data_type))
90
91
class X509Parser:
    """A parser class to parse crt/crl/csr file or data in PEM/DER format."""
    # A whole PEM block, from its BEGIN line to the END line with the same tag.
    PEM_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}(?P<data>.*?)-{5}END (?P=type)-{5}'
    # The BEGIN line of a PEM block. Both LF and CRLF line endings are
    # accepted; requiring a bare `\n` made PEM data with CRLF line endings
    # unrecognizable as PEM, so it was handed to the DER loader and dropped.
    PEM_TAG_REGEX = br'-{5}BEGIN (?P<type>.*?)-{5}\r?\n'
    # PEM tag expected for each supported DataType.
    PEM_TAGS = {
        DataType.CRT: 'CERTIFICATE',
        DataType.CRL: 'X509 CRL',
        DataType.CSR: 'CERTIFICATE REQUEST'
    }

    def __init__(self,
                 backends:
                 typing.Dict[DataType,
                             typing.Dict[DataFormat,
                                         typing.Callable[[bytes], object]]]) \
    -> None:
        """
        :param backends: loader callables indexed first by DataType and then
                         by DataFormat. Each loader takes raw bytes and
                         returns the parsed object; a ValueError raised by a
                         DER loader is treated as "not this type".
        """
        self.backends = backends
        self.__generate_parsers()

    def __generate_parser(self, data_type: DataType):
        """
        Parser generator for a specific DataType.

        The returned callable takes raw bytes and returns the object built by
        the matching backend loader, or None when the data is PEM with a
        different tag or cannot be loaded as DER of this type. A ValueError
        raised by the PEM loader is propagated to the caller.
        """
        tag = self.PEM_TAGS[data_type]
        pem_loader = self.backends[data_type][DataFormat.PEM]
        der_loader = self.backends[data_type][DataFormat.DER]
        def wrapper(data: bytes):
            pem_type = X509Parser.pem_data_type(data)
            # It is in PEM format with target tag
            if pem_type == tag:
                return pem_loader(data)
            # It is in PEM format without target tag
            if pem_type:
                return None
            # It might be in DER format
            try:
                result = der_loader(data)
            except ValueError:
                result = None
            return result
        wrapper.__name__ = "{}.parser[{}]".format(type(self).__name__, tag)
        return wrapper

    def __generate_parsers(self):
        """Generate parsers for all supported DataTypes."""
        self.parsers = {data_type: self.__generate_parser(data_type)
                        for data_type in self.PEM_TAGS}

    def __getitem__(self, item):
        """Return the parser callable generated for DataType `item`."""
        return self.parsers[item]

    @staticmethod
    def pem_data_type(data: bytes) -> typing.Optional[str]:
        """Get the tag from the data in PEM format

        :param data: data to be checked in binary mode.
        :return: PEM tag, or None when no PEM BEGIN line is found.
        """
        m = re.search(X509Parser.PEM_TAG_REGEX, data)
        if m is None:
            return None
        return m.group('type').decode('UTF-8')

    @staticmethod
    def check_hex_string(hex_str: str) -> bool:
        """Check if the hex string is possibly DER data.

        Only the outer TLV header is inspected: the first byte must be the
        SEQUENCE tag (0x30) and the definite length it encodes must account
        for exactly the rest of the string.

        :param hex_str: string of hexadecimal digits.
        :return: True if the string looks like a single DER SEQUENCE.
        """
        hex_len = len(hex_str)
        # At least 6 hex char for 3 bytes: Type + Length + Content
        if hex_len < 6:
            return False
        # Check if Type (1 byte) is SEQUENCE.
        if hex_str[0:2] != '30':
            return False
        # Check LENGTH (1 byte) value
        content_len = int(hex_str[2:4], base=16)
        consumed = 4
        if content_len in (128, 255):
            # Indefinite or Reserved
            return False
        if content_len > 127:
            # Definite, Long: the low 7 bits give the number of length bytes.
            length_len = (content_len - 128) * 2
            if hex_len < consumed + length_len:
                # The length field itself is truncated.
                return False
            content_len = int(hex_str[consumed:consumed+length_len], base=16)
            consumed += length_len
        # Check LENGTH
        return hex_len == content_len * 2 + consumed
180
181
class Auditor:
    """
    A base class that uses X509Parser to parse files to a list of AuditData.

    A subclass must implement the following methods:
      - collect_default_files: Return a list of file names that are used for
        parsing (auditing) by default. The list will be stored in
        Auditor.default_files.
      - parse_file: Method that parses a single file to a list of AuditData.

    A subclass may override the following methods:
      - parse_bytes: By default, it parses `bytes` that contains only one
        valid X.509 data (DER/PEM format) to an X.509 object.
      - walk_all: By default, it iterates over all the files in the provided
        file name list, calls `parse_file` for each file and merges the
        results into the `results` dictionary passed to the function.
    """
    def __init__(self, logger):
        self.logger = logger
        self.default_files = self.collect_default_files()
        self.parser = X509Parser({
            DataType.CRT: {
                DataFormat.PEM: x509.load_pem_x509_certificate,
                DataFormat.DER: x509.load_der_x509_certificate
            },
            DataType.CRL: {
                DataFormat.PEM: x509.load_pem_x509_crl,
                DataFormat.DER: x509.load_der_x509_crl
            },
            DataType.CSR: {
                DataFormat.PEM: x509.load_pem_x509_csr,
                DataFormat.DER: x509.load_der_x509_csr
            },
        })

    def collect_default_files(self) -> typing.List[str]:
        """Collect the default files for parsing."""
        raise NotImplementedError

    def parse_file(self, filename: str) -> typing.List[AuditData]:
        """
        Parse a list of AuditData from file.

        :param filename: name of the file to parse.
        :return list of AuditData parsed from the file.
        """
        raise NotImplementedError

    def parse_bytes(self, data: bytes):
        """
        Parse AuditData from bytes.

        Every DataType is tried in declaration order; the first parser that
        returns an object wins. A ValueError raised by a parser is logged as
        a warning and the next DataType is tried.

        :param data: bytes holding a single X.509 object (PEM or DER).
        :return: AuditData for the first matching DataType, or None.
        """
        for data_type in DataType:
            try:
                result = self.parser[data_type](data)
            except ValueError as val_error:
                self.logger.warning(val_error)
                continue
            if result is not None:
                return AuditData(data_type, result)
        return None

    def walk_all(self,
                 results: typing.Dict[str, AuditData],
                 file_list: typing.Optional[typing.List[str]] = None) \
        -> None:
        """
        Iterate over all the files in the list and get audit data. The
        results will be written to `results` passed to this function.

        :param results: The dictionary used to store the parsed
                        AuditData. The keys of this dictionary should
                        be the identifier of the AuditData.
        :param file_list: names of the files to parse; defaults to
                          `self.default_files` when None.
        """
        if file_list is None:
            file_list = self.default_files
        for filename in file_list:
            for d in self.parse_file(filename):
                if d.identifier in results:
                    # Object already known: only record where else it is.
                    results[d.identifier].locations.extend(d.locations)
                else:
                    results[d.identifier] = d

    @staticmethod
    def find_test_dir():
        """Get the relative path for the MbedTLS test directory."""
        return os.path.relpath(os.path.join(build_tree.guess_mbedtls_root(),
                                            'tests'))
269
270
class TestDataAuditor(Auditor):
    """Class for auditing files in `tests/data_files/`"""

    def collect_default_files(self):
        """Return every regular file found (recursively) in `tests/data_files/`."""
        pattern = os.path.join(self.find_test_dir(), 'data_files/**')
        candidates = glob.glob(pattern, recursive=True)
        # The recursive glob also yields directories; keep files only.
        return list(filter(os.path.isfile, candidates))

    def parse_file(self, filename: str) -> typing.List[AuditData]:
        """
        Parse a list of AuditData from data file.

        Every PEM block in the file is parsed and located as
        "<filename>#<1-based block index>". Only when the file holds no PEM
        block at all is its whole content tried as a single DER object.

        :param filename: name of the file to parse.
        :return list of AuditData parsed from the file.
        """
        with open(filename, 'rb') as handle:
            content = handle.read()

        found = []
        pem_blocks = list(re.finditer(X509Parser.PEM_REGEX, content, flags=re.S))
        if pem_blocks:
            for position, block in enumerate(pem_blocks, start=1):
                audit_data = self.parse_bytes(content[block.start():block.end()])
                if audit_data is None:
                    continue
                audit_data.locations.append("{}#{}".format(filename, position))
                found.append(audit_data)
        else:
            # No PEM block: the file might be DER.
            audit_data = self.parse_bytes(content)
            if audit_data is not None:
                audit_data.locations.append("{}".format(filename))
                found.append(audit_data)

        return found
310
311
def parse_suite_data(data_f):
    """
    Yield the arguments of every test line in a .data file that may carry
    X.509 data.

    This is a loose heuristic: a line is accepted when it starts with a
    word, optionally followed by colon-separated fields, and then a double
    quote. Arguments are split on colons not preceded by a backslash, and
    empty fields are dropped. If you need a more precise parser, please use
    generate_test_code.parse_test_data instead.

    :param data_f: iterable of lines (e.g. a file object of the data file).
    :return: Generator that yields test function argument lists (the fields
             after the first one).
    """
    param_line = re.compile(r'\A\w+(.*:)?\"')
    unescaped_colon = re.compile(r'(?<!\\):')
    for raw_line in data_f:
        text = raw_line.strip()
        if text.startswith('#'):
            # Comment line.
            continue
        if not param_line.search(text):
            continue
        fields = [field for field in unescaped_colon.split(text) if field]
        yield fields[1:]
335
336
class SuiteDataAuditor(Auditor):
    """Class for auditing files in `tests/suites/*.data`"""

    def collect_default_files(self):
        """Collect all files in `tests/suites/*.data`"""
        test_dir = self.find_test_dir()
        suites_data_folder = os.path.join(test_dir, 'suites')
        data_files = glob.glob(os.path.join(suites_data_folder, '*.data'))
        return data_files

    def parse_file(self, filename: str):
        """
        Parse a list of AuditData from test suite data file.

        Every quoted hexadecimal test argument that looks like a DER
        SEQUENCE is parsed; a hit is located as
        "<filename>:<line number>:#<1-based argument index>".

        :param filename: name of the file to parse.
        :return list of AuditData parsed from the file.
        """
        audit_data_list = []
        # Use the wrapper as a context manager so the underlying file is
        # closed once parsing is done (it was previously never closed).
        with FileWrapper(filename) as data_f:
            for test_args in parse_suite_data(data_f):
                for idx, test_arg in enumerate(test_args, 1):
                    match = re.match(r'"(?P<data>[0-9a-fA-F]+)"', test_arg)
                    if not match:
                        continue
                    hex_data = match.group('data')
                    if not X509Parser.check_hex_string(hex_data):
                        continue
                    audit_data = self.parse_bytes(bytes.fromhex(hex_data))
                    if audit_data is None:
                        continue
                    audit_data.locations.append("{}:{}:#{}".format(filename,
                                                                   data_f.line_no,
                                                                   idx))
                    audit_data_list.append(audit_data)

        return audit_data_list
372
373
def list_all(audit_data: AuditData):
    """
    Print one tab-separated row per location of `audit_data`.

    Columns: identifier, not-valid-before and not-valid-after (ISO 8601 to
    the second, padded to 20 characters), data type name (padded to 3
    characters) and the location.
    """
    for location in audit_data.locations:
        fields = (
            audit_data.identifier,
            audit_data.not_valid_before.isoformat(timespec='seconds'),
            audit_data.not_valid_after.isoformat(timespec='seconds'),
            audit_data.data_type.name,
            location,
        )
        print("{}\t{:20}\t{:20}\t{:3}\t{}".format(*fields))
382
383
def _parse_utc_datetime(text: str) -> datetime.datetime:
    """
    Parse an ISO 8601 date (or date-time) given on the command line.

    Naive values are taken to already be in UTC, as documented for --from
    and --to. Timezone-aware values are converted to UTC and made naive so
    they can be compared with the naive validity dates.

    :raises ValueError: if `text` is not a valid ISO 8601 date/time.
    """
    value = datetime.datetime.fromisoformat(text)
    if value.tzinfo is not None:
        value = value.astimezone(datetime.timezone.utc).replace(tzinfo=None)
    return value


def main():
    """
    Perform argument parsing, audit the selected files and print the X.509
    objects whose validity period does not cover the requested period (or
    every object when --all is given).
    """
    parser = argparse.ArgumentParser(description=__doc__)

    parser.add_argument('-a', '--all',
                        action='store_true',
                        help='list the information of all the files')
    parser.add_argument('-v', '--verbose',
                        action='store_true', dest='verbose',
                        help='show logs')
    parser.add_argument('--from', dest='start_date',
                        help=('Start of desired validity period (UTC, YYYY-MM-DD). '
                              'Default: today'),
                        metavar='DATE')
    parser.add_argument('--to', dest='end_date',
                        help=('End of desired validity period (UTC, YYYY-MM-DD). '
                              'Default: --from'),
                        metavar='DATE')
    parser.add_argument('--data-files', action='append', nargs='*',
                        help='data files to audit',
                        metavar='FILE')
    parser.add_argument('--suite-data-files', action='append', nargs='*',
                        help='suite data files to audit',
                        metavar='FILE')

    args = parser.parse_args()

    # Desired validity period, as naive UTC datetimes. "Now" must be taken in
    # UTC as documented; datetime.today() returns local time and would shift
    # the period by the local UTC offset. Dates are validated before any
    # file is scanned so that a typo fails fast with a usage error.
    try:
        if args.start_date:
            start_date = _parse_utc_datetime(args.start_date)
        else:
            start_date = datetime.datetime.now(
                datetime.timezone.utc).replace(tzinfo=None)
        if args.end_date:
            end_date = _parse_utc_datetime(args.end_date)
        else:
            end_date = start_date
    except ValueError as err:
        parser.error("invalid date: {}".format(err))
    if end_date < start_date:
        parser.error("--to date must not be earlier than --from date")

    # start main routine
    # setup logger
    logger = logging.getLogger()
    logging_util.configure_logger(logger)
    logger.setLevel(logging.DEBUG if args.verbose else logging.ERROR)

    td_auditor = TestDataAuditor(logger)
    sd_auditor = SuiteDataAuditor(logger)

    data_files = []
    suite_data_files = []
    if args.data_files is None and args.suite_data_files is None:
        data_files = td_auditor.default_files
        suite_data_files = sd_auditor.default_files
    else:
        # Each --data-files/--suite-data-files occurrence appends a list.
        if args.data_files is not None:
            data_files = [x for group in args.data_files for x in group]
        if args.suite_data_files is not None:
            suite_data_files = [x for group in args.suite_data_files for x in group]

    # go through all the files
    audit_results = {}
    td_auditor.walk_all(audit_results, data_files)
    sd_auditor.walk_all(audit_results, suite_data_files)

    logger.info("Total: %d objects found!", len(audit_results))

    # Unless --all is given, leave out the objects whose validity period
    # covers the whole desired period [start_date, end_date].
    if args.all:
        selected = list(audit_results.values())
    else:
        selected = [d for d in audit_results.values()
                    if start_date < d.not_valid_before
                    or d.not_valid_after < end_date]

    # Output the results, the earliest-expiring first.
    for d in sorted(selected, key=lambda d: d.not_valid_after):
        list_all(d)

    logger.debug("Done!")
466
# Check the installed cryptography version at import time, before anything
# else runs. This check also runs when the module is imported rather than
# executed as a script.
check_cryptography_version()
if __name__ == "__main__":
    main()
470