1 /*
2  *  PSA RSA layer on top of Mbed TLS crypto
3  */
4 /*
5  *  Copyright The Mbed TLS Contributors
6  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
7  */
8 
9 #include "common.h"
10 
11 #if defined(MBEDTLS_PSA_CRYPTO_C)
12 
13 #include <psa/crypto.h>
14 #include "psa/crypto_values.h"
15 #include "psa_crypto_core.h"
16 #include "psa_crypto_random_impl.h"
17 #include "psa_crypto_rsa.h"
18 #include "psa_crypto_hash.h"
19 #include "md_psa.h"
20 
21 #include <stdlib.h>
22 #include <string.h>
23 #include "mbedtls/platform.h"
24 
25 #include <mbedtls/rsa.h>
26 #include <mbedtls/error.h>
27 #include <mbedtls/pk.h>
28 #include "pk_wrap.h"
29 
30 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
31     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) || \
32     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
33     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) || \
34     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) || \
35     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) || \
36     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
37 
38 /* Mbed TLS doesn't support non-byte-aligned key sizes (i.e. key sizes
39  * that are not a multiple of 8) well. For example, there is only
40  * mbedtls_rsa_get_len(), which returns a number of bytes, and no
41  * way to return the exact bit size of a key.
42  * To keep things simple, reject non-byte-aligned key sizes. */
psa_check_rsa_key_byte_aligned(const mbedtls_rsa_context * rsa)43 static psa_status_t psa_check_rsa_key_byte_aligned(
44     const mbedtls_rsa_context *rsa)
45 {
46     mbedtls_mpi n;
47     psa_status_t status;
48     mbedtls_mpi_init(&n);
49     status = mbedtls_to_psa_error(
50         mbedtls_rsa_export(rsa, &n, NULL, NULL, NULL, NULL));
51     if (status == PSA_SUCCESS) {
52         if (mbedtls_mpi_bitlen(&n) % 8 != 0) {
53             status = PSA_ERROR_NOT_SUPPORTED;
54         }
55     }
56     mbedtls_mpi_free(&n);
57     return status;
58 }
59 
mbedtls_psa_rsa_load_representation(psa_key_type_t type,const uint8_t * data,size_t data_length,mbedtls_rsa_context ** p_rsa)60 psa_status_t mbedtls_psa_rsa_load_representation(
61     psa_key_type_t type, const uint8_t *data, size_t data_length,
62     mbedtls_rsa_context **p_rsa)
63 {
64     psa_status_t status;
65     mbedtls_pk_context ctx;
66     size_t bits;
67     mbedtls_pk_init(&ctx);
68 
69     /* Parse the data. */
70     if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
71         status = mbedtls_to_psa_error(
72             mbedtls_pk_parse_key(&ctx, data, data_length, NULL, 0,
73                                  mbedtls_psa_get_random, MBEDTLS_PSA_RANDOM_STATE));
74     } else {
75         status = mbedtls_to_psa_error(
76             mbedtls_pk_parse_public_key(&ctx, data, data_length));
77     }
78     if (status != PSA_SUCCESS) {
79         goto exit;
80     }
81 
82     /* We have something that the pkparse module recognizes. If it is a
83      * valid RSA key, store it. */
84     if (mbedtls_pk_get_type(&ctx) != MBEDTLS_PK_RSA) {
85         status = PSA_ERROR_INVALID_ARGUMENT;
86         goto exit;
87     }
88 
89     /* The size of an RSA key doesn't have to be a multiple of 8. Mbed TLS
90      * supports non-byte-aligned key sizes, but not well. For example,
91      * mbedtls_rsa_get_len() returns the key size in bytes, not in bits. */
92     bits = PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(mbedtls_pk_rsa(ctx)));
93     if (bits > PSA_VENDOR_RSA_MAX_KEY_BITS) {
94         status = PSA_ERROR_NOT_SUPPORTED;
95         goto exit;
96     }
97     status = psa_check_rsa_key_byte_aligned(mbedtls_pk_rsa(ctx));
98     if (status != PSA_SUCCESS) {
99         goto exit;
100     }
101 
102     /* Copy out the pointer to the RSA context, and reset the PK context
103      * such that pk_free doesn't free the RSA context we just grabbed. */
104     *p_rsa = mbedtls_pk_rsa(ctx);
105     ctx.pk_info = NULL;
106 
107 exit:
108     mbedtls_pk_free(&ctx);
109     return status;
110 }
111 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
112         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) ||
113         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
114         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) ||
115         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) ||
116         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) ||
117         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
118 
119 #if (defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) && \
120     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT)) || \
121     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
mbedtls_psa_rsa_import_key(const psa_key_attributes_t * attributes,const uint8_t * data,size_t data_length,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length,size_t * bits)122 psa_status_t mbedtls_psa_rsa_import_key(
123     const psa_key_attributes_t *attributes,
124     const uint8_t *data, size_t data_length,
125     uint8_t *key_buffer, size_t key_buffer_size,
126     size_t *key_buffer_length, size_t *bits)
127 {
128     psa_status_t status;
129     mbedtls_rsa_context *rsa = NULL;
130 
131     /* Parse input */
132     status = mbedtls_psa_rsa_load_representation(attributes->core.type,
133                                                  data,
134                                                  data_length,
135                                                  &rsa);
136     if (status != PSA_SUCCESS) {
137         goto exit;
138     }
139 
140     *bits = (psa_key_bits_t) PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(rsa));
141 
142     /* Re-export the data to PSA export format, such that we can store export
143      * representation in the key slot. Export representation in case of RSA is
144      * the smallest representation that's allowed as input, so a straight-up
145      * allocation of the same size as the input buffer will be large enough. */
146     status = mbedtls_psa_rsa_export_key(attributes->core.type,
147                                         rsa,
148                                         key_buffer,
149                                         key_buffer_size,
150                                         key_buffer_length);
151 exit:
152     /* Always free the RSA object */
153     mbedtls_rsa_free(rsa);
154     mbedtls_free(rsa);
155 
156     return status;
157 }
158 #endif /* (defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) &&
159         *  defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT)) ||
160         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
161 
162 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) || \
163     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
mbedtls_psa_rsa_export_key(psa_key_type_t type,mbedtls_rsa_context * rsa,uint8_t * data,size_t data_size,size_t * data_length)164 psa_status_t mbedtls_psa_rsa_export_key(psa_key_type_t type,
165                                         mbedtls_rsa_context *rsa,
166                                         uint8_t *data,
167                                         size_t data_size,
168                                         size_t *data_length)
169 {
170     int ret;
171     mbedtls_pk_context pk;
172     uint8_t *pos = data + data_size;
173 
174     mbedtls_pk_init(&pk);
175     pk.pk_info = &mbedtls_rsa_info;
176     pk.pk_ctx = rsa;
177 
178     /* PSA Crypto API defines the format of an RSA key as a DER-encoded
179      * representation of the non-encrypted PKCS#1 RSAPrivateKey for a
180      * private key and of the RFC3279 RSAPublicKey for a public key. */
181     if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
182         ret = mbedtls_pk_write_key_der(&pk, data, data_size);
183     } else {
184         ret = mbedtls_pk_write_pubkey(&pos, data, &pk);
185     }
186 
187     if (ret < 0) {
188         /* Clean up in case pk_write failed halfway through. */
189         memset(data, 0, data_size);
190         return mbedtls_to_psa_error(ret);
191     }
192 
193     /* The mbedtls_pk_xxx functions write to the end of the buffer.
194      * Move the data to the beginning and erase remaining data
195      * at the original location. */
196     if (2 * (size_t) ret <= data_size) {
197         memcpy(data, data + data_size - ret, ret);
198         memset(data + data_size - ret, 0, ret);
199     } else if ((size_t) ret < data_size) {
200         memmove(data, data + data_size - ret, ret);
201         memset(data + ret, 0, data_size - ret);
202     }
203 
204     *data_length = ret;
205     return PSA_SUCCESS;
206 }
207 
mbedtls_psa_rsa_export_public_key(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,uint8_t * data,size_t data_size,size_t * data_length)208 psa_status_t mbedtls_psa_rsa_export_public_key(
209     const psa_key_attributes_t *attributes,
210     const uint8_t *key_buffer, size_t key_buffer_size,
211     uint8_t *data, size_t data_size, size_t *data_length)
212 {
213     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
214     mbedtls_rsa_context *rsa = NULL;
215 
216     status = mbedtls_psa_rsa_load_representation(
217         attributes->core.type, key_buffer, key_buffer_size, &rsa);
218     if (status != PSA_SUCCESS) {
219         return status;
220     }
221 
222     status = mbedtls_psa_rsa_export_key(PSA_KEY_TYPE_RSA_PUBLIC_KEY,
223                                         rsa,
224                                         data,
225                                         data_size,
226                                         data_length);
227 
228     mbedtls_rsa_free(rsa);
229     mbedtls_free(rsa);
230 
231     return status;
232 }
233 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) ||
234         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
235 
236 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
psa_rsa_read_exponent(const uint8_t * domain_parameters,size_t domain_parameters_size,int * exponent)237 static psa_status_t psa_rsa_read_exponent(const uint8_t *domain_parameters,
238                                           size_t domain_parameters_size,
239                                           int *exponent)
240 {
241     size_t i;
242     uint32_t acc = 0;
243 
244     if (domain_parameters_size == 0) {
245         *exponent = 65537;
246         return PSA_SUCCESS;
247     }
248 
249     /* Mbed TLS encodes the public exponent as an int. For simplicity, only
250      * support values that fit in a 32-bit integer, which is larger than
251      * int on just about every platform anyway. */
252     if (domain_parameters_size > sizeof(acc)) {
253         return PSA_ERROR_NOT_SUPPORTED;
254     }
255     for (i = 0; i < domain_parameters_size; i++) {
256         acc = (acc << 8) | domain_parameters[i];
257     }
258     if (acc > INT_MAX) {
259         return PSA_ERROR_NOT_SUPPORTED;
260     }
261     *exponent = acc;
262     return PSA_SUCCESS;
263 }
264 
mbedtls_psa_rsa_generate_key(const psa_key_attributes_t * attributes,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length)265 psa_status_t mbedtls_psa_rsa_generate_key(
266     const psa_key_attributes_t *attributes,
267     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
268 {
269     psa_status_t status;
270     mbedtls_rsa_context rsa;
271     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
272     int exponent;
273 
274     status = psa_rsa_read_exponent(attributes->domain_parameters,
275                                    attributes->domain_parameters_size,
276                                    &exponent);
277     if (status != PSA_SUCCESS) {
278         return status;
279     }
280 
281     mbedtls_rsa_init(&rsa);
282     ret = mbedtls_rsa_gen_key(&rsa,
283                               mbedtls_psa_get_random,
284                               MBEDTLS_PSA_RANDOM_STATE,
285                               (unsigned int) attributes->core.bits,
286                               exponent);
287     if (ret != 0) {
288         return mbedtls_to_psa_error(ret);
289     }
290 
291     status = mbedtls_psa_rsa_export_key(attributes->core.type,
292                                         &rsa, key_buffer, key_buffer_size,
293                                         key_buffer_length);
294     mbedtls_rsa_free(&rsa);
295 
296     return status;
297 }
298 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE) */
299 
300 /****************************************************************/
301 /* Sign/verify hashes */
302 /****************************************************************/
303 
304 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
305     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
306 
307 /* Decode the hash algorithm from alg and store the mbedtls encoding in
308  * md_alg. Verify that the hash length is acceptable. */
psa_rsa_decode_md_type(psa_algorithm_t alg,size_t hash_length,mbedtls_md_type_t * md_alg)309 static psa_status_t psa_rsa_decode_md_type(psa_algorithm_t alg,
310                                            size_t hash_length,
311                                            mbedtls_md_type_t *md_alg)
312 {
313     psa_algorithm_t hash_alg = PSA_ALG_SIGN_GET_HASH(alg);
314     *md_alg = mbedtls_md_type_from_psa_alg(hash_alg);
315 
316     /* The Mbed TLS RSA module uses an unsigned int for hash length
317      * parameters. Validate that it fits so that we don't risk an
318      * overflow later. */
319 #if SIZE_MAX > UINT_MAX
320     if (hash_length > UINT_MAX) {
321         return PSA_ERROR_INVALID_ARGUMENT;
322     }
323 #endif
324 
325     /* For signatures using a hash, the hash length must be correct. */
326     if (alg != PSA_ALG_RSA_PKCS1V15_SIGN_RAW) {
327         if (*md_alg == MBEDTLS_MD_NONE) {
328             return PSA_ERROR_NOT_SUPPORTED;
329         }
330         if (mbedtls_md_get_size_from_type(*md_alg) != hash_length) {
331             return PSA_ERROR_INVALID_ARGUMENT;
332         }
333     }
334 
335     return PSA_SUCCESS;
336 }
337 
/* Sign a precomputed hash with an RSA key, using PKCS#1 v1.5 or PSS padding.
 *
 * attributes, key_buffer, key_buffer_size: the key, parsed with
 *   mbedtls_psa_rsa_load_representation().
 * alg: must satisfy PSA_ALG_IS_RSA_PKCS1V15_SIGN() or PSA_ALG_IS_RSA_PSS(),
 *   each only when the matching built-in implementation is compiled in.
 *   Any other value yields PSA_ERROR_INVALID_ARGUMENT.
 * hash, hash_length: the hash to sign. Its length is validated against alg
 *   by psa_rsa_decode_md_type().
 * signature, signature_size: output buffer. It must hold at least the key
 *   length in bytes, otherwise PSA_ERROR_BUFFER_TOO_SMALL is returned.
 * signature_length: on success, set to the key length in bytes.
 *
 * The parsed RSA context is released on every path after it is loaded.
 */
psa_status_t mbedtls_psa_rsa_sign_hash(
    const psa_key_attributes_t *attributes,
    const uint8_t *key_buffer, size_t key_buffer_size,
    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
    uint8_t *signature, size_t signature_size, size_t *signature_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    mbedtls_rsa_context *rsa = NULL;
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    mbedtls_md_type_t md_alg;

    status = mbedtls_psa_rsa_load_representation(attributes->core.type,
                                                 key_buffer,
                                                 key_buffer_size,
                                                 &rsa);
    if (status != PSA_SUCCESS) {
        return status;
    }

    status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
    if (status != PSA_SUCCESS) {
        goto exit;
    }

    /* An RSA signature is always exactly as long as the modulus. */
    if (signature_size < mbedtls_rsa_get_len(rsa)) {
        status = PSA_ERROR_BUFFER_TOO_SMALL;
        goto exit;
    }

    /* Each branch below is compiled only when its built-in algorithm is
     * enabled. The trailing "else" chains into the next enabled branch or
     * into the final INVALID_ARGUMENT block. */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
    if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
        /* PKCS#1 v1.5: the hash algorithm is passed per call (md_alg), so the
         * context's padding hash is left as MBEDTLS_MD_NONE. */
        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
                                      MBEDTLS_MD_NONE);
        if (ret == 0) {
            ret = mbedtls_rsa_pkcs1_sign(rsa,
                                         mbedtls_psa_get_random,
                                         MBEDTLS_PSA_RANDOM_STATE,
                                         md_alg,
                                         (unsigned int) hash_length,
                                         hash,
                                         signature);
        }
    } else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
    if (PSA_ALG_IS_RSA_PSS(alg)) {
        /* PSS: the hash algorithm is configured on the context via
         * set_padding(). The sign call itself passes MBEDTLS_MD_NONE together
         * with the explicit hash_length. */
        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);

        if (ret == 0) {
            ret = mbedtls_rsa_rsassa_pss_sign(rsa,
                                              mbedtls_psa_get_random,
                                              MBEDTLS_PSA_RANDOM_STATE,
                                              MBEDTLS_MD_NONE,
                                              (unsigned int) hash_length,
                                              hash,
                                              signature);
        }
    } else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
    {
        status = PSA_ERROR_INVALID_ARGUMENT;
        goto exit;
    }

    if (ret == 0) {
        *signature_length = mbedtls_rsa_get_len(rsa);
    }
    status = mbedtls_to_psa_error(ret);

exit:
    mbedtls_rsa_free(rsa);
    mbedtls_free(rsa);

    return status;
}
413 
414 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
/* Salt length to require when verifying an RSA-PSS signature.
 *
 * For PSA_ALG_IS_RSA_PSS_ANY_SALT() algorithms, returns
 * MBEDTLS_RSA_SALT_LEN_ANY. Otherwise returns the standard salt length: the
 * hash length, capped by the room left in the encoded message
 * (key length - hash length - 2). Returns 0 when there is no room at all.
 */
static int rsa_pss_expected_salt_len(psa_algorithm_t alg,
                                     const mbedtls_rsa_context *rsa,
                                     size_t hash_length)
{
    int key_len;
    int hash_len;
    int max_salt;

    if (PSA_ALG_IS_RSA_PSS_ANY_SALT(alg)) {
        return MBEDTLS_RSA_SALT_LEN_ANY;
    }

    /* Both lengths are small enough to fit in an int. */
    key_len = (int) mbedtls_rsa_get_len(rsa);
    hash_len = (int) hash_length;
    max_salt = key_len - 2 - hash_len;

    if (max_salt < 0) {
        /* No signature can be valid with this key/hash combination anyway. */
        return 0;
    }
    return (max_salt < hash_len) ? max_salt : hash_len;
}
435 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
436 
/* Verify an RSA signature over a precomputed hash (PKCS#1 v1.5 or PSS).
 *
 * attributes, key_buffer, key_buffer_size: the key, parsed with
 *   mbedtls_psa_rsa_load_representation().
 * alg: must satisfy PSA_ALG_IS_RSA_PKCS1V15_SIGN() or PSA_ALG_IS_RSA_PSS(),
 *   each only when the matching built-in implementation is compiled in.
 *   Any other value yields PSA_ERROR_INVALID_ARGUMENT.
 * hash, hash_length: the signed hash, validated against alg by
 *   psa_rsa_decode_md_type().
 * signature, signature_length: the signature. A length different from the
 *   key length yields PSA_ERROR_INVALID_SIGNATURE without further work.
 *
 * The parsed RSA context is released on every path.
 */
psa_status_t mbedtls_psa_rsa_verify_hash(
    const psa_key_attributes_t *attributes,
    const uint8_t *key_buffer, size_t key_buffer_size,
    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
    const uint8_t *signature, size_t signature_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    mbedtls_rsa_context *rsa = NULL;
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    mbedtls_md_type_t md_alg;

    /* On failure rsa stays NULL; the exit path relies on mbedtls_rsa_free()
     * and mbedtls_free() accepting NULL. */
    status = mbedtls_psa_rsa_load_representation(attributes->core.type,
                                                 key_buffer,
                                                 key_buffer_size,
                                                 &rsa);
    if (status != PSA_SUCCESS) {
        goto exit;
    }

    status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
    if (status != PSA_SUCCESS) {
        goto exit;
    }

    if (signature_length != mbedtls_rsa_get_len(rsa)) {
        status = PSA_ERROR_INVALID_SIGNATURE;
        goto exit;
    }

    /* Each branch below is compiled only when its built-in algorithm is
     * enabled. The trailing "else" chains into the next enabled branch or
     * into the final INVALID_ARGUMENT block. */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
    if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
                                      MBEDTLS_MD_NONE);
        if (ret == 0) {
            ret = mbedtls_rsa_pkcs1_verify(rsa,
                                           md_alg,
                                           (unsigned int) hash_length,
                                           hash,
                                           signature);
        }
    } else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
    if (PSA_ALG_IS_RSA_PSS(alg)) {
        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
        if (ret == 0) {
            /* The same md_alg is passed both as the message hash and as the
             * second hash argument of the _ext variant. slen is either a
             * fixed expected salt length or MBEDTLS_RSA_SALT_LEN_ANY. */
            int slen = rsa_pss_expected_salt_len(alg, rsa, hash_length);
            ret = mbedtls_rsa_rsassa_pss_verify_ext(rsa,
                                                    md_alg,
                                                    (unsigned) hash_length,
                                                    hash,
                                                    md_alg,
                                                    slen,
                                                    signature);
        }
    } else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
    {
        status = PSA_ERROR_INVALID_ARGUMENT;
        goto exit;
    }

    /* Mbed TLS distinguishes "invalid padding" from "valid padding but
     * the rest of the signature is invalid". This has little use in
     * practice and PSA doesn't report this distinction. */
    status = (ret == MBEDTLS_ERR_RSA_INVALID_PADDING) ?
             PSA_ERROR_INVALID_SIGNATURE :
             mbedtls_to_psa_error(ret);

exit:
    mbedtls_rsa_free(rsa);
    mbedtls_free(rsa);

    return status;
}
512 
513 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
514         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
515 
516 /****************************************************************/
517 /* Asymmetric cryptography */
518 /****************************************************************/
519 
520 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
/* Configure rsa for RSAES-OAEP with the hash algorithm encoded in alg.
 *
 * Returns an Mbed TLS error code (0 on success), never a PSA status: both
 * callers pass the result through mbedtls_to_psa_error().
 */
static int psa_rsa_oaep_set_padding_mode(psa_algorithm_t alg,
                                         mbedtls_rsa_context *rsa)
{
    psa_algorithm_t hash_alg = PSA_ALG_RSA_OAEP_GET_HASH(alg);
    mbedtls_md_type_t md_alg = mbedtls_md_type_from_psa_alg(hash_alg);

    /* mbedtls_rsa_set_padding() doesn't distinguish between "bad RSA
     * algorithm" and "unknown hash", so detect an unknown hash here to get
     * the error status right.
     *
     * Do not return PSA_ERROR_NOT_SUPPORTED here. It is a PSA status, and the
     * callers' mbedtls_to_psa_error() would reinterpret it as an Mbed TLS code.
     * NOTE(review): that function dispatches on the low-level error bits,
     * which turns -134 into an unrelated MPI error.
     * The Mbed TLS "digest feature unavailable" code is expected to translate
     * to PSA_ERROR_NOT_SUPPORTED instead. */
    if (mbedtls_md_info_from_type(md_alg) == NULL) {
        return MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE;
    }

    return mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
}
535 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
536 
/* Encrypt input with an RSA key (public key or key pair), using PKCS#1 v1.5
 * or OAEP padding.
 *
 * Non-RSA key types yield PSA_ERROR_NOT_SUPPORTED.
 * alg must be PSA_ALG_RSA_PKCS1V15_CRYPT or satisfy PSA_ALG_IS_RSA_OAEP().
 * Each yields PSA_ERROR_NOT_SUPPORTED when its built-in implementation is
 * not compiled in. Any other alg yields PSA_ERROR_INVALID_ARGUMENT.
 * output_size must be at least the key length in bytes, otherwise
 * PSA_ERROR_BUFFER_TOO_SMALL is returned.
 * salt/salt_length are forwarded only on the OAEP path.
 * On success, *output_length is set to the key length in bytes.
 */
psa_status_t mbedtls_psa_asymmetric_encrypt(const psa_key_attributes_t *attributes,
                                            const uint8_t *key_buffer,
                                            size_t key_buffer_size,
                                            psa_algorithm_t alg,
                                            const uint8_t *input,
                                            size_t input_length,
                                            const uint8_t *salt,
                                            size_t salt_length,
                                            uint8_t *output,
                                            size_t output_size,
                                            size_t *output_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    /* These casts silence unused-parameter warnings in configurations where
     * neither built-in RSA encryption algorithm is compiled in. */
    (void) key_buffer;
    (void) key_buffer_size;
    (void) input;
    (void) input_length;
    (void) salt;
    (void) salt_length;
    (void) output;
    (void) output_size;
    (void) output_length;

    if (PSA_KEY_TYPE_IS_RSA(attributes->core.type)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        mbedtls_rsa_context *rsa = NULL;
        status = mbedtls_psa_rsa_load_representation(attributes->core.type,
                                                     key_buffer,
                                                     key_buffer_size,
                                                     &rsa);
        if (status != PSA_SUCCESS) {
            goto rsa_exit;
        }

        /* The reported ciphertext length is the key length (see rsa_exit),
         * so require that much room up front. */
        if (output_size < mbedtls_rsa_get_len(rsa)) {
            status = PSA_ERROR_BUFFER_TOO_SMALL;
            goto rsa_exit;
        }
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
        if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
            /* NOTE(review): no mbedtls_rsa_set_padding() call on this path;
             * it relies on the freshly parsed context's default padding mode
             * being PKCS#1 v1.5. Confirm. */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
            status = mbedtls_to_psa_error(
                mbedtls_rsa_pkcs1_encrypt(rsa,
                                          mbedtls_psa_get_random,
                                          MBEDTLS_PSA_RANDOM_STATE,
                                          input_length,
                                          input,
                                          output));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
        } else
        if (PSA_ALG_IS_RSA_OAEP(alg)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
            status = mbedtls_to_psa_error(
                psa_rsa_oaep_set_padding_mode(alg, rsa));
            if (status != PSA_SUCCESS) {
                goto rsa_exit;
            }

            /* salt/salt_length go into the label position of the OAEP
             * encryption call (presumably the OAEP label). */
            status = mbedtls_to_psa_error(
                mbedtls_rsa_rsaes_oaep_encrypt(rsa,
                                               mbedtls_psa_get_random,
                                               MBEDTLS_PSA_RANDOM_STATE,
                                               salt, salt_length,
                                               input_length,
                                               input,
                                               output));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
        } else {
            status = PSA_ERROR_INVALID_ARGUMENT;
        }
        /* Common cleanup. If loading failed, rsa is still NULL here and status
         * is not PSA_SUCCESS. The free calls are relied on to accept NULL. */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
rsa_exit:
        if (status == PSA_SUCCESS) {
            *output_length = mbedtls_rsa_get_len(rsa);
        }

        mbedtls_rsa_free(rsa);
        mbedtls_free(rsa);
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
    } else {
        status = PSA_ERROR_NOT_SUPPORTED;
    }

    return status;
}
630 
/* Decrypt input with an RSA key pair, using PKCS#1 v1.5 or OAEP padding.
 *
 * Only PSA_KEY_TYPE_RSA_KEY_PAIR keys are handled. Any other key type yields
 * PSA_ERROR_NOT_SUPPORTED.
 * input_length must equal the key length in bytes, otherwise
 * PSA_ERROR_INVALID_ARGUMENT is returned.
 * alg must be PSA_ALG_RSA_PKCS1V15_CRYPT or satisfy PSA_ALG_IS_RSA_OAEP().
 * Each yields PSA_ERROR_NOT_SUPPORTED when its built-in implementation is
 * not compiled in. Any other alg yields PSA_ERROR_INVALID_ARGUMENT.
 * salt/salt_length are forwarded only on the OAEP path.
 * output/output_size and output_length are passed to the Mbed TLS
 * decryption routine.
 */
psa_status_t mbedtls_psa_asymmetric_decrypt(const psa_key_attributes_t *attributes,
                                            const uint8_t *key_buffer,
                                            size_t key_buffer_size,
                                            psa_algorithm_t alg,
                                            const uint8_t *input,
                                            size_t input_length,
                                            const uint8_t *salt,
                                            size_t salt_length,
                                            uint8_t *output,
                                            size_t output_size,
                                            size_t *output_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    /* These casts silence unused-parameter warnings in configurations where
     * neither built-in RSA encryption algorithm is compiled in. */
    (void) key_buffer;
    (void) key_buffer_size;
    (void) input;
    (void) input_length;
    (void) salt;
    (void) salt_length;
    (void) output;
    (void) output_size;
    (void) output_length;

    /* Give *output_length a defined value before any early exit. */
    *output_length = 0;

    if (attributes->core.type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        mbedtls_rsa_context *rsa = NULL;
        status = mbedtls_psa_rsa_load_representation(attributes->core.type,
                                                     key_buffer,
                                                     key_buffer_size,
                                                     &rsa);
        if (status != PSA_SUCCESS) {
            goto rsa_exit;
        }

        /* A ciphertext is exactly as long as the modulus. */
        if (input_length != mbedtls_rsa_get_len(rsa)) {
            status = PSA_ERROR_INVALID_ARGUMENT;
            goto rsa_exit;
        }
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */

        if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
            /* NOTE(review): no mbedtls_rsa_set_padding() call on this path;
             * it relies on the freshly parsed context's default padding mode
             * being PKCS#1 v1.5. Confirm. */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
            status = mbedtls_to_psa_error(
                mbedtls_rsa_pkcs1_decrypt(rsa,
                                          mbedtls_psa_get_random,
                                          MBEDTLS_PSA_RANDOM_STATE,
                                          output_length,
                                          input,
                                          output,
                                          output_size));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
        } else
        if (PSA_ALG_IS_RSA_OAEP(alg)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
            status = mbedtls_to_psa_error(
                psa_rsa_oaep_set_padding_mode(alg, rsa));
            if (status != PSA_SUCCESS) {
                goto rsa_exit;
            }

            /* salt/salt_length go into the label position of the OAEP
             * decryption call (presumably the OAEP label). */
            status = mbedtls_to_psa_error(
                mbedtls_rsa_rsaes_oaep_decrypt(rsa,
                                               mbedtls_psa_get_random,
                                               MBEDTLS_PSA_RANDOM_STATE,
                                               salt, salt_length,
                                               output_length,
                                               input,
                                               output,
                                               output_size));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
        } else {
            status = PSA_ERROR_INVALID_ARGUMENT;
        }

        /* Common cleanup. If loading failed, rsa is still NULL here; the free
         * calls are relied on to accept NULL. */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
rsa_exit:
        mbedtls_rsa_free(rsa);
        mbedtls_free(rsa);
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
    } else {
        status = PSA_ERROR_NOT_SUPPORTED;
    }

    return status;
}
726 
727 #endif /* MBEDTLS_PSA_CRYPTO_C */
728