1#!/usr/bin/env python3
2
3"""Generate psa_constant_names_generated.c
4which is included by programs/psa/psa_constant_names.c.
5The code generated by this module is only meant to be used in the context
6of that program.
7
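An optional argument passed to this script changes the directory where the
file is written:
* by default (no argument passed): writes to programs/psa/
* OUTPUT_FILE_DIR passed: writes to OUTPUT_FILE_DIR/
A relative OUTPUT_FILE_DIR is resolved after the script has changed to the
root of the source tree (the same directory programs/psa/ is relative to),
not from the directory the script was started in.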
12"""
13
14# Copyright The Mbed TLS Contributors
15# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
16
17import os
18import sys
19
20import framework_scripts_path # pylint: disable=unused-import
21from mbedtls_framework import build_tree
22from mbedtls_framework import macro_collector
23
# Skeleton of the generated C file. CaseBuilder.write_file fills it in with
# the `%` operator, so every %(name)s placeholder receives generated code
# (switch cases or if/else-if chains), and literal percent signs in the C
# code (printf formats such as "0x%%04x") must be written as %%.
# The %(key_type_code)s and %(algorithm_code)s placeholders are followed
# directly by a "{ ... }" fallback block, which terminates the generated
# "if (...) { ... } else " chain.
# The generated functions call append(), append_integer(), append_with_alg(),
# append_with_curve() and append_with_group(), which are not defined here;
# presumably they are provided by programs/psa/psa_constant_names.c, the file
# that includes this output (TODO confirm).
# The psa_snprint_* functions terminate the buffer with buffer[0] = 0 and
# return required_size as accumulated by those append helpers.
OUTPUT_TEMPLATE = '''\
/* Automatically generated by generate_psa_constant.py. DO NOT EDIT. */

static const char *psa_strerror(psa_status_t status)
{
    switch (status) {
    %(status_cases)s
    default: return NULL;
    }
}

static const char *psa_ecc_family_name(psa_ecc_family_t curve)
{
    switch (curve) {
    %(ecc_curve_cases)s
    default: return NULL;
    }
}

static const char *psa_dh_family_name(psa_dh_family_t group)
{
    switch (group) {
    %(dh_group_cases)s
    default: return NULL;
    }
}

static const char *psa_hash_algorithm_name(psa_algorithm_t hash_alg)
{
    switch (hash_alg) {
    %(hash_algorithm_cases)s
    default: return NULL;
    }
}

static const char *psa_ka_algorithm_name(psa_algorithm_t ka_alg)
{
    switch (ka_alg) {
    %(ka_algorithm_cases)s
    default: return NULL;
    }
}

static int psa_snprint_key_type(char *buffer, size_t buffer_size,
                                psa_key_type_t type)
{
    size_t required_size = 0;
    switch (type) {
    %(key_type_cases)s
    default:
        %(key_type_code)s{
            return snprintf(buffer, buffer_size,
                            "0x%%04x", (unsigned) type);
        }
        break;
    }
    buffer[0] = 0;
    return (int) required_size;
}

#define NO_LENGTH_MODIFIER 0xfffffffflu
static int psa_snprint_algorithm(char *buffer, size_t buffer_size,
                                 psa_algorithm_t alg)
{
    size_t required_size = 0;
    psa_algorithm_t core_alg = alg;
    unsigned long length_modifier = NO_LENGTH_MODIFIER;
    if (PSA_ALG_IS_MAC(alg)) {
        core_alg = PSA_ALG_TRUNCATED_MAC(alg, 0);
        if (alg & PSA_ALG_MAC_AT_LEAST_THIS_LENGTH_FLAG) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_AT_LEAST_THIS_LENGTH_MAC(", 33);
            length_modifier = PSA_MAC_TRUNCATED_LENGTH(alg);
        } else if (core_alg != alg) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_TRUNCATED_MAC(", 22);
            length_modifier = PSA_MAC_TRUNCATED_LENGTH(alg);
        }
    } else if (PSA_ALG_IS_AEAD(alg)) {
        core_alg = PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG(alg);
        if (core_alg == 0) {
            /* For unknown AEAD algorithms, there is no "default tag length". */
            core_alg = alg;
        } else if (alg & PSA_ALG_AEAD_AT_LEAST_THIS_LENGTH_FLAG) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(", 43);
            length_modifier = PSA_ALG_AEAD_GET_TAG_LENGTH(alg);
        } else if (core_alg != alg) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_AEAD_WITH_SHORTENED_TAG(", 32);
            length_modifier = PSA_ALG_AEAD_GET_TAG_LENGTH(alg);
        }
    } else if (PSA_ALG_IS_KEY_AGREEMENT(alg) &&
               !PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) {
        core_alg = PSA_ALG_KEY_AGREEMENT_GET_KDF(alg);
        append(&buffer, buffer_size, &required_size,
               "PSA_ALG_KEY_AGREEMENT(", 22);
        append_with_alg(&buffer, buffer_size, &required_size,
                        psa_ka_algorithm_name,
                        PSA_ALG_KEY_AGREEMENT_GET_BASE(alg));
        append(&buffer, buffer_size, &required_size, ", ", 2);
    }
    switch (core_alg) {
    %(algorithm_cases)s
    default:
        %(algorithm_code)s{
            append_integer(&buffer, buffer_size, &required_size,
                           "0x%%08lx", (unsigned long) core_alg);
        }
        break;
    }
    if (core_alg != alg) {
        if (length_modifier != NO_LENGTH_MODIFIER) {
            append(&buffer, buffer_size, &required_size, ", ", 2);
            append_integer(&buffer, buffer_size, &required_size,
                           "%%lu", length_modifier);
        }
        append(&buffer, buffer_size, &required_size, ")", 1);
    }
    buffer[0] = 0;
    return (int) required_size;
}

static int psa_snprint_key_usage(char *buffer, size_t buffer_size,
                                 psa_key_usage_t usage)
{
    size_t required_size = 0;
    if (usage == 0) {
        if (buffer_size > 1) {
            buffer[0] = '0';
            buffer[1] = 0;
        } else if (buffer_size == 1) {
            buffer[0] = 0;
        }
        return 1;
    }
%(key_usage_code)s
    if (usage != 0) {
        if (required_size != 0) {
            append(&buffer, buffer_size, &required_size, " | ", 3);
        }
        append_integer(&buffer, buffer_size, &required_size,
                       "0x%%08lx", (unsigned long) usage);
    } else {
        buffer[0] = 0;
    }
    return (int) required_size;
}

/* End of automatically generated file. */
'''
175
# One link of the if/else-if chain substituted for %(key_type_code)s in the
# generated psa_snprint_key_type(). `tester` is a C macro applied to `type`;
# when it holds, `builder` is emitted as text (with its length in
# `builder_length`) together with PSA_KEY_TYPE_ECC_GET_FAMILY(type), via
# append_with_curve(). The trailing "} else " lets several links be
# concatenated; the chain is closed by the "{ ... }" fallback block that
# follows the placeholder in the output template.
KEY_TYPE_FROM_CURVE_TEMPLATE = '''if (%(tester)s(type)) {
            append_with_curve(&buffer, buffer_size, &required_size,
                              "%(builder)s", %(builder_length)s,
                              PSA_KEY_TYPE_ECC_GET_FAMILY(type));
        } else '''
181
# One link of the if/else-if chain substituted for %(key_type_code)s in the
# generated psa_snprint_key_type(), for key types built from a DH family.
# `tester` is a C macro applied to `type`; when it holds, `builder` is
# emitted as text (with its length in `builder_length`) together with
# PSA_KEY_TYPE_DH_GET_FAMILY(type), via append_with_group(). The trailing
# "} else " lets several links be concatenated; the chain is closed by the
# "{ ... }" fallback block that follows the placeholder in the output template.
KEY_TYPE_FROM_GROUP_TEMPLATE = '''if (%(tester)s(type)) {
            append_with_group(&buffer, buffer_size, &required_size,
                              "%(builder)s", %(builder_length)s,
                              PSA_KEY_TYPE_DH_GET_FAMILY(type));
        } else '''
187
# One link of the if/else-if chain substituted for %(algorithm_code)s in the
# generated psa_snprint_algorithm(). When `tester`(core_alg) holds, it emits
# "builder(" followed by the hash algorithm name of
# PSA_ALG_GET_HASH(core_alg) and ")". The "+ 1" on `builder_length` accounts
# for the "(" appended after the builder name. The trailing "} else " chains
# links together; the fallback "{ ... }" block in the output template ends
# the chain.
ALGORITHM_FROM_HASH_TEMPLATE = '''if (%(tester)s(core_alg)) {
            append(&buffer, buffer_size, &required_size,
                   "%(builder)s(", %(builder_length)s + 1);
            append_with_alg(&buffer, buffer_size, &required_size,
                            psa_hash_algorithm_name,
                            PSA_ALG_GET_HASH(core_alg));
            append(&buffer, buffer_size, &required_size, ")", 1);
        } else '''
196
# Code for one usage flag in the generated psa_snprint_key_usage(): if `flag`
# is set in the C variable `var`, append the flag's name (preceded by " | "
# when something was already written) and clear that bit with ^=, so that
# any bits left over afterwards can be printed in hex. The backslashes right
# after the opening ''' and after the final "}" are line continuations: the
# rendered snippet has neither a leading nor a trailing newline.
BIT_TEST_TEMPLATE = '''\
    if (%(var)s & %(flag)s) {
        if (required_size != 0) {
            append(&buffer, buffer_size, &required_size, " | ", 3);
        }
        append(&buffer, buffer_size, &required_size, "%(flag)s", %(length)d);
        %(var)s ^= %(flag)s;
    }\
'''
206
class CaseBuilder(macro_collector.PSAMacroCollector):
    """Turn collected PSA crypto macro names into C value-recognition code.

    Typical use:

    1. Call `read_file` once per input header file.
    2. Call `write_file` to emit ``psa_constant_names_generated.c``.

    Each ``_make_*`` helper renders the text for one placeholder of the
    output template. Every collection is rendered in sorted order, so the
    generated file does not depend on the order the macros were read in.
    """

    def __init__(self):
        # Forward include_intermediate=True to the base collector; see
        # PSAMacroCollector for what this flag selects.
        super().__init__(include_intermediate=True)

    @staticmethod
    def _make_return_case(name):
        """C switch case that returns the constant's own name as a string."""
        return 'case {0}: return "{0}";'.format(name)

    @staticmethod
    def _make_append_case(name):
        """C switch case that appends the constant's name and its length."""
        pattern = ('case {0}: '
                   'append(&buffer, buffer_size, &required_size, "{0}", {1}); '
                   'break;')
        return pattern.format(name, len(name))

    @staticmethod
    def _make_bit_test(var, flag):
        """Render BIT_TEST_TEMPLATE for testing `flag` in the C variable `var`."""
        fields = {'var': var, 'flag': flag, 'length': len(flag)}
        return BIT_TEST_TEMPLATE % fields

    @staticmethod
    def _render_sorted(make_case, names):
        """Render each of `names` (sorted) with `make_case`.

        The pieces are joined by a newline plus the four-space indentation
        used for case labels inside the generated switch statements.
        """
        return '\n    '.join(make_case(name) for name in sorted(names))

    def _make_status_cases(self):
        return self._render_sorted(self._make_return_case, self.statuses)

    def _make_ecc_curve_cases(self):
        return self._render_sorted(self._make_return_case, self.ecc_curves)

    def _make_dh_group_cases(self):
        return self._render_sorted(self._make_return_case, self.dh_groups)

    def _make_key_type_cases(self):
        return self._render_sorted(self._make_append_case, self.key_types)

    @staticmethod
    def _make_key_type_from_curve_code(builder, tester):
        return KEY_TYPE_FROM_CURVE_TEMPLATE % {
            'tester': tester,
            'builder': builder,
            'builder_length': len(builder),
        }

    @staticmethod
    def _make_key_type_from_group_code(builder, tester):
        return KEY_TYPE_FROM_GROUP_TEMPLATE % {
            'tester': tester,
            'builder': builder,
            'builder_length': len(builder),
        }

    def _make_ecc_key_type_code(self):
        """Concatenated if/else-if links for key types built from an ECC family."""
        testers = self.key_types_from_curve
        links = [self._make_key_type_from_curve_code(builder, testers[builder])
                 for builder in sorted(testers)]
        return ''.join(links)

    def _make_dh_key_type_code(self):
        """Concatenated if/else-if links for key types built from a DH family."""
        testers = self.key_types_from_group
        links = [self._make_key_type_from_group_code(builder, testers[builder])
                 for builder in sorted(testers)]
        return ''.join(links)

    def _make_hash_algorithm_cases(self):
        return self._render_sorted(self._make_return_case,
                                   self.hash_algorithms)

    def _make_ka_algorithm_cases(self):
        return self._render_sorted(self._make_return_case,
                                   self.ka_algorithms)

    def _make_algorithm_cases(self):
        return self._render_sorted(self._make_append_case, self.algorithms)

    @staticmethod
    def _make_algorithm_from_hash_code(builder, tester):
        return ALGORITHM_FROM_HASH_TEMPLATE % {
            'tester': tester,
            'builder': builder,
            'builder_length': len(builder),
        }

    def _make_algorithm_code(self):
        """Concatenated if/else-if links for algorithms parametrized by a hash."""
        testers = self.algorithms_from_hash
        links = [self._make_algorithm_from_hash_code(builder, testers[builder])
                 for builder in sorted(testers)]
        return ''.join(links)

    def _make_key_usage_code(self):
        """One bit test per usage flag (sorted), separated by newlines."""
        tests = [self._make_bit_test('usage', flag)
                 for flag in sorted(self.key_usage_flags)]
        return '\n'.join(tests)

    def write_file(self, output_file):
        """Render OUTPUT_TEMPLATE from the collected constants and write it.

        `output_file` is any object with a ``write`` method accepting a str.
        The whole generated C source is passed to a single ``write`` call.
        """
        substitutions = {
            'status_cases': self._make_status_cases(),
            'ecc_curve_cases': self._make_ecc_curve_cases(),
            'dh_group_cases': self._make_dh_group_cases(),
            'key_type_cases': self._make_key_type_cases(),
            'key_type_code': (self._make_ecc_key_type_code() +
                              self._make_dh_key_type_code()),
            'hash_algorithm_cases': self._make_hash_algorithm_cases(),
            'ka_algorithm_cases': self._make_ka_algorithm_cases(),
            'algorithm_cases': self._make_algorithm_cases(),
            'algorithm_code': self._make_algorithm_code(),
            'key_usage_code': self._make_key_usage_code(),
        }
        output_file.write(OUTPUT_TEMPLATE % substitutions)
316
def generate_psa_constants(header_file_names, output_file_name):
    """Collect PSA constants from header files and write the generated C file.

    header_file_names: paths of the headers to scan. Each is opened in binary
        mode and passed, in order, to `CaseBuilder.read_file`.
    output_file_name: path of the C file to produce.

    The output is first written to ``output_file_name + '.tmp'`` and then
    moved into place with `os.replace`, so an existing output file is only
    replaced once generation has completed. If writing or renaming fails,
    the temporary file is removed and the original exception propagates.
    """
    collector = CaseBuilder()
    for header_file_name in header_file_names:
        with open(header_file_name, 'rb') as header_file:
            collector.read_file(header_file)
    temp_file_name = output_file_name + '.tmp'
    try:
        with open(temp_file_name, 'w') as output_file:
            collector.write_file(output_file)
        os.replace(temp_file_name, output_file_name)
    except BaseException:
        # Don't leave a partially written .tmp file next to the output.
        try:
            os.remove(temp_file_name)
        except OSError:
            pass  # Best effort: the temporary file may never have been created.
        raise
326
if __name__ == '__main__':
    # Accept at most one optional argument: the directory to write
    # psa_constant_names_generated.c to. Reject extra arguments instead of
    # silently falling back to the default directory.
    if len(sys.argv) > 2:
        sys.exit('Usage: {} [OUTPUT_FILE_DIR]'.format(sys.argv[0]))
    build_tree.chdir_to_root()
    # The header paths below and the default output directory are relative to
    # the directory chdir_to_root() switches to. A relative OUTPUT_FILE_DIR is
    # therefore resolved from there as well, not from the caller's directory.
    OUTPUT_FILE_DIR = sys.argv[1] if len(sys.argv) == 2 else "programs/psa"
    generate_psa_constants(['include/psa/crypto_values.h',
                            'include/psa/crypto_extra.h'],
                           os.path.join(OUTPUT_FILE_DIR,
                                        'psa_constant_names_generated.c'))
334