1#!/usr/bin/env python3
2"""Generate test data for PSA cryptographic mechanisms.
3
4With no arguments, generate all test data. With non-option arguments,
5generate only the specified files.
6"""
7
8# Copyright The Mbed TLS Contributors
9# SPDX-License-Identifier: Apache-2.0
10#
11# Licensed under the Apache License, Version 2.0 (the "License"); you may
12# not use this file except in compliance with the License.
13# You may obtain a copy of the License at
14#
15# http://www.apache.org/licenses/LICENSE-2.0
16#
17# Unless required by applicable law or agreed to in writing, software
18# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
19# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20# See the License for the specific language governing permissions and
21# limitations under the License.
22
23import enum
24import re
25import sys
26from typing import Callable, Dict, FrozenSet, Iterable, Iterator, List, Optional
27
28import scripts_path # pylint: disable=unused-import
29from mbedtls_dev import crypto_knowledge
30from mbedtls_dev import macro_collector
31from mbedtls_dev import psa_storage
32from mbedtls_dev import test_case
33from mbedtls_dev import test_data_generation
34
35
def psa_want_symbol(name: str) -> str:
    """Map a PSA crypto feature name to its PSA_WANT_xxx configuration symbol.

    The ``WANT_`` infix is inserted right after the ``PSA_`` prefix, e.g.
    ``PSA_ALG_SHA_256`` becomes ``PSA_WANT_ALG_SHA_256``.

    Raise ValueError if `name` does not start with ``PSA_``.
    """
    prefix = 'PSA_'
    if not name.startswith(prefix):
        raise ValueError('Unable to determine the PSA_WANT_ symbol for ' + name)
    return prefix + 'WANT_' + name[len(prefix):]
42
def finish_family_dependency(dep: str, bits: int) -> str:
    """Complete `dep` with a key size if it is a family dependency prefix.

    A family dependency prefix is a PSA_WANT_ symbol containing ``_FAMILY_``
    that must be qualified by a key size. For such a symbol, the
    ``_FAMILY`` marker is dropped and ``_<bits>`` is appended, e.g.
    ``PSA_WANT_ECC_FAMILY_SECP_R1`` with 256 bits becomes
    ``PSA_WANT_ECC_SECP_R1_256``. Any other symbol is returned unchanged.
    """
    # Group 1 captures everything after "_FAMILY_"; the replacement keeps it
    # and appends the size.
    replacement = r'_\1_{}'.format(bits)
    return re.sub(r'_FAMILY_(.*)', replacement, dep)
51
def finish_family_dependencies(dependencies: List[str], bits: int) -> List[str]:
    """Return a new list with every family prefix in `dependencies` completed.

    Each element goes through `finish_family_dependency` with the same key
    size `bits`; order is preserved and the input list is not modified.
    """
    completed = [finish_family_dependency(symbol, bits)
                 for symbol in dependencies]
    return completed
58
# PSA_xxx symbols that never translate into a PSA_WANT_xxx dependency.
SYMBOLS_WITHOUT_DEPENDENCY = frozenset({
    'PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG', # modifier, only in policies
    'PSA_ALG_AEAD_WITH_SHORTENED_TAG', # modifier
    'PSA_ALG_ANY_HASH', # only in policies
    'PSA_ALG_AT_LEAST_THIS_LENGTH_MAC', # modifier, only in policies
    'PSA_ALG_KEY_AGREEMENT', # chaining
    'PSA_ALG_TRUNCATED_MAC', # modifier
})
def automatic_dependencies(*expressions: str) -> List[str]:
    """Infer the PSA_WANT_xxx dependencies of a test case from C expressions.

    Every ``PSA_ALG_xxx``, ``PSA_ECC_FAMILY_xxx`` and ``PSA_KEY_TYPE_xxx``
    token found in `expressions` is mapped to its PSA_WANT_ symbol, except
    for the symbols in SYMBOLS_WITHOUT_DEPENDENCY. The result is sorted and
    free of duplicates.

    The scan is purely textual: do not pass string literals or comments,
    since tokens inside them would be picked up too.
    """
    token_re = re.compile(r'PSA_(?:ALG|ECC_FAMILY|KEY_TYPE)_\w+')
    mentioned = {token
                 for expression in expressions
                 for token in token_re.findall(expression)}
    relevant = mentioned - SYMBOLS_WITHOUT_DEPENDENCY
    return sorted(map(psa_want_symbol, relevant))
79
80# A temporary hack: at the time of writing, not all dependency symbols
81# are implemented yet. Skip test cases for which the dependency symbols are
82# not available. Once all dependency symbols are available, this hack must
83# be removed so that a bug in the dependency symbols properly leads to a test
84# failure.
def read_implemented_dependencies(filename: str) -> FrozenSet[str]:
    """Return the set of PSA_WANT_xxx symbols mentioned anywhere in `filename`.

    The file is scanned textually, line by line: every whole-word
    occurrence of a ``PSA_WANT_`` identifier counts, including occurrences
    inside comments.
    """
    # Use a context manager so that the file handle is closed as soon as the
    # scan is done rather than whenever the garbage collector gets to it.
    with open(filename, encoding='utf-8') as inp:
        return frozenset(symbol
                         for line in inp
                         for symbol in re.findall(r'\bPSA_WANT_\w+\b', line))
_implemented_dependencies = None #type: Optional[FrozenSet[str]] #pylint: disable=invalid-name
def hack_dependencies_not_implemented(dependencies: List[str]) -> None:
    """Mark `dependencies` as unsatisfiable if it uses an unknown PSA_WANT_ symbol.

    The known PSA_WANT_ symbols are read from ``include/psa/crypto_config.h``
    on first use and cached in a module global. Only entries containing
    ``PSA_WANT`` are checked, and a leading ``!`` is ignored for the lookup.
    If any checked entry is unknown, the pseudo-dependency
    ``DEPENDENCY_NOT_IMPLEMENTED_YET`` is appended to `dependencies` in
    place, so that the test case is skipped.
    """
    global _implemented_dependencies #pylint: disable=global-statement,invalid-name
    if _implemented_dependencies is None:
        _implemented_dependencies = \
            read_implemented_dependencies('include/psa/crypto_config.h')
    known = _implemented_dependencies
    has_unknown = any('PSA_WANT' in dep and dep.lstrip('!') not in known
                      for dep in dependencies)
    if has_unknown:
        dependencies.append('DEPENDENCY_NOT_IMPLEMENTED_YET')
98
99
class Information:
    """Collect the PSA macros (key types, algorithms, ...) used by the test generators."""

    def __init__(self) -> None:
        # Parsed once here and shared through `self.constructors`.
        self.constructors = self.read_psa_interface()

    @staticmethod
    def remove_unwanted_macros(
            constructors: macro_collector.PSAMacroEnumerator
    ) -> None:
        """Drop key types for which no test case should be generated.

        Mbed TLS doesn't support finite-field DH yet and will not support
        finite-field DSA, so those key types are discarded from
        `constructors.key_types` in place.
        """
        for unwanted in ('PSA_KEY_TYPE_DH_KEY_PAIR',
                         'PSA_KEY_TYPE_DH_PUBLIC_KEY',
                         'PSA_KEY_TYPE_DSA_KEY_PAIR',
                         'PSA_KEY_TYPE_DSA_PUBLIC_KEY'):
            constructors.key_types.discard(unwanted)

    def read_psa_interface(self) -> macro_collector.PSAMacroEnumerator:
        """Parse the PSA headers and metadata tests; return the collected macros.

        The headers are parsed first, then the metadata test suite. Unwanted
        key types are removed before the macro arguments are gathered.
        """
        constructors = macro_collector.InputsForTest()
        for header in ('include/psa/crypto_values.h',
                       'include/psa/crypto_extra.h'):
            constructors.parse_header(header)
        constructors.parse_test_cases(
            'tests/suites/test_suite_psa_crypto_metadata.data')
        self.remove_unwanted_macros(constructors)
        constructors.gather_arguments()
        return constructors
130
131
def test_case_for_key_type_not_supported(
        verb: str, key_type: str, bits: int,
        dependencies: List[str],
        *args: str,
        param_descr: str = ''
) -> test_case.TestCase:
    """Build a test case for a key creation method on an unsupported key type or size.

    * `verb`: the creation method (``import`` or ``generate`` in this file);
      the test function is ``<verb>_not_supported``.
    * `key_type`, `bits`: the key type expression and the key size.
    * `dependencies`: the test case dependencies. This list may be extended
      in place by `hack_dependencies_not_implemented`.
    * `args`: extra test function arguments, passed after `key_type`.
    * `param_descr`: optional word prepended to "not"/"never" in the description.
    """
    hack_dependencies_not_implemented(dependencies)
    # With no dependency at all, the case is unconditional: never supported.
    qualifier = 'not' if dependencies else 'never'
    if param_descr:
        qualifier = param_descr + ' ' + qualifier
    description = 'PSA {} {} {}-bit {} supported'.format(
        verb, crypto_knowledge.short_expression(key_type), bits, qualifier)
    tc = test_case.TestCase()
    tc.set_description(description)
    tc.set_dependencies(dependencies)
    tc.set_function(verb + '_not_supported')
    tc.set_arguments([key_type, *args])
    return tc
153
class KeyTypeNotSupported:
    """Generate test cases for when a key type is not supported."""

    def __init__(self, info: Information) -> None:
        self.constructors = info.constructors

    # Key types for which no "not supported" test case is generated.
    ALWAYS_SUPPORTED = frozenset([
        'PSA_KEY_TYPE_DERIVE',
        'PSA_KEY_TYPE_PASSWORD',
        'PSA_KEY_TYPE_PASSWORD_HASH',
        'PSA_KEY_TYPE_RAW_DATA',
        'PSA_KEY_TYPE_HMAC'
    ])
    def test_cases_for_key_type_not_supported(
            self,
            kt: crypto_knowledge.KeyType,
            param: Optional[int] = None,
            param_descr: str = '',
    ) -> Iterator[test_case.TestCase]:
        """Yield test cases exercising key creation when the given type is unsupported.

        If `param` is None, the dependency on the base key type is negated.
        Otherwise the dependency on the key type parameter at index `param`
        (e.g. the curve family) is negated instead.
        """
        if kt.name in self.ALWAYS_SUPPORTED:
            # Such test cases would be skipped in every configuration,
            # which is just noise.
            return
        base_symbol = psa_want_symbol(kt.name)
        import_dependencies = ['!' + base_symbol if param is None
                               else base_symbol]
        if kt.params is not None:
            for index, sym in enumerate(kt.params):
                negation = '!' if index == param else ''
                import_dependencies.append(negation + psa_want_symbol(sym))
        # Public keys cannot be generated, so generation has no dependency.
        generate_dependencies = ([] if kt.name.endswith('_PUBLIC_KEY')
                                 else import_dependencies)
        for bits in kt.sizes_to_test():
            yield test_case_for_key_type_not_supported(
                'import', kt.expression, bits,
                finish_family_dependencies(import_dependencies, bits),
                test_case.hex_string(kt.key_material(bits)),
                param_descr=param_descr,
            )
            if param is not None and not generate_dependencies:
                # Generation is impossible for this key type regardless of
                # implementation capabilities, so don't repeat a generate
                # case for each parameter.
                continue
            if kt.is_public():
                # For public keys, key generation is expected to fail with
                # INVALID_ARGUMENT. That is handled by the KeyGenerate class.
                continue
            yield test_case_for_key_type_not_supported(
                'generate', kt.expression, bits,
                finish_family_dependencies(generate_dependencies, bits),
                str(bits),
                param_descr=param_descr,
            )
            # To be added: derive

    ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                     'PSA_KEY_TYPE_ECC_PUBLIC_KEY')

    def test_cases_for_not_supported(self) -> Iterator[test_case.TestCase]:
        """Generate test cases that exercise the creation of keys of unsupported types.

        Non-ECC key types are handled directly. ECC key types are expanded
        over every known curve family, with one pass that negates the key
        type and one pass that negates the curve family (parameter 0).
        """
        for key_type in sorted(self.constructors.key_types):
            if key_type not in self.ECC_KEY_TYPES:
                yield from self.test_cases_for_key_type_not_supported(
                    crypto_knowledge.KeyType(key_type))
        for curve_family in sorted(self.constructors.ecc_curves):
            for constr in self.ECC_KEY_TYPES:
                ecc_kt = crypto_knowledge.KeyType(constr, [curve_family])
                yield from self.test_cases_for_key_type_not_supported(
                    ecc_kt, param_descr='type')
                yield from self.test_cases_for_key_type_not_supported(
                    ecc_kt, 0, param_descr='curve')
233
def test_case_for_key_generation(
        key_type: str, bits: int,
        dependencies: List[str],
        *args: str,
        result: str = ''
) -> test_case.TestCase:
    """Build one ``generate_key`` test case for `key_type` of size `bits`.

    `dependencies` may be extended in place by
    `hack_dependencies_not_implemented`. The test arguments are `key_type`,
    then `args`, then `result`. `result` is appended even when it is the
    empty default.
    """
    hack_dependencies_not_implemented(dependencies)
    description = 'PSA {} {}-bit'.format(
        crypto_knowledge.short_expression(key_type), bits)
    tc = test_case.TestCase()
    tc.set_description(description)
    tc.set_dependencies(dependencies)
    tc.set_function('generate_key')
    tc.set_arguments([key_type, *args, result])
    return tc
252
class KeyGenerate:
    """Generate positive and negative (invalid argument) test cases for key generation."""

    def __init__(self, info: Information) -> None:
        self.constructors = info.constructors

    ECC_KEY_TYPES = ('PSA_KEY_TYPE_ECC_KEY_PAIR',
                     'PSA_KEY_TYPE_ECC_PUBLIC_KEY')

    @staticmethod
    def test_cases_for_key_type_key_generation(
            kt: crypto_knowledge.KeyType
    ) -> Iterator[test_case.TestCase]:
        """Return test cases exercising key generation.

        All key types can be generated except for public keys. For public
        keys, PSA_ERROR_INVALID_ARGUMENT status is expected. RSA key pair
        generation additionally depends on MBEDTLS_GENPRIME.
        """
        result = 'PSA_SUCCESS'

        import_dependencies = [psa_want_symbol(kt.name)]
        if kt.params is not None:
            import_dependencies += [psa_want_symbol(sym)
                                    for sym in kt.params]
        if kt.name.endswith('_PUBLIC_KEY'):
            # The library checks whether the key type is a public key generically,
            # before it reaches a point where it needs support for the specific key
            # type, so it returns INVALID_ARGUMENT for unsupported public key types.
            generate_dependencies = [] # type: List[str]
            result = 'PSA_ERROR_INVALID_ARGUMENT'
        else:
            # Copy, so that the extra RSA requirement below does not leak
            # into import_dependencies through aliasing.
            generate_dependencies = list(import_dependencies)
            if kt.name == 'PSA_KEY_TYPE_RSA_KEY_PAIR':
                generate_dependencies.append("MBEDTLS_GENPRIME")
        for bits in kt.sizes_to_test():
            # `result` is keyword-only in test_case_for_key_generation. If it
            # were passed positionally, it would be swallowed by *args and an
            # extra empty argument would be appended to the test case.
            yield test_case_for_key_generation(
                kt.expression, bits,
                finish_family_dependencies(generate_dependencies, bits),
                str(bits),
                result=result
            )

    def test_cases_for_key_generation(self) -> Iterator[test_case.TestCase]:
        """Generate test cases that exercise the generation of keys.

        Non-ECC key types are handled directly; ECC key types are expanded
        over every known curve family.
        """
        for key_type in sorted(self.constructors.key_types):
            if key_type in self.ECC_KEY_TYPES:
                continue
            kt = crypto_knowledge.KeyType(key_type)
            yield from self.test_cases_for_key_type_key_generation(kt)
        for curve_family in sorted(self.constructors.ecc_curves):
            for constr in self.ECC_KEY_TYPES:
                kt = crypto_knowledge.KeyType(constr, [curve_family])
                yield from self.test_cases_for_key_type_key_generation(kt)
306
class OpFail:
    """Generate test cases for operations that must fail."""
    #pylint: disable=too-few-public-methods

    class Reason(enum.Enum):
        # Why an operation is expected to fail. NOT_SUPPORTED expects
        # PSA_ERROR_NOT_SUPPORTED; the others expect PSA_ERROR_INVALID_ARGUMENT.
        NOT_SUPPORTED = 0
        INVALID = 1
        INCOMPATIBLE = 2
        PUBLIC = 3

    def __init__(self, info: Information) -> None:
        self.constructors = info.constructors
        expressions = self.constructors.generate_expressions(
            sorted(self.constructors.key_types)
        )
        self.key_types = [crypto_knowledge.KeyType(expr)
                          for expr in expressions]

    def make_test_case(
            self,
            alg: crypto_knowledge.Algorithm,
            category: crypto_knowledge.AlgorithmCategory,
            reason: 'Reason',
            kt: Optional[crypto_knowledge.KeyType] = None,
            not_deps: FrozenSet[str] = frozenset(),
    ) -> test_case.TestCase:
        """Construct a failure test case for a one-key or keyless operation.

        * `alg`, `category`: the algorithm and the operation category; the
          test function is ``<category>_fail``.
        * `reason`: why the operation must fail.
        * `kt`: the key type for a one-key operation, None for a keyless one.
        * `not_deps`: dependency symbols that are negated in the test case.
        """
        #pylint: disable=too-many-arguments,too-many-locals
        if reason == self.Reason.NOT_SUPPORTED:
            stripped = sorted(re.sub(r'PSA_WANT_ALG_', r'', dep)
                              for dep in not_deps)
            pretty_reason = '!' + '&'.join(stripped)
        else:
            pretty_reason = reason.name.lower()
        key_type = kt.expression if kt else ''
        pretty_type = kt.short_expression() if kt else ''
        suffix = ' with ' + pretty_type if pretty_type else ''
        tc = test_case.TestCase()
        tc.set_description('PSA {} {}: {}{}'
                           .format(category.name.lower(),
                                   alg.short_expression(),
                                   pretty_reason,
                                   suffix))
        # Negate exactly the dependencies listed in not_deps.
        dependencies = ['!' + dep if dep in not_deps else dep
                        for dep in automatic_dependencies(alg.base_expression,
                                                          key_type)]
        tc.set_dependencies(dependencies)
        tc.set_function(category.name.lower() + '_fail')
        arguments = [] # type: List[str]
        if kt:
            # Key material for the first size listed by sizes_to_test().
            material = kt.key_material(kt.sizes_to_test()[0])
            arguments += [key_type, test_case.hex_string(material)]
        arguments.append(alg.expression)
        if category.is_asymmetric():
            # '1' only for the PUBLIC reason (presumably tells the test
            # function that the key is a public key — confirm in the .function).
            arguments.append('1' if reason == self.Reason.PUBLIC else '0')
        if reason == self.Reason.NOT_SUPPORTED:
            error = 'NOT_SUPPORTED'
        else:
            error = 'INVALID_ARGUMENT'
        arguments.append('PSA_ERROR_' + error)
        tc.set_arguments(arguments)
        return tc

    def no_key_test_cases(
            self,
            alg: crypto_knowledge.Algorithm,
            category: crypto_knowledge.AlgorithmCategory,
    ) -> Iterator[test_case.TestCase]:
        """Generate failure test cases for keyless operations with the specified algorithm."""
        if not alg.can_do(category):
            # Incompatible operation, supported algorithm
            yield self.make_test_case(alg, category, self.Reason.INVALID)
            return
        # Compatible operation, unsupported algorithm: one test case per
        # dependency, each negating only that dependency.
        for dep in automatic_dependencies(alg.base_expression):
            yield self.make_test_case(alg, category,
                                      self.Reason.NOT_SUPPORTED,
                                      not_deps=frozenset([dep]))

    def one_key_test_cases(
            self,
            alg: crypto_knowledge.Algorithm,
            category: crypto_knowledge.AlgorithmCategory,
    ) -> Iterator[test_case.TestCase]:
        """Generate failure test cases for one-key operations with the specified algorithm."""
        for kt in self.key_types:
            key_ok = kt.can_do(alg)
            op_ok = alg.can_do(category)
            if key_ok and op_ok:
                # Compatible key and operation, unsupported algorithm
                for dep in automatic_dependencies(alg.base_expression):
                    yield self.make_test_case(alg, category,
                                              self.Reason.NOT_SUPPORTED,
                                              kt=kt, not_deps=frozenset([dep]))
                # Public key for a private-key operation
                if category.is_asymmetric() and kt.is_public():
                    yield self.make_test_case(alg, category,
                                              self.Reason.PUBLIC,
                                              kt=kt)
            elif key_ok:
                # Compatible key, incompatible operation, supported algorithm
                yield self.make_test_case(alg, category,
                                          self.Reason.INVALID,
                                          kt=kt)
            elif op_ok:
                # Incompatible key, compatible operation, supported algorithm
                yield self.make_test_case(alg, category,
                                          self.Reason.INCOMPATIBLE,
                                          kt=kt)
            # Otherwise both the key and the operation are wrong. Such
            # cases are not generated, to keep the number of test cases
            # reasonable.

    def test_cases_for_algorithm(
            self,
            alg: crypto_knowledge.Algorithm,
    ) -> Iterator[test_case.TestCase]:
        """Generate operation failure test cases for the specified algorithm."""
        for category in crypto_knowledge.AlgorithmCategory:
            if category == crypto_knowledge.AlgorithmCategory.PAKE:
                # PAKE operations are not implemented yet
                continue
            if category.requires_key():
                yield from self.one_key_test_cases(alg, category)
            else:
                yield from self.no_key_test_cases(alg, category)

    def all_test_cases(self) -> Iterator[test_case.TestCase]:
        """Generate all test cases for operations that must fail."""
        sorted_algs = sorted(self.constructors.algorithms)
        for expr in self.constructors.generate_expressions(sorted_algs):
            yield from self.test_cases_for_algorithm(
                crypto_knowledge.Algorithm(expr))
444
445
class StorageKey(psa_storage.Key):
    """Representation of a key for storage format testing."""

    # Mapping of usage flags to the flags that they imply.
    IMPLICIT_USAGE_FLAGS = {
        'PSA_KEY_USAGE_SIGN_HASH': 'PSA_KEY_USAGE_SIGN_MESSAGE',
        'PSA_KEY_USAGE_VERIFY_HASH': 'PSA_KEY_USAGE_VERIFY_MESSAGE'
    } #type: Dict[str, str]

    def __init__(
            self,
            usage: Iterable[str],
            without_implicit_usage: Optional[bool] = False,
            **kwargs
    ) -> None:
        """Prepare to generate a key.

        * `usage`                 : The usage flag names of the key.
        * `without_implicit_usage`: If true, store exactly `usage`. Otherwise
                                    also add the flags that `usage` implies
                                    according to IMPLICIT_USAGE_FLAGS (one
                                    level, not transitively).
        Other keyword arguments are forwarded to psa_storage.Key.

        The stored usage is the sorted flags joined with `` | ``, or ``'0'``
        if there are none.
        """
        flags = set(usage)
        if not without_implicit_usage:
            implied = {self.IMPLICIT_USAGE_FLAGS[flag]
                       for flag in flags
                       if flag in self.IMPLICIT_USAGE_FLAGS}
            flags |= implied
        usage_expression = ' | '.join(sorted(flags)) if flags else '0'
        super().__init__(usage=usage_expression, **kwargs)
476
class StorageTestData(StorageKey):
    """Representation of test case data for storage format testing."""

    def __init__(
            self,
            description: str,
            expected_usage: Optional[List[str]] = None,
            **kwargs
    ) -> None:
        """Prepare to generate test data.

        * `description`   : used for the test case names.
        * `expected_usage`: the usage flags generated as the expected usage flags
                            in the test cases. Can differ from the usage flags
                            stored in the keys because of the usage flags extension.
                            None means "same as the stored usage"; an empty
                            list means no usage (0).
        Other keyword arguments are forwarded to StorageKey.
        """
        super().__init__(**kwargs)
        self.description = description #type: str
        if expected_usage is None:
            self.expected_usage = self.usage #type: psa_storage.Expr
            return
        if expected_usage:
            self.expected_usage = psa_storage.Expr(' | '.join(expected_usage))
        else:
            self.expected_usage = psa_storage.Expr(0)
501
502class StorageFormat:
503    """Storage format stability test cases."""
504
505    def __init__(self, info: Information, version: int, forward: bool) -> None:
506        """Prepare to generate test cases for storage format stability.
507
508        * `info`: information about the API. See the `Information` class.
509        * `version`: the storage format version to generate test cases for.
510        * `forward`: if true, generate forward compatibility test cases which
511          save a key and check that its representation is as intended. Otherwise
512          generate backward compatibility test cases which inject a key
513          representation and check that it can be read and used.
514        """
515        self.constructors = info.constructors #type: macro_collector.PSAMacroEnumerator
516        self.version = version #type: int
517        self.forward = forward #type: bool
518
519    RSA_OAEP_RE = re.compile(r'PSA_ALG_RSA_OAEP\((.*)\)\Z')
520    BRAINPOOL_RE = re.compile(r'PSA_KEY_TYPE_\w+\(PSA_ECC_FAMILY_BRAINPOOL_\w+\)\Z')
521    @classmethod
522    def exercise_key_with_algorithm(
523            cls,
524            key_type: psa_storage.Expr, bits: int,
525            alg: psa_storage.Expr
526    ) -> bool:
527        """Whether to exercise the given key with the given algorithm.
528
529        Normally only the type and algorithm matter for compatibility, and
530        this is handled in crypto_knowledge.KeyType.can_do(). This function
531        exists to detect exceptional cases. Exceptional cases detected here
532        are not tested in OpFail and should therefore have manually written
533        test cases.
534        """
535        # Some test keys have the RAW_DATA type and attributes that don't
536        # necessarily make sense. We do this to validate numerical
537        # encodings of the attributes.
538        # Raw data keys have no useful exercise anyway so there is no
539        # loss of test coverage.
540        if key_type.string == 'PSA_KEY_TYPE_RAW_DATA':
541            return False
542        # OAEP requires room for two hashes plus wrapping
543        m = cls.RSA_OAEP_RE.match(alg.string)
544        if m:
545            hash_alg = m.group(1)
546            hash_length = crypto_knowledge.Algorithm.hash_length(hash_alg)
547            key_length = (bits + 7) // 8
548            # Leave enough room for at least one byte of plaintext
549            return key_length > 2 * hash_length + 2
550        # There's nothing wrong with ECC keys on Brainpool curves,
551        # but operations with them are very slow. So we only exercise them
552        # with a single algorithm, not with all possible hashes. We do
553        # exercise other curves with all algorithms so test coverage is
554        # perfectly adequate like this.
555        m = cls.BRAINPOOL_RE.match(key_type.string)
556        if m and alg.string != 'PSA_ALG_ECDSA_ANY':
557            return False
558        return True
559
560    def make_test_case(self, key: StorageTestData) -> test_case.TestCase:
561        """Construct a storage format test case for the given key.
562
563        If ``forward`` is true, generate a forward compatibility test case:
564        create a key and validate that it has the expected representation.
565        Otherwise generate a backward compatibility test case: inject the
566        key representation into storage and validate that it can be read
567        correctly.
568        """
569        verb = 'save' if self.forward else 'read'
570        tc = test_case.TestCase()
571        tc.set_description(verb + ' ' + key.description)
572        dependencies = automatic_dependencies(
573            key.lifetime.string, key.type.string,
574            key.alg.string, key.alg2.string,
575        )
576        dependencies = finish_family_dependencies(dependencies, key.bits)
577        tc.set_dependencies(dependencies)
578        tc.set_function('key_storage_' + verb)
579        if self.forward:
580            extra_arguments = []
581        else:
582            flags = []
583            if self.exercise_key_with_algorithm(key.type, key.bits, key.alg):
584                flags.append('TEST_FLAG_EXERCISE')
585            if 'READ_ONLY' in key.lifetime.string:
586                flags.append('TEST_FLAG_READ_ONLY')
587            extra_arguments = [' | '.join(flags) if flags else '0']
588        tc.set_arguments([key.lifetime.string,
589                          key.type.string, str(key.bits),
590                          key.expected_usage.string,
591                          key.alg.string, key.alg2.string,
592                          '"' + key.material.hex() + '"',
593                          '"' + key.hex() + '"',
594                          *extra_arguments])
595        return tc
596
597    def key_for_lifetime(
598            self,
599            lifetime: str,
600    ) -> StorageTestData:
601        """Construct a test key for the given lifetime."""
602        short = lifetime
603        short = re.sub(r'PSA_KEY_LIFETIME_FROM_PERSISTENCE_AND_LOCATION',
604                       r'', short)
605        short = crypto_knowledge.short_expression(short)
606        description = 'lifetime: ' + short
607        key = StorageTestData(version=self.version,
608                              id=1, lifetime=lifetime,
609                              type='PSA_KEY_TYPE_RAW_DATA', bits=8,
610                              usage=['PSA_KEY_USAGE_EXPORT'], alg=0, alg2=0,
611                              material=b'L',
612                              description=description)
613        return key
614
615    def all_keys_for_lifetimes(self) -> Iterator[StorageTestData]:
616        """Generate test keys covering lifetimes."""
617        lifetimes = sorted(self.constructors.lifetimes)
618        expressions = self.constructors.generate_expressions(lifetimes)
619        for lifetime in expressions:
620            # Don't attempt to create or load a volatile key in storage
621            if 'VOLATILE' in lifetime:
622                continue
623            # Don't attempt to create a read-only key in storage,
624            # but do attempt to load one.
625            if 'READ_ONLY' in lifetime and self.forward:
626                continue
627            yield self.key_for_lifetime(lifetime)
628
629    def key_for_usage_flags(
630            self,
631            usage_flags: List[str],
632            short: Optional[str] = None,
633            test_implicit_usage: Optional[bool] = True
634    ) -> StorageTestData:
635        """Construct a test key for the given key usage."""
636        extra_desc = ' without implication' if test_implicit_usage else ''
637        description = 'usage' + extra_desc + ': '
638        key1 = StorageTestData(version=self.version,
639                               id=1, lifetime=0x00000001,
640                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
641                               expected_usage=usage_flags,
642                               without_implicit_usage=not test_implicit_usage,
643                               usage=usage_flags, alg=0, alg2=0,
644                               material=b'K',
645                               description=description)
646        if short is None:
647            usage_expr = key1.expected_usage.string
648            key1.description += crypto_knowledge.short_expression(usage_expr)
649        else:
650            key1.description += short
651        return key1
652
653    def generate_keys_for_usage_flags(self, **kwargs) -> Iterator[StorageTestData]:
654        """Generate test keys covering usage flags."""
655        known_flags = sorted(self.constructors.key_usage_flags)
656        yield self.key_for_usage_flags(['0'], **kwargs)
657        for usage_flag in known_flags:
658            yield self.key_for_usage_flags([usage_flag], **kwargs)
659        for flag1, flag2 in zip(known_flags,
660                                known_flags[1:] + [known_flags[0]]):
661            yield self.key_for_usage_flags([flag1, flag2], **kwargs)
662
663    def generate_key_for_all_usage_flags(self) -> Iterator[StorageTestData]:
664        known_flags = sorted(self.constructors.key_usage_flags)
665        yield self.key_for_usage_flags(known_flags, short='all known')
666
667    def all_keys_for_usage_flags(self) -> Iterator[StorageTestData]:
668        yield from self.generate_keys_for_usage_flags()
669        yield from self.generate_key_for_all_usage_flags()
670
671    def key_for_type_and_alg(
672            self,
673            kt: crypto_knowledge.KeyType,
674            bits: int,
675            alg: Optional[crypto_knowledge.Algorithm] = None,
676    ) -> StorageTestData:
677        """Construct a test key of the given type.
678
679        If alg is not None, this key allows it.
680        """
681        usage_flags = ['PSA_KEY_USAGE_EXPORT']
682        alg1 = 0 #type: psa_storage.Exprable
683        alg2 = 0
684        if alg is not None:
685            alg1 = alg.expression
686            usage_flags += alg.usage_flags(public=kt.is_public())
687        key_material = kt.key_material(bits)
688        description = 'type: {} {}-bit'.format(kt.short_expression(1), bits)
689        if alg is not None:
690            description += ', ' + alg.short_expression(1)
691        key = StorageTestData(version=self.version,
692                              id=1, lifetime=0x00000001,
693                              type=kt.expression, bits=bits,
694                              usage=usage_flags, alg=alg1, alg2=alg2,
695                              material=key_material,
696                              description=description)
697        return key
698
699    def keys_for_type(
700            self,
701            key_type: str,
702            all_algorithms: List[crypto_knowledge.Algorithm],
703    ) -> Iterator[StorageTestData]:
704        """Generate test keys for the given key type."""
705        kt = crypto_knowledge.KeyType(key_type)
706        for bits in kt.sizes_to_test():
707            # Test a non-exercisable key, as well as exercisable keys for
708            # each compatible algorithm.
709            # To do: test reading a key from storage with an incompatible
710            # or unsupported algorithm.
711            yield self.key_for_type_and_alg(kt, bits)
712            compatible_algorithms = [alg for alg in all_algorithms
713                                     if kt.can_do(alg)]
714            for alg in compatible_algorithms:
715                yield self.key_for_type_and_alg(kt, bits, alg)
716
717    def all_keys_for_types(self) -> Iterator[StorageTestData]:
718        """Generate test keys covering key types and their representations."""
719        key_types = sorted(self.constructors.key_types)
720        all_algorithms = [crypto_knowledge.Algorithm(alg)
721                          for alg in self.constructors.generate_expressions(
722                              sorted(self.constructors.algorithms)
723                          )]
724        for key_type in self.constructors.generate_expressions(key_types):
725            yield from self.keys_for_type(key_type, all_algorithms)
726
727    def keys_for_algorithm(self, alg: str) -> Iterator[StorageTestData]:
728        """Generate test keys for the encoding of the specified algorithm."""
729        # These test cases only validate the encoding of algorithms, not
730        # whether the key read from storage is suitable for an operation.
731        # `keys_for_types` generate read tests with an algorithm and a
732        # compatible key.
733        descr = crypto_knowledge.short_expression(alg, 1)
734        usage = ['PSA_KEY_USAGE_EXPORT']
735        key1 = StorageTestData(version=self.version,
736                               id=1, lifetime=0x00000001,
737                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
738                               usage=usage, alg=alg, alg2=0,
739                               material=b'K',
740                               description='alg: ' + descr)
741        yield key1
742        key2 = StorageTestData(version=self.version,
743                               id=1, lifetime=0x00000001,
744                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
745                               usage=usage, alg=0, alg2=alg,
746                               material=b'L',
747                               description='alg2: ' + descr)
748        yield key2
749
750    def all_keys_for_algorithms(self) -> Iterator[StorageTestData]:
751        """Generate test keys covering algorithm encodings."""
752        algorithms = sorted(self.constructors.algorithms)
753        for alg in self.constructors.generate_expressions(algorithms):
754            yield from self.keys_for_algorithm(alg)
755
756    def generate_all_keys(self) -> Iterator[StorageTestData]:
757        """Generate all keys for the test cases."""
758        yield from self.all_keys_for_lifetimes()
759        yield from self.all_keys_for_usage_flags()
760        yield from self.all_keys_for_types()
761        yield from self.all_keys_for_algorithms()
762
763    def all_test_cases(self) -> Iterator[test_case.TestCase]:
764        """Generate all storage format test cases."""
765        # First build a list of all keys, then construct all the corresponding
766        # test cases. This allows all required information to be obtained in
767        # one go, which is a significant performance gain as the information
768        # includes numerical values obtained by compiling a C program.
769        all_keys = list(self.generate_all_keys())
770        for key in all_keys:
771            if key.location_value() != 0:
772                # Skip keys with a non-default location, because they
773                # require a driver and we currently have no mechanism to
774                # determine whether a driver is available.
775                continue
776            yield self.make_test_case(key)
777
class StorageFormatForward(StorageFormat):
    """Storage format stability test cases for forward compatibility.

    Unlike StorageFormatV0, the format version is chosen by the caller.
    """

    def __init__(self, info: Information, version: int) -> None:
        # NOTE(review): the third StorageFormat argument presumably selects
        # forward-compatibility mode (StorageFormatV0 passes False) — confirm
        # against StorageFormat.__init__.
        forward = True
        super().__init__(info, version, forward)
783
class StorageFormatV0(StorageFormat):
    """Storage format stability test cases for version 0 compatibility."""

    def __init__(self, info: Information) -> None:
        super().__init__(info, 0, False)

    def all_keys_for_usage_flags(self) -> Iterator[StorageTestData]:
        """Generate test keys covering usage flags.

        On top of the base class keys, generate the single-flag and
        flag-pair keys again with test_implicit_usage=False.
        """
        yield from super().all_keys_for_usage_flags()
        yield from self.generate_keys_for_usage_flags(test_implicit_usage=False)

    def keys_for_implicit_usage(
            self,
            implyer_usage: str,
            alg: str,
            key_type: crypto_knowledge.KeyType
    ) -> StorageTestData:
        # pylint: disable=too-many-locals
        """Construct a test key for one implicit usage flag, algorithm and
        key type combination.

        The key is stored with PSA_KEY_USAGE_EXPORT and implyer_usage. It is
        expected to be read back with the flag that
        StorageKey.IMPLICIT_USAGE_FLAGS associates with implyer_usage added
        as well. Only the first size from key_type.sizes_to_test() is used.
        """
        bits = key_type.sizes_to_test()[0]
        implicit_usage = StorageKey.IMPLICIT_USAGE_FLAGS[implyer_usage]
        usage_flags = ['PSA_KEY_USAGE_EXPORT']
        # Flags written to storage vs. flags expected after reading back.
        material_usage_flags = usage_flags + [implyer_usage]
        expected_usage_flags = material_usage_flags + [implicit_usage]
        alg2 = 0
        key_material = key_type.key_material(bits)
        usage_expression = crypto_knowledge.short_expression(implyer_usage, 1)
        alg_expression = crypto_knowledge.short_expression(alg, 1)
        key_type_expression = key_type.short_expression(1)
        description = 'implied by {}: {} {} {}-bit'.format(
            usage_expression, alg_expression, key_type_expression, bits)
        key = StorageTestData(version=self.version,
                              id=1, lifetime=0x00000001,
                              type=key_type.expression, bits=bits,
                              usage=material_usage_flags,
                              expected_usage=expected_usage_flags,
                              without_implicit_usage=True,
                              alg=alg, alg2=alg2,
                              material=key_material,
                              description=description)
        return key

    def gather_key_types_for_sign_alg(self) -> Dict[str, List[str]]:
        # pylint: disable=too-many-locals
        """Match possible key types for sign algorithms.

        Algorithm and key type expressions are split into keywords derived
        from their names. An algorithm is paired with a key type when they
        share at least one keyword.

        The following are filtered out:
        * algorithms with a keyword in incompatible_alg_keyword;
        * key types with a keyword in incompatible_key_type_keywords.

        Return a dictionary mapping each sign algorithm expression that has
        at least one compatible key type to the list of those key type
        expressions.
        """
        incompatible_alg_keyword = frozenset(['RAW', 'ANY', 'PURE'])
        incompatible_key_type_keywords = frozenset(['MONTGOMERY'])
        # Map algorithm keywords (regex, anchored at the start by re.match)
        # to the keyword used in key type names.
        keyword_translation = {
            'ECDSA': 'ECC',
            'ED[0-9]*.*' : 'EDWARDS'
        }
        # If the key keyword on the left is present, drop the keyword on the
        # right. An Edwards key type also carries ECC in its name, which
        # would otherwise wrongly match ECDSA algorithms.
        exclusive_keywords = {
            'EDWARDS': 'ECC'
        }
        key_types = set(self.constructors.generate_expressions(self.constructors.key_types))
        algorithms = set(self.constructors.generate_expressions(self.constructors.sign_algorithms))
        # Turn 'X(Y)' into 'X_Y' so that the family argument of a key type
        # contributes keywords too.
        translation_table = str.maketrans('(', '_', ')')

        def translate_keyword(keyword: str) -> str:
            """Return the key type keyword for an algorithm keyword.

            The first matching pattern in keyword_translation wins. Keywords
            that match no pattern are returned unchanged.
            """
            for pattern, replacement in keyword_translation.items():
                if re.match(pattern, keyword):
                    return replacement
            return keyword

        # A key type's keywords do not depend on the algorithm, so compute
        # them once. Key types with an incompatible keyword are dropped here.
        key_type_keywords = {} #type: Dict[str, FrozenSet[str]]
        for key_type in key_types:
            keywords = set(key_type.translate(translation_table).split(sep='_')[3:])
            # Remove ambiguous keywords; discard() tolerates their absence.
            for keyword1, keyword2 in exclusive_keywords.items():
                if keyword1 in keywords:
                    keywords.discard(keyword2)
            if keywords.isdisjoint(incompatible_key_type_keywords):
                key_type_keywords[key_type] = frozenset(keywords)

        alg_with_keys = {} #type: Dict[str, List[str]]
        for alg in algorithms:
            # Keywords come from the algorithm name without its argument and
            # without the PSA_ALG_ prefix. Translating them as a pure mapping
            # avoids mutating the set while deciding what to translate.
            alg_keywords = {translate_keyword(keyword)
                            for keyword in alg.partition('(')[0].split(sep='_')[2:]}
            # Filter out incompatible algorithms
            if not alg_keywords.isdisjoint(incompatible_alg_keyword):
                continue
            for key_type, keywords in key_type_keywords.items():
                if not keywords.isdisjoint(alg_keywords):
                    alg_with_keys.setdefault(alg, []).append(key_type)
        return alg_with_keys

    def all_keys_for_implicit_usage(self) -> Iterator[StorageTestData]:
        """Generate test keys for usage flag extensions.

        For each implicit-usage flag, yield one key per compatible sign
        algorithm and key type pair (all in sorted order). Public key types
        are skipped for flags containing '_SIGN_', because signing needs a
        private key.
        """
        # Generate a key type and algorithm pair for each extendable usage
        # flag to generate a valid key for exercising. The key is generated
        # without usage extension to check the extension compatibility.
        alg_with_keys = self.gather_key_types_for_sign_alg()

        for usage in sorted(StorageKey.IMPLICIT_USAGE_FLAGS, key=str):
            for alg in sorted(alg_with_keys):
                for key_type in sorted(alg_with_keys[alg]):
                    # The key types must be filtered to fit the specific usage flag.
                    kt = crypto_knowledge.KeyType(key_type)
                    if kt.is_public() and '_SIGN_' in usage:
                        # Can't sign with a public key
                        continue
                    yield self.keys_for_implicit_usage(usage, alg, kt)

    def generate_all_keys(self) -> Iterator[StorageTestData]:
        """Generate all base class keys, then the implicit-usage keys."""
        yield from super().generate_all_keys()
        yield from self.all_keys_for_implicit_usage()
896
class PSATestGenerator(test_data_generation.TestGenerator):
    """Test generator subclass including PSA targets and info.

    Each target maps an output file name to a function that takes an
    Information object and returns the test cases for that file.
    """
    # Note that targets whose names contain 'test_format' have their content
    # validated by `abi_check.py`.
    targets = {
        'test_suite_psa_crypto_generate_key.generated':
        lambda info: KeyGenerate(info).test_cases_for_key_generation(),
        'test_suite_psa_crypto_not_supported.generated':
        lambda info: KeyTypeNotSupported(info).test_cases_for_not_supported(),
        'test_suite_psa_crypto_op_fail.generated':
        lambda info: OpFail(info).all_test_cases(),
        'test_suite_psa_crypto_storage_format.current':
        lambda info: StorageFormatForward(info, 0).all_test_cases(),
        'test_suite_psa_crypto_storage_format.v0':
        lambda info: StorageFormatV0(info).all_test_cases(),
    } #type: Dict[str, Callable[[Information], Iterable[test_case.TestCase]]]

    def __init__(self, options):
        super().__init__(options)
        # A single Information object, built once and handed to every target.
        self.info = Information()

    def generate_target(self, name: str, *target_args) -> None:
        """Generate the target called name from this generator's info.

        Any extra positional arguments are ignored: the shared Information
        object is always passed to the base class instead.
        """
        shared_info = self.info
        super().generate_target(name, shared_info)
920
if __name__ == '__main__':
    # Pass the command-line arguments (without the program name), the module
    # docstring (presumably used as help text) and the generator class to
    # the shared driver.
    cli_args = sys.argv[1:]
    test_data_generation.main(cli_args, __doc__, PSATestGenerator)
923