1 /*
2  *  Public Key abstraction layer: wrapper functions
3  *
4  *  Copyright (C) 2006-2015, ARM Limited, All Rights Reserved
5  *  SPDX-License-Identifier: Apache-2.0
6  *
7  *  Licensed under the Apache License, Version 2.0 (the "License"); you may
8  *  not use this file except in compliance with the License.
9  *  You may obtain a copy of the License at
10  *
11  *  http://www.apache.org/licenses/LICENSE-2.0
12  *
13  *  Unless required by applicable law or agreed to in writing, software
14  *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
15  *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  *  See the License for the specific language governing permissions and
17  *  limitations under the License.
18  *
19  *  This file is part of mbed TLS (https://tls.mbed.org)
20  */
21 
22 #if !defined(MBEDTLS_CONFIG_FILE)
23 #include "mbedtls/config.h"
24 #else
25 #include MBEDTLS_CONFIG_FILE
26 #endif
27 
28 #if defined(MBEDTLS_PK_C)
29 #include "mbedtls/pk_internal.h"
30 
/* Included even when RSA is not enabled, because the RSA-alt wrappers need it */
32 #include "mbedtls/rsa.h"
33 
#include <limits.h>
#include <stdint.h>
#include <string.h>
35 
36 #if defined(MBEDTLS_ECP_C)
37 #include "mbedtls/ecp.h"
38 #endif
39 
40 #if defined(MBEDTLS_ECDSA_C)
41 #include "mbedtls/ecdsa.h"
42 #endif
43 
44 #if defined(MBEDTLS_PLATFORM_C)
45 #include "mbedtls/platform.h"
46 #else
47 #include <stdlib.h>
48 #define mbedtls_calloc    calloc
49 #define mbedtls_free       free
50 #endif
51 
52 #if defined(MBEDTLS_PK_RSA_ALT_SUPPORT)
53 /* Implementation that should never be optimized out by the compiler */
static void mbedtls_zeroize( void *v, size_t n )
{
    /* Stores go through a volatile pointer so the compiler cannot treat
     * them as dead and drop them, even when the buffer is freed next. */
    volatile unsigned char *byte = (volatile unsigned char *) v;
    size_t i;

    for( i = 0; i < n; i++ )
        byte[i] = 0;
}
57 #endif
58 
59 #if defined(MBEDTLS_RSA_C)
/*
 * Report whether an RSA context can serve requests of the given type.
 * Both plain RSA and RSASSA-PSS are accepted; returns 1 or 0.
 */
static int rsa_can_do( mbedtls_pk_type_t type )
{
    switch( type )
    {
        case MBEDTLS_PK_RSA:
        case MBEDTLS_PK_RSASSA_PSS:
            return( 1 );

        default:
            return( 0 );
    }
}
65 
rsa_get_bitlen(const void * ctx)66 static size_t rsa_get_bitlen( const void *ctx )
67 {
68     return( 8 * ((const mbedtls_rsa_context *) ctx)->len );
69 }
70 
/*
 * Verify an RSA signature over `hash` using the public key in `ctx`.
 *
 * - The RSA layer takes the hash length as `unsigned int`. For
 *   MBEDTLS_MD_NONE the caller-supplied length is what gets used, so a
 *   size_t value that would be truncated is rejected with
 *   MBEDTLS_ERR_PK_BAD_INPUT_DATA instead of silently checking a prefix.
 * - A signature shorter than the key length fails immediately.
 * - The RSA layer receives no signature length, so only the first
 *   key-length bytes are checked. A valid signature followed by extra
 *   bytes is reported as MBEDTLS_ERR_PK_SIG_LEN_MISMATCH.
 *
 * Returns 0 on success or an RSA/PK error code.
 */
static int rsa_verify_wrap( void *ctx, mbedtls_md_type_t md_alg,
                   const unsigned char *hash, size_t hash_len,
                   const unsigned char *sig, size_t sig_len )
{
    int ret;
    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;

#if SIZE_MAX > UINT_MAX
    /* Guard the (unsigned int) narrowing below. The #if avoids an
     * always-false comparison on platforms where size_t fits in an
     * unsigned int. */
    if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len )
        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
#endif /* SIZE_MAX > UINT_MAX */

    if( sig_len < rsa->len )
        return( MBEDTLS_ERR_RSA_VERIFY_FAILED );

    if( ( ret = mbedtls_rsa_pkcs1_verify( rsa, NULL, NULL,
                                  MBEDTLS_RSA_PUBLIC, md_alg,
                                  (unsigned int) hash_len, hash, sig ) ) != 0 )
        return( ret );

    /* The signature verified but the buffer holds trailing bytes. */
    if( sig_len > rsa->len )
        return( MBEDTLS_ERR_PK_SIG_LEN_MISMATCH );

    return( 0 );
}
90 
/*
 * Sign `hash` with the RSA private key in `ctx`, writing the signature to
 * `sig` and its length (the key length) to `*sig_len`.
 *
 * The RSA layer takes the hash length as `unsigned int`. For
 * MBEDTLS_MD_NONE that length is what gets signed, so a size_t value that
 * would be truncated is rejected with MBEDTLS_ERR_PK_BAD_INPUT_DATA rather
 * than silently signing only a prefix of the input.
 *
 * Returns 0 on success or an RSA/PK error code.
 */
static int rsa_sign_wrap( void *ctx, mbedtls_md_type_t md_alg,
                   const unsigned char *hash, size_t hash_len,
                   unsigned char *sig, size_t *sig_len,
                   int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
{
    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;

#if SIZE_MAX > UINT_MAX
    if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len )
        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
#endif /* SIZE_MAX > UINT_MAX */

    *sig_len = rsa->len;

    return( mbedtls_rsa_pkcs1_sign( rsa, f_rng, p_rng, MBEDTLS_RSA_PRIVATE,
                md_alg, (unsigned int) hash_len, hash, sig ) );
}
101 
rsa_decrypt_wrap(void * ctx,const unsigned char * input,size_t ilen,unsigned char * output,size_t * olen,size_t osize,int (* f_rng)(void *,unsigned char *,size_t),void * p_rng)102 static int rsa_decrypt_wrap( void *ctx,
103                     const unsigned char *input, size_t ilen,
104                     unsigned char *output, size_t *olen, size_t osize,
105                     int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
106 {
107     if( ilen != ((mbedtls_rsa_context *) ctx)->len )
108         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
109 
110     return( mbedtls_rsa_pkcs1_decrypt( (mbedtls_rsa_context *) ctx, f_rng, p_rng,
111                 MBEDTLS_RSA_PRIVATE, olen, input, output, osize ) );
112 }
113 
rsa_encrypt_wrap(void * ctx,const unsigned char * input,size_t ilen,unsigned char * output,size_t * olen,size_t osize,int (* f_rng)(void *,unsigned char *,size_t),void * p_rng)114 static int rsa_encrypt_wrap( void *ctx,
115                     const unsigned char *input, size_t ilen,
116                     unsigned char *output, size_t *olen, size_t osize,
117                     int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
118 {
119     *olen = ((mbedtls_rsa_context *) ctx)->len;
120 
121     if( *olen > osize )
122         return( MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE );
123 
124     return( mbedtls_rsa_pkcs1_encrypt( (mbedtls_rsa_context *) ctx,
125                 f_rng, p_rng, MBEDTLS_RSA_PUBLIC, ilen, input, output ) );
126 }
127 
rsa_check_pair_wrap(const void * pub,const void * prv)128 static int rsa_check_pair_wrap( const void *pub, const void *prv )
129 {
130     return( mbedtls_rsa_check_pub_priv( (const mbedtls_rsa_context *) pub,
131                                 (const mbedtls_rsa_context *) prv ) );
132 }
133 
rsa_alloc_wrap(void)134 static void *rsa_alloc_wrap( void )
135 {
136     void *ctx = mbedtls_calloc( 1, sizeof( mbedtls_rsa_context ) );
137 
138     if( ctx != NULL )
139         mbedtls_rsa_init( (mbedtls_rsa_context *) ctx, 0, 0 );
140 
141     return( ctx );
142 }
143 
rsa_free_wrap(void * ctx)144 static void rsa_free_wrap( void *ctx )
145 {
146     mbedtls_rsa_free( (mbedtls_rsa_context *) ctx );
147     mbedtls_free( ctx );
148 }
149 
/*
 * Describe the context's N and E fields as two MPI debug items.
 * `items` must have room for at least two entries.
 */
static void rsa_debug( const void *ctx, mbedtls_pk_debug_item *items )
{
    mbedtls_rsa_context *rsa = (mbedtls_rsa_context *) ctx;

    items[0].type  = MBEDTLS_PK_DEBUG_MPI;
    items[0].name  = "rsa.N";
    items[0].value = &rsa->N;

    items[1].type  = MBEDTLS_PK_DEBUG_MPI;
    items[1].name  = "rsa.E";
    items[1].value = &rsa->E;
}
162 
/*
 * Descriptor for plain RSA keys: every operation slot is wired to the
 * mbedtls_rsa_context wrappers above. Entries are positional, so their
 * order must match the mbedtls_pk_info_t declaration.
 */
const mbedtls_pk_info_t mbedtls_rsa_info = {
    MBEDTLS_PK_RSA,
    "RSA",
    rsa_get_bitlen,
    rsa_can_do,
    rsa_verify_wrap,      /* verify */
    rsa_sign_wrap,        /* sign */
    rsa_decrypt_wrap,     /* decrypt */
    rsa_encrypt_wrap,     /* encrypt */
    rsa_check_pair_wrap,  /* public/private pair check */
    rsa_alloc_wrap,
    rsa_free_wrap,
    rsa_debug,
};
177 #endif /* MBEDTLS_RSA_C */
178 
179 #if defined(MBEDTLS_ECP_C)
180 /*
181  * Generic EC key
182  */
/*
 * A generic EC key can serve generic EC, ECDH-restricted and ECDSA
 * requests. Returns 1 or 0.
 */
static int eckey_can_do( mbedtls_pk_type_t type )
{
    switch( type )
    {
        case MBEDTLS_PK_ECKEY:
        case MBEDTLS_PK_ECKEY_DH:
        case MBEDTLS_PK_ECDSA:
            return( 1 );

        default:
            return( 0 );
    }
}
189 
eckey_get_bitlen(const void * ctx)190 static size_t eckey_get_bitlen( const void *ctx )
191 {
192     return( ((mbedtls_ecp_keypair *) ctx)->grp.pbits );
193 }
194 
195 #if defined(MBEDTLS_ECDSA_C)
/*
 * Forward declarations: the generic EC key wrappers below delegate to these
 * ECDSA wrappers, which are defined later in this file.
 */
static int ecdsa_verify_wrap( void *ctx, mbedtls_md_type_t md_alg,
                       const unsigned char *hash, size_t hash_len,
                       const unsigned char *sig, size_t sig_len );

static int ecdsa_sign_wrap( void *ctx, mbedtls_md_type_t md_alg,
                   const unsigned char *hash, size_t hash_len,
                   unsigned char *sig, size_t *sig_len,
                   int (*f_rng)(void *, unsigned char *, size_t), void *p_rng );
205 
/*
 * Verify an ECDSA signature using a generic EC key pair.
 * The key is loaded into a temporary ECDSA context and the work is passed
 * to ecdsa_verify_wrap(). The temporary is always freed, whatever the
 * outcome.
 */
static int eckey_verify_wrap( void *ctx, mbedtls_md_type_t md_alg,
                       const unsigned char *hash, size_t hash_len,
                       const unsigned char *sig, size_t sig_len )
{
    mbedtls_ecdsa_context ecdsa;
    int ret;

    mbedtls_ecdsa_init( &ecdsa );

    ret = mbedtls_ecdsa_from_keypair( &ecdsa, ctx );
    if( ret == 0 )
    {
        ret = ecdsa_verify_wrap( &ecdsa, md_alg, hash, hash_len,
                                 sig, sig_len );
    }

    mbedtls_ecdsa_free( &ecdsa );

    return( ret );
}
222 
/*
 * Produce an ECDSA signature using a generic EC key pair.
 * The key is loaded into a temporary ECDSA context, ecdsa_sign_wrap() does
 * the signing, and the temporary is always freed.
 */
static int eckey_sign_wrap( void *ctx, mbedtls_md_type_t md_alg,
                   const unsigned char *hash, size_t hash_len,
                   unsigned char *sig, size_t *sig_len,
                   int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
{
    mbedtls_ecdsa_context ecdsa;
    int ret;

    mbedtls_ecdsa_init( &ecdsa );

    ret = mbedtls_ecdsa_from_keypair( &ecdsa, ctx );
    if( ret == 0 )
    {
        ret = ecdsa_sign_wrap( &ecdsa, md_alg, hash, hash_len,
                               sig, sig_len, f_rng, p_rng );
    }

    mbedtls_ecdsa_free( &ecdsa );

    return( ret );
}
241 
242 #endif /* MBEDTLS_ECDSA_C */
243 
eckey_check_pair(const void * pub,const void * prv)244 static int eckey_check_pair( const void *pub, const void *prv )
245 {
246     return( mbedtls_ecp_check_pub_priv( (const mbedtls_ecp_keypair *) pub,
247                                 (const mbedtls_ecp_keypair *) prv ) );
248 }
249 
eckey_alloc_wrap(void)250 static void *eckey_alloc_wrap( void )
251 {
252     void *ctx = mbedtls_calloc( 1, sizeof( mbedtls_ecp_keypair ) );
253 
254     if( ctx != NULL )
255         mbedtls_ecp_keypair_init( ctx );
256 
257     return( ctx );
258 }
259 
eckey_free_wrap(void * ctx)260 static void eckey_free_wrap( void *ctx )
261 {
262     mbedtls_ecp_keypair_free( (mbedtls_ecp_keypair *) ctx );
263     mbedtls_free( ctx );
264 }
265 
/* Describe the key pair's Q point as a single ECP debug item. */
static void eckey_debug( const void *ctx, mbedtls_pk_debug_item *items )
{
    mbedtls_ecp_keypair *key = (mbedtls_ecp_keypair *) ctx;

    items[0].type  = MBEDTLS_PK_DEBUG_ECP;
    items[0].name  = "eckey.Q";
    items[0].value = &key->Q;
}
272 
/*
 * Descriptor for a generic EC key. Sign and verify go through a temporary
 * ECDSA context and exist only when MBEDTLS_ECDSA_C is enabled. Decrypt
 * and encrypt are not provided. Entries are positional, in
 * mbedtls_pk_info_t declaration order.
 */
const mbedtls_pk_info_t mbedtls_eckey_info = {
    MBEDTLS_PK_ECKEY,
    "EC",
    eckey_get_bitlen,
    eckey_can_do,
#if defined(MBEDTLS_ECDSA_C)
    eckey_verify_wrap,    /* verify */
    eckey_sign_wrap,      /* sign */
#else
    NULL,                 /* verify: needs MBEDTLS_ECDSA_C */
    NULL,                 /* sign: needs MBEDTLS_ECDSA_C */
#endif
    NULL,                 /* decrypt: not supported */
    NULL,                 /* encrypt: not supported */
    eckey_check_pair,
    eckey_alloc_wrap,
    eckey_free_wrap,
    eckey_debug,
};
292 
293 /*
294  * EC key restricted to ECDH
295  */
/*
 * An ECDH-restricted key serves generic EC and ECDH requests, but not
 * ECDSA. Returns 1 or 0.
 */
static int eckeydh_can_do( mbedtls_pk_type_t type )
{
    if( type == MBEDTLS_PK_ECKEY )
        return( 1 );

    if( type == MBEDTLS_PK_ECKEY_DH )
        return( 1 );

    return( 0 );
}
301 
/*
 * Descriptor for an EC key restricted to ECDH: no sign, verify, encrypt or
 * decrypt slots. It shares the generic EC key structure and its helpers.
 * Entries are positional, in mbedtls_pk_info_t declaration order.
 */
const mbedtls_pk_info_t mbedtls_eckeydh_info = {
    MBEDTLS_PK_ECKEY_DH,
    "EC_DH",
    eckey_get_bitlen,         /* Same underlying key structure */
    eckeydh_can_do,
    NULL,                     /* verify: not supported */
    NULL,                     /* sign: not supported */
    NULL,                     /* decrypt: not supported */
    NULL,                     /* encrypt: not supported */
    eckey_check_pair,
    eckey_alloc_wrap,       /* Same underlying key structure */
    eckey_free_wrap,        /* Same underlying key structure */
    eckey_debug,            /* Same underlying key structure */
};
316 #endif /* MBEDTLS_ECP_C */
317 
318 #if defined(MBEDTLS_ECDSA_C)
/* An ECDSA context serves ECDSA requests only. Returns 1 or 0. */
static int ecdsa_can_do( mbedtls_pk_type_t type )
{
    if( type != MBEDTLS_PK_ECDSA )
        return( 0 );

    return( 1 );
}
323 
/*
 * Verify a serialized ECDSA signature over `hash`.
 * `md_alg` is unused: the hash is passed through as given. An ECP-layer
 * length mismatch is translated to MBEDTLS_ERR_PK_SIG_LEN_MISMATCH. Every
 * other result is returned unchanged.
 */
static int ecdsa_verify_wrap( void *ctx, mbedtls_md_type_t md_alg,
                       const unsigned char *hash, size_t hash_len,
                       const unsigned char *sig, size_t sig_len )
{
    mbedtls_ecdsa_context *ecdsa = (mbedtls_ecdsa_context *) ctx;
    int ret;

    (void) md_alg;

    ret = mbedtls_ecdsa_read_signature( ecdsa, hash, hash_len, sig, sig_len );

    return( ret == MBEDTLS_ERR_ECP_SIG_LEN_MISMATCH ?
            MBEDTLS_ERR_PK_SIG_LEN_MISMATCH : ret );
}
339 
/*
 * Sign `hash` with the ECDSA context. The serialized signature goes to
 * `sig` and its length to `*sig_len`.
 */
static int ecdsa_sign_wrap( void *ctx, mbedtls_md_type_t md_alg,
                   const unsigned char *hash, size_t hash_len,
                   unsigned char *sig, size_t *sig_len,
                   int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
{
    mbedtls_ecdsa_context *ecdsa = (mbedtls_ecdsa_context *) ctx;
    int ret;

    ret = mbedtls_ecdsa_write_signature( ecdsa, md_alg, hash, hash_len,
                                         sig, sig_len, f_rng, p_rng );

    return( ret );
}
348 
ecdsa_alloc_wrap(void)349 static void *ecdsa_alloc_wrap( void )
350 {
351     void *ctx = mbedtls_calloc( 1, sizeof( mbedtls_ecdsa_context ) );
352 
353     if( ctx != NULL )
354         mbedtls_ecdsa_init( (mbedtls_ecdsa_context *) ctx );
355 
356     return( ctx );
357 }
358 
ecdsa_free_wrap(void * ctx)359 static void ecdsa_free_wrap( void *ctx )
360 {
361     mbedtls_ecdsa_free( (mbedtls_ecdsa_context *) ctx );
362     mbedtls_free( ctx );
363 }
364 
/*
 * Descriptor for ECDSA keys: sign and verify only, with no encrypt or
 * decrypt. Bit length, pair check and debug output reuse the generic EC
 * key helpers. Entries are positional, in mbedtls_pk_info_t declaration
 * order.
 */
const mbedtls_pk_info_t mbedtls_ecdsa_info = {
    MBEDTLS_PK_ECDSA,
    "ECDSA",
    eckey_get_bitlen,     /* Compatible key structures */
    ecdsa_can_do,
    ecdsa_verify_wrap,    /* verify */
    ecdsa_sign_wrap,      /* sign */
    NULL,                 /* decrypt: not supported */
    NULL,                 /* encrypt: not supported */
    eckey_check_pair,   /* Compatible key structures */
    ecdsa_alloc_wrap,
    ecdsa_free_wrap,
    eckey_debug,        /* Compatible key structures */
};
379 #endif /* MBEDTLS_ECDSA_C */
380 
381 #if defined(MBEDTLS_PK_RSA_ALT_SUPPORT)
382 /*
383  * Support for alternative RSA-private implementations
384  */
385 
/* An RSA-alt key answers only for plain RSA (not RSASSA-PSS). */
static int rsa_alt_can_do( mbedtls_pk_type_t type )
{
    int is_rsa = ( type == MBEDTLS_PK_RSA );

    return( is_rsa );
}
390 
rsa_alt_get_bitlen(const void * ctx)391 static size_t rsa_alt_get_bitlen( const void *ctx )
392 {
393     const mbedtls_rsa_alt_context *rsa_alt = (const mbedtls_rsa_alt_context *) ctx;
394 
395     return( 8 * rsa_alt->key_len_func( rsa_alt->key ) );
396 }
397 
/*
 * Sign `hash` through the user-supplied RSA-alt sign callback.
 * `*sig_len` is set to the key length reported by the key_len callback.
 *
 * The callback takes the hash length as `unsigned int`. For
 * MBEDTLS_MD_NONE that length is what gets signed, so a size_t value that
 * would be truncated is rejected with MBEDTLS_ERR_PK_BAD_INPUT_DATA instead
 * of silently signing a prefix.
 */
static int rsa_alt_sign_wrap( void *ctx, mbedtls_md_type_t md_alg,
                   const unsigned char *hash, size_t hash_len,
                   unsigned char *sig, size_t *sig_len,
                   int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
{
    mbedtls_rsa_alt_context *rsa_alt = (mbedtls_rsa_alt_context *) ctx;

#if SIZE_MAX > UINT_MAX
    if( md_alg == MBEDTLS_MD_NONE && UINT_MAX < hash_len )
        return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
#endif /* SIZE_MAX > UINT_MAX */

    *sig_len = rsa_alt->key_len_func( rsa_alt->key );

    return( rsa_alt->sign_func( rsa_alt->key, f_rng, p_rng, MBEDTLS_RSA_PRIVATE,
                md_alg, (unsigned int) hash_len, hash, sig ) );
}
410 
rsa_alt_decrypt_wrap(void * ctx,const unsigned char * input,size_t ilen,unsigned char * output,size_t * olen,size_t osize,int (* f_rng)(void *,unsigned char *,size_t),void * p_rng)411 static int rsa_alt_decrypt_wrap( void *ctx,
412                     const unsigned char *input, size_t ilen,
413                     unsigned char *output, size_t *olen, size_t osize,
414                     int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
415 {
416     mbedtls_rsa_alt_context *rsa_alt = (mbedtls_rsa_alt_context *) ctx;
417 
418     ((void) f_rng);
419     ((void) p_rng);
420 
421     if( ilen != rsa_alt->key_len_func( rsa_alt->key ) )
422         return( MBEDTLS_ERR_RSA_BAD_INPUT_DATA );
423 
424     return( rsa_alt->decrypt_func( rsa_alt->key,
425                 MBEDTLS_RSA_PRIVATE, olen, input, output, osize ) );
426 }
427 
428 #if defined(MBEDTLS_RSA_C)
rsa_alt_check_pair(const void * pub,const void * prv)429 static int rsa_alt_check_pair( const void *pub, const void *prv )
430 {
431     unsigned char sig[MBEDTLS_MPI_MAX_SIZE];
432     unsigned char hash[32];
433     size_t sig_len = 0;
434     int ret;
435 
436     if( rsa_alt_get_bitlen( prv ) != rsa_get_bitlen( pub ) )
437         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
438 
439     memset( hash, 0x2a, sizeof( hash ) );
440 
441     if( ( ret = rsa_alt_sign_wrap( (void *) prv, MBEDTLS_MD_NONE,
442                                    hash, sizeof( hash ),
443                                    sig, &sig_len, NULL, NULL ) ) != 0 )
444     {
445         return( ret );
446     }
447 
448     if( rsa_verify_wrap( (void *) pub, MBEDTLS_MD_NONE,
449                          hash, sizeof( hash ), sig, sig_len ) != 0 )
450     {
451         return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED );
452     }
453 
454     return( 0 );
455 }
456 #endif /* MBEDTLS_RSA_C */
457 
rsa_alt_alloc_wrap(void)458 static void *rsa_alt_alloc_wrap( void )
459 {
460     void *ctx = mbedtls_calloc( 1, sizeof( mbedtls_rsa_alt_context ) );
461 
462     if( ctx != NULL )
463         memset( ctx, 0, sizeof( mbedtls_rsa_alt_context ) );
464 
465     return( ctx );
466 }
467 
rsa_alt_free_wrap(void * ctx)468 static void rsa_alt_free_wrap( void *ctx )
469 {
470     mbedtls_zeroize( ctx, sizeof( mbedtls_rsa_alt_context ) );
471     mbedtls_free( ctx );
472 }
473 
/*
 * Descriptor for RSA-alt keys, whose private-key operations are delegated
 * to user callbacks. Only sign and decrypt are provided. The pair check
 * needs the plain RSA code and so exists only with MBEDTLS_RSA_C. There is
 * no verify, encrypt or debug support. Entries are positional, in
 * mbedtls_pk_info_t declaration order.
 */
const mbedtls_pk_info_t mbedtls_rsa_alt_info = {
    MBEDTLS_PK_RSA_ALT,
    "RSA-alt",
    rsa_alt_get_bitlen,
    rsa_alt_can_do,
    NULL,                   /* verify: not supported */
    rsa_alt_sign_wrap,      /* sign */
    rsa_alt_decrypt_wrap,   /* decrypt */
    NULL,                   /* encrypt: not supported */
#if defined(MBEDTLS_RSA_C)
    rsa_alt_check_pair,
#else
    NULL,                   /* pair check: needs MBEDTLS_RSA_C */
#endif
    rsa_alt_alloc_wrap,
    rsa_alt_free_wrap,
    NULL,                   /* debug: not supported */
};
492 
493 #endif /* MBEDTLS_PK_RSA_ALT_SUPPORT */
494 
495 #endif /* MBEDTLS_PK_C */
496