1#!/usr/bin/env python3
2
3"""Generate psa_constant_names_generated.c
4which is included by programs/psa/psa_constant_names.c.
5The code generated by this module is only meant to be used in the context
6of that program.
7
8An argument passed to this script will modify the output directory where the
9file is written:
10* by default (no arguments passed): writes to programs/psa/
11* OUTPUT_FILE_DIR passed: writes to OUTPUT_FILE_DIR/
12"""
13
14# Copyright The Mbed TLS Contributors
15# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
16
17import os
18import sys
19
20from mbedtls_dev import build_tree
21from mbedtls_dev import macro_collector
22
# C source skeleton for the generated file. CaseBuilder.write_file() expands
# it with the "%" operator and a dict, so every %(name)s placeholder below is
# replaced by generated C code, and a percent sign that must survive into the
# emitted C (the snprintf-style formats) is written as "%%".
# The %(key_type_code)s and %(algorithm_code)s placeholders are immediately
# followed by "{": the generated code is a chain of "if (...) {...} else "
# branches, and the braced block is the final fallback of that chain.
OUTPUT_TEMPLATE = '''\
/* Automatically generated by generate_psa_constant.py. DO NOT EDIT. */

static const char *psa_strerror(psa_status_t status)
{
    switch (status) {
    %(status_cases)s
    default: return NULL;
    }
}

static const char *psa_ecc_family_name(psa_ecc_family_t curve)
{
    switch (curve) {
    %(ecc_curve_cases)s
    default: return NULL;
    }
}

static const char *psa_dh_family_name(psa_dh_family_t group)
{
    switch (group) {
    %(dh_group_cases)s
    default: return NULL;
    }
}

static const char *psa_hash_algorithm_name(psa_algorithm_t hash_alg)
{
    switch (hash_alg) {
    %(hash_algorithm_cases)s
    default: return NULL;
    }
}

static const char *psa_ka_algorithm_name(psa_algorithm_t ka_alg)
{
    switch (ka_alg) {
    %(ka_algorithm_cases)s
    default: return NULL;
    }
}

static int psa_snprint_key_type(char *buffer, size_t buffer_size,
                                psa_key_type_t type)
{
    size_t required_size = 0;
    switch (type) {
    %(key_type_cases)s
    default:
        %(key_type_code)s{
            return snprintf(buffer, buffer_size,
                            "0x%%04x", (unsigned) type);
        }
        break;
    }
    buffer[0] = 0;
    return (int) required_size;
}

#define NO_LENGTH_MODIFIER 0xfffffffflu
static int psa_snprint_algorithm(char *buffer, size_t buffer_size,
                                 psa_algorithm_t alg)
{
    size_t required_size = 0;
    psa_algorithm_t core_alg = alg;
    unsigned long length_modifier = NO_LENGTH_MODIFIER;
    if (PSA_ALG_IS_MAC(alg)) {
        core_alg = PSA_ALG_TRUNCATED_MAC(alg, 0);
        if (alg & PSA_ALG_MAC_AT_LEAST_THIS_LENGTH_FLAG) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_AT_LEAST_THIS_LENGTH_MAC(", 33);
            length_modifier = PSA_MAC_TRUNCATED_LENGTH(alg);
        } else if (core_alg != alg) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_TRUNCATED_MAC(", 22);
            length_modifier = PSA_MAC_TRUNCATED_LENGTH(alg);
        }
    } else if (PSA_ALG_IS_AEAD(alg)) {
        core_alg = PSA_ALG_AEAD_WITH_DEFAULT_LENGTH_TAG(alg);
        if (core_alg == 0) {
            /* For unknown AEAD algorithms, there is no "default tag length". */
            core_alg = alg;
        } else if (alg & PSA_ALG_AEAD_AT_LEAST_THIS_LENGTH_FLAG) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_AEAD_WITH_AT_LEAST_THIS_LENGTH_TAG(", 43);
            length_modifier = PSA_ALG_AEAD_GET_TAG_LENGTH(alg);
        } else if (core_alg != alg) {
            append(&buffer, buffer_size, &required_size,
                   "PSA_ALG_AEAD_WITH_SHORTENED_TAG(", 32);
            length_modifier = PSA_ALG_AEAD_GET_TAG_LENGTH(alg);
        }
    } else if (PSA_ALG_IS_KEY_AGREEMENT(alg) &&
               !PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) {
        core_alg = PSA_ALG_KEY_AGREEMENT_GET_KDF(alg);
        append(&buffer, buffer_size, &required_size,
               "PSA_ALG_KEY_AGREEMENT(", 22);
        append_with_alg(&buffer, buffer_size, &required_size,
                        psa_ka_algorithm_name,
                        PSA_ALG_KEY_AGREEMENT_GET_BASE(alg));
        append(&buffer, buffer_size, &required_size, ", ", 2);
    }
    switch (core_alg) {
    %(algorithm_cases)s
    default:
        %(algorithm_code)s{
            append_integer(&buffer, buffer_size, &required_size,
                           "0x%%08lx", (unsigned long) core_alg);
        }
        break;
    }
    if (core_alg != alg) {
        if (length_modifier != NO_LENGTH_MODIFIER) {
            append(&buffer, buffer_size, &required_size, ", ", 2);
            append_integer(&buffer, buffer_size, &required_size,
                           "%%lu", length_modifier);
        }
        append(&buffer, buffer_size, &required_size, ")", 1);
    }
    buffer[0] = 0;
    return (int) required_size;
}

static int psa_snprint_key_usage(char *buffer, size_t buffer_size,
                                 psa_key_usage_t usage)
{
    size_t required_size = 0;
    if (usage == 0) {
        if (buffer_size > 1) {
            buffer[0] = '0';
            buffer[1] = 0;
        } else if (buffer_size == 1) {
            buffer[0] = 0;
        }
        return 1;
    }
%(key_usage_code)s
    if (usage != 0) {
        if (required_size != 0) {
            append(&buffer, buffer_size, &required_size, " | ", 3);
        }
        append_integer(&buffer, buffer_size, &required_size,
                       "0x%%08lx", (unsigned long) usage);
    } else {
        buffer[0] = 0;
    }
    return (int) required_size;
}

/* End of automatically generated file. */
'''
174
# One C "if" branch for a key type that is built from an ECC family.
# Placeholders: %(tester)s is the C predicate applied to `type`,
# %(builder)s is the name emitted as a string, %(builder_length)s its length.
# The text deliberately ends with "else " so that consecutive branches chain
# and the last one attaches to the "{" following %(key_type_code)s in
# OUTPUT_TEMPLATE.
KEY_TYPE_FROM_CURVE_TEMPLATE = '''if (%(tester)s(type)) {
            append_with_curve(&buffer, buffer_size, &required_size,
                              "%(builder)s", %(builder_length)s,
                              PSA_KEY_TYPE_ECC_GET_FAMILY(type));
        } else '''
180
# One C "if" branch for a key type that is built from a DH family.
# Same placeholders as the curve variant: %(tester)s, %(builder)s and
# %(builder_length)s. The trailing "else " lets branches be concatenated
# ahead of the "{" that follows %(key_type_code)s in OUTPUT_TEMPLATE.
KEY_TYPE_FROM_GROUP_TEMPLATE = '''if (%(tester)s(type)) {
            append_with_group(&buffer, buffer_size, &required_size,
                              "%(builder)s", %(builder_length)s,
                              PSA_KEY_TYPE_DH_GET_FAMILY(type));
        } else '''
186
# One C "if" branch for an algorithm that is parametrized by a hash.
# Placeholders: %(tester)s is the C predicate applied to `core_alg`,
# %(builder)s is the name emitted as a string, %(builder_length)s its length.
# The emitted length is "%(builder_length)s + 1" because the string literal
# carries an extra "(" after the builder name. The trailing "else " lets
# branches be concatenated ahead of the "{" that follows %(algorithm_code)s
# in OUTPUT_TEMPLATE.
ALGORITHM_FROM_HASH_TEMPLATE = '''if (%(tester)s(core_alg)) {
            append(&buffer, buffer_size, &required_size,
                   "%(builder)s(", %(builder_length)s + 1);
            append_with_alg(&buffer, buffer_size, &required_size,
                            psa_hash_algorithm_name,
                            PSA_ALG_GET_HASH(core_alg));
            append(&buffer, buffer_size, &required_size, ")", 1);
        } else '''
195
# C snippet that tests one flag bit of a bit-mask variable.
# Placeholders: %(var)s is the C variable, %(flag)s the flag name (used both
# as a C expression and as the emitted string), %(length)d the length of that
# name. When the flag is set, its name is appended (preceded by " | " if
# something was already written) and the bit is cleared from the variable
# with XOR, so whatever remains afterwards is an unrecognized bit.
# The backslash before the closing quotes suppresses the final newline.
BIT_TEST_TEMPLATE = '''\
    if (%(var)s & %(flag)s) {
        if (required_size != 0) {
            append(&buffer, buffer_size, &required_size, " | ", 3);
        }
        append(&buffer, buffer_size, &required_size, "%(flag)s", %(length)d);
        %(var)s ^= %(flag)s;
    }\
'''
205
class CaseBuilder(macro_collector.PSAMacroCollector):
    """Turn collected PSA crypto macro definitions into C recognition code.

    Usage:

    1. Feed every input header file to `read_file`.
    2. Call `write_file` to emit ``psa_constant_names_generated.c``.
    """

    def __init__(self):
        super().__init__(include_intermediate=True)

    # -- Building blocks shared by the _make_xxx methods below --

    @staticmethod
    def _sorted_case_lines(make_case, names):
        """Apply `make_case` to each of `names` in sorted order.

        The resulting lines are joined with a newline followed by the
        four-space indentation used inside the generated ``switch`` bodies.
        """
        return '\n    '.join(make_case(name) for name in sorted(names))

    @staticmethod
    def _chained_branches(make_branch, testers_by_builder):
        """Concatenate one generated branch per mapping entry.

        `make_branch` is called as ``make_branch(builder, tester)`` for each
        key of `testers_by_builder` in sorted key order; the pieces are glued
        together without any separator.
        """
        return ''.join(make_branch(builder, testers_by_builder[builder])
                       for builder in sorted(testers_by_builder))

    # -- Single-item generators --

    @staticmethod
    def _make_return_case(name):
        """Return a C ``case`` line that returns `name` as a string."""
        fields = {'name': name}
        return 'case %(name)s: return "%(name)s";' % fields

    @staticmethod
    def _make_append_case(name):
        """Return a C ``case`` line that appends `name` to the buffer."""
        fields = {'name': name, 'length': len(name)}
        return ('case %(name)s: '
                'append(&buffer, buffer_size, &required_size, "%(name)s", %(length)d); '
                'break;') % fields

    @staticmethod
    def _make_bit_test(var, flag):
        """Expand ``BIT_TEST_TEMPLATE`` for `flag` tested in variable `var`."""
        fields = dict(var=var, flag=flag, length=len(flag))
        return BIT_TEST_TEMPLATE % fields

    @staticmethod
    def _make_key_type_from_curve_code(builder, tester):
        """Expand ``KEY_TYPE_FROM_CURVE_TEMPLATE`` for one builder/tester pair."""
        fields = dict(builder=builder,
                      builder_length=len(builder),
                      tester=tester)
        return KEY_TYPE_FROM_CURVE_TEMPLATE % fields

    @staticmethod
    def _make_key_type_from_group_code(builder, tester):
        """Expand ``KEY_TYPE_FROM_GROUP_TEMPLATE`` for one builder/tester pair."""
        fields = dict(builder=builder,
                      builder_length=len(builder),
                      tester=tester)
        return KEY_TYPE_FROM_GROUP_TEMPLATE % fields

    @staticmethod
    def _make_algorithm_from_hash_code(builder, tester):
        """Expand ``ALGORITHM_FROM_HASH_TEMPLATE`` for one builder/tester pair."""
        fields = dict(builder=builder,
                      builder_length=len(builder),
                      tester=tester)
        return ALGORITHM_FROM_HASH_TEMPLATE % fields

    # -- ``switch`` case lists, one line per collected name --

    def _make_status_cases(self):
        return self._sorted_case_lines(self._make_return_case,
                                       self.statuses)

    def _make_ecc_curve_cases(self):
        return self._sorted_case_lines(self._make_return_case,
                                       self.ecc_curves)

    def _make_dh_group_cases(self):
        return self._sorted_case_lines(self._make_return_case,
                                       self.dh_groups)

    def _make_key_type_cases(self):
        return self._sorted_case_lines(self._make_append_case,
                                       self.key_types)

    def _make_hash_algorithm_cases(self):
        return self._sorted_case_lines(self._make_return_case,
                                       self.hash_algorithms)

    def _make_ka_algorithm_cases(self):
        return self._sorted_case_lines(self._make_return_case,
                                       self.ka_algorithms)

    def _make_algorithm_cases(self):
        return self._sorted_case_lines(self._make_append_case,
                                       self.algorithms)

    # -- ``if ... else`` chains for parametrized constants --

    def _make_ecc_key_type_code(self):
        return self._chained_branches(self._make_key_type_from_curve_code,
                                      self.key_types_from_curve)

    def _make_dh_key_type_code(self):
        return self._chained_branches(self._make_key_type_from_group_code,
                                      self.key_types_from_group)

    def _make_algorithm_code(self):
        return self._chained_branches(self._make_algorithm_from_hash_code,
                                      self.algorithms_from_hash)

    def _make_key_usage_code(self):
        """Return one bit test on ``usage`` per key usage flag, sorted."""
        bit_tests = (self._make_bit_test('usage', flag)
                     for flag in sorted(self.key_usage_flags))
        return '\n'.join(bit_tests)

    def write_file(self, output_file):
        """Write the pretty-printer functions to `output_file`.

        Each placeholder of ``OUTPUT_TEMPLATE`` is filled with code
        generated from the constant definitions gathered so far.
        """
        substitutions = {
            'status_cases': self._make_status_cases(),
            'ecc_curve_cases': self._make_ecc_curve_cases(),
            'dh_group_cases': self._make_dh_group_cases(),
            'key_type_cases': self._make_key_type_cases(),
            # ECC branches come first, then DH branches, in a single chain.
            'key_type_code': (self._make_ecc_key_type_code() +
                              self._make_dh_key_type_code()),
            'hash_algorithm_cases': self._make_hash_algorithm_cases(),
            'ka_algorithm_cases': self._make_ka_algorithm_cases(),
            'algorithm_cases': self._make_algorithm_cases(),
            'algorithm_code': self._make_algorithm_code(),
            'key_usage_code': self._make_key_usage_code(),
        }
        output_file.write(OUTPUT_TEMPLATE % substitutions)
315
def generate_psa_constants(header_file_names, output_file_name):
    """Generate the C constant-name code from the given header files.

    Each file in `header_file_names` is opened in binary mode and passed to
    a `CaseBuilder`, which then writes the generated code.

    The output is first written to ``output_file_name + '.tmp'`` and only
    moved over `output_file_name` once it has been written completely, so
    the destination never holds partial content.

    If writing or moving the file fails, the temporary file is removed and
    the original exception is re-raised.
    """
    collector = CaseBuilder()
    for header_file_name in header_file_names:
        with open(header_file_name, 'rb') as header_file:
            collector.read_file(header_file)
    temp_file_name = output_file_name + '.tmp'
    try:
        with open(temp_file_name, 'w') as output_file:
            collector.write_file(output_file)
        os.replace(temp_file_name, output_file_name)
    except BaseException:
        # Do not leave a half-written temporary file in the output directory.
        # Cleanup is best effort: the file may not exist (e.g. open() itself
        # failed), and the original error is the one worth reporting.
        try:
            os.remove(temp_file_name)
        except OSError:
            pass
        raise
325
if __name__ == '__main__':
    build_tree.chdir_to_root()
    # The directory receiving psa_constant_names_generated.c can be
    # overridden on the command line; the override is honoured only when
    # exactly one argument is given.
    if len(sys.argv) == 2:
        OUTPUT_FILE_DIR = sys.argv[1]
    else:
        OUTPUT_FILE_DIR = "programs/psa"
    HEADER_FILE_NAMES = ['include/psa/crypto_values.h',
                         'include/psa/crypto_extra.h']
    generate_psa_constants(HEADER_FILE_NAMES,
                           OUTPUT_FILE_DIR + '/psa_constant_names_generated.c')
333