1 /*
2  *  PSA RSA layer on top of Mbed TLS crypto
3  */
4 /*
5  *  Copyright The Mbed TLS Contributors
6  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
7  */
8 
9 #include "common.h"
10 
11 #if defined(MBEDTLS_PSA_CRYPTO_C) || defined(MCUBOOT_USE_PSA_CRYPTO)
12 
13 #include <psa/crypto.h>
14 #include "psa/crypto_values.h"
15 #include "psa_crypto_core.h"
16 #include "psa_crypto_random_impl.h"
17 #include "psa_crypto_rsa.h"
18 #include "psa_crypto_hash.h"
19 #include "mbedtls/psa_util.h"
20 
21 #include <stdlib.h>
22 #include <string.h>
23 #include "mbedtls/platform.h"
24 
25 #include <mbedtls/rsa.h>
26 #include <mbedtls/error.h>
27 #include "rsa_internal.h"
28 
29 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
30     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) || \
31     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
32     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) || \
33     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) || \
34     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) || \
35     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
36 
37 /* Mbed TLS doesn't support non-byte-aligned key sizes (i.e. key sizes
38  * that are not a multiple of 8) well. For example, there is only
39  * mbedtls_rsa_get_len(), which returns a number of bytes, and no
40  * way to return the exact bit size of a key.
41  * To keep things simple, reject non-byte-aligned key sizes. */
psa_check_rsa_key_byte_aligned(const mbedtls_rsa_context * rsa)42 static psa_status_t psa_check_rsa_key_byte_aligned(
43     const mbedtls_rsa_context *rsa)
44 {
45     mbedtls_mpi n;
46     psa_status_t status;
47     mbedtls_mpi_init(&n);
48     status = mbedtls_to_psa_error(
49         mbedtls_rsa_export(rsa, &n, NULL, NULL, NULL, NULL));
50     if (status == PSA_SUCCESS) {
51         if (mbedtls_mpi_bitlen(&n) % 8 != 0) {
52             status = PSA_ERROR_NOT_SUPPORTED;
53         }
54     }
55     mbedtls_mpi_free(&n);
56     return status;
57 }
58 
mbedtls_psa_rsa_load_representation(psa_key_type_t type,const uint8_t * data,size_t data_length,mbedtls_rsa_context ** p_rsa)59 psa_status_t mbedtls_psa_rsa_load_representation(
60     psa_key_type_t type, const uint8_t *data, size_t data_length,
61     mbedtls_rsa_context **p_rsa)
62 {
63     psa_status_t status;
64     size_t bits;
65 
66     *p_rsa = mbedtls_calloc(1, sizeof(mbedtls_rsa_context));
67     if (*p_rsa == NULL) {
68         return PSA_ERROR_INSUFFICIENT_MEMORY;
69     }
70     mbedtls_rsa_init(*p_rsa);
71 
72     /* Parse the data. */
73     if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
74         status = mbedtls_to_psa_error(mbedtls_rsa_parse_key(*p_rsa, data, data_length));
75     } else {
76         status = mbedtls_to_psa_error(mbedtls_rsa_parse_pubkey(*p_rsa, data, data_length));
77     }
78     if (status != PSA_SUCCESS) {
79         goto exit;
80     }
81 
82     /* The size of an RSA key doesn't have to be a multiple of 8. Mbed TLS
83      * supports non-byte-aligned key sizes, but not well. For example,
84      * mbedtls_rsa_get_len() returns the key size in bytes, not in bits. */
85     bits = PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(*p_rsa));
86     if (bits > PSA_VENDOR_RSA_MAX_KEY_BITS) {
87         status = PSA_ERROR_NOT_SUPPORTED;
88         goto exit;
89     }
90     status = psa_check_rsa_key_byte_aligned(*p_rsa);
91     if (status != PSA_SUCCESS) {
92         goto exit;
93     }
94 
95 exit:
96     return status;
97 }
98 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
99         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) ||
100         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
101         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) ||
102         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) ||
103         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) ||
104         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
105 
106 #if (defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) && \
107     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT)) || \
108     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
mbedtls_psa_rsa_import_key(const psa_key_attributes_t * attributes,const uint8_t * data,size_t data_length,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length,size_t * bits)109 psa_status_t mbedtls_psa_rsa_import_key(
110     const psa_key_attributes_t *attributes,
111     const uint8_t *data, size_t data_length,
112     uint8_t *key_buffer, size_t key_buffer_size,
113     size_t *key_buffer_length, size_t *bits)
114 {
115     psa_status_t status;
116     mbedtls_rsa_context *rsa = NULL;
117 
118     /* Parse input */
119     status = mbedtls_psa_rsa_load_representation(attributes->type,
120                                                  data,
121                                                  data_length,
122                                                  &rsa);
123     if (status != PSA_SUCCESS) {
124         goto exit;
125     }
126 
127     *bits = (psa_key_bits_t) PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(rsa));
128 
129     /* Re-export the data to PSA export format, such that we can store export
130      * representation in the key slot. Export representation in case of RSA is
131      * the smallest representation that's allowed as input, so a straight-up
132      * allocation of the same size as the input buffer will be large enough. */
133     status = mbedtls_psa_rsa_export_key(attributes->type,
134                                         rsa,
135                                         key_buffer,
136                                         key_buffer_size,
137                                         key_buffer_length);
138 exit:
139     /* Always free the RSA object */
140     mbedtls_rsa_free(rsa);
141     mbedtls_free(rsa);
142 
143     return status;
144 }
145 #endif /* (defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) &&
146         *  defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT)) ||
147         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
148 
149 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) || \
150     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
mbedtls_psa_rsa_export_key(psa_key_type_t type,mbedtls_rsa_context * rsa,uint8_t * data,size_t data_size,size_t * data_length)151 psa_status_t mbedtls_psa_rsa_export_key(psa_key_type_t type,
152                                         mbedtls_rsa_context *rsa,
153                                         uint8_t *data,
154                                         size_t data_size,
155                                         size_t *data_length)
156 {
157     int ret;
158     uint8_t *end = data + data_size;
159 
160     /* PSA Crypto API defines the format of an RSA key as a DER-encoded
161      * representation of the non-encrypted PKCS#1 RSAPrivateKey for a
162      * private key and of the RFC3279 RSAPublicKey for a public key. */
163     if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
164         ret = mbedtls_rsa_write_key(rsa, data, &end);
165     } else {
166         ret = mbedtls_rsa_write_pubkey(rsa, data, &end);
167     }
168 
169     if (ret < 0) {
170         /* Clean up in case pk_write failed halfway through. */
171         memset(data, 0, data_size);
172         return mbedtls_to_psa_error(ret);
173     }
174 
175     /* The mbedtls_pk_xxx functions write to the end of the buffer.
176      * Move the data to the beginning and erase remaining data
177      * at the original location. */
178     if (2 * (size_t) ret <= data_size) {
179         memcpy(data, data + data_size - ret, ret);
180         memset(data + data_size - ret, 0, ret);
181     } else if ((size_t) ret < data_size) {
182         memmove(data, data + data_size - ret, ret);
183         memset(data + ret, 0, data_size - ret);
184     }
185 
186     *data_length = ret;
187     return PSA_SUCCESS;
188 }
189 
mbedtls_psa_rsa_export_public_key(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,uint8_t * data,size_t data_size,size_t * data_length)190 psa_status_t mbedtls_psa_rsa_export_public_key(
191     const psa_key_attributes_t *attributes,
192     const uint8_t *key_buffer, size_t key_buffer_size,
193     uint8_t *data, size_t data_size, size_t *data_length)
194 {
195     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
196     mbedtls_rsa_context *rsa = NULL;
197 
198     status = mbedtls_psa_rsa_load_representation(
199         attributes->type, key_buffer, key_buffer_size, &rsa);
200     if (status == PSA_SUCCESS) {
201         status = mbedtls_psa_rsa_export_key(PSA_KEY_TYPE_RSA_PUBLIC_KEY,
202                                             rsa,
203                                             data,
204                                             data_size,
205                                             data_length);
206     }
207 
208     mbedtls_rsa_free(rsa);
209     mbedtls_free(rsa);
210 
211     return status;
212 }
213 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) ||
214         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
215 
216 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
psa_rsa_read_exponent(const uint8_t * e_bytes,size_t e_length,int * exponent)217 static psa_status_t psa_rsa_read_exponent(const uint8_t *e_bytes,
218                                           size_t e_length,
219                                           int *exponent)
220 {
221     size_t i;
222     uint32_t acc = 0;
223 
224     /* Mbed TLS encodes the public exponent as an int. For simplicity, only
225      * support values that fit in a 32-bit integer, which is larger than
226      * int on just about every platform anyway. */
227     if (e_length > sizeof(acc)) {
228         return PSA_ERROR_NOT_SUPPORTED;
229     }
230     for (i = 0; i < e_length; i++) {
231         acc = (acc << 8) | e_bytes[i];
232     }
233     if (acc > INT_MAX) {
234         return PSA_ERROR_NOT_SUPPORTED;
235     }
236     *exponent = acc;
237     return PSA_SUCCESS;
238 }
239 
mbedtls_psa_rsa_generate_key(const psa_key_attributes_t * attributes,const uint8_t * custom_data,size_t custom_data_length,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length)240 psa_status_t mbedtls_psa_rsa_generate_key(
241     const psa_key_attributes_t *attributes,
242     const uint8_t *custom_data, size_t custom_data_length,
243     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
244 {
245     psa_status_t status;
246     mbedtls_rsa_context rsa;
247     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
248     int exponent = 65537;
249 
250     if (custom_data_length != 0) {
251         status = psa_rsa_read_exponent(custom_data, custom_data_length,
252                                        &exponent);
253         if (status != PSA_SUCCESS) {
254             return status;
255         }
256     }
257 
258     mbedtls_rsa_init(&rsa);
259     ret = mbedtls_rsa_gen_key(&rsa,
260                               mbedtls_psa_get_random,
261                               MBEDTLS_PSA_RANDOM_STATE,
262                               (unsigned int) attributes->bits,
263                               exponent);
264     if (ret != 0) {
265         mbedtls_rsa_free(&rsa);
266         return mbedtls_to_psa_error(ret);
267     }
268 
269     status = mbedtls_psa_rsa_export_key(attributes->type,
270                                         &rsa, key_buffer, key_buffer_size,
271                                         key_buffer_length);
272     mbedtls_rsa_free(&rsa);
273 
274     return status;
275 }
276 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE) */
277 
278 /****************************************************************/
279 /* Sign/verify hashes */
280 /****************************************************************/
281 
282 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
283     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
284 
285 /* Decode the hash algorithm from alg and store the mbedtls encoding in
286  * md_alg. Verify that the hash length is acceptable. */
psa_rsa_decode_md_type(psa_algorithm_t alg,size_t hash_length,mbedtls_md_type_t * md_alg)287 static psa_status_t psa_rsa_decode_md_type(psa_algorithm_t alg,
288                                            size_t hash_length,
289                                            mbedtls_md_type_t *md_alg)
290 {
291     psa_algorithm_t hash_alg = PSA_ALG_SIGN_GET_HASH(alg);
292     *md_alg = mbedtls_md_type_from_psa_alg(hash_alg);
293 
294     /* The Mbed TLS RSA module uses an unsigned int for hash length
295      * parameters. Validate that it fits so that we don't risk an
296      * overflow later. */
297 #if SIZE_MAX > UINT_MAX
298     if (hash_length > UINT_MAX) {
299         return PSA_ERROR_INVALID_ARGUMENT;
300     }
301 #endif
302 
303     /* For signatures using a hash, the hash length must be correct. */
304     if (alg != PSA_ALG_RSA_PKCS1V15_SIGN_RAW) {
305         if (*md_alg == MBEDTLS_MD_NONE) {
306             return PSA_ERROR_NOT_SUPPORTED;
307         }
308         if (mbedtls_md_get_size_from_type(*md_alg) != hash_length) {
309             return PSA_ERROR_INVALID_ARGUMENT;
310         }
311     }
312 
313     return PSA_SUCCESS;
314 }
315 
mbedtls_psa_rsa_sign_hash(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * hash,size_t hash_length,uint8_t * signature,size_t signature_size,size_t * signature_length)316 psa_status_t mbedtls_psa_rsa_sign_hash(
317     const psa_key_attributes_t *attributes,
318     const uint8_t *key_buffer, size_t key_buffer_size,
319     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
320     uint8_t *signature, size_t signature_size, size_t *signature_length)
321 {
322     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
323     mbedtls_rsa_context *rsa = NULL;
324     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
325     mbedtls_md_type_t md_alg;
326 
327     status = mbedtls_psa_rsa_load_representation(attributes->type,
328                                                  key_buffer,
329                                                  key_buffer_size,
330                                                  &rsa);
331     if (status != PSA_SUCCESS) {
332         goto exit;
333     }
334 
335     status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
336     if (status != PSA_SUCCESS) {
337         goto exit;
338     }
339 
340     if (signature_size < mbedtls_rsa_get_len(rsa)) {
341         status = PSA_ERROR_BUFFER_TOO_SMALL;
342         goto exit;
343     }
344 
345 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
346     if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
347         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
348                                       MBEDTLS_MD_NONE);
349         if (ret == 0) {
350             ret = mbedtls_rsa_pkcs1_sign(rsa,
351                                          mbedtls_psa_get_random,
352                                          MBEDTLS_PSA_RANDOM_STATE,
353                                          md_alg,
354                                          (unsigned int) hash_length,
355                                          hash,
356                                          signature);
357         }
358     } else
359 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
360 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
361     if (PSA_ALG_IS_RSA_PSS(alg)) {
362         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
363 
364         if (ret == 0) {
365             ret = mbedtls_rsa_rsassa_pss_sign(rsa,
366                                               mbedtls_psa_get_random,
367                                               MBEDTLS_PSA_RANDOM_STATE,
368                                               MBEDTLS_MD_NONE,
369                                               (unsigned int) hash_length,
370                                               hash,
371                                               signature);
372         }
373     } else
374 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
375     {
376         status = PSA_ERROR_INVALID_ARGUMENT;
377         goto exit;
378     }
379 
380     if (ret == 0) {
381         *signature_length = mbedtls_rsa_get_len(rsa);
382     }
383     status = mbedtls_to_psa_error(ret);
384 
385 exit:
386     mbedtls_rsa_free(rsa);
387     mbedtls_free(rsa);
388 
389     return status;
390 }
391 
392 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
rsa_pss_expected_salt_len(psa_algorithm_t alg,const mbedtls_rsa_context * rsa,size_t hash_length)393 static int rsa_pss_expected_salt_len(psa_algorithm_t alg,
394                                      const mbedtls_rsa_context *rsa,
395                                      size_t hash_length)
396 {
397     if (PSA_ALG_IS_RSA_PSS_ANY_SALT(alg)) {
398         return MBEDTLS_RSA_SALT_LEN_ANY;
399     }
400     /* Otherwise: standard salt length, i.e. largest possible salt length
401      * up to the hash length. */
402     int klen = (int) mbedtls_rsa_get_len(rsa);   // known to fit
403     int hlen = (int) hash_length; // known to fit
404     int room = klen - 2 - hlen;
405     if (room < 0) {
406         return 0;  // there is no valid signature in this case anyway
407     } else if (room > hlen) {
408         return hlen;
409     } else {
410         return room;
411     }
412 }
413 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
414 
mbedtls_psa_rsa_verify_hash(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * hash,size_t hash_length,const uint8_t * signature,size_t signature_length)415 psa_status_t mbedtls_psa_rsa_verify_hash(
416     const psa_key_attributes_t *attributes,
417     const uint8_t *key_buffer, size_t key_buffer_size,
418     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
419     const uint8_t *signature, size_t signature_length)
420 {
421     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
422     mbedtls_rsa_context *rsa = NULL;
423     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
424     mbedtls_md_type_t md_alg;
425 
426     status = mbedtls_psa_rsa_load_representation(attributes->type,
427                                                  key_buffer,
428                                                  key_buffer_size,
429                                                  &rsa);
430     if (status != PSA_SUCCESS) {
431         goto exit;
432     }
433 
434     status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
435     if (status != PSA_SUCCESS) {
436         goto exit;
437     }
438 
439     if (signature_length != mbedtls_rsa_get_len(rsa)) {
440         status = PSA_ERROR_INVALID_SIGNATURE;
441         goto exit;
442     }
443 
444 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
445     if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
446         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
447                                       MBEDTLS_MD_NONE);
448         if (ret == 0) {
449             ret = mbedtls_rsa_pkcs1_verify(rsa,
450                                            md_alg,
451                                            (unsigned int) hash_length,
452                                            hash,
453                                            signature);
454         }
455     } else
456 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
457 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
458     if (PSA_ALG_IS_RSA_PSS(alg)) {
459         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
460         if (ret == 0) {
461             int slen = rsa_pss_expected_salt_len(alg, rsa, hash_length);
462             ret = mbedtls_rsa_rsassa_pss_verify_ext(rsa,
463                                                     md_alg,
464                                                     (unsigned) hash_length,
465                                                     hash,
466                                                     md_alg,
467                                                     slen,
468                                                     signature);
469         }
470     } else
471 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
472     {
473         status = PSA_ERROR_INVALID_ARGUMENT;
474         goto exit;
475     }
476 
477     /* Mbed TLS distinguishes "invalid padding" from "valid padding but
478      * the rest of the signature is invalid". This has little use in
479      * practice and PSA doesn't report this distinction. */
480     status = (ret == MBEDTLS_ERR_RSA_INVALID_PADDING) ?
481              PSA_ERROR_INVALID_SIGNATURE :
482              mbedtls_to_psa_error(ret);
483 
484 exit:
485     mbedtls_rsa_free(rsa);
486     mbedtls_free(rsa);
487 
488     return status;
489 }
490 
491 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
492         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
493 
494 /****************************************************************/
495 /* Asymmetric cryptography */
496 /****************************************************************/
497 
498 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
psa_rsa_oaep_set_padding_mode(psa_algorithm_t alg,mbedtls_rsa_context * rsa)499 static int psa_rsa_oaep_set_padding_mode(psa_algorithm_t alg,
500                                          mbedtls_rsa_context *rsa)
501 {
502     psa_algorithm_t hash_alg = PSA_ALG_RSA_OAEP_GET_HASH(alg);
503     mbedtls_md_type_t md_alg = mbedtls_md_type_from_psa_alg(hash_alg);
504 
505     /* Just to get the error status right, as rsa_set_padding() doesn't
506      * distinguish between "bad RSA algorithm" and "unknown hash". */
507     if (mbedtls_md_info_from_type(md_alg) == NULL) {
508         return PSA_ERROR_NOT_SUPPORTED;
509     }
510 
511     return mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
512 }
513 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
514 
mbedtls_psa_asymmetric_encrypt(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * input,size_t input_length,const uint8_t * salt,size_t salt_length,uint8_t * output,size_t output_size,size_t * output_length)515 psa_status_t mbedtls_psa_asymmetric_encrypt(const psa_key_attributes_t *attributes,
516                                             const uint8_t *key_buffer,
517                                             size_t key_buffer_size,
518                                             psa_algorithm_t alg,
519                                             const uint8_t *input,
520                                             size_t input_length,
521                                             const uint8_t *salt,
522                                             size_t salt_length,
523                                             uint8_t *output,
524                                             size_t output_size,
525                                             size_t *output_length)
526 {
527     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
528     (void) key_buffer;
529     (void) key_buffer_size;
530     (void) input;
531     (void) input_length;
532     (void) salt;
533     (void) salt_length;
534     (void) output;
535     (void) output_size;
536     (void) output_length;
537 
538     if (PSA_KEY_TYPE_IS_RSA(attributes->type)) {
539 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
540         defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
541         mbedtls_rsa_context *rsa = NULL;
542         status = mbedtls_psa_rsa_load_representation(attributes->type,
543                                                      key_buffer,
544                                                      key_buffer_size,
545                                                      &rsa);
546         if (status != PSA_SUCCESS) {
547             goto rsa_exit;
548         }
549 
550         if (output_size < mbedtls_rsa_get_len(rsa)) {
551             status = PSA_ERROR_BUFFER_TOO_SMALL;
552             goto rsa_exit;
553         }
554 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
555         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
556         if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
557 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
558             status = mbedtls_to_psa_error(
559                 mbedtls_rsa_pkcs1_encrypt(rsa,
560                                           mbedtls_psa_get_random,
561                                           MBEDTLS_PSA_RANDOM_STATE,
562                                           input_length,
563                                           input,
564                                           output));
565 #else
566             status = PSA_ERROR_NOT_SUPPORTED;
567 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
568         } else
569         if (PSA_ALG_IS_RSA_OAEP(alg)) {
570 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
571             status = mbedtls_to_psa_error(
572                 psa_rsa_oaep_set_padding_mode(alg, rsa));
573             if (status != PSA_SUCCESS) {
574                 goto rsa_exit;
575             }
576 
577             status = mbedtls_to_psa_error(
578                 mbedtls_rsa_rsaes_oaep_encrypt(rsa,
579                                                mbedtls_psa_get_random,
580                                                MBEDTLS_PSA_RANDOM_STATE,
581                                                salt, salt_length,
582                                                input_length,
583                                                input,
584                                                output));
585 #else
586             status = PSA_ERROR_NOT_SUPPORTED;
587 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
588         } else {
589             status = PSA_ERROR_INVALID_ARGUMENT;
590         }
591 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
592         defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
593 rsa_exit:
594         if (status == PSA_SUCCESS) {
595             *output_length = mbedtls_rsa_get_len(rsa);
596         }
597 
598         mbedtls_rsa_free(rsa);
599         mbedtls_free(rsa);
600 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
601         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
602     } else {
603         status = PSA_ERROR_NOT_SUPPORTED;
604     }
605 
606     return status;
607 }
608 
mbedtls_psa_asymmetric_decrypt(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * input,size_t input_length,const uint8_t * salt,size_t salt_length,uint8_t * output,size_t output_size,size_t * output_length)609 psa_status_t mbedtls_psa_asymmetric_decrypt(const psa_key_attributes_t *attributes,
610                                             const uint8_t *key_buffer,
611                                             size_t key_buffer_size,
612                                             psa_algorithm_t alg,
613                                             const uint8_t *input,
614                                             size_t input_length,
615                                             const uint8_t *salt,
616                                             size_t salt_length,
617                                             uint8_t *output,
618                                             size_t output_size,
619                                             size_t *output_length)
620 {
621     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
622     (void) key_buffer;
623     (void) key_buffer_size;
624     (void) input;
625     (void) input_length;
626     (void) salt;
627     (void) salt_length;
628     (void) output;
629     (void) output_size;
630     (void) output_length;
631 
632     *output_length = 0;
633 
634     if (attributes->type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
635 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
636         defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
637         mbedtls_rsa_context *rsa = NULL;
638         status = mbedtls_psa_rsa_load_representation(attributes->type,
639                                                      key_buffer,
640                                                      key_buffer_size,
641                                                      &rsa);
642         if (status != PSA_SUCCESS) {
643             goto rsa_exit;
644         }
645 
646         if (input_length != mbedtls_rsa_get_len(rsa)) {
647             status = PSA_ERROR_INVALID_ARGUMENT;
648             goto rsa_exit;
649         }
650 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
651         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
652 
653         if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
654 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
655             status = mbedtls_to_psa_error(
656                 mbedtls_rsa_pkcs1_decrypt(rsa,
657                                           mbedtls_psa_get_random,
658                                           MBEDTLS_PSA_RANDOM_STATE,
659                                           output_length,
660                                           input,
661                                           output,
662                                           output_size));
663 #else
664             status = PSA_ERROR_NOT_SUPPORTED;
665 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
666         } else
667         if (PSA_ALG_IS_RSA_OAEP(alg)) {
668 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
669             status = mbedtls_to_psa_error(
670                 psa_rsa_oaep_set_padding_mode(alg, rsa));
671             if (status != PSA_SUCCESS) {
672                 goto rsa_exit;
673             }
674 
675             status = mbedtls_to_psa_error(
676                 mbedtls_rsa_rsaes_oaep_decrypt(rsa,
677                                                mbedtls_psa_get_random,
678                                                MBEDTLS_PSA_RANDOM_STATE,
679                                                salt, salt_length,
680                                                output_length,
681                                                input,
682                                                output,
683                                                output_size));
684 #else
685             status = PSA_ERROR_NOT_SUPPORTED;
686 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
687         } else {
688             status = PSA_ERROR_INVALID_ARGUMENT;
689         }
690 
691 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
692         defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
693 rsa_exit:
694         mbedtls_rsa_free(rsa);
695         mbedtls_free(rsa);
696 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
697         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
698     } else {
699         status = PSA_ERROR_NOT_SUPPORTED;
700     }
701 
702     return status;
703 }
704 
#endif /* MBEDTLS_PSA_CRYPTO_C || MCUBOOT_USE_PSA_CRYPTO */
706