1 /*
2  *  PSA RSA layer on top of Mbed TLS crypto
3  */
4 /*
5  *  Copyright The Mbed TLS Contributors
6  *  SPDX-License-Identifier: Apache-2.0
7  *
8  *  Licensed under the Apache License, Version 2.0 (the "License"); you may
9  *  not use this file except in compliance with the License.
10  *  You may obtain a copy of the License at
11  *
12  *  http://www.apache.org/licenses/LICENSE-2.0
13  *
14  *  Unless required by applicable law or agreed to in writing, software
15  *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16  *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  *  See the License for the specific language governing permissions and
18  *  limitations under the License.
19  */
20 
21 #include "common.h"
22 
23 #if defined(MBEDTLS_PSA_CRYPTO_C)
24 
25 #include <psa/crypto.h>
26 #include "psa/crypto_values.h"
27 #include "psa_crypto_core.h"
28 #include "psa_crypto_random_impl.h"
29 #include "psa_crypto_rsa.h"
30 #include "psa_crypto_hash.h"
31 
32 #include <stdlib.h>
33 #include <string.h>
34 #include "mbedtls/platform.h"
35 
36 #include <mbedtls/rsa.h>
37 #include <mbedtls/error.h>
38 #include <mbedtls/pk.h>
39 #include "pk_wrap.h"
40 #include "hash_info.h"
41 
42 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
43     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) || \
44     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
45     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) || \
46     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
47     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
48 
49 /* Mbed TLS doesn't support non-byte-aligned key sizes (i.e. key sizes
50  * that are not a multiple of 8) well. For example, there is only
51  * mbedtls_rsa_get_len(), which returns a number of bytes, and no
52  * way to return the exact bit size of a key.
53  * To keep things simple, reject non-byte-aligned key sizes. */
psa_check_rsa_key_byte_aligned(const mbedtls_rsa_context * rsa)54 static psa_status_t psa_check_rsa_key_byte_aligned(
55     const mbedtls_rsa_context *rsa )
56 {
57     mbedtls_mpi n;
58     psa_status_t status;
59     mbedtls_mpi_init( &n );
60     status = mbedtls_to_psa_error(
61         mbedtls_rsa_export( rsa, &n, NULL, NULL, NULL, NULL ) );
62     if( status == PSA_SUCCESS )
63     {
64         if( mbedtls_mpi_bitlen( &n ) % 8 != 0 )
65             status = PSA_ERROR_NOT_SUPPORTED;
66     }
67     mbedtls_mpi_free( &n );
68     return( status );
69 }
70 
mbedtls_psa_rsa_load_representation(psa_key_type_t type,const uint8_t * data,size_t data_length,mbedtls_rsa_context ** p_rsa)71 psa_status_t mbedtls_psa_rsa_load_representation(
72     psa_key_type_t type, const uint8_t *data, size_t data_length,
73     mbedtls_rsa_context **p_rsa )
74 {
75     psa_status_t status;
76     mbedtls_pk_context ctx;
77     size_t bits;
78     mbedtls_pk_init( &ctx );
79 
80     /* Parse the data. */
81     if( PSA_KEY_TYPE_IS_KEY_PAIR( type ) )
82         status = mbedtls_to_psa_error(
83             mbedtls_pk_parse_key( &ctx, data, data_length, NULL, 0,
84                 mbedtls_psa_get_random, MBEDTLS_PSA_RANDOM_STATE ) );
85     else
86         status = mbedtls_to_psa_error(
87             mbedtls_pk_parse_public_key( &ctx, data, data_length ) );
88     if( status != PSA_SUCCESS )
89         goto exit;
90 
91     /* We have something that the pkparse module recognizes. If it is a
92      * valid RSA key, store it. */
93     if( mbedtls_pk_get_type( &ctx ) != MBEDTLS_PK_RSA )
94     {
95         status = PSA_ERROR_INVALID_ARGUMENT;
96         goto exit;
97     }
98 
99     /* The size of an RSA key doesn't have to be a multiple of 8. Mbed TLS
100      * supports non-byte-aligned key sizes, but not well. For example,
101      * mbedtls_rsa_get_len() returns the key size in bytes, not in bits. */
102     bits = PSA_BYTES_TO_BITS( mbedtls_rsa_get_len( mbedtls_pk_rsa( ctx ) ) );
103     if( bits > PSA_VENDOR_RSA_MAX_KEY_BITS )
104     {
105         status = PSA_ERROR_NOT_SUPPORTED;
106         goto exit;
107     }
108     status = psa_check_rsa_key_byte_aligned( mbedtls_pk_rsa( ctx ) );
109     if( status != PSA_SUCCESS )
110         goto exit;
111 
112     /* Copy out the pointer to the RSA context, and reset the PK context
113      * such that pk_free doesn't free the RSA context we just grabbed. */
114     *p_rsa = mbedtls_pk_rsa( ctx );
115     ctx.pk_info = NULL;
116 
117 exit:
118     mbedtls_pk_free( &ctx );
119     return( status );
120 }
121 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
122         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) ||
123         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
124         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) ||
125         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) ||
126         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
127 
128 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
129     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
130 
mbedtls_psa_rsa_import_key(const psa_key_attributes_t * attributes,const uint8_t * data,size_t data_length,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length,size_t * bits)131 psa_status_t mbedtls_psa_rsa_import_key(
132     const psa_key_attributes_t *attributes,
133     const uint8_t *data, size_t data_length,
134     uint8_t *key_buffer, size_t key_buffer_size,
135     size_t *key_buffer_length, size_t *bits )
136 {
137     psa_status_t status;
138     mbedtls_rsa_context *rsa = NULL;
139 
140     /* Parse input */
141     status = mbedtls_psa_rsa_load_representation( attributes->core.type,
142                                                   data,
143                                                   data_length,
144                                                   &rsa );
145     if( status != PSA_SUCCESS )
146         goto exit;
147 
148     *bits = (psa_key_bits_t) PSA_BYTES_TO_BITS( mbedtls_rsa_get_len( rsa ) );
149 
150     /* Re-export the data to PSA export format, such that we can store export
151      * representation in the key slot. Export representation in case of RSA is
152      * the smallest representation that's allowed as input, so a straight-up
153      * allocation of the same size as the input buffer will be large enough. */
154     status = mbedtls_psa_rsa_export_key( attributes->core.type,
155                                          rsa,
156                                          key_buffer,
157                                          key_buffer_size,
158                                          key_buffer_length );
159 exit:
160     /* Always free the RSA object */
161     mbedtls_rsa_free( rsa );
162     mbedtls_free( rsa );
163 
164     return( status );
165 }
166 
mbedtls_psa_rsa_export_key(psa_key_type_t type,mbedtls_rsa_context * rsa,uint8_t * data,size_t data_size,size_t * data_length)167 psa_status_t mbedtls_psa_rsa_export_key( psa_key_type_t type,
168                                          mbedtls_rsa_context *rsa,
169                                          uint8_t *data,
170                                          size_t data_size,
171                                          size_t *data_length )
172 {
173 #if defined(MBEDTLS_PK_WRITE_C)
174     int ret;
175     mbedtls_pk_context pk;
176     uint8_t *pos = data + data_size;
177 
178     mbedtls_pk_init( &pk );
179     pk.pk_info = &mbedtls_rsa_info;
180     pk.pk_ctx = rsa;
181 
182     /* PSA Crypto API defines the format of an RSA key as a DER-encoded
183      * representation of the non-encrypted PKCS#1 RSAPrivateKey for a
184      * private key and of the RFC3279 RSAPublicKey for a public key. */
185     if( PSA_KEY_TYPE_IS_KEY_PAIR( type ) )
186         ret = mbedtls_pk_write_key_der( &pk, data, data_size );
187     else
188         ret = mbedtls_pk_write_pubkey( &pos, data, &pk );
189 
190     if( ret < 0 )
191     {
192         /* Clean up in case pk_write failed halfway through. */
193         memset( data, 0, data_size );
194         return( mbedtls_to_psa_error( ret ) );
195     }
196 
197     /* The mbedtls_pk_xxx functions write to the end of the buffer.
198      * Move the data to the beginning and erase remaining data
199      * at the original location. */
200     if( 2 * (size_t) ret <= data_size )
201     {
202         memcpy( data, data + data_size - ret, ret );
203         memset( data + data_size - ret, 0, ret );
204     }
205     else if( (size_t) ret < data_size )
206     {
207         memmove( data, data + data_size - ret, ret );
208         memset( data + ret, 0, data_size - ret );
209     }
210 
211     *data_length = ret;
212     return( PSA_SUCCESS );
213 #else
214     (void) type;
215     (void) rsa;
216     (void) data;
217     (void) data_size;
218     (void) data_length;
219     return( PSA_ERROR_NOT_SUPPORTED );
220 #endif /* MBEDTLS_PK_WRITE_C */
221 }
222 
mbedtls_psa_rsa_export_public_key(const psa_key_attributes_t * attributes,const uint8_t * key_buffer,size_t key_buffer_size,uint8_t * data,size_t data_size,size_t * data_length)223 psa_status_t mbedtls_psa_rsa_export_public_key(
224     const psa_key_attributes_t *attributes,
225     const uint8_t *key_buffer, size_t key_buffer_size,
226     uint8_t *data, size_t data_size, size_t *data_length )
227 {
228     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
229     mbedtls_rsa_context *rsa = NULL;
230 
231     status = mbedtls_psa_rsa_load_representation(
232                  attributes->core.type, key_buffer, key_buffer_size, &rsa );
233     if( status != PSA_SUCCESS )
234         return( status );
235 
236     status = mbedtls_psa_rsa_export_key( PSA_KEY_TYPE_RSA_PUBLIC_KEY,
237                                          rsa,
238                                          data,
239                                          data_size,
240                                          data_length );
241 
242     mbedtls_rsa_free( rsa );
243     mbedtls_free( rsa );
244 
245     return( status );
246 }
247 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) ||
248         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
249 
250 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) && \
251     defined(MBEDTLS_GENPRIME)
psa_rsa_read_exponent(const uint8_t * domain_parameters,size_t domain_parameters_size,int * exponent)252 static psa_status_t psa_rsa_read_exponent( const uint8_t *domain_parameters,
253                                            size_t domain_parameters_size,
254                                            int *exponent )
255 {
256     size_t i;
257     uint32_t acc = 0;
258 
259     if( domain_parameters_size == 0 )
260     {
261         *exponent = 65537;
262         return( PSA_SUCCESS );
263     }
264 
265     /* Mbed TLS encodes the public exponent as an int. For simplicity, only
266      * support values that fit in a 32-bit integer, which is larger than
267      * int on just about every platform anyway. */
268     if( domain_parameters_size > sizeof( acc ) )
269         return( PSA_ERROR_NOT_SUPPORTED );
270     for( i = 0; i < domain_parameters_size; i++ )
271         acc = ( acc << 8 ) | domain_parameters[i];
272     if( acc > INT_MAX )
273         return( PSA_ERROR_NOT_SUPPORTED );
274     *exponent = acc;
275     return( PSA_SUCCESS );
276 }
277 
mbedtls_psa_rsa_generate_key(const psa_key_attributes_t * attributes,uint8_t * key_buffer,size_t key_buffer_size,size_t * key_buffer_length)278 psa_status_t mbedtls_psa_rsa_generate_key(
279     const psa_key_attributes_t *attributes,
280     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length )
281 {
282     psa_status_t status;
283     mbedtls_rsa_context rsa;
284     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
285     int exponent;
286 
287     status = psa_rsa_read_exponent( attributes->domain_parameters,
288                                     attributes->domain_parameters_size,
289                                     &exponent );
290     if( status != PSA_SUCCESS )
291         return( status );
292 
293     mbedtls_rsa_init( &rsa );
294     ret = mbedtls_rsa_gen_key( &rsa,
295                                mbedtls_psa_get_random,
296                                MBEDTLS_PSA_RANDOM_STATE,
297                                (unsigned int)attributes->core.bits,
298                                exponent );
299     if( ret != 0 )
300         return( mbedtls_to_psa_error( ret ) );
301 
302     status = mbedtls_psa_rsa_export_key( attributes->core.type,
303                                          &rsa, key_buffer, key_buffer_size,
304                                          key_buffer_length );
305     mbedtls_rsa_free( &rsa );
306 
307     return( status );
308 }
309 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR)
310         * defined(MBEDTLS_GENPRIME) */
311 
312 /****************************************************************/
313 /* Sign/verify hashes */
314 /****************************************************************/
315 
316 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
317     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
318 
319 /* Decode the hash algorithm from alg and store the mbedtls encoding in
320  * md_alg. Verify that the hash length is acceptable. */
psa_rsa_decode_md_type(psa_algorithm_t alg,size_t hash_length,mbedtls_md_type_t * md_alg)321 static psa_status_t psa_rsa_decode_md_type( psa_algorithm_t alg,
322                                             size_t hash_length,
323                                             mbedtls_md_type_t *md_alg )
324 {
325     psa_algorithm_t hash_alg = PSA_ALG_SIGN_GET_HASH( alg );
326     *md_alg = mbedtls_hash_info_md_from_psa( hash_alg );
327 
328     /* The Mbed TLS RSA module uses an unsigned int for hash length
329      * parameters. Validate that it fits so that we don't risk an
330      * overflow later. */
331 #if SIZE_MAX > UINT_MAX
332     if( hash_length > UINT_MAX )
333         return( PSA_ERROR_INVALID_ARGUMENT );
334 #endif
335 
336     /* For signatures using a hash, the hash length must be correct. */
337     if( alg != PSA_ALG_RSA_PKCS1V15_SIGN_RAW )
338     {
339         if( *md_alg == MBEDTLS_MD_NONE )
340             return( PSA_ERROR_NOT_SUPPORTED );
341         if( mbedtls_hash_info_get_size( *md_alg ) != hash_length )
342             return( PSA_ERROR_INVALID_ARGUMENT );
343     }
344 
345     return( PSA_SUCCESS );
346 }
347 
/* Sign an already-computed hash with an RSA private key.
 *
 * Handles PKCS#1 v1.5 signatures (PSA_ALG_IS_RSA_PKCS1V15_SIGN) and RSA-PSS
 * (PSA_ALG_IS_RSA_PSS), each only when the matching
 * MBEDTLS_PSA_BUILTIN_ALG_xxx option is enabled.
 *
 * \param attributes        Key attributes; core.type selects how
 *                          \p key_buffer is parsed.
 * \param key_buffer        Key in PSA export format.
 * \param key_buffer_size   Size of \p key_buffer in bytes.
 * \param alg               RSA signature algorithm.
 * \param hash              Hash to sign.
 * \param hash_length       Length of \p hash; must match the hash encoded in
 *                          \p alg except for PSA_ALG_RSA_PKCS1V15_SIGN_RAW.
 * \param signature         Output buffer.
 * \param signature_size    Size of \p signature; must be at least the key
 *                          length in bytes.
 * \param signature_length  On success, set to the key length in bytes.
 *
 * \return PSA_SUCCESS;
 *         PSA_ERROR_BUFFER_TOO_SMALL if \p signature_size is too small;
 *         PSA_ERROR_INVALID_ARGUMENT if \p alg is not a compiled-in RSA
 *         signature algorithm or the hash length is wrong;
 *         PSA_ERROR_NOT_SUPPORTED if the hash algorithm is unknown;
 *         otherwise an error from key loading or the Mbed TLS RSA module.
 */
psa_status_t mbedtls_psa_rsa_sign_hash(
    const psa_key_attributes_t *attributes,
    const uint8_t *key_buffer, size_t key_buffer_size,
    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
    uint8_t *signature, size_t signature_size, size_t *signature_length )
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    mbedtls_rsa_context *rsa = NULL;
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    mbedtls_md_type_t md_alg;

    status = mbedtls_psa_rsa_load_representation( attributes->core.type,
                                                  key_buffer,
                                                  key_buffer_size,
                                                  &rsa );
    /* Loading failed: rsa was never set, so there is nothing to release. */
    if( status != PSA_SUCCESS )
        return( status );

    status = psa_rsa_decode_md_type( alg, hash_length, &md_alg );
    if( status != PSA_SUCCESS )
        goto exit;

    /* The signature written below is exactly as long as the key. */
    if( signature_size < mbedtls_rsa_get_len( rsa ) )
    {
        status = PSA_ERROR_BUFFER_TOO_SMALL;
        goto exit;
    }

    /* The branches below are chained with `else` across the preprocessor
     * conditionals; the final block catches every algorithm that no enabled
     * branch claimed. */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
    if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) )
    {
        /* No hash is set on the context: md_alg is passed directly to
         * mbedtls_rsa_pkcs1_sign() instead. */
        ret = mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15,
                                       MBEDTLS_MD_NONE );
        if( ret == 0 )
        {
            ret = mbedtls_rsa_pkcs1_sign( rsa,
                                          mbedtls_psa_get_random,
                                          MBEDTLS_PSA_RANDOM_STATE,
                                          md_alg,
                                          (unsigned int) hash_length,
                                          hash,
                                          signature );
        }
    }
    else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
    if( PSA_ALG_IS_RSA_PSS( alg ) )
    {
        /* For PSS the hash is configured on the context here, and the sign
         * call passes MBEDTLS_MD_NONE together with hash_length.
         * NOTE(review): per the mbedtls_rsa_rsassa_pss_sign() documentation
         * this makes it take `hash` as-is — confirm. */
        ret = mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );

        if( ret == 0 )
        {
            ret = mbedtls_rsa_rsassa_pss_sign( rsa,
                                               mbedtls_psa_get_random,
                                               MBEDTLS_PSA_RANDOM_STATE,
                                               MBEDTLS_MD_NONE,
                                               (unsigned int) hash_length,
                                               hash,
                                               signature );
        }
    }
    else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
    {
        status = PSA_ERROR_INVALID_ARGUMENT;
        goto exit;
    }

    if( ret == 0 )
        *signature_length = mbedtls_rsa_get_len( rsa );
    status = mbedtls_to_psa_error( ret );

exit:
    mbedtls_rsa_free( rsa );
    mbedtls_free( rsa );

    return( status );
}
427 
428 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
/* Salt length to require when verifying an RSA-PSS signature.
 *
 * PSA_ALG_RSA_PSS_ANY_SALT accepts any salt (MBEDTLS_RSA_SALT_LEN_ANY).
 * Otherwise the standard length applies: the largest salt that fits in the
 * encoded message, capped at the hash length.
 */
static int rsa_pss_expected_salt_len( psa_algorithm_t alg,
                                      const mbedtls_rsa_context *rsa,
                                      size_t hash_length )
{
    int key_len;
    int hash_len;
    int max_salt;

    if( PSA_ALG_IS_RSA_PSS_ANY_SALT( alg ) )
        return( MBEDTLS_RSA_SALT_LEN_ANY );

    /* Both values are known to fit in an int at this point. */
    key_len = (int) mbedtls_rsa_get_len( rsa );
    hash_len = (int) hash_length;

    /* Space left once the hash and two fixed bytes are accounted for. */
    max_salt = key_len - 2 - hash_len;

    /* A negative value means no valid signature exists anyway. */
    if( max_salt < 0 )
        return( 0 );

    return( max_salt > hash_len ? hash_len : max_salt );
}
447 #endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
448 
/* Verify an RSA signature over an already-computed hash.
 *
 * Handles PKCS#1 v1.5 signatures (PSA_ALG_IS_RSA_PKCS1V15_SIGN) and RSA-PSS
 * (PSA_ALG_IS_RSA_PSS), each only when the matching
 * MBEDTLS_PSA_BUILTIN_ALG_xxx option is enabled.
 *
 * \param attributes        Key attributes; core.type selects how
 *                          \p key_buffer is parsed.
 * \param key_buffer        Key in PSA export format.
 * \param key_buffer_size   Size of \p key_buffer in bytes.
 * \param alg               RSA signature algorithm.
 * \param hash              Hash that was signed.
 * \param hash_length       Length of \p hash; must match the hash encoded in
 *                          \p alg except for PSA_ALG_RSA_PKCS1V15_SIGN_RAW.
 * \param signature         Signature to check.
 * \param signature_length  Must equal the key length in bytes.
 *
 * \return PSA_SUCCESS if the signature is valid;
 *         PSA_ERROR_INVALID_SIGNATURE if it has the wrong length or does not
 *         verify (including bad padding);
 *         PSA_ERROR_INVALID_ARGUMENT if \p alg is not a compiled-in RSA
 *         signature algorithm or the hash length is wrong;
 *         PSA_ERROR_NOT_SUPPORTED if the hash algorithm is unknown;
 *         otherwise an error from key loading or the Mbed TLS RSA module.
 */
psa_status_t mbedtls_psa_rsa_verify_hash(
    const psa_key_attributes_t *attributes,
    const uint8_t *key_buffer, size_t key_buffer_size,
    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
    const uint8_t *signature, size_t signature_length )
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    mbedtls_rsa_context *rsa = NULL;
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    mbedtls_md_type_t md_alg;

    status = mbedtls_psa_rsa_load_representation( attributes->core.type,
                                                  key_buffer,
                                                  key_buffer_size,
                                                  &rsa );
    /* On failure rsa is still NULL; the cleanup at exit runs anyway. */
    if( status != PSA_SUCCESS )
        goto exit;

    status = psa_rsa_decode_md_type( alg, hash_length, &md_alg );
    if( status != PSA_SUCCESS )
        goto exit;

    /* A signature of any other length cannot be valid for this key. */
    if( signature_length != mbedtls_rsa_get_len( rsa ) )
    {
        status = PSA_ERROR_INVALID_SIGNATURE;
        goto exit;
    }

    /* The branches below are chained with `else` across the preprocessor
     * conditionals; the final block catches every algorithm that no enabled
     * branch claimed. */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
    if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) )
    {
        /* No hash is set on the context: md_alg is passed directly to
         * mbedtls_rsa_pkcs1_verify() instead. */
        ret = mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15,
                                       MBEDTLS_MD_NONE );
        if( ret == 0 )
        {
            ret = mbedtls_rsa_pkcs1_verify( rsa,
                                            md_alg,
                                            (unsigned int) hash_length,
                                            hash,
                                            signature );
        }
    }
    else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
    if( PSA_ALG_IS_RSA_PSS( alg ) )
    {
        ret = mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
        if( ret == 0 )
        {
            /* md_alg is used both as the message hash and as the mask
             * generation hash; the salt length is checked against the PSA
             * expectation for alg. */
            int slen = rsa_pss_expected_salt_len( alg, rsa, hash_length );
            ret = mbedtls_rsa_rsassa_pss_verify_ext( rsa,
                                                     md_alg,
                                                     (unsigned) hash_length,
                                                     hash,
                                                     md_alg,
                                                     slen,
                                                     signature );
        }
    }
    else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
    {
        status = PSA_ERROR_INVALID_ARGUMENT;
        goto exit;
    }

    /* Mbed TLS distinguishes "invalid padding" from "valid padding but
     * the rest of the signature is invalid". This has little use in
     * practice and PSA doesn't report this distinction. */
    status = ( ret == MBEDTLS_ERR_RSA_INVALID_PADDING ) ?
             PSA_ERROR_INVALID_SIGNATURE :
             mbedtls_to_psa_error( ret );

exit:
    mbedtls_rsa_free( rsa );
    mbedtls_free( rsa );

    return( status );
}
529 
530 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
531         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
532 
533 /****************************************************************/
534 /* Asymmetric cryptography */
535 /****************************************************************/
536 
537 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
/* Configure rsa for RSAES-OAEP (PKCS#1 v2.1 padding), using the hash named
 * by the OAEP algorithm \p alg.
 *
 * \return 0 on success, or the error code of mbedtls_rsa_set_padding().
 */
static int psa_rsa_oaep_set_padding_mode( psa_algorithm_t alg,
                                          mbedtls_rsa_context *rsa )
{
    return( mbedtls_rsa_set_padding(
                rsa, MBEDTLS_RSA_PKCS_V21,
                mbedtls_hash_info_md_from_psa(
                    PSA_ALG_RSA_OAEP_GET_HASH( alg ) ) ) );
}
546 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
547 
/* Encrypt a short message with an RSA key (public key or key pair).
 *
 * Handles PSA_ALG_RSA_PKCS1V15_CRYPT and RSAES-OAEP (PSA_ALG_IS_RSA_OAEP),
 * each only when the matching MBEDTLS_PSA_BUILTIN_ALG_xxx option is enabled.
 *
 * \param attributes       Key attributes; only RSA key types are handled.
 * \param key_buffer       Key in PSA export format.
 * \param key_buffer_size  Size of \p key_buffer in bytes.
 * \param alg              Encryption algorithm.
 * \param input            Plaintext.
 * \param input_length     Length of \p input in bytes.
 * \param salt             Passed only to mbedtls_rsa_rsaes_oaep_encrypt()
 *                         (in its label position); unused for PKCS#1 v1.5.
 * \param salt_length      Length of \p salt in bytes.
 * \param output           Ciphertext buffer; must hold at least the key
 *                         length in bytes.
 * \param output_size      Size of \p output in bytes.
 * \param output_length    On success, set to the key length in bytes.
 *
 * \return PSA_SUCCESS;
 *         PSA_ERROR_NOT_SUPPORTED for non-RSA keys or for an algorithm whose
 *         implementation is not compiled in;
 *         PSA_ERROR_BUFFER_TOO_SMALL if \p output_size is too small;
 *         PSA_ERROR_INVALID_ARGUMENT for any other algorithm;
 *         otherwise an error from key loading or the Mbed TLS RSA module.
 */
psa_status_t mbedtls_psa_asymmetric_encrypt( const psa_key_attributes_t *attributes,
                                             const uint8_t *key_buffer,
                                             size_t key_buffer_size,
                                             psa_algorithm_t alg,
                                             const uint8_t *input,
                                             size_t input_length,
                                             const uint8_t *salt,
                                             size_t salt_length,
                                             uint8_t *output,
                                             size_t output_size,
                                             size_t *output_length )
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    /* Every use of these parameters sits under an #if below; the casts keep
     * configurations without any RSA encryption algorithm warning-free. */
    (void) key_buffer;
    (void) key_buffer_size;
    (void) input;
    (void) input_length;
    (void) salt;
    (void) salt_length;
    (void) output;
    (void) output_size;
    (void) output_length;

    if( PSA_KEY_TYPE_IS_RSA( attributes->core.type ) )
    {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        mbedtls_rsa_context *rsa = NULL;
        status = mbedtls_psa_rsa_load_representation( attributes->core.type,
                                                      key_buffer,
                                                      key_buffer_size,
                                                      &rsa );
        if( status != PSA_SUCCESS )
            goto rsa_exit;

        /* The ciphertext is exactly as long as the key. */
        if( output_size < mbedtls_rsa_get_len( rsa ) )
        {
            status = PSA_ERROR_BUFFER_TOO_SMALL;
            goto rsa_exit;
        }
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
        if( alg == PSA_ALG_RSA_PKCS1V15_CRYPT )
        {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
            /* Uses the padding mode the freshly loaded context already has;
             * no mbedtls_rsa_set_padding() call is made on this path. */
            status = mbedtls_to_psa_error(
                    mbedtls_rsa_pkcs1_encrypt( rsa,
                                               mbedtls_psa_get_random,
                                               MBEDTLS_PSA_RANDOM_STATE,
                                               input_length,
                                               input,
                                               output ) );
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
        }
        else
        if( PSA_ALG_IS_RSA_OAEP( alg ) )
        {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
            status = mbedtls_to_psa_error(
                         psa_rsa_oaep_set_padding_mode( alg, rsa ) );
            if( status != PSA_SUCCESS )
                goto rsa_exit;

            status = mbedtls_to_psa_error(
                mbedtls_rsa_rsaes_oaep_encrypt( rsa,
                                                mbedtls_psa_get_random,
                                                MBEDTLS_PSA_RANDOM_STATE,
                                                salt, salt_length,
                                                input_length,
                                                input,
                                                output ) );
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
        }
        else
        {
            status = PSA_ERROR_INVALID_ARGUMENT;
        }
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
rsa_exit:
        /* Common cleanup; rsa is NULL if loading failed. */
        if( status == PSA_SUCCESS )
            *output_length = mbedtls_rsa_get_len( rsa );

        mbedtls_rsa_free( rsa );
        mbedtls_free( rsa );
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
    }
    else
    {
        status = PSA_ERROR_NOT_SUPPORTED;
    }

    return status;
}
647 
/* Decrypt an RSA ciphertext with an RSA key pair.
 *
 * Handles PSA_ALG_RSA_PKCS1V15_CRYPT and RSAES-OAEP (PSA_ALG_IS_RSA_OAEP),
 * each only when the matching MBEDTLS_PSA_BUILTIN_ALG_xxx option is enabled.
 *
 * \param attributes       Key attributes; core.type must be
 *                         PSA_KEY_TYPE_RSA_KEY_PAIR.
 * \param key_buffer       Key in PSA export format.
 * \param key_buffer_size  Size of \p key_buffer in bytes.
 * \param alg              Encryption algorithm.
 * \param input            Ciphertext.
 * \param input_length     Must equal the key length in bytes.
 * \param salt             Passed only to mbedtls_rsa_rsaes_oaep_decrypt()
 *                         (in its label position); unused for PKCS#1 v1.5.
 * \param salt_length      Length of \p salt in bytes.
 * \param output           Plaintext buffer.
 * \param output_size      Size of \p output in bytes.
 * \param output_length    Set to 0 up front; filled in by the Mbed TLS
 *                         decrypt call.
 *
 * \return PSA_SUCCESS;
 *         PSA_ERROR_NOT_SUPPORTED if the key is not an RSA key pair or the
 *         algorithm's implementation is not compiled in;
 *         PSA_ERROR_INVALID_ARGUMENT if \p input_length is wrong or the
 *         algorithm is not an RSA encryption algorithm;
 *         otherwise an error from key loading or the Mbed TLS RSA module.
 */
psa_status_t mbedtls_psa_asymmetric_decrypt( const psa_key_attributes_t *attributes,
                                             const uint8_t *key_buffer,
                                             size_t key_buffer_size,
                                             psa_algorithm_t alg,
                                             const uint8_t *input,
                                             size_t input_length,
                                             const uint8_t *salt,
                                             size_t salt_length,
                                             uint8_t *output,
                                             size_t output_size,
                                             size_t *output_length )
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    /* Every use of these parameters sits under an #if below; the casts keep
     * configurations without any RSA encryption algorithm warning-free. */
    (void) key_buffer;
    (void) key_buffer_size;
    (void) input;
    (void) input_length;
    (void) salt;
    (void) salt_length;
    (void) output;
    (void) output_size;
    (void) output_length;

    /* Report an empty result unless a decrypt call below fills it in. */
    *output_length = 0;

    /* Decryption needs the private key, so public keys are rejected. */
    if( attributes->core.type == PSA_KEY_TYPE_RSA_KEY_PAIR )
    {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        mbedtls_rsa_context *rsa = NULL;
        status = mbedtls_psa_rsa_load_representation( attributes->core.type,
                                                      key_buffer,
                                                      key_buffer_size,
                                                      &rsa );
        if( status != PSA_SUCCESS )
            goto rsa_exit;

        /* A well-formed ciphertext is exactly as long as the key. */
        if( input_length != mbedtls_rsa_get_len( rsa ) )
        {
            status = PSA_ERROR_INVALID_ARGUMENT;
            goto rsa_exit;
        }
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */

        if( alg == PSA_ALG_RSA_PKCS1V15_CRYPT )
        {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
            /* Uses the padding mode the freshly loaded context already has;
             * no mbedtls_rsa_set_padding() call is made on this path. */
            status = mbedtls_to_psa_error(
                mbedtls_rsa_pkcs1_decrypt( rsa,
                                           mbedtls_psa_get_random,
                                           MBEDTLS_PSA_RANDOM_STATE,
                                           output_length,
                                           input,
                                           output,
                                           output_size ) );
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
        }
        else
        if( PSA_ALG_IS_RSA_OAEP( alg ) )
        {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
            status = mbedtls_to_psa_error(
                         psa_rsa_oaep_set_padding_mode( alg, rsa ) );
            if( status != PSA_SUCCESS )
                goto rsa_exit;

            status = mbedtls_to_psa_error(
                mbedtls_rsa_rsaes_oaep_decrypt( rsa,
                                                mbedtls_psa_get_random,
                                                MBEDTLS_PSA_RANDOM_STATE,
                                                salt, salt_length,
                                                output_length,
                                                input,
                                                output,
                                                output_size ) );
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
        }
        else
        {
            status = PSA_ERROR_INVALID_ARGUMENT;
        }

#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
rsa_exit:
        /* Common cleanup; rsa is NULL if loading failed. */
        mbedtls_rsa_free( rsa );
        mbedtls_free( rsa );
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
    }
    else
    {
        status = PSA_ERROR_NOT_SUPPORTED;
    }

    return status;
}
750 
751 #endif /* MBEDTLS_PSA_CRYPTO_C */
752