1#!/usr/bin/env python3
2"""Generate test data for PSA cryptographic mechanisms.
3
4With no arguments, generate all test data. With non-option arguments,
5generate only the specified files.
6"""
7
8# Copyright The Mbed TLS Contributors
9# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
10
11import enum
12import re
13import sys
14from typing import Callable, Dict, FrozenSet, Iterable, Iterator, List, Optional
15
16import scripts_path # pylint: disable=unused-import
17from mbedtls_dev import crypto_data_tests
18from mbedtls_dev import crypto_knowledge
19from mbedtls_dev import macro_collector #pylint: disable=unused-import
20from mbedtls_dev import psa_information
21from mbedtls_dev import psa_storage
22from mbedtls_dev import test_case
23from mbedtls_dev import test_data_generation
24
25
26
def test_case_for_key_type_not_supported(
        verb: str, key_type: str, bits: int,
        dependencies: List[str],
        *args: str,
        param_descr: str = ''
) -> test_case.TestCase:
    """Build one test case checking that creating a key fails as unsupported.

    * `verb`: the key creation method being exercised (e.g. 'import' or
      'generate'); the test function is ``<verb>_not_supported``.
    * `key_type`: the PSA key type expression, first test argument.
    * `bits`: the key size, used only in the description.
    * `dependencies`: dependency symbols of the test case. It is first handed
      to `psa_information.hack_dependencies_not_implemented`, whose return
      value is ignored (the call is made for its effect on this list).
    * `args`: extra test arguments placed after the key type.
    * `param_descr`: optional word placed before 'not'/'never' in the
      description.
    """
    psa_information.hack_dependencies_not_implemented(dependencies)
    # An empty dependency list means the case applies in every configuration,
    # so the type is described as 'never' supported rather than 'not'.
    qualifier = 'never' if not dependencies else 'not'
    if param_descr:
        qualifier = '{} {}'.format(param_descr, qualifier)
    description = 'PSA {} {} {}-bit {} supported'.format(
        verb, crypto_knowledge.short_expression(key_type), bits, qualifier)
    tc = test_case.TestCase()
    tc.set_description(description)
    tc.set_dependencies(dependencies)
    tc.set_function(verb + '_not_supported')
    tc.set_arguments([key_type, *args])
    return tc
48
class KeyTypeNotSupported:
    """Generate test cases for when a key type is not supported."""

    def __init__(self, info: psa_information.Information) -> None:
        self.constructors = info.constructors

    # Key types for which no "not supported" case is generated: such a case
    # would be skipped in every configuration, which is just noise.
    ALWAYS_SUPPORTED = frozenset([
        'PSA_KEY_TYPE_DERIVE',
        'PSA_KEY_TYPE_PASSWORD',
        'PSA_KEY_TYPE_PASSWORD_HASH',
        'PSA_KEY_TYPE_RAW_DATA',
        'PSA_KEY_TYPE_HMAC'
    ])

    def test_cases_for_key_type_not_supported(
            self,
            kt: crypto_knowledge.KeyType,
            param: Optional[int] = None,
            param_descr: str = '',
    ) -> Iterator[test_case.TestCase]:
        """Yield test cases exercising key creation when `kt` is unsupported.

        When `param` is None, the cases depend on the base key type being
        unsupported. Otherwise the base type is required and only the type
        parameter at index `param` is required to be unsupported.

        For each size from `kt.sizes_to_test()` an 'import' case is yielded,
        followed by a 'generate' case unless the key type is public.
        """
        if kt.name in self.ALWAYS_SUPPORTED:
            return
        base_prefix = '!' if param is None else ''
        import_dependencies = [base_prefix + psa_information.psa_want_symbol(kt.name)]
        if kt.params is not None:
            for index, symbol in enumerate(kt.params):
                prefix = '!' if index == param else ''
                import_dependencies.append(prefix + psa_information.psa_want_symbol(symbol))
        if kt.name.endswith('_PUBLIC_KEY'):
            generate_dependencies = []
        else:
            # Derive the GENERATE flavour from the unadjusted list before the
            # import list itself is rewritten to the BASIC flavour.
            generate_dependencies = psa_information.fix_key_pair_dependencies(
                import_dependencies, 'GENERATE')
            import_dependencies = psa_information.fix_key_pair_dependencies(
                import_dependencies, 'BASIC')
        for bits in kt.sizes_to_test():
            yield test_case_for_key_type_not_supported(
                'import', kt.expression, bits,
                psa_information.finish_family_dependencies(import_dependencies, bits),
                test_case.hex_string(kt.key_material(bits)),
                param_descr=param_descr,
            )
            # If generation is impossible for this key type, rather than
            # supported or not depending on implementation capabilities,
            # only generate the test case once.
            generation_skipped = not generate_dependencies and param is not None
            # Public keys are never generated here: generation is expected to
            # fail with INVALID_ARGUMENT, which the KeyGenerate class covers.
            if not generation_skipped and not kt.is_public():
                yield test_case_for_key_type_not_supported(
                    'generate', kt.expression, bits,
                    psa_information.finish_family_dependencies(generate_dependencies, bits),
                    str(bits),
                    param_descr=param_descr,
                )
            # To be added: derive

    # Constructors taking an ECC curve family / DH group family parameter.
    ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                     'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
    DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
                    'PSA_KEY_TYPE_DH_PUBLIC_KEY')

    def _test_cases_for_families(
            self,
            families: Iterable[str],
            constructors: Iterable[str],
            param_name: str,
    ) -> Iterator[test_case.TestCase]:
        """Yield type-unsupported then family-unsupported cases per family/constructor."""
        for family in sorted(families):
            for constr in constructors:
                kt = crypto_knowledge.KeyType(constr, [family])
                yield from self.test_cases_for_key_type_not_supported(
                    kt, param_descr='type')
                yield from self.test_cases_for_key_type_not_supported(
                    kt, 0, param_descr=param_name)

    def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]:
        """Generate test cases that exercise the creation of keys of unsupported types."""
        parameterized = self.ECC_KEY_TYPES + self.DH_KEY_TYPES
        for key_type in sorted(self.constructors.key_types):
            if key_type not in parameterized:
                yield from self.test_cases_for_key_type_not_supported(
                    crypto_knowledge.KeyType(key_type))
        yield from self._test_cases_for_families(
            self.constructors.ecc_curves, self.ECC_KEY_TYPES, 'curve')
        yield from self._test_cases_for_families(
            self.constructors.dh_groups, self.DH_KEY_TYPES, 'group')
142
def test_case_for_key_generation(
        key_type: str, bits: int,
        dependencies: List[str],
        *args: str,
        result: str = ''
) -> test_case.TestCase:
    """Build one ``generate_key`` test case for the given key type and size.

    * `key_type`: the PSA key type expression, first test argument.
    * `bits`: the key size, used only in the description.
    * `dependencies`: dependency symbols of the test case. It is first handed
      to `psa_information.hack_dependencies_not_implemented`, whose return
      value is ignored (the call is made for its effect on this list).
    * `args`: test arguments placed after the key type.
    * `result`: keyword-only (it follows ``*args``); always appended as the
      last test argument, even when left at its default ''.
    """
    psa_information.hack_dependencies_not_implemented(dependencies)
    description = 'PSA {} {}-bit'.format(
        crypto_knowledge.short_expression(key_type), bits)
    tc = test_case.TestCase()
    tc.set_description(description)
    tc.set_dependencies(dependencies)
    tc.set_function('generate_key')
    tc.set_arguments([key_type, *args, result])
    return tc
161
class KeyGenerate:
    """Generate positive and negative (invalid argument) test cases for key generation."""

    def __init__(self, info: psa_information.Information) -> None:
        self.constructors = info.constructors

    # Constructors parameterized by an ECC curve family or a DH group family.
    # They are not used bare; they are instantiated once per known family.
    ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                     'PSA_KEY_TYPE_ECC_PUBLIC_KEY')
    DH_KEY_TYPES = ('PSA_KEY_TYPE_DH_KEY_PAIR',
                    'PSA_KEY_TYPE_DH_PUBLIC_KEY')

    @staticmethod
    def test_cases_for_key_type_key_generation(
            kt: crypto_knowledge.KeyType
    ) -> Iterator[test_case.TestCase]:
        """Return test cases exercising key generation.

        All key types can be generated except for public keys. For public key
        PSA_ERROR_INVALID_ARGUMENT status is expected.

        One test case is yielded per size from `kt.sizes_to_test()`. Cases for
        PSA_KEY_TYPE_RSA_KEY_PAIR additionally depend on
        PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS being at most the tested size.
        """
        result = 'PSA_SUCCESS'

        type_dependencies = [psa_information.psa_want_symbol(kt.name)]
        if kt.params is not None:
            type_dependencies += [psa_information.psa_want_symbol(sym)
                                  for sym in kt.params]
        if kt.name.endswith('_PUBLIC_KEY'):
            # The library checks whether the key type is a public key generically,
            # before it reaches a point where it needs support for the specific key
            # type, so it returns INVALID_ARGUMENT for unsupported public key types.
            generate_dependencies = []
            result = 'PSA_ERROR_INVALID_ARGUMENT'
        else:
            generate_dependencies = \
                psa_information.fix_key_pair_dependencies(type_dependencies, 'GENERATE')
        for bits in kt.sizes_to_test():
            if kt.name == 'PSA_KEY_TYPE_RSA_KEY_PAIR':
                size_dependency = "PSA_VENDOR_RSA_GENERATE_MIN_KEY_BITS <= " + str(bits)
                test_dependencies = generate_dependencies + [size_dependency]
            else:
                test_dependencies = generate_dependencies
            yield test_case_for_key_generation(
                kt.expression, bits,
                psa_information.finish_family_dependencies(test_dependencies, bits),
                str(bits),
                # `result` follows *args in test_case_for_key_generation, so it
                # must be passed by keyword: passed positionally it would land in
                # *args and an extra empty trailing argument would be emitted.
                result=result,
            )

    def test_cases_for_key_generation(self) -> Iterator[test_case.TestCase]:
        """Generate test cases that exercise the generation of keys."""
        for key_type in sorted(self.constructors.key_types):
            if key_type in self.ECC_KEY_TYPES:
                continue
            if key_type in self.DH_KEY_TYPES:
                continue
            kt = crypto_knowledge.KeyType(key_type)
            yield from self.test_cases_for_key_type_key_generation(kt)
        for curve_family in sorted(self.constructors.ecc_curves):
            for constr in self.ECC_KEY_TYPES:
                kt = crypto_knowledge.KeyType(constr, [curve_family])
                yield from self.test_cases_for_key_type_key_generation(kt)
        for dh_family in sorted(self.constructors.dh_groups):
            for constr in self.DH_KEY_TYPES:
                kt = crypto_knowledge.KeyType(constr, [dh_family])
                yield from self.test_cases_for_key_type_key_generation(kt)
227
class OpFail:
    """Generate test cases for operations that must fail."""
    #pylint: disable=too-few-public-methods

    class Reason(enum.Enum):
        """Why an operation is expected to fail."""
        # Some dependency of the algorithm is absent (expects NOT_SUPPORTED).
        NOT_SUPPORTED = 0
        # The algorithm cannot perform this category of operation.
        INVALID = 1
        # The key type cannot be used with the algorithm.
        INCOMPATIBLE = 2
        # A public key is used for a private-key operation.
        PUBLIC = 3

    def __init__(self, info: psa_information.Information) -> None:
        self.constructors = info.constructors
        expressions = self.constructors.generate_expressions(
            sorted(self.constructors.key_types)
        )
        self.key_types = [crypto_knowledge.KeyType(expr) for expr in expressions]

    def make_test_case(
            self,
            alg: crypto_knowledge.Algorithm,
            category: crypto_knowledge.AlgorithmCategory,
            reason: 'Reason',
            kt: Optional[crypto_knowledge.KeyType] = None,
            not_deps: FrozenSet[str] = frozenset(),
    ) -> test_case.TestCase:
        """Construct a failure test case for a one-key or keyless operation.

        * `alg`: the algorithm under test.
        * `category`: the operation category; its lower-cased name selects
          the ``<category>_fail`` test function.
        * `reason`: why the operation must fail. NOT_SUPPORTED expects
          PSA_ERROR_NOT_SUPPORTED; every other reason expects
          PSA_ERROR_INVALID_ARGUMENT.
        * `kt`: the key type to use, or None for a keyless operation.
        * `not_deps`: dependencies that are negated in the test case.
        """
        #pylint: disable=too-many-arguments,too-many-locals
        is_not_supported = reason == self.Reason.NOT_SUPPORTED
        if is_not_supported:
            shortened = sorted(re.sub(r'PSA_WANT_ALG_', r'', dep)
                               for dep in not_deps)
            pretty_reason = '!' + '&'.join(shortened)
        else:
            pretty_reason = reason.name.lower()
        key_type = kt.expression if kt else ''
        pretty_type = kt.short_expression() if kt else ''
        description = 'PSA {} {}: {}{}'.format(
            category.name.lower(),
            alg.short_expression(),
            pretty_reason,
            ' with ' + pretty_type if pretty_type else '')
        dependencies = psa_information.automatic_dependencies(alg.base_expression, key_type)
        dependencies = psa_information.fix_key_pair_dependencies(dependencies, 'BASIC')
        # Require the absence of the dependencies listed in `not_deps`.
        dependencies = ['!' + dep if dep in not_deps else dep
                        for dep in dependencies]
        arguments = [] # type: List[str]
        if kt:
            # Use key material for the first size the key type tests.
            material = kt.key_material(kt.sizes_to_test()[0])
            arguments.extend([key_type, test_case.hex_string(material)])
        arguments.append(alg.expression)
        if category.is_asymmetric():
            arguments.append('1' if reason == self.Reason.PUBLIC else '0')
        error = 'NOT_SUPPORTED' if is_not_supported else 'INVALID_ARGUMENT'
        arguments.append('PSA_ERROR_' + error)
        tc = test_case.TestCase()
        tc.set_description(description)
        tc.set_dependencies(dependencies)
        tc.set_function(category.name.lower() + '_fail')
        tc.set_arguments(arguments)
        return tc

    def no_key_test_cases(
            self,
            alg: crypto_knowledge.Algorithm,
            category: crypto_knowledge.AlgorithmCategory,
    ) -> Iterator[test_case.TestCase]:
        """Generate failure test cases for keyless operations with the specified algorithm."""
        if not alg.can_do(category):
            # Incompatible operation, supported algorithm
            yield self.make_test_case(alg, category, self.Reason.INVALID)
            return
        # Compatible operation: one case per dependency, with that dependency
        # required to be absent.
        for dep in psa_information.automatic_dependencies(alg.base_expression):
            yield self.make_test_case(alg, category,
                                      self.Reason.NOT_SUPPORTED,
                                      not_deps=frozenset([dep]))

    def one_key_test_cases(
            self,
            alg: crypto_knowledge.Algorithm,
            category: crypto_knowledge.AlgorithmCategory,
    ) -> Iterator[test_case.TestCase]:
        """Generate failure test cases for one-key operations with the specified algorithm."""
        for kt in self.key_types:
            key_ok = kt.can_do(alg)
            operation_ok = alg.can_do(category)
            if key_ok and operation_ok:
                # Compatible key and operation, unsupported algorithm
                for dep in psa_information.automatic_dependencies(alg.base_expression):
                    yield self.make_test_case(alg, category,
                                              self.Reason.NOT_SUPPORTED,
                                              kt=kt, not_deps=frozenset([dep]))
                # Public key for a private-key operation
                if category.is_asymmetric() and kt.is_public():
                    yield self.make_test_case(alg, category,
                                              self.Reason.PUBLIC, kt=kt)
            elif key_ok:
                # Compatible key, incompatible operation, supported algorithm
                yield self.make_test_case(alg, category,
                                          self.Reason.INVALID, kt=kt)
            elif operation_ok:
                # Incompatible key, compatible operation, supported algorithm
                yield self.make_test_case(alg, category,
                                          self.Reason.INCOMPATIBLE, kt=kt)
            # Otherwise both key and operation are incompatible. Cases where
            # multiple things are wrong are not generated, to keep the number
            # of test cases reasonable.

    def test_cases_for_algorithm(
            self,
            alg: crypto_knowledge.Algorithm,
    ) -> Iterator[test_case.TestCase]:
        """Generate operation failure test cases for the specified algorithm."""
        for category in crypto_knowledge.AlgorithmCategory:
            if category == crypto_knowledge.AlgorithmCategory.PAKE:
                # PAKE operations are not implemented yet
                continue
            if category.requires_key():
                yield from self.one_key_test_cases(alg, category)
            else:
                yield from self.no_key_test_cases(alg, category)

    def all_test_cases(self) -> Iterator[test_case.TestCase]:
        """Generate all test cases for operations that must fail."""
        expressions = self.constructors.generate_expressions(
            sorted(self.constructors.algorithms))
        for expr in expressions:
            yield from self.test_cases_for_algorithm(crypto_knowledge.Algorithm(expr))
366
367
class StorageKey(psa_storage.Key):
    """Representation of a key for storage format testing."""

    # Mapping of usage flags to the flags that they imply.
    IMPLICIT_USAGE_FLAGS = {
        'PSA_KEY_USAGE_SIGN_HASH': 'PSA_KEY_USAGE_SIGN_MESSAGE',
        'PSA_KEY_USAGE_VERIFY_HASH': 'PSA_KEY_USAGE_VERIFY_MESSAGE'
    } #type: Dict[str, str]

    def __init__(
            self,
            usage: Iterable[str],
            without_implicit_usage: Optional[bool] = False,
            **kwargs
    ) -> None:
        """Prepare to generate a key.

        * `usage`                 : the usage flags of the key.
        * `without_implicit_usage`: if true, store `usage` as given; otherwise
                                    also add every flag that a member of
                                    `usage` implies per IMPLICIT_USAGE_FLAGS.

        The flags are joined with ' | ' in sorted order ('0' when there are
        none); remaining keyword arguments go to `psa_storage.Key`.
        """
        flags = set(usage)
        if not without_implicit_usage:
            flags |= {self.IMPLICIT_USAGE_FLAGS[flag]
                      for flag in flags
                      if flag in self.IMPLICIT_USAGE_FLAGS}
        usage_expression = ' | '.join(sorted(flags)) if flags else '0'
        super().__init__(usage=usage_expression, **kwargs)
398
class StorageTestData(StorageKey):
    """Representation of test case data for storage format testing."""

    def __init__(
            self,
            description: str,
            expected_usage: Optional[List[str]] = None,
            **kwargs
    ) -> None:
        """Prepare to generate test data.

        * `description`   : used for the test case names.
        * `expected_usage`: the usage flags generated as the expected usage
                            flags in the test cases. Can differ from the usage
                            flags stored in the key because of the usage flags
                            extension. None means "same as the stored usage";
                            an empty list means no usage (0).

        Remaining keyword arguments are passed to `StorageKey`.
        """
        super().__init__(**kwargs)
        self.description = description #type: str
        if expected_usage is None:
            self.expected_usage = self.usage #type: psa_storage.Expr
        else:
            self.expected_usage = psa_storage.Expr(
                ' | '.join(expected_usage) if expected_usage else 0)
423
class StorageFormat:
    """Storage format stability test cases."""

    def __init__(self, info: psa_information.Information, version: int, forward: bool) -> None:
        """Prepare to generate test cases for storage format stability.

        * `info`: information about the API. See the `Information` class.
        * `version`: the storage format version to generate test cases for.
        * `forward`: if true, generate forward compatibility test cases which
          save a key and check that its representation is as intended. Otherwise
          generate backward compatibility test cases which inject a key
          representation and check that it can be read and used.
        """
        self.constructors = info.constructors #type: macro_collector.PSAMacroEnumerator
        self.version = version #type: int
        self.forward = forward #type: bool

    # An RSA OAEP algorithm expression; group 1 captures its hash algorithm.
    RSA_OAEP_RE = re.compile(r'PSA_ALG_RSA_OAEP\((.*)\)\Z')
    # A key type expression whose family is a Brainpool curve.
    BRAINPOOL_RE = re.compile(r'PSA_KEY_TYPE_\w+\(PSA_ECC_FAMILY_BRAINPOOL_\w+\)\Z')

    @classmethod
    def exercise_key_with_algorithm(
            cls,
            key_type: psa_storage.Expr, bits: int,
            alg: psa_storage.Expr
    ) -> bool:
        """Whether to exercise the given key with the given algorithm.

        Normally only the type and algorithm matter for compatibility, and
        this is handled in crypto_knowledge.KeyType.can_do(). This function
        exists to detect exceptional cases. Exceptional cases detected here
        are not tested in OpFail and should therefore have manually written
        test cases.
        """
        # Some test keys have the RAW_DATA type and attributes that don't
        # necessarily make sense. We do this to validate numerical
        # encodings of the attributes.
        # Raw data keys have no useful exercise anyway so there is no
        # loss of test coverage.
        if key_type.string == 'PSA_KEY_TYPE_RAW_DATA':
            return False
        # OAEP requires room for two hashes plus wrapping
        m = cls.RSA_OAEP_RE.match(alg.string)
        if m:
            hash_alg = m.group(1)
            hash_length = crypto_knowledge.Algorithm.hash_length(hash_alg)
            key_length = (bits + 7) // 8
            # Leave enough room for at least one byte of plaintext
            return key_length > 2 * hash_length + 2
        # There's nothing wrong with ECC keys on Brainpool curves,
        # but operations with them are very slow. So we only exercise them
        # with a single algorithm, not with all possible hashes. We do
        # exercise other curves with all algorithms so test coverage is
        # perfectly adequate like this.
        m = cls.BRAINPOOL_RE.match(key_type.string)
        if m and alg.string != 'PSA_ALG_ECDSA_ANY':
            return False
        return True

    def make_test_case(self, key: StorageTestData) -> test_case.TestCase:
        """Construct a storage format test case for the given key.

        If ``forward`` is true, generate a forward compatibility test case:
        create a key and validate that it has the expected representation.
        Otherwise generate a backward compatibility test case: inject the
        key representation into storage and validate that it can be read
        correctly.

        Backward cases take one extra flags argument: TEST_FLAG_EXERCISE when
        `exercise_key_with_algorithm` allows it, and TEST_FLAG_READ_ONLY for
        lifetimes containing READ_ONLY ('0' when neither applies).
        """
        verb = 'save' if self.forward else 'read'
        tc = test_case.TestCase()
        tc.set_description(verb + ' ' + key.description)
        dependencies = psa_information.automatic_dependencies(
            key.lifetime.string, key.type.string,
            key.alg.string, key.alg2.string,
        )
        dependencies = psa_information.finish_family_dependencies(dependencies, key.bits)
        dependencies += psa_information.generate_deps_from_description(key.description)
        dependencies = psa_information.fix_key_pair_dependencies(dependencies, 'BASIC')
        tc.set_dependencies(dependencies)
        tc.set_function('key_storage_' + verb)
        if self.forward:
            extra_arguments = []
        else:
            flags = []
            if self.exercise_key_with_algorithm(key.type, key.bits, key.alg):
                flags.append('TEST_FLAG_EXERCISE')
            if 'READ_ONLY' in key.lifetime.string:
                flags.append('TEST_FLAG_READ_ONLY')
            extra_arguments = [' | '.join(flags) if flags else '0']
        tc.set_arguments([key.lifetime.string,
                          key.type.string, str(key.bits),
                          key.expected_usage.string,
                          key.alg.string, key.alg2.string,
                          '"' + key.material.hex() + '"',
                          '"' + key.hex() + '"',
                          *extra_arguments])
        return tc

    def key_for_lifetime(
            self,
            lifetime: str,
    ) -> StorageTestData:
        """Construct a test key for the given lifetime.

        The key is an 8-bit RAW_DATA key with EXPORT usage and no algorithm.
        """
        short = lifetime
        short = re.sub(r'PSA_KEY_LIFETIME_FROM_PERSISTENCE_AND_LOCATION',
                       r'', short)
        short = crypto_knowledge.short_expression(short)
        description = 'lifetime: ' + short
        key = StorageTestData(version=self.version,
                              id=1, lifetime=lifetime,
                              type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                              usage=['PSA_KEY_USAGE_EXPORT'], alg=0, alg2=0,
                              material=b'L',
                              description=description)
        return key

    def all_keys_for_lifetimes(self) -> Iterator[StorageTestData]:
        """Generate test keys covering lifetimes."""
        lifetimes = sorted(self.constructors.lifetimes)
        expressions = self.constructors.generate_expressions(lifetimes)
        for lifetime in expressions:
            # Don't attempt to create or load a volatile key in storage
            if 'VOLATILE' in lifetime:
                continue
            # Don't attempt to create a read-only key in storage,
            # but do attempt to load one.
            if 'READ_ONLY' in lifetime and self.forward:
                continue
            yield self.key_for_lifetime(lifetime)

    def key_for_usage_flags(
            self,
            usage_flags: List[str],
            short: Optional[str] = None,
            test_implicit_usage: Optional[bool] = True
    ) -> StorageTestData:
        """Construct a test key for the given key usage.

        * `usage_flags`: the usage flags of the key, also used verbatim as the
          expected usage.
        * `short`: description suffix; defaults to the short form of the
          expected usage expression.
        * `test_implicit_usage`: if true, the stored usage is extended with
          the flags implied by `usage_flags` (see
          `StorageKey.IMPLICIT_USAGE_FLAGS`) and the description gets a
          ' without implication' suffix; otherwise the stored usage is
          exactly `usage_flags`.
        """
        extra_desc = ' without implication' if test_implicit_usage else ''
        description = 'usage' + extra_desc + ': '
        key1 = StorageTestData(version=self.version,
                               id=1, lifetime=0x00000001,
                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                               expected_usage=usage_flags,
                               without_implicit_usage=not test_implicit_usage,
                               usage=usage_flags, alg=0, alg2=0,
                               material=b'K',
                               description=description)
        if short is None:
            usage_expr = key1.expected_usage.string
            key1.description += crypto_knowledge.short_expression(usage_expr)
        else:
            key1.description += short
        return key1

    def generate_keys_for_usage_flags(self, **kwargs) -> Iterator[StorageTestData]:
        """Generate test keys covering usage flags.

        Yields a key with no usage ('0'), one key per known flag, and one key
        per pair of adjacent known flags in sorted order, wrapping around from
        the last flag to the first. `kwargs` go to `key_for_usage_flags`.
        """
        known_flags = sorted(self.constructors.key_usage_flags)
        yield self.key_for_usage_flags(['0'], **kwargs)
        for usage_flag in known_flags:
            yield self.key_for_usage_flags([usage_flag], **kwargs)
        # Rotate with slices rather than indexing [0], so that an empty flag
        # list yields no pairs instead of raising IndexError.
        for flag1, flag2 in zip(known_flags,
                                known_flags[1:] + known_flags[:1]):
            yield self.key_for_usage_flags([flag1, flag2], **kwargs)

    def generate_key_for_all_usage_flags(self) -> Iterator[StorageTestData]:
        """Generate a single test key that has every known usage flag."""
        known_flags = sorted(self.constructors.key_usage_flags)
        yield self.key_for_usage_flags(known_flags, short='all known')

    def all_keys_for_usage_flags(self) -> Iterator[StorageTestData]:
        """Generate test keys covering usage flags individually, in pairs and all at once."""
        yield from self.generate_keys_for_usage_flags()
        yield from self.generate_key_for_all_usage_flags()

    def key_for_type_and_alg(
            self,
            kt: crypto_knowledge.KeyType,
            bits: int,
            alg: Optional[crypto_knowledge.Algorithm] = None,
    ) -> StorageTestData:
        """Construct a test key of the given type.

        If alg is not None, this key allows it, together with the usage flags
        that `alg.usage_flags` reports for it (in addition to EXPORT).
        """
        usage_flags = ['PSA_KEY_USAGE_EXPORT']
        alg1 = 0 #type: psa_storage.Exprable
        alg2 = 0
        if alg is not None:
            alg1 = alg.expression
            usage_flags += alg.usage_flags(public=kt.is_public())
        key_material = kt.key_material(bits)
        description = 'type: {} {}-bit'.format(kt.short_expression(1), bits)
        if alg is not None:
            description += ', ' + alg.short_expression(1)
        key = StorageTestData(version=self.version,
                              id=1, lifetime=0x00000001,
                              type=kt.expression, bits=bits,
                              usage=usage_flags, alg=alg1, alg2=alg2,
                              material=key_material,
                              description=description)
        return key

    def keys_for_type(
            self,
            key_type: str,
            all_algorithms: List[crypto_knowledge.Algorithm],
    ) -> Iterator[StorageTestData]:
        """Generate test keys for the given key type."""
        kt = crypto_knowledge.KeyType(key_type)
        # Compatibility depends only on the key type, not on its size, so
        # compute it once rather than once per tested size.
        compatible_algorithms = [alg for alg in all_algorithms
                                 if kt.can_do(alg)]
        for bits in kt.sizes_to_test():
            # Test a non-exercisable key, as well as exercisable keys for
            # each compatible algorithm.
            # To do: test reading a key from storage with an incompatible
            # or unsupported algorithm.
            yield self.key_for_type_and_alg(kt, bits)
            for alg in compatible_algorithms:
                yield self.key_for_type_and_alg(kt, bits, alg)

    def all_keys_for_types(self) -> Iterator[StorageTestData]:
        """Generate test keys covering key types and their representations."""
        key_types = sorted(self.constructors.key_types)
        all_algorithms = [crypto_knowledge.Algorithm(alg)
                          for alg in self.constructors.generate_expressions(
                              sorted(self.constructors.algorithms)
                          )]
        for key_type in self.constructors.generate_expressions(key_types):
            yield from self.keys_for_type(key_type, all_algorithms)

    def keys_for_algorithm(self, alg: str) -> Iterator[StorageTestData]:
        """Generate test keys for the encoding of the specified algorithm.

        Yields two RAW_DATA keys: one with `alg` as the primary algorithm and
        one with it as the secondary algorithm (alg2).
        """
        # These test cases only validate the encoding of algorithms, not
        # whether the key read from storage is suitable for an operation.
        # `keys_for_type` generates read tests with an algorithm and a
        # compatible key.
        descr = crypto_knowledge.short_expression(alg, 1)
        usage = ['PSA_KEY_USAGE_EXPORT']
        key1 = StorageTestData(version=self.version,
                               id=1, lifetime=0x00000001,
                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                               usage=usage, alg=alg, alg2=0,
                               material=b'K',
                               description='alg: ' + descr)
        yield key1
        key2 = StorageTestData(version=self.version,
                               id=1, lifetime=0x00000001,
                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
                               usage=usage, alg=0, alg2=alg,
                               material=b'L',
                               description='alg2: ' + descr)
        yield key2

    def all_keys_for_algorithms(self) -> Iterator[StorageTestData]:
        """Generate test keys covering algorithm encodings."""
        algorithms = sorted(self.constructors.algorithms)
        for alg in self.constructors.generate_expressions(algorithms):
            yield from self.keys_for_algorithm(alg)

    def generate_all_keys(self) -> Iterator[StorageTestData]:
        """Generate all keys for the test cases."""
        yield from self.all_keys_for_lifetimes()
        yield from self.all_keys_for_usage_flags()
        yield from self.all_keys_for_types()
        yield from self.all_keys_for_algorithms()

    def all_test_cases(self) -> Iterator[test_case.TestCase]:
        """Generate all storage format test cases."""
        # First build a list of all keys, then construct all the corresponding
        # test cases. This allows all required information to be obtained in
        # one go, which is a significant performance gain as the information
        # includes numerical values obtained by compiling a C program.
        all_keys = list(self.generate_all_keys())
        for key in all_keys:
            if key.location_value() != 0:
                # Skip keys with a non-default location, because they
                # require a driver and we currently have no mechanism to
                # determine whether a driver is available.
                continue
            yield self.make_test_case(key)
701
class StorageFormatForward(StorageFormat):
    """Forward-compatibility storage format stability test cases.

    A thin specialization of `StorageFormat` that always passes ``True`` as
    the third constructor argument.
    """

    def __init__(self, info: psa_information.Information, version: int) -> None:
        is_forward = True
        super().__init__(info, version, is_forward)
707
class StorageFormatV0(StorageFormat):
    """Storage format stability test cases for version 0 compatibility.

    The base class is constructed for version 0 with ``False`` as its third
    argument. Besides the base class keys, this class generates extra usage
    flag keys and keys that are stored without an implied usage flag but
    whose expected usage includes it.
    """

    def __init__(self, info: psa_information.Information) -> None:
        super().__init__(info, 0, False)

    def all_keys_for_usage_flags(self) -> Iterator[StorageTestData]:
        """Generate test keys covering usage flags.

        This extends the base class keys with those produced by
        `generate_keys_for_usage_flags(test_implicit_usage=False)`.
        """
        yield from super().all_keys_for_usage_flags()
        yield from self.generate_keys_for_usage_flags(test_implicit_usage=False)

    def keys_for_implicit_usage(
            self,
            implyer_usage: str,
            alg: str,
            key_type: crypto_knowledge.KeyType
    ) -> StorageTestData:
        # pylint: disable=too-many-locals
        """Generate a test key for the specified implicit usage flag,
           algorithm and key type combination.

        The key is stored with `PSA_KEY_USAGE_EXPORT` and `implyer_usage`.
        Its expected usage additionally contains the flag that
        `StorageKey.IMPLICIT_USAGE_FLAGS` maps `implyer_usage` to. The key
        size is the first entry of `key_type.sizes_to_test()`.
        """
        bits = key_type.sizes_to_test()[0]
        implicit_usage = StorageKey.IMPLICIT_USAGE_FLAGS[implyer_usage]
        usage_flags = ['PSA_KEY_USAGE_EXPORT']
        material_usage_flags = usage_flags + [implyer_usage]
        expected_usage_flags = material_usage_flags + [implicit_usage]
        alg2 = 0
        key_material = key_type.key_material(bits)
        usage_expression = crypto_knowledge.short_expression(implyer_usage, 1)
        alg_expression = crypto_knowledge.short_expression(alg, 1)
        key_type_expression = key_type.short_expression(1)
        description = 'implied by {}: {} {} {}-bit'.format(
            usage_expression, alg_expression, key_type_expression, bits)
        key = StorageTestData(version=self.version,
                              id=1, lifetime=0x00000001,
                              type=key_type.expression, bits=bits,
                              usage=material_usage_flags,
                              expected_usage=expected_usage_flags,
                              without_implicit_usage=True,
                              alg=alg, alg2=alg2,
                              material=key_material,
                              description=description)
        return key

    def gather_key_types_for_sign_alg(self) -> Dict[str, List[str]]:
        # pylint: disable=too-many-locals
        """Match possible key types for sign algorithms.

        Return a dictionary that maps each signature algorithm expression to
        the key type expressions sharing at least one keyword with it.
        Algorithms that are filtered out, or that match no key type, are
        absent from the result.
        """
        # To create a valid combination both the algorithms and key types
        # must be filtered. Pair them with keywords created from their names.
        incompatible_alg_keyword = frozenset(['RAW', 'ANY', 'PURE'])
        incompatible_key_type_keywords = frozenset(['MONTGOMERY'])
        keyword_translation = {
            'ECDSA': 'ECC',
            'ED[0-9]*.*' : 'EDWARDS'
        }
        exclusive_keywords = {
            'EDWARDS': 'ECC'
        }
        key_types = set(self.constructors.generate_expressions(self.constructors.key_types))
        algorithms = set(self.constructors.generate_expressions(self.constructors.sign_algorithms))
        translation_table = str.maketrans('(', '_', ')')

        # The keywords of a key type do not depend on the algorithm, so they
        # are computed once per key type rather than once per (algorithm,
        # key type) pair. Key types with an incompatible keyword are dropped.
        key_type_keywords = {} #type: Dict[str, FrozenSet[str]]
        for key_type in key_types:
            # Generate keywords from the name of the key type: turn '(' into
            # '_', drop ')', split on '_' and skip the first three parts.
            keywords = set(key_type.translate(translation_table).split(sep='_')[3:])
            # Remove ambiguous keywords. discard() rather than remove(), so
            # a key type that has the exclusive keyword but not the one it
            # supersedes does not raise KeyError.
            for keyword1, keyword2 in exclusive_keywords.items():
                if keyword1 in keywords:
                    keywords.discard(keyword2)
            if keywords.isdisjoint(incompatible_key_type_keywords):
                key_type_keywords[key_type] = frozenset(keywords)

        alg_with_keys = {} #type: Dict[str, List[str]]
        for alg in algorithms:
            # Generate keywords from the name of the algorithm: drop any
            # parenthesized arguments and the first two '_'-separated parts.
            alg_keywords = set(alg.partition('(')[0].split(sep='_')[2:])
            # Translate keywords for better matching with the key types.
            for keyword in alg_keywords.copy():
                for pattern, replace in keyword_translation.items():
                    if re.match(pattern, keyword):
                        alg_keywords.remove(keyword)
                        alg_keywords.add(replace)
                        # The keyword is already gone: a second matching
                        # pattern would make remove() raise KeyError.
                        break
            # Filter out incompatible algorithms
            if not alg_keywords.isdisjoint(incompatible_alg_keyword):
                continue

            for key_type, keywords in key_type_keywords.items():
                if not keywords.isdisjoint(alg_keywords):
                    alg_with_keys.setdefault(alg, []).append(key_type)
        return alg_with_keys

    def all_keys_for_implicit_usage(self) -> Iterator[StorageTestData]:
        """Generate test keys for usage flag extensions.

        One key is generated for each combination of a usage flag in
        `StorageKey.IMPLICIT_USAGE_FLAGS` and an algorithm / key type pair
        from `gather_key_types_for_sign_alg`, in sorted order. Public key
        types are skipped for usage flags containing '_SIGN_'.
        """
        # Each key is generated without usage extension, to check the
        # extension compatibility when it is read back.
        alg_with_keys = self.gather_key_types_for_sign_alg()

        for usage in sorted(StorageKey.IMPLICIT_USAGE_FLAGS, key=str):
            for alg in sorted(alg_with_keys):
                for key_type in sorted(alg_with_keys[alg]):
                    # The key types must be filtered to fit the specific usage flag.
                    kt = crypto_knowledge.KeyType(key_type)
                    if kt.is_public() and '_SIGN_' in usage:
                        # Can't sign with a public key
                        continue
                    yield self.keys_for_implicit_usage(usage, alg, kt)

    def generate_all_keys(self) -> Iterator[StorageTestData]:
        """Generate the base class keys, then the implicit usage keys."""
        yield from super().generate_all_keys()
        yield from self.all_keys_for_implicit_usage()
820
821
class PSATestGenerator(test_data_generation.TestGenerator):
    """Test data generator that knows the PSA-specific targets.

    `targets` maps each output file name to a callable that receives the
    shared `psa_information.Information` object and returns the test cases
    for that file.
    """
    # The content of the storage format targets (names containing
    # 'storage_format') is validated by `abi_check.py`.
    targets = {
        'test_suite_psa_crypto_generate_key.generated':
        lambda psa_info: KeyGenerate(psa_info).test_cases_for_key_generation(),
        'test_suite_psa_crypto_not_supported.generated':
        lambda psa_info: KeyTypeNotSupported(psa_info).test_cases_for_not_supported(),
        'test_suite_psa_crypto_low_hash.generated':
        lambda psa_info: crypto_data_tests.HashPSALowLevel(psa_info).all_test_cases(),
        'test_suite_psa_crypto_op_fail.generated':
        lambda psa_info: OpFail(psa_info).all_test_cases(),
        'test_suite_psa_crypto_storage_format.current':
        lambda psa_info: StorageFormatForward(psa_info, 0).all_test_cases(),
        'test_suite_psa_crypto_storage_format.v0':
        lambda psa_info: StorageFormatV0(psa_info).all_test_cases(),
    } #type: Dict[str, Callable[[psa_information.Information], Iterable[test_case.TestCase]]]

    def __init__(self, options):
        super().__init__(options)
        # A single Information object is created here and handed to every target.
        self.info = psa_information.Information()

    def generate_target(self, name: str, *target_args) -> None:
        # Any caller-supplied target arguments are ignored: every PSA target
        # is built from the shared `self.info`.
        shared_info = self.info
        super().generate_target(name, shared_info)
847
848
if __name__ == '__main__':
    # Pass the command-line arguments (without the program name) and the
    # module docstring (presumably used as the usage description) to the
    # generic test data generation driver.
    command_line_args = sys.argv[1:]
    test_data_generation.main(command_line_args, __doc__, PSATestGenerator)
851