1 /*
2  *  PSA RSA layer on top of Mbed TLS crypto
3  */
4 /*
5  *  Copyright The Mbed TLS Contributors
6  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
7  */
8 
9 #include "common.h"
10 
11 #if defined(MBEDTLS_PSA_CRYPTO_C) || defined(MCUBOOT_USE_PSA_CRYPTO)
12 
13 #include <psa/crypto.h>
14 #include "psa/crypto_values.h"
15 #include "psa_crypto_core.h"
16 #include "psa_crypto_random_impl.h"
17 #include "psa_crypto_rsa.h"
18 #include "psa_crypto_hash.h"
19 #include "mbedtls/psa_util.h"
20 
21 #include <stdlib.h>
22 #include <string.h>
23 #include "mbedtls/platform.h"
24 
25 #include <mbedtls/rsa.h>
26 #include <mbedtls/error.h>
27 #include "rsa_internal.h"
28 
29 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
30     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) || \
31     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
32     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) || \
33     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) || \
34     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) || \
35     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
36 
37 /* Mbed TLS doesn't support non-byte-aligned key sizes (i.e. key sizes
38  * that are not a multiple of 8) well. For example, there is only
39  * mbedtls_rsa_get_len(), which returns a number of bytes, and no
40  * way to return the exact bit size of a key.
41  * To keep things simple, reject non-byte-aligned key sizes. */
psa_check_rsa_key_byte_aligned(const mbedtls_rsa_context * rsa)42 static psa_status_t psa_check_rsa_key_byte_aligned(
43     const mbedtls_rsa_context *rsa)
44 {
45     mbedtls_mpi n;
46     psa_status_t status;
47     mbedtls_mpi_init(&n);
48     status = mbedtls_to_psa_error(
49         mbedtls_rsa_export(rsa, &n, NULL, NULL, NULL, NULL));
50     if (status == PSA_SUCCESS) {
51         if (mbedtls_mpi_bitlen(&n) % 8 != 0) {
52             status = PSA_ERROR_NOT_SUPPORTED;
53         }
54     }
55     mbedtls_mpi_free(&n);
56     return status;
57 }
58 
mbedtls_psa_rsa_load_representation(psa_key_type_t type,const uint8_t * data,size_t data_length,mbedtls_rsa_context ** p_rsa)59 psa_status_t mbedtls_psa_rsa_load_representation(
60     psa_key_type_t type, const uint8_t *data, size_t data_length,
61     mbedtls_rsa_context **p_rsa)
62 {
63     psa_status_t status;
64     size_t bits;
65 
66     *p_rsa = mbedtls_calloc(1, sizeof(mbedtls_rsa_context));
67     if (*p_rsa == NULL) {
68         return PSA_ERROR_INSUFFICIENT_MEMORY;
69     }
70     mbedtls_rsa_init(*p_rsa);
71 
72     /* Parse the data. */
73     if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
74         status = mbedtls_to_psa_error(mbedtls_rsa_parse_key(*p_rsa, data, data_length));
75     } else {
76         status = mbedtls_to_psa_error(mbedtls_rsa_parse_pubkey(*p_rsa, data, data_length));
77     }
78     if (status != PSA_SUCCESS) {
79         goto exit;
80     }
81 
82     /* The size of an RSA key doesn't have to be a multiple of 8. Mbed TLS
83      * supports non-byte-aligned key sizes, but not well. For example,
84      * mbedtls_rsa_get_len() returns the key size in bytes, not in bits. */
85     bits = PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(*p_rsa));
86     if (bits > PSA_VENDOR_RSA_MAX_KEY_BITS) {
87         status = PSA_ERROR_NOT_SUPPORTED;
88         goto exit;
89     }
90     status = psa_check_rsa_key_byte_aligned(*p_rsa);
91     if (status != PSA_SUCCESS) {
92         goto exit;
93     }
94 
95 exit:
96     return status;
97 }
98 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
99         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) ||
100         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
101         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) ||
102         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) ||
103         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) ||
104         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
105 
106 #if (defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) && \
107     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT)) || \
108     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
mbedtls_psa_rsa_import_key(const psa_key_attributes_t * attributes,const uint8_t * data,size_t data_length,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length,size_t * bits)109 psa_status_t mbedtls_psa_rsa_import_key(
110     const psa_key_attributes_t *attributes,
111     const uint8_t *data, size_t data_length,
112     uint8_t *key_buffer, size_t key_buffer_size,
113     size_t *key_buffer_length, size_t *bits)
114 {
115     psa_status_t status;
116     mbedtls_rsa_context *rsa = NULL;
117 
118     /* Parse input */
119     status = mbedtls_psa_rsa_load_representation(attributes->type,
120                                                  data,
121                                                  data_length,
122                                                  &rsa);
123     if (status != PSA_SUCCESS) {
124         goto exit;
125     }
126 
127     *bits = (psa_key_bits_t) PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(rsa));
128 
129     /* Re-export the data to PSA export format, such that we can store export
130      * representation in the key slot. Export representation in case of RSA is
131      * the smallest representation that's allowed as input, so a straight-up
132      * allocation of the same size as the input buffer will be large enough. */
133     status = mbedtls_psa_rsa_export_key(attributes->type,
134                                         rsa,
135                                         key_buffer,
136                                         key_buffer_size,
137                                         key_buffer_length);
138 exit:
139     /* Always free the RSA object */
140     mbedtls_rsa_free(rsa);
141     mbedtls_free(rsa);
142 
143     return status;
144 }
145 #endif /* (defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) &&
146         *  defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT)) ||
147         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
148 
149 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) || \
150     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
mbedtls_psa_rsa_export_key(psa_key_type_t type,mbedtls_rsa_context * rsa,uint8_t * data,size_t data_size,size_t * data_length)151 psa_status_t mbedtls_psa_rsa_export_key(psa_key_type_t type,
152                                         mbedtls_rsa_context *rsa,
153                                         uint8_t *data,
154                                         size_t data_size,
155                                         size_t *data_length)
156 {
157     int ret;
158     uint8_t *end = data + data_size;
159 
160     /* PSA Crypto API defines the format of an RSA key as a DER-encoded
161      * representation of the non-encrypted PKCS#1 RSAPrivateKey for a
162      * private key and of the RFC3279 RSAPublicKey for a public key. */
163     if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
164         ret = mbedtls_rsa_write_key(rsa, data, &end);
165     } else {
166         ret = mbedtls_rsa_write_pubkey(rsa, data, &end);
167     }
168 
169     if (ret < 0) {
170         /* Clean up in case pk_write failed halfway through. */
171         memset(data, 0, data_size);
172         return mbedtls_to_psa_error(ret);
173     }
174 
175     /* The mbedtls_pk_xxx functions write to the end of the buffer.
176      * Move the data to the beginning and erase remaining data
177      * at the original location. */
178     if (2 * (size_t) ret <= data_size) {
179         memcpy(data, data + data_size - ret, ret);
180         memset(data + data_size - ret, 0, ret);
181     } else if ((size_t) ret < data_size) {
182         memmove(data, data + data_size - ret, ret);
183         memset(data + ret, 0, data_size - ret);
184     }
185 
186     *data_length = ret;
187     return PSA_SUCCESS;
188 }
189 
mbedtls_psa_rsa_export_public_key(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,uint8_t * data,size_t data_size,size_t * data_length)190 psa_status_t mbedtls_psa_rsa_export_public_key(
191     const psa_key_attributes_t *attributes,
192     const uint8_t *key_buffer, size_t key_buffer_size,
193     uint8_t *data, size_t data_size, size_t *data_length)
194 {
195     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
196     mbedtls_rsa_context *rsa = NULL;
197 
198     status = mbedtls_psa_rsa_load_representation(
199         attributes->type, key_buffer, key_buffer_size, &rsa);
200     if (status != PSA_SUCCESS) {
201         return status;
202     }
203 
204     status = mbedtls_psa_rsa_export_key(PSA_KEY_TYPE_RSA_PUBLIC_KEY,
205                                         rsa,
206                                         data,
207                                         data_size,
208                                         data_length);
209 
210     mbedtls_rsa_free(rsa);
211     mbedtls_free(rsa);
212 
213     return status;
214 }
215 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) ||
216         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
217 
218 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
psa_rsa_read_exponent(const uint8_t * e_bytes,size_t e_length,int * exponent)219 static psa_status_t psa_rsa_read_exponent(const uint8_t *e_bytes,
220                                           size_t e_length,
221                                           int *exponent)
222 {
223     size_t i;
224     uint32_t acc = 0;
225 
226     /* Mbed TLS encodes the public exponent as an int. For simplicity, only
227      * support values that fit in a 32-bit integer, which is larger than
228      * int on just about every platform anyway. */
229     if (e_length > sizeof(acc)) {
230         return PSA_ERROR_NOT_SUPPORTED;
231     }
232     for (i = 0; i < e_length; i++) {
233         acc = (acc << 8) | e_bytes[i];
234     }
235     if (acc > INT_MAX) {
236         return PSA_ERROR_NOT_SUPPORTED;
237     }
238     *exponent = acc;
239     return PSA_SUCCESS;
240 }
241 
mbedtls_psa_rsa_generate_key(const psa_key_attributes_t * attributes,const psa_key_production_parameters_t * params,size_t params_data_length,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length)242 psa_status_t mbedtls_psa_rsa_generate_key(
243     const psa_key_attributes_t *attributes,
244     const psa_key_production_parameters_t *params, size_t params_data_length,
245     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
246 {
247     psa_status_t status;
248     mbedtls_rsa_context rsa;
249     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
250     int exponent = 65537;
251 
252     if (params_data_length != 0) {
253         status = psa_rsa_read_exponent(params->data, params_data_length,
254                                        &exponent);
255         if (status != PSA_SUCCESS) {
256             return status;
257         }
258     }
259 
260     mbedtls_rsa_init(&rsa);
261     ret = mbedtls_rsa_gen_key(&rsa,
262                               mbedtls_psa_get_random,
263                               MBEDTLS_PSA_RANDOM_STATE,
264                               (unsigned int) attributes->bits,
265                               exponent);
266     if (ret != 0) {
267         return mbedtls_to_psa_error(ret);
268     }
269 
270     status = mbedtls_psa_rsa_export_key(attributes->type,
271                                         &rsa, key_buffer, key_buffer_size,
272                                         key_buffer_length);
273     mbedtls_rsa_free(&rsa);
274 
275     return status;
276 }
277 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE) */
278 
279 /****************************************************************/
280 /* Sign/verify hashes */
281 /****************************************************************/
282 
283 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
284     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
285 
286 /* Decode the hash algorithm from alg and store the mbedtls encoding in
287  * md_alg. Verify that the hash length is acceptable. */
psa_rsa_decode_md_type(psa_algorithm_t alg,size_t hash_length,mbedtls_md_type_t * md_alg)288 static psa_status_t psa_rsa_decode_md_type(psa_algorithm_t alg,
289                                            size_t hash_length,
290                                            mbedtls_md_type_t *md_alg)
291 {
292     psa_algorithm_t hash_alg = PSA_ALG_SIGN_GET_HASH(alg);
293     *md_alg = mbedtls_md_type_from_psa_alg(hash_alg);
294 
295     /* The Mbed TLS RSA module uses an unsigned int for hash length
296      * parameters. Validate that it fits so that we don't risk an
297      * overflow later. */
298 #if SIZE_MAX > UINT_MAX
299     if (hash_length > UINT_MAX) {
300         return PSA_ERROR_INVALID_ARGUMENT;
301     }
302 #endif
303 
304     /* For signatures using a hash, the hash length must be correct. */
305     if (alg != PSA_ALG_RSA_PKCS1V15_SIGN_RAW) {
306         if (*md_alg == MBEDTLS_MD_NONE) {
307             return PSA_ERROR_NOT_SUPPORTED;
308         }
309         if (mbedtls_md_get_size_from_type(*md_alg) != hash_length) {
310             return PSA_ERROR_INVALID_ARGUMENT;
311         }
312     }
313 
314     return PSA_SUCCESS;
315 }
316 
mbedtls_psa_rsa_sign_hash(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * hash,size_t hash_length,uint8_t * signature,size_t signature_size,size_t * signature_length)317 psa_status_t mbedtls_psa_rsa_sign_hash(
318     const psa_key_attributes_t *attributes,
319     const uint8_t *key_buffer, size_t key_buffer_size,
320     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
321     uint8_t *signature, size_t signature_size, size_t *signature_length)
322 {
323     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
324     mbedtls_rsa_context *rsa = NULL;
325     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
326     mbedtls_md_type_t md_alg;
327 
328     status = mbedtls_psa_rsa_load_representation(attributes->type,
329                                                  key_buffer,
330                                                  key_buffer_size,
331                                                  &rsa);
332     if (status != PSA_SUCCESS) {
333         return status;
334     }
335 
336     status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
337     if (status != PSA_SUCCESS) {
338         goto exit;
339     }
340 
341     if (signature_size < mbedtls_rsa_get_len(rsa)) {
342         status = PSA_ERROR_BUFFER_TOO_SMALL;
343         goto exit;
344     }
345 
346 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
347     if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
348         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
349                                       MBEDTLS_MD_NONE);
350         if (ret == 0) {
351             ret = mbedtls_rsa_pkcs1_sign(rsa,
352                                          mbedtls_psa_get_random,
353                                          MBEDTLS_PSA_RANDOM_STATE,
354                                          md_alg,
355                                          (unsigned int) hash_length,
356                                          hash,
357                                          signature);
358         }
359     } else
360 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
361 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
362     if (PSA_ALG_IS_RSA_PSS(alg)) {
363         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
364 
365         if (ret == 0) {
366             ret = mbedtls_rsa_rsassa_pss_sign(rsa,
367                                               mbedtls_psa_get_random,
368                                               MBEDTLS_PSA_RANDOM_STATE,
369                                               MBEDTLS_MD_NONE,
370                                               (unsigned int) hash_length,
371                                               hash,
372                                               signature);
373         }
374     } else
375 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
376     {
377         status = PSA_ERROR_INVALID_ARGUMENT;
378         goto exit;
379     }
380 
381     if (ret == 0) {
382         *signature_length = mbedtls_rsa_get_len(rsa);
383     }
384     status = mbedtls_to_psa_error(ret);
385 
386 exit:
387     mbedtls_rsa_free(rsa);
388     mbedtls_free(rsa);
389 
390     return status;
391 }
392 
393 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
/* Salt length to require when verifying an RSA-PSS signature.
 *
 * For an "any salt" PSS algorithm, any salt length is accepted. Otherwise
 * the standard salt length applies: the largest salt that fits, capped at
 * the hash length, i.e. min(hash_length, key_length - 2 - hash_length).
 * If the key is too small for that to be non-negative, 0 is returned;
 * no valid signature exists in that case anyway.
 */
static int rsa_pss_expected_salt_len(psa_algorithm_t alg,
                                     const mbedtls_rsa_context *rsa,
                                     size_t hash_length)
{
    int key_bytes;
    int hash_bytes;
    int available;

    if (PSA_ALG_IS_RSA_PSS_ANY_SALT(alg)) {
        return MBEDTLS_RSA_SALT_LEN_ANY;
    }

    /* Both values are known to fit in an int: the key size is bounded when
     * the key is loaded, and the hash length has been checked against the
     * digest size before this is called. */
    key_bytes = (int) mbedtls_rsa_get_len(rsa);
    hash_bytes = (int) hash_length;
    available = key_bytes - 2 - hash_bytes;

    if (available < 0) {
        return 0;
    }
    return (available < hash_bytes) ? available : hash_bytes;
}
414 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
415 
mbedtls_psa_rsa_verify_hash(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * hash,size_t hash_length,const uint8_t * signature,size_t signature_length)416 psa_status_t mbedtls_psa_rsa_verify_hash(
417     const psa_key_attributes_t *attributes,
418     const uint8_t *key_buffer, size_t key_buffer_size,
419     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
420     const uint8_t *signature, size_t signature_length)
421 {
422     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
423     mbedtls_rsa_context *rsa = NULL;
424     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
425     mbedtls_md_type_t md_alg;
426 
427     status = mbedtls_psa_rsa_load_representation(attributes->type,
428                                                  key_buffer,
429                                                  key_buffer_size,
430                                                  &rsa);
431     if (status != PSA_SUCCESS) {
432         goto exit;
433     }
434 
435     status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
436     if (status != PSA_SUCCESS) {
437         goto exit;
438     }
439 
440     if (signature_length != mbedtls_rsa_get_len(rsa)) {
441         status = PSA_ERROR_INVALID_SIGNATURE;
442         goto exit;
443     }
444 
445 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
446     if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
447         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
448                                       MBEDTLS_MD_NONE);
449         if (ret == 0) {
450             ret = mbedtls_rsa_pkcs1_verify(rsa,
451                                            md_alg,
452                                            (unsigned int) hash_length,
453                                            hash,
454                                            signature);
455         }
456     } else
457 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
458 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
459     if (PSA_ALG_IS_RSA_PSS(alg)) {
460         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
461         if (ret == 0) {
462             int slen = rsa_pss_expected_salt_len(alg, rsa, hash_length);
463             ret = mbedtls_rsa_rsassa_pss_verify_ext(rsa,
464                                                     md_alg,
465                                                     (unsigned) hash_length,
466                                                     hash,
467                                                     md_alg,
468                                                     slen,
469                                                     signature);
470         }
471     } else
472 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
473     {
474         status = PSA_ERROR_INVALID_ARGUMENT;
475         goto exit;
476     }
477 
478     /* Mbed TLS distinguishes "invalid padding" from "valid padding but
479      * the rest of the signature is invalid". This has little use in
480      * practice and PSA doesn't report this distinction. */
481     status = (ret == MBEDTLS_ERR_RSA_INVALID_PADDING) ?
482              PSA_ERROR_INVALID_SIGNATURE :
483              mbedtls_to_psa_error(ret);
484 
485 exit:
486     mbedtls_rsa_free(rsa);
487     mbedtls_free(rsa);
488 
489     return status;
490 }
491 
492 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
493         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
494 
495 /****************************************************************/
496 /* Asymmetric cryptography */
497 /****************************************************************/
498 
499 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
/* Configure rsa for RSAES-OAEP with the hash encoded in alg.
 *
 * Returns an Mbed TLS error code (0 on success). Callers translate the
 * result with mbedtls_to_psa_error(), so this function must not return a
 * psa_status_t: mbedtls_to_psa_error() interprets its argument as an
 * Mbed TLS error code, and a PSA status passed through it is not preserved.
 */
static int psa_rsa_oaep_set_padding_mode(psa_algorithm_t alg,
                                         mbedtls_rsa_context *rsa)
{
    psa_algorithm_t hash_alg = PSA_ALG_RSA_OAEP_GET_HASH(alg);
    mbedtls_md_type_t md_alg = mbedtls_md_type_from_psa_alg(hash_alg);

    /* Just to get the error status right, as rsa_set_padding() doesn't
     * distinguish between "bad RSA algorithm" and "unknown hash".
     * MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE is the Mbed TLS code for an
     * unavailable digest, and is meant to surface as
     * PSA_ERROR_NOT_SUPPORTED once the caller translates it. */
    if (mbedtls_md_info_from_type(md_alg) == NULL) {
        return MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE;
    }

    return mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
}
514 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
515 
/* Encrypt a message with an RSA key, using RSAES-PKCS1-v1_5 or RSAES-OAEP.
 *
 * \param attributes       Only RSA key types (public key or key pair) are
 *                         handled. Any other type yields
 *                         PSA_ERROR_NOT_SUPPORTED.
 * \param key_buffer       Key material in PSA export format.
 * \param key_buffer_size  Size of \p key_buffer in bytes.
 * \param alg              PSA_ALG_RSA_PKCS1V15_CRYPT or an RSA-OAEP algorithm.
 *                         Other algorithms yield PSA_ERROR_INVALID_ARGUMENT.
 *                         One of these two whose support is not compiled in
 *                         yields PSA_ERROR_NOT_SUPPORTED.
 * \param input            Message to encrypt.
 * \param input_length     Size of \p input in bytes.
 * \param salt             Passed to mbedtls_rsa_rsaes_oaep_encrypt() as the
 *                         OAEP label. Unused for PKCS#1 v1.5.
 * \param salt_length      Size of \p salt in bytes.
 * \param output           Ciphertext buffer. It must be at least as large as
 *                         the modulus (else PSA_ERROR_BUFFER_TOO_SMALL).
 * \param output_size      Size of \p output in bytes.
 * \param[out] output_length  Set to the modulus size on success.
 */
psa_status_t mbedtls_psa_asymmetric_encrypt(const psa_key_attributes_t *attributes,
                                            const uint8_t *key_buffer,
                                            size_t key_buffer_size,
                                            psa_algorithm_t alg,
                                            const uint8_t *input,
                                            size_t input_length,
                                            const uint8_t *salt,
                                            size_t salt_length,
                                            uint8_t *output,
                                            size_t output_size,
                                            size_t *output_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    /* Every use of these parameters is conditionally compiled; the casts
     * keep configurations without RSA encryption free of warnings. */
    (void) key_buffer;
    (void) key_buffer_size;
    (void) input;
    (void) input_length;
    (void) salt;
    (void) salt_length;
    (void) output;
    (void) output_size;
    (void) output_length;

    if (PSA_KEY_TYPE_IS_RSA(attributes->type)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        mbedtls_rsa_context *rsa = NULL;
        status = mbedtls_psa_rsa_load_representation(attributes->type,
                                                     key_buffer,
                                                     key_buffer_size,
                                                     &rsa);
        if (status != PSA_SUCCESS) {
            goto rsa_exit;
        }

        /* The ciphertext is always exactly as long as the modulus. */
        if (output_size < mbedtls_rsa_get_len(rsa)) {
            status = PSA_ERROR_BUFFER_TOO_SMALL;
            goto rsa_exit;
        }
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
        if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
            status = mbedtls_to_psa_error(
                mbedtls_rsa_pkcs1_encrypt(rsa,
                                          mbedtls_psa_get_random,
                                          MBEDTLS_PSA_RANDOM_STATE,
                                          input_length,
                                          input,
                                          output));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
        } else
        if (PSA_ALG_IS_RSA_OAEP(alg)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
            /* Switch the freshly loaded context to OAEP with alg's hash. */
            status = mbedtls_to_psa_error(
                psa_rsa_oaep_set_padding_mode(alg, rsa));
            if (status != PSA_SUCCESS) {
                goto rsa_exit;
            }

            status = mbedtls_to_psa_error(
                mbedtls_rsa_rsaes_oaep_encrypt(rsa,
                                               mbedtls_psa_get_random,
                                               MBEDTLS_PSA_RANDOM_STATE,
                                               salt, salt_length,
                                               input_length,
                                               input,
                                               output));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
        } else {
            status = PSA_ERROR_INVALID_ARGUMENT;
        }
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        /* Common exit for the RSA path. It runs on success and on every
         * error above, including a failed load, where rsa may be NULL. */
rsa_exit:
        if (status == PSA_SUCCESS) {
            *output_length = mbedtls_rsa_get_len(rsa);
        }

        mbedtls_rsa_free(rsa);
        mbedtls_free(rsa);
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
    } else {
        status = PSA_ERROR_NOT_SUPPORTED;
    }

    return status;
}
609 
/* Decrypt a ciphertext with an RSA key pair, using RSAES-PKCS1-v1_5 or
 * RSAES-OAEP.
 *
 * \param attributes       Only PSA_KEY_TYPE_RSA_KEY_PAIR is handled. Any
 *                         other type yields PSA_ERROR_NOT_SUPPORTED.
 * \param key_buffer       Key material in PSA export format.
 * \param key_buffer_size  Size of \p key_buffer in bytes.
 * \param alg              PSA_ALG_RSA_PKCS1V15_CRYPT or an RSA-OAEP algorithm.
 *                         Other algorithms yield PSA_ERROR_INVALID_ARGUMENT.
 *                         One of these two whose support is not compiled in
 *                         yields PSA_ERROR_NOT_SUPPORTED.
 * \param input            Ciphertext. Its length must equal the modulus size
 *                         (else PSA_ERROR_INVALID_ARGUMENT).
 * \param input_length     Size of \p input in bytes.
 * \param salt             Passed to mbedtls_rsa_rsaes_oaep_decrypt() as the
 *                         OAEP label. Unused for PKCS#1 v1.5.
 * \param salt_length      Size of \p salt in bytes.
 * \param output           Plaintext buffer.
 * \param output_size      Size of \p output in bytes.
 * \param[out] output_length  Set to 0 before any other work. The
 *                         decryption call writes the plaintext length
 *                         through it.
 */
psa_status_t mbedtls_psa_asymmetric_decrypt(const psa_key_attributes_t *attributes,
                                            const uint8_t *key_buffer,
                                            size_t key_buffer_size,
                                            psa_algorithm_t alg,
                                            const uint8_t *input,
                                            size_t input_length,
                                            const uint8_t *salt,
                                            size_t salt_length,
                                            uint8_t *output,
                                            size_t output_size,
                                            size_t *output_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    /* Every use of these parameters is conditionally compiled; the casts
     * keep configurations without RSA decryption free of warnings. */
    (void) key_buffer;
    (void) key_buffer_size;
    (void) input;
    (void) input_length;
    (void) salt;
    (void) salt_length;
    (void) output;
    (void) output_size;
    (void) output_length;

    *output_length = 0;

    if (attributes->type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        mbedtls_rsa_context *rsa = NULL;
        status = mbedtls_psa_rsa_load_representation(attributes->type,
                                                     key_buffer,
                                                     key_buffer_size,
                                                     &rsa);
        if (status != PSA_SUCCESS) {
            goto rsa_exit;
        }

        /* An RSA ciphertext is always exactly as long as the modulus. */
        if (input_length != mbedtls_rsa_get_len(rsa)) {
            status = PSA_ERROR_INVALID_ARGUMENT;
            goto rsa_exit;
        }
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */

        if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
            status = mbedtls_to_psa_error(
                mbedtls_rsa_pkcs1_decrypt(rsa,
                                          mbedtls_psa_get_random,
                                          MBEDTLS_PSA_RANDOM_STATE,
                                          output_length,
                                          input,
                                          output,
                                          output_size));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
        } else
        if (PSA_ALG_IS_RSA_OAEP(alg)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
            /* Switch the freshly loaded context to OAEP with alg's hash. */
            status = mbedtls_to_psa_error(
                psa_rsa_oaep_set_padding_mode(alg, rsa));
            if (status != PSA_SUCCESS) {
                goto rsa_exit;
            }

            status = mbedtls_to_psa_error(
                mbedtls_rsa_rsaes_oaep_decrypt(rsa,
                                               mbedtls_psa_get_random,
                                               MBEDTLS_PSA_RANDOM_STATE,
                                               salt, salt_length,
                                               output_length,
                                               input,
                                               output,
                                               output_size));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
        } else {
            status = PSA_ERROR_INVALID_ARGUMENT;
        }

#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        /* Common exit for the RSA path. It runs on success and on every
         * error above, including a failed load, where rsa may be NULL. */
rsa_exit:
        mbedtls_rsa_free(rsa);
        mbedtls_free(rsa);
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
    } else {
        status = PSA_ERROR_NOT_SUPPORTED;
    }

    return status;
}
705 
706 #endif /* MBEDTLS_PSA_CRYPTO_C */
707