1 /*
2  *  PSA RSA layer on top of Mbed TLS crypto
3  */
4 /*
5  *  Copyright The Mbed TLS Contributors
6  *  SPDX-License-Identifier: Apache-2.0
7  *
8  *  Licensed under the Apache License, Version 2.0 (the "License"); you may
9  *  not use this file except in compliance with the License.
10  *  You may obtain a copy of the License at
11  *
12  *  http://www.apache.org/licenses/LICENSE-2.0
13  *
14  *  Unless required by applicable law or agreed to in writing, software
15  *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16  *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  *  See the License for the specific language governing permissions and
18  *  limitations under the License.
19  */
20 
21 #include "common.h"
22 
23 #if defined(MBEDTLS_PSA_CRYPTO_C)
24 
25 #include <psa/crypto.h>
26 #include "psa/crypto_values.h"
27 #include "psa_crypto_core.h"
28 #include "psa_crypto_random_impl.h"
29 #include "psa_crypto_rsa.h"
30 #include "psa_crypto_hash.h"
31 
32 #include <stdlib.h>
33 #include <string.h>
34 #include "mbedtls/platform.h"
35 
36 #include <mbedtls/rsa.h>
37 #include <mbedtls/error.h>
38 #include <mbedtls/pk.h>
39 #include "pk_wrap.h"
40 #include "hash_info.h"
41 
42 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
43     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) || \
44     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
45     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) || \
46     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
47     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
48 
49 /* Mbed TLS doesn't support non-byte-aligned key sizes (i.e. key sizes
50  * that are not a multiple of 8) well. For example, there is only
51  * mbedtls_rsa_get_len(), which returns a number of bytes, and no
52  * way to return the exact bit size of a key.
53  * To keep things simple, reject non-byte-aligned key sizes. */
psa_check_rsa_key_byte_aligned(const mbedtls_rsa_context * rsa)54 static psa_status_t psa_check_rsa_key_byte_aligned(
55     const mbedtls_rsa_context *rsa)
56 {
57     mbedtls_mpi n;
58     psa_status_t status;
59     mbedtls_mpi_init(&n);
60     status = mbedtls_to_psa_error(
61         mbedtls_rsa_export(rsa, &n, NULL, NULL, NULL, NULL));
62     if (status == PSA_SUCCESS) {
63         if (mbedtls_mpi_bitlen(&n) % 8 != 0) {
64             status = PSA_ERROR_NOT_SUPPORTED;
65         }
66     }
67     mbedtls_mpi_free(&n);
68     return status;
69 }
70 
mbedtls_psa_rsa_load_representation(psa_key_type_t type,const uint8_t * data,size_t data_length,mbedtls_rsa_context ** p_rsa)71 psa_status_t mbedtls_psa_rsa_load_representation(
72     psa_key_type_t type, const uint8_t *data, size_t data_length,
73     mbedtls_rsa_context **p_rsa)
74 {
75     psa_status_t status;
76     mbedtls_pk_context ctx;
77     size_t bits;
78     mbedtls_pk_init(&ctx);
79 
80     /* Parse the data. */
81     if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
82         status = mbedtls_to_psa_error(
83             mbedtls_pk_parse_key(&ctx, data, data_length, NULL, 0,
84                                  mbedtls_psa_get_random, MBEDTLS_PSA_RANDOM_STATE));
85     } else {
86         status = mbedtls_to_psa_error(
87             mbedtls_pk_parse_public_key(&ctx, data, data_length));
88     }
89     if (status != PSA_SUCCESS) {
90         goto exit;
91     }
92 
93     /* We have something that the pkparse module recognizes. If it is a
94      * valid RSA key, store it. */
95     if (mbedtls_pk_get_type(&ctx) != MBEDTLS_PK_RSA) {
96         status = PSA_ERROR_INVALID_ARGUMENT;
97         goto exit;
98     }
99 
100     /* The size of an RSA key doesn't have to be a multiple of 8. Mbed TLS
101      * supports non-byte-aligned key sizes, but not well. For example,
102      * mbedtls_rsa_get_len() returns the key size in bytes, not in bits. */
103     bits = PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(mbedtls_pk_rsa(ctx)));
104     if (bits > PSA_VENDOR_RSA_MAX_KEY_BITS) {
105         status = PSA_ERROR_NOT_SUPPORTED;
106         goto exit;
107     }
108     status = psa_check_rsa_key_byte_aligned(mbedtls_pk_rsa(ctx));
109     if (status != PSA_SUCCESS) {
110         goto exit;
111     }
112 
113     /* Copy out the pointer to the RSA context, and reset the PK context
114      * such that pk_free doesn't free the RSA context we just grabbed. */
115     *p_rsa = mbedtls_pk_rsa(ctx);
116     ctx.pk_info = NULL;
117 
118 exit:
119     mbedtls_pk_free(&ctx);
120     return status;
121 }
122 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
123         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) ||
124         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
125         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) ||
126         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) ||
127         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
128 
129 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
130     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
131 
mbedtls_psa_rsa_import_key(const psa_key_attributes_t * attributes,const uint8_t * data,size_t data_length,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length,size_t * bits)132 psa_status_t mbedtls_psa_rsa_import_key(
133     const psa_key_attributes_t *attributes,
134     const uint8_t *data, size_t data_length,
135     uint8_t *key_buffer, size_t key_buffer_size,
136     size_t *key_buffer_length, size_t *bits)
137 {
138     psa_status_t status;
139     mbedtls_rsa_context *rsa = NULL;
140 
141     /* Parse input */
142     status = mbedtls_psa_rsa_load_representation(attributes->core.type,
143                                                  data,
144                                                  data_length,
145                                                  &rsa);
146     if (status != PSA_SUCCESS) {
147         goto exit;
148     }
149 
150     *bits = (psa_key_bits_t) PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(rsa));
151 
152     /* Re-export the data to PSA export format, such that we can store export
153      * representation in the key slot. Export representation in case of RSA is
154      * the smallest representation that's allowed as input, so a straight-up
155      * allocation of the same size as the input buffer will be large enough. */
156     status = mbedtls_psa_rsa_export_key(attributes->core.type,
157                                         rsa,
158                                         key_buffer,
159                                         key_buffer_size,
160                                         key_buffer_length);
161 exit:
162     /* Always free the RSA object */
163     mbedtls_rsa_free(rsa);
164     mbedtls_free(rsa);
165 
166     return status;
167 }
168 
mbedtls_psa_rsa_export_key(psa_key_type_t type,mbedtls_rsa_context * rsa,uint8_t * data,size_t data_size,size_t * data_length)169 psa_status_t mbedtls_psa_rsa_export_key(psa_key_type_t type,
170                                         mbedtls_rsa_context *rsa,
171                                         uint8_t *data,
172                                         size_t data_size,
173                                         size_t *data_length)
174 {
175     int ret;
176     mbedtls_pk_context pk;
177     uint8_t *pos = data + data_size;
178 
179     mbedtls_pk_init(&pk);
180     pk.pk_info = &mbedtls_rsa_info;
181     pk.pk_ctx = rsa;
182 
183     /* PSA Crypto API defines the format of an RSA key as a DER-encoded
184      * representation of the non-encrypted PKCS#1 RSAPrivateKey for a
185      * private key and of the RFC3279 RSAPublicKey for a public key. */
186     if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
187         ret = mbedtls_pk_write_key_der(&pk, data, data_size);
188     } else {
189         ret = mbedtls_pk_write_pubkey(&pos, data, &pk);
190     }
191 
192     if (ret < 0) {
193         /* Clean up in case pk_write failed halfway through. */
194         memset(data, 0, data_size);
195         return mbedtls_to_psa_error(ret);
196     }
197 
198     /* The mbedtls_pk_xxx functions write to the end of the buffer.
199      * Move the data to the beginning and erase remaining data
200      * at the original location. */
201     if (2 * (size_t) ret <= data_size) {
202         memcpy(data, data + data_size - ret, ret);
203         memset(data + data_size - ret, 0, ret);
204     } else if ((size_t) ret < data_size) {
205         memmove(data, data + data_size - ret, ret);
206         memset(data + ret, 0, data_size - ret);
207     }
208 
209     *data_length = ret;
210     return PSA_SUCCESS;
211 }
212 
mbedtls_psa_rsa_export_public_key(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,uint8_t * data,size_t data_size,size_t * data_length)213 psa_status_t mbedtls_psa_rsa_export_public_key(
214     const psa_key_attributes_t *attributes,
215     const uint8_t *key_buffer, size_t key_buffer_size,
216     uint8_t *data, size_t data_size, size_t *data_length)
217 {
218     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
219     mbedtls_rsa_context *rsa = NULL;
220 
221     status = mbedtls_psa_rsa_load_representation(
222         attributes->core.type, key_buffer, key_buffer_size, &rsa);
223     if (status != PSA_SUCCESS) {
224         return status;
225     }
226 
227     status = mbedtls_psa_rsa_export_key(PSA_KEY_TYPE_RSA_PUBLIC_KEY,
228                                         rsa,
229                                         data,
230                                         data_size,
231                                         data_length);
232 
233     mbedtls_rsa_free(rsa);
234     mbedtls_free(rsa);
235 
236     return status;
237 }
238 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) ||
239         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
240 
241 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) && \
242     defined(MBEDTLS_GENPRIME)
psa_rsa_read_exponent(const uint8_t * domain_parameters,size_t domain_parameters_size,int * exponent)243 static psa_status_t psa_rsa_read_exponent(const uint8_t *domain_parameters,
244                                           size_t domain_parameters_size,
245                                           int *exponent)
246 {
247     size_t i;
248     uint32_t acc = 0;
249 
250     if (domain_parameters_size == 0) {
251         *exponent = 65537;
252         return PSA_SUCCESS;
253     }
254 
255     /* Mbed TLS encodes the public exponent as an int. For simplicity, only
256      * support values that fit in a 32-bit integer, which is larger than
257      * int on just about every platform anyway. */
258     if (domain_parameters_size > sizeof(acc)) {
259         return PSA_ERROR_NOT_SUPPORTED;
260     }
261     for (i = 0; i < domain_parameters_size; i++) {
262         acc = (acc << 8) | domain_parameters[i];
263     }
264     if (acc > INT_MAX) {
265         return PSA_ERROR_NOT_SUPPORTED;
266     }
267     *exponent = acc;
268     return PSA_SUCCESS;
269 }
270 
mbedtls_psa_rsa_generate_key(const psa_key_attributes_t * attributes,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length)271 psa_status_t mbedtls_psa_rsa_generate_key(
272     const psa_key_attributes_t *attributes,
273     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
274 {
275     psa_status_t status;
276     mbedtls_rsa_context rsa;
277     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
278     int exponent;
279 
280     status = psa_rsa_read_exponent(attributes->domain_parameters,
281                                    attributes->domain_parameters_size,
282                                    &exponent);
283     if (status != PSA_SUCCESS) {
284         return status;
285     }
286 
287     mbedtls_rsa_init(&rsa);
288     ret = mbedtls_rsa_gen_key(&rsa,
289                               mbedtls_psa_get_random,
290                               MBEDTLS_PSA_RANDOM_STATE,
291                               (unsigned int) attributes->core.bits,
292                               exponent);
293     if (ret != 0) {
294         return mbedtls_to_psa_error(ret);
295     }
296 
297     status = mbedtls_psa_rsa_export_key(attributes->core.type,
298                                         &rsa, key_buffer, key_buffer_size,
299                                         key_buffer_length);
300     mbedtls_rsa_free(&rsa);
301 
302     return status;
303 }
#endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) &&
        * defined(MBEDTLS_GENPRIME) */
306 
307 /****************************************************************/
308 /* Sign/verify hashes */
309 /****************************************************************/
310 
311 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
312     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
313 
314 /* Decode the hash algorithm from alg and store the mbedtls encoding in
315  * md_alg. Verify that the hash length is acceptable. */
psa_rsa_decode_md_type(psa_algorithm_t alg,size_t hash_length,mbedtls_md_type_t * md_alg)316 static psa_status_t psa_rsa_decode_md_type(psa_algorithm_t alg,
317                                            size_t hash_length,
318                                            mbedtls_md_type_t *md_alg)
319 {
320     psa_algorithm_t hash_alg = PSA_ALG_SIGN_GET_HASH(alg);
321     *md_alg = mbedtls_hash_info_md_from_psa(hash_alg);
322 
323     /* The Mbed TLS RSA module uses an unsigned int for hash length
324      * parameters. Validate that it fits so that we don't risk an
325      * overflow later. */
326     if (hash_length > UINT_MAX) {
327         return PSA_ERROR_INVALID_ARGUMENT;
328     }
329 
330     /* For signatures using a hash, the hash length must be correct. */
331     if (alg != PSA_ALG_RSA_PKCS1V15_SIGN_RAW) {
332         if (*md_alg == MBEDTLS_MD_NONE) {
333             return PSA_ERROR_NOT_SUPPORTED;
334         }
335         if (mbedtls_hash_info_get_size(*md_alg) != hash_length) {
336             return PSA_ERROR_INVALID_ARGUMENT;
337         }
338     }
339 
340     return PSA_SUCCESS;
341 }
342 
mbedtls_psa_rsa_sign_hash(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * hash,size_t hash_length,uint8_t * signature,size_t signature_size,size_t * signature_length)343 psa_status_t mbedtls_psa_rsa_sign_hash(
344     const psa_key_attributes_t *attributes,
345     const uint8_t *key_buffer, size_t key_buffer_size,
346     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
347     uint8_t *signature, size_t signature_size, size_t *signature_length)
348 {
349     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
350     mbedtls_rsa_context *rsa = NULL;
351     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
352     mbedtls_md_type_t md_alg;
353 
354     status = mbedtls_psa_rsa_load_representation(attributes->core.type,
355                                                  key_buffer,
356                                                  key_buffer_size,
357                                                  &rsa);
358     if (status != PSA_SUCCESS) {
359         return status;
360     }
361 
362     status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
363     if (status != PSA_SUCCESS) {
364         goto exit;
365     }
366 
367     if (signature_size < mbedtls_rsa_get_len(rsa)) {
368         status = PSA_ERROR_BUFFER_TOO_SMALL;
369         goto exit;
370     }
371 
372 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
373     if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
374         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
375                                       MBEDTLS_MD_NONE);
376         if (ret == 0) {
377             ret = mbedtls_rsa_pkcs1_sign(rsa,
378                                          mbedtls_psa_get_random,
379                                          MBEDTLS_PSA_RANDOM_STATE,
380                                          md_alg,
381                                          (unsigned int) hash_length,
382                                          hash,
383                                          signature);
384         }
385     } else
386 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
387 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
388     if (PSA_ALG_IS_RSA_PSS(alg)) {
389         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
390 
391         if (ret == 0) {
392             ret = mbedtls_rsa_rsassa_pss_sign(rsa,
393                                               mbedtls_psa_get_random,
394                                               MBEDTLS_PSA_RANDOM_STATE,
395                                               MBEDTLS_MD_NONE,
396                                               (unsigned int) hash_length,
397                                               hash,
398                                               signature);
399         }
400     } else
401 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
402     {
403         status = PSA_ERROR_INVALID_ARGUMENT;
404         goto exit;
405     }
406 
407     if (ret == 0) {
408         *signature_length = mbedtls_rsa_get_len(rsa);
409     }
410     status = mbedtls_to_psa_error(ret);
411 
412 exit:
413     mbedtls_rsa_free(rsa);
414     mbedtls_free(rsa);
415 
416     return status;
417 }
418 
419 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
/* Salt length to require when verifying a PSS signature.
 *
 * "Any salt" algorithms accept any length (MBEDTLS_RSA_SALT_LEN_ANY).
 * Otherwise the standard salt length applies: the largest salt that fits,
 * capped at the hash length. If not even an empty salt fits, 0 is returned,
 * since no valid signature can exist in that case. The key and hash lengths
 * are assumed to fit in an int. */
static int rsa_pss_expected_salt_len(psa_algorithm_t alg,
                                     const mbedtls_rsa_context *rsa,
                                     size_t hash_length)
{
    int key_len;
    int digest_len;
    int max_salt;

    if (PSA_ALG_IS_RSA_PSS_ANY_SALT(alg)) {
        return MBEDTLS_RSA_SALT_LEN_ANY;
    }

    key_len = (int) mbedtls_rsa_get_len(rsa);
    digest_len = (int) hash_length;
    max_salt = key_len - 2 - digest_len;

    if (max_salt < 0) {
        return 0;
    }
    return (max_salt < digest_len) ? max_salt : digest_len;
}
440 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
441 
mbedtls_psa_rsa_verify_hash(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * hash,size_t hash_length,const uint8_t * signature,size_t signature_length)442 psa_status_t mbedtls_psa_rsa_verify_hash(
443     const psa_key_attributes_t *attributes,
444     const uint8_t *key_buffer, size_t key_buffer_size,
445     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
446     const uint8_t *signature, size_t signature_length)
447 {
448     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
449     mbedtls_rsa_context *rsa = NULL;
450     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
451     mbedtls_md_type_t md_alg;
452 
453     status = mbedtls_psa_rsa_load_representation(attributes->core.type,
454                                                  key_buffer,
455                                                  key_buffer_size,
456                                                  &rsa);
457     if (status != PSA_SUCCESS) {
458         goto exit;
459     }
460 
461     status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
462     if (status != PSA_SUCCESS) {
463         goto exit;
464     }
465 
466     if (signature_length != mbedtls_rsa_get_len(rsa)) {
467         status = PSA_ERROR_INVALID_SIGNATURE;
468         goto exit;
469     }
470 
471 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
472     if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
473         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
474                                       MBEDTLS_MD_NONE);
475         if (ret == 0) {
476             ret = mbedtls_rsa_pkcs1_verify(rsa,
477                                            md_alg,
478                                            (unsigned int) hash_length,
479                                            hash,
480                                            signature);
481         }
482     } else
483 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
484 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
485     if (PSA_ALG_IS_RSA_PSS(alg)) {
486         ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
487         if (ret == 0) {
488             int slen = rsa_pss_expected_salt_len(alg, rsa, hash_length);
489             ret = mbedtls_rsa_rsassa_pss_verify_ext(rsa,
490                                                     md_alg,
491                                                     (unsigned) hash_length,
492                                                     hash,
493                                                     md_alg,
494                                                     slen,
495                                                     signature);
496         }
497     } else
498 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
499     {
500         status = PSA_ERROR_INVALID_ARGUMENT;
501         goto exit;
502     }
503 
504     /* Mbed TLS distinguishes "invalid padding" from "valid padding but
505      * the rest of the signature is invalid". This has little use in
506      * practice and PSA doesn't report this distinction. */
507     status = (ret == MBEDTLS_ERR_RSA_INVALID_PADDING) ?
508              PSA_ERROR_INVALID_SIGNATURE :
509              mbedtls_to_psa_error(ret);
510 
511 exit:
512     mbedtls_rsa_free(rsa);
513     mbedtls_free(rsa);
514 
515     return status;
516 }
517 
518 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
519         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
520 
521 /****************************************************************/
522 /* Asymmetric cryptography */
523 /****************************************************************/
524 
525 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
/* Configure rsa for OAEP (PKCS#1 v2.1) padding, using the hash named by the
 * OAEP algorithm alg. Returns the Mbed TLS status of
 * mbedtls_rsa_set_padding(). */
static int psa_rsa_oaep_set_padding_mode(psa_algorithm_t alg,
                                         mbedtls_rsa_context *rsa)
{
    return mbedtls_rsa_set_padding(
        rsa, MBEDTLS_RSA_PKCS_V21,
        mbedtls_hash_info_md_from_psa(PSA_ALG_RSA_OAEP_GET_HASH(alg)));
}
534 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
535 
mbedtls_psa_asymmetric_encrypt(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * input,size_t input_length,const uint8_t * salt,size_t salt_length,uint8_t * output,size_t output_size,size_t * output_length)536 psa_status_t mbedtls_psa_asymmetric_encrypt(const psa_key_attributes_t *attributes,
537                                             const uint8_t *key_buffer,
538                                             size_t key_buffer_size,
539                                             psa_algorithm_t alg,
540                                             const uint8_t *input,
541                                             size_t input_length,
542                                             const uint8_t *salt,
543                                             size_t salt_length,
544                                             uint8_t *output,
545                                             size_t output_size,
546                                             size_t *output_length)
547 {
548     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
549     (void) key_buffer;
550     (void) key_buffer_size;
551     (void) input;
552     (void) input_length;
553     (void) salt;
554     (void) salt_length;
555     (void) output;
556     (void) output_size;
557     (void) output_length;
558 
559     if (PSA_KEY_TYPE_IS_RSA(attributes->core.type)) {
560 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
561         defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
562         mbedtls_rsa_context *rsa = NULL;
563         status = mbedtls_psa_rsa_load_representation(attributes->core.type,
564                                                      key_buffer,
565                                                      key_buffer_size,
566                                                      &rsa);
567         if (status != PSA_SUCCESS) {
568             goto rsa_exit;
569         }
570 
571         if (output_size < mbedtls_rsa_get_len(rsa)) {
572             status = PSA_ERROR_BUFFER_TOO_SMALL;
573             goto rsa_exit;
574         }
575 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
576         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
577         if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
578 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
579             status = mbedtls_to_psa_error(
580                 mbedtls_rsa_pkcs1_encrypt(rsa,
581                                           mbedtls_psa_get_random,
582                                           MBEDTLS_PSA_RANDOM_STATE,
583                                           input_length,
584                                           input,
585                                           output));
586 #else
587             status = PSA_ERROR_NOT_SUPPORTED;
588 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
589         } else
590         if (PSA_ALG_IS_RSA_OAEP(alg)) {
591 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
592             status = mbedtls_to_psa_error(
593                 psa_rsa_oaep_set_padding_mode(alg, rsa));
594             if (status != PSA_SUCCESS) {
595                 goto rsa_exit;
596             }
597 
598             status = mbedtls_to_psa_error(
599                 mbedtls_rsa_rsaes_oaep_encrypt(rsa,
600                                                mbedtls_psa_get_random,
601                                                MBEDTLS_PSA_RANDOM_STATE,
602                                                salt, salt_length,
603                                                input_length,
604                                                input,
605                                                output));
606 #else
607             status = PSA_ERROR_NOT_SUPPORTED;
608 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
609         } else {
610             status = PSA_ERROR_INVALID_ARGUMENT;
611         }
612 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
613         defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
614 rsa_exit:
615         if (status == PSA_SUCCESS) {
616             *output_length = mbedtls_rsa_get_len(rsa);
617         }
618 
619         mbedtls_rsa_free(rsa);
620         mbedtls_free(rsa);
621 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
622         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
623     } else {
624         status = PSA_ERROR_NOT_SUPPORTED;
625     }
626 
627     return status;
628 }
629 
mbedtls_psa_asymmetric_decrypt(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,psa_algorithm_t alg,const uint8_t * input,size_t input_length,const uint8_t * salt,size_t salt_length,uint8_t * output,size_t output_size,size_t * output_length)630 psa_status_t mbedtls_psa_asymmetric_decrypt(const psa_key_attributes_t *attributes,
631                                             const uint8_t *key_buffer,
632                                             size_t key_buffer_size,
633                                             psa_algorithm_t alg,
634                                             const uint8_t *input,
635                                             size_t input_length,
636                                             const uint8_t *salt,
637                                             size_t salt_length,
638                                             uint8_t *output,
639                                             size_t output_size,
640                                             size_t *output_length)
641 {
642     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
643     (void) key_buffer;
644     (void) key_buffer_size;
645     (void) input;
646     (void) input_length;
647     (void) salt;
648     (void) salt_length;
649     (void) output;
650     (void) output_size;
651     (void) output_length;
652 
653     *output_length = 0;
654 
655     if (attributes->core.type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
656 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
657         defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
658         mbedtls_rsa_context *rsa = NULL;
659         status = mbedtls_psa_rsa_load_representation(attributes->core.type,
660                                                      key_buffer,
661                                                      key_buffer_size,
662                                                      &rsa);
663         if (status != PSA_SUCCESS) {
664             goto rsa_exit;
665         }
666 
667         if (input_length != mbedtls_rsa_get_len(rsa)) {
668             status = PSA_ERROR_INVALID_ARGUMENT;
669             goto rsa_exit;
670         }
671 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
672         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
673 
674         if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
675 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
676             status = mbedtls_to_psa_error(
677                 mbedtls_rsa_pkcs1_decrypt(rsa,
678                                           mbedtls_psa_get_random,
679                                           MBEDTLS_PSA_RANDOM_STATE,
680                                           output_length,
681                                           input,
682                                           output,
683                                           output_size));
684 #else
685             status = PSA_ERROR_NOT_SUPPORTED;
686 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
687         } else
688         if (PSA_ALG_IS_RSA_OAEP(alg)) {
689 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
690             status = mbedtls_to_psa_error(
691                 psa_rsa_oaep_set_padding_mode(alg, rsa));
692             if (status != PSA_SUCCESS) {
693                 goto rsa_exit;
694             }
695 
696             status = mbedtls_to_psa_error(
697                 mbedtls_rsa_rsaes_oaep_decrypt(rsa,
698                                                mbedtls_psa_get_random,
699                                                MBEDTLS_PSA_RANDOM_STATE,
700                                                salt, salt_length,
701                                                output_length,
702                                                input,
703                                                output,
704                                                output_size));
705 #else
706             status = PSA_ERROR_NOT_SUPPORTED;
707 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
708         } else {
709             status = PSA_ERROR_INVALID_ARGUMENT;
710         }
711 
712 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
713         defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
714 rsa_exit:
715         mbedtls_rsa_free(rsa);
716         mbedtls_free(rsa);
717 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
718         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
719     } else {
720         status = PSA_ERROR_NOT_SUPPORTED;
721     }
722 
723     return status;
724 }
725 
726 #endif /* MBEDTLS_PSA_CRYPTO_C */
727