1 /*
2  * SPDX-FileCopyrightText: 2020-2021 Espressif Systems (Shanghai) CO LTD
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #ifdef ESP_PLATFORM
8 #include "esp_system.h"
9 #endif
10 
11 #include "utils/includes.h"
12 #include "utils/common.h"
13 #include "crypto.h"
14 #include "random.h"
15 #include "sha256.h"
16 
17 #include "mbedtls/ecp.h"
18 #include "mbedtls/entropy.h"
19 #include "mbedtls/ctr_drbg.h"
20 #include "mbedtls/md.h"
21 #include "mbedtls/aes.h"
22 #include "mbedtls/bignum.h"
23 #include "mbedtls/pkcs5.h"
24 #include "mbedtls/cmac.h"
25 #include "mbedtls/nist_kw.h"
26 #include "mbedtls/des.h"
27 #include "mbedtls/ccm.h"
28 #include "mbedtls/arc4.h"
29 
30 #include "common.h"
31 #include "utils/wpabuf.h"
32 #include "dh_group5.h"
33 #include "sha1.h"
34 #include "sha256.h"
35 #include "md5.h"
36 #include "aes_wrap.h"
37 #include "crypto.h"
38 #include "mbedtls/esp_config.h"
39 
/*
 * Compute a plain (non-keyed) digest over a list of buffers.
 *
 * @md_type:  mbedTLS digest identifier
 * @num_elem: number of buffers in @addr / @len
 * @addr:     buffer pointers
 * @len:      buffer lengths
 * @mac:      output, must hold the full digest size of @md_type
 *
 * Returns 0 on success, -1 if @md_type is unknown, otherwise the mbedTLS
 * error code from the failing md call.
 */
static int digest_vector(mbedtls_md_type_t md_type, size_t num_elem,
			 const u8 *addr[], const size_t *len, u8 *mac)
{
	mbedtls_md_context_t md_ctx;
	const mbedtls_md_info_t *info;
	size_t idx;
	int err;

	mbedtls_md_init(&md_ctx);

	info = mbedtls_md_info_from_type(md_type);
	if (info == NULL) {
		wpa_printf(MSG_ERROR, "mbedtls_md_info_from_type() failed");
		return -1;
	}

	err = mbedtls_md_setup(&md_ctx, info, 0);
	if (err != 0) {
		wpa_printf(MSG_ERROR, "mbedtls_md_setup() returned error");
		goto out;
	}

	err = mbedtls_md_starts(&md_ctx);
	if (err != 0) {
		wpa_printf(MSG_ERROR, "mbedtls_md_starts returned error");
		goto out;
	}

	for (idx = 0; idx < num_elem; idx++) {
		err = mbedtls_md_update(&md_ctx, addr[idx], len[idx]);
		if (err != 0) {
			wpa_printf(MSG_ERROR, "mbedtls_md_update ret=%d", err);
			goto out;
		}
	}

	err = mbedtls_md_finish(&md_ctx, mac);

out:
	mbedtls_md_free(&md_ctx);
	return err;
}
83 
/* SHA-256 over a list of buffers; @mac receives 32 bytes. */
int sha256_vector(size_t num_elem, const u8 *addr[], const size_t *len,
		  u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_SHA256;

	return digest_vector(alg, num_elem, addr, len, mac);
}
89 
/* SHA-384 over a list of buffers; @mac receives 48 bytes. */
int sha384_vector(size_t num_elem, const u8 *addr[], const size_t *len,
		  u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_SHA384;

	return digest_vector(alg, num_elem, addr, len, mac);
}
95 
/* SHA-1 over a list of buffers; @mac receives 20 bytes. */
int sha1_vector(size_t num_elem, const u8 *addr[], const size_t *len, u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_SHA1;

	return digest_vector(alg, num_elem, addr, len, mac);
}
100 
/* MD5 over a list of buffers; @mac receives 16 bytes. */
int md5_vector(size_t num_elem, const u8 *addr[], const size_t *len, u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_MD5;

	return digest_vector(alg, num_elem, addr, len, mac);
}
105 
106 #ifdef MBEDTLS_MD4_C
/* MD4 over a list of buffers; @mac receives 16 bytes. */
int md4_vector(size_t num_elem, const u8 *addr[], const size_t *len, u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_MD4;

	return digest_vector(alg, num_elem, addr, len, mac);
}
111 #endif
112 
/*
 * Opaque HMAC state handed out by crypto_hash_init(): a single mbedTLS
 * message-digest context that crypto_hash_init() sets up with HMAC enabled.
 */
struct crypto_hash {
	mbedtls_md_context_t ctx;
};
116 
crypto_hash_init(enum crypto_hash_alg alg,const u8 * key,size_t key_len)117 struct crypto_hash * crypto_hash_init(enum crypto_hash_alg alg, const u8 *key,
118 				      size_t key_len)
119 {
120 	struct crypto_hash *ctx;
121 	mbedtls_md_type_t md_type;
122 	const mbedtls_md_info_t *md_info;
123 	int ret;
124 
125 	switch (alg) {
126 		case CRYPTO_HASH_ALG_HMAC_MD5:
127 			md_type = MBEDTLS_MD_MD5;
128 			break;
129 		case CRYPTO_HASH_ALG_HMAC_SHA1:
130 			md_type = MBEDTLS_MD_SHA1;
131 			break;
132 		case CRYPTO_HASH_ALG_HMAC_SHA256:
133 			md_type = MBEDTLS_MD_SHA256;
134 			break;
135 		default:
136 			return NULL;
137 	}
138 
139 	ctx = os_zalloc(sizeof(*ctx));
140 	if (ctx == NULL) {
141 		return NULL;
142 	}
143 
144 	mbedtls_md_init(&ctx->ctx);
145 	md_info = mbedtls_md_info_from_type(md_type);
146 	if (!md_info) {
147 		os_free(ctx);
148 		return NULL;
149 	}
150 	ret = mbedtls_md_setup(&ctx->ctx, md_info, 1);
151 	if (ret != 0) {
152 		os_free(ctx);
153 		return NULL;
154 	}
155 	mbedtls_md_hmac_starts(&ctx->ctx, key, key_len);
156 
157 	return ctx;
158 }
159 
crypto_hash_update(struct crypto_hash * ctx,const u8 * data,size_t len)160 void crypto_hash_update(struct crypto_hash *ctx, const u8 *data, size_t len)
161 {
162 	if (ctx == NULL) {
163 		return;
164 	}
165 	mbedtls_md_hmac_update(&ctx->ctx, data, len);
166 }
167 
/*
 * Finalize an HMAC computation and free the context.
 *
 * @ctx: context from crypto_hash_init(); always freed unless NULL
 * @mac: output buffer, or NULL to just discard the context
 * @len: in: size of @mac; out: digest length (or the required length when
 *       @mac is too small); NULL to just discard the context
 *
 * Returns 0 on success, -1 if @mac is too small (with *len set to the
 * needed size), or -2 for a NULL @ctx or an mbedTLS failure.
 */
int crypto_hash_finish(struct crypto_hash *ctx, u8 *mac, size_t *len)
{
	size_t mac_len;
	int ret = 0;

	if (ctx == NULL) {
		return -2;
	}

	if (mac == NULL || len == NULL) {
		goto out;
	}

	/* Never write past the caller's buffer: check it fits the digest. */
	mac_len = mbedtls_md_get_size(ctx->ctx.md_info);
	if (*len < mac_len) {
		*len = mac_len;
		ret = -1;
		goto out;
	}

	if (mbedtls_md_hmac_finish(&ctx->ctx, mac) != 0) {
		ret = -2;
		goto out;
	}
	*len = mac_len;

out:
	mbedtls_md_free(&ctx->ctx);
	bin_clear_free(ctx, sizeof(*ctx));
	return ret;
}
185 
/*
 * Compute an HMAC over a list of buffers.
 *
 * @md_type:  mbedTLS digest identifier for the HMAC
 * @key:      HMAC key (@key_len bytes)
 * @num_elem: number of buffers in @addr / @len
 * @mac:      output, must hold the full digest size of @md_type
 *
 * Returns 0 on success, -1 if @md_type is unknown, otherwise the mbedTLS
 * error code of the first failing step. Every step is checked, so a failed
 * step is never reported as a valid MAC.
 */
static int hmac_vector(mbedtls_md_type_t md_type,
		       const u8 *key, size_t key_len,
		       size_t num_elem, const u8 *addr[],
		       const size_t *len, u8 *mac)
{
	size_t i;
	const mbedtls_md_info_t *md_info;
	mbedtls_md_context_t md_ctx;
	int ret;

	mbedtls_md_init(&md_ctx);

	md_info = mbedtls_md_info_from_type(md_type);
	if (!md_info) {
		return -1;
	}

	ret = mbedtls_md_setup(&md_ctx, md_info, 1);
	if (ret != 0) {
		return ret;
	}

	ret = mbedtls_md_hmac_starts(&md_ctx, key, key_len);
	if (ret != 0) {
		goto cleanup;
	}

	for (i = 0; i < num_elem; i++) {
		ret = mbedtls_md_hmac_update(&md_ctx, addr[i], len[i]);
		if (ret != 0) {
			goto cleanup;
		}
	}

	ret = mbedtls_md_hmac_finish(&md_ctx, mac);

cleanup:
	mbedtls_md_free(&md_ctx);
	return ret;
}
220 
/* HMAC-SHA384 over a list of buffers; @mac receives 48 bytes. */
int hmac_sha384_vector(const u8 *key, size_t key_len, size_t num_elem,
		       const u8 *addr[], const size_t *len, u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_SHA384;

	return hmac_vector(alg, key, key_len, num_elem, addr, len, mac);
}
227 
228 
/* HMAC-SHA384 over a single buffer. */
int hmac_sha384(const u8 *key, size_t key_len, const u8 *data,
		size_t data_len, u8 *mac)
{
	const u8 *elems[1] = { data };
	size_t lens[1] = { data_len };

	return hmac_sha384_vector(key, key_len, 1, elems, lens, mac);
}
234 
/* HMAC-SHA256 over a list of buffers; @mac receives 32 bytes. */
int hmac_sha256_vector(const u8 *key, size_t key_len, size_t num_elem,
		       const u8 *addr[], const size_t *len, u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_SHA256;

	return hmac_vector(alg, key, key_len, num_elem, addr, len, mac);
}
241 
/* HMAC-SHA256 over a single buffer. */
int hmac_sha256(const u8 *key, size_t key_len, const u8 *data,
		size_t data_len, u8 *mac)
{
	const u8 *elems[1] = { data };
	size_t lens[1] = { data_len };

	return hmac_sha256_vector(key, key_len, 1, elems, lens, mac);
}
247 
/* HMAC-MD5 over a list of buffers; @mac receives 16 bytes. */
int hmac_md5_vector(const u8 *key, size_t key_len, size_t num_elem,
		    const u8 *addr[], const size_t *len, u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_MD5;

	return hmac_vector(alg, key, key_len, num_elem, addr, len, mac);
}
254 
/* HMAC-MD5 over a single buffer. */
int hmac_md5(const u8 *key, size_t key_len, const u8 *data, size_t data_len,
	     u8 *mac)
{
	const u8 *elems[1] = { data };
	size_t lens[1] = { data_len };

	return hmac_md5_vector(key, key_len, 1, elems, lens, mac);
}
260 
/* HMAC-SHA1 over a list of buffers; @mac receives 20 bytes. */
int hmac_sha1_vector(const u8 *key, size_t key_len, size_t num_elem,
		     const u8 *addr[], const size_t *len, u8 *mac)
{
	const mbedtls_md_type_t alg = MBEDTLS_MD_SHA1;

	return hmac_vector(alg, key, key_len, num_elem, addr, len, mac);
}
267 
/* HMAC-SHA1 over a single buffer. */
int hmac_sha1(const u8 *key, size_t key_len, const u8 *data, size_t data_len,
	      u8 *mac)
{
	const u8 *elems[1] = { data };
	size_t lens[1] = { data_len };

	return hmac_sha1_vector(key, key_len, 1, elems, lens, mac);
}
273 
aes_crypt_init(int mode,const u8 * key,size_t len)274 static void *aes_crypt_init(int mode, const u8 *key, size_t len)
275 {
276 	int ret = -1;
277 	mbedtls_aes_context *aes = os_malloc(sizeof(*aes));
278 	if (!aes) {
279 		return NULL;
280 	}
281 	mbedtls_aes_init(aes);
282 
283 	if (mode == MBEDTLS_AES_ENCRYPT) {
284 		ret = mbedtls_aes_setkey_enc(aes, key, len * 8);
285 	} else if (mode == MBEDTLS_AES_DECRYPT){
286 		ret = mbedtls_aes_setkey_dec(aes, key, len * 8);
287 	}
288 	if (ret < 0) {
289 		mbedtls_aes_free(aes);
290 		os_free(aes);
291 		wpa_printf(MSG_ERROR, "%s: mbedtls_aes_setkey_enc/mbedtls_aes_setkey_dec failed", __func__);
292 		return NULL;
293 	}
294 
295 	return (void *) aes;
296 }
297 
/* Run one AES-ECB block operation in @mode on a context from aes_crypt_init(). */
static int aes_crypt(void *ctx, int mode, const u8 *in, u8 *out)
{
	mbedtls_aes_context *aes = (mbedtls_aes_context *) ctx;

	return mbedtls_aes_crypt_ecb(aes, mode, in, out);
}
303 
aes_crypt_deinit(void * ctx)304 static void aes_crypt_deinit(void *ctx)
305 {
306 	mbedtls_aes_free((mbedtls_aes_context *)ctx);
307 	os_free(ctx);
308 }
309 
aes_encrypt_init(const u8 * key,size_t len)310 void *aes_encrypt_init(const u8 *key, size_t len)
311 {
312 	return aes_crypt_init(MBEDTLS_AES_ENCRYPT, key, len);
313 }
314 
/* Encrypt a single 16-byte block from @plain into @crypt. */
int aes_encrypt(void *ctx, const u8 *plain, u8 *crypt)
{
	const int mode = MBEDTLS_AES_ENCRYPT;

	return aes_crypt(ctx, mode, plain, crypt);
}
319 
/*
 * Release a context created by aes_encrypt_init().
 *
 * The original wrote `return aes_crypt_deinit(ctx);`. A return statement
 * with an expression is not allowed in a void function in ISO C, so the
 * call is made as a plain statement.
 */
void aes_encrypt_deinit(void *ctx)
{
	aes_crypt_deinit(ctx);
}
324 
aes_decrypt_init(const u8 * key,size_t len)325 void * aes_decrypt_init(const u8 *key, size_t len)
326 {
327 	return aes_crypt_init(MBEDTLS_AES_DECRYPT, key, len);
328 }
329 
/* Decrypt a single 16-byte block from @crypt into @plain. */
int aes_decrypt(void *ctx, const u8 *crypt, u8 *plain)
{
	const int mode = MBEDTLS_AES_DECRYPT;

	return aes_crypt(ctx, mode, crypt, plain);
}
334 
/*
 * Release a context created by aes_decrypt_init().
 *
 * The call is a plain statement: `return <expr>;` is not permitted in a
 * void function in ISO C.
 */
void aes_decrypt_deinit(void *ctx)
{
	aes_crypt_deinit(ctx);
}
339 
/*
 * AES-128-CBC encrypt @data in place with a 16-byte @key and @iv.
 * @data_len is passed straight to mbedTLS, which rejects lengths that are
 * not a multiple of the block size. @iv itself is not modified.
 * Returns 0 on success or an mbedTLS error code.
 */
int aes_128_cbc_encrypt(const u8 *key, const u8 *iv, u8 *data, size_t data_len)
{
	mbedtls_aes_context aes;
	u8 chain[MBEDTLS_AES_BLOCK_SIZE];
	int ret;

	mbedtls_aes_init(&aes);

	ret = mbedtls_aes_setkey_enc(&aes, key, 128);
	if (ret >= 0) {
		/* mbedTLS updates the IV buffer as it goes, so use a copy. */
		os_memcpy(chain, iv, MBEDTLS_AES_BLOCK_SIZE);
		ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_ENCRYPT,
					    data_len, chain, data, data);
	}

	mbedtls_aes_free(&aes);
	return ret;
}
361 
/*
 * AES-128-CBC decrypt @data in place with a 16-byte @key and @iv.
 * @data_len is passed straight to mbedTLS, which rejects lengths that are
 * not a multiple of the block size. @iv itself is not modified.
 * Returns 0 on success or an mbedTLS error code.
 */
int aes_128_cbc_decrypt(const u8 *key, const u8 *iv, u8 *data, size_t data_len)
{
	mbedtls_aes_context aes;
	u8 chain[MBEDTLS_AES_BLOCK_SIZE];
	int ret;

	mbedtls_aes_init(&aes);

	ret = mbedtls_aes_setkey_dec(&aes, key, 128);
	if (ret >= 0) {
		/* mbedTLS updates the IV buffer as it goes, so use a copy. */
		os_memcpy(chain, iv, MBEDTLS_AES_BLOCK_SIZE);
		ret = mbedtls_aes_crypt_cbc(&aes, MBEDTLS_AES_DECRYPT,
					    data_len, chain, data, data);
	}

	mbedtls_aes_free(&aes);
	return ret;
}
384 
/*
 * Cipher state returned by crypto_cipher_init(): two mbedTLS cipher
 * contexts, one keyed for encryption and one for decryption, both set up
 * with the same key and IV.
 */
struct crypto_cipher {
	mbedtls_cipher_context_t ctx_enc;
	mbedtls_cipher_context_t ctx_dec;
};
389 
/*
 * Initialise @ctx for @cipher_info, then load @key (using the cipher's
 * native key length), @iv, and reset it for @operation.
 *
 * Returns 0 on success or -1 on failure. On failure @ctx is released with
 * mbedtls_cipher_free(), so the caller has nothing to clean up. Before this
 * fix, anything allocated by mbedtls_cipher_setup() leaked.
 */
static int crypto_init_cipher_ctx(mbedtls_cipher_context_t *ctx,
				  const mbedtls_cipher_info_t *cipher_info,
				  const u8 *iv, const u8 *key,
				  mbedtls_operation_t operation)
{
	mbedtls_cipher_init(ctx);

	if (mbedtls_cipher_setup(ctx, cipher_info) != 0) {
		goto fail;
	}

	if (mbedtls_cipher_setkey(ctx, key, cipher_info->key_bitlen,
				  operation) != 0) {
		wpa_printf(MSG_ERROR, "mbedtls_cipher_setkey returned error");
		goto fail;
	}
	if (mbedtls_cipher_set_iv(ctx, iv, cipher_info->iv_size) != 0) {
		wpa_printf(MSG_ERROR, "mbedtls_cipher_set_iv returned error");
		goto fail;
	}
	if (mbedtls_cipher_reset(ctx) != 0) {
		wpa_printf(MSG_ERROR, "mbedtls_cipher_reset() returned error");
		goto fail;
	}

	return 0;

fail:
	mbedtls_cipher_free(ctx);
	return -1;
}
419 
alg_to_mbedtls_cipher(enum crypto_cipher_alg alg,size_t key_len)420 static mbedtls_cipher_type_t alg_to_mbedtls_cipher(enum crypto_cipher_alg alg,
421 						   size_t key_len)
422 {
423 	switch (alg) {
424 #ifdef MBEDTLS_ARC4_C
425 	case CRYPTO_CIPHER_ALG_RC4:
426 		return MBEDTLS_CIPHER_ARC4_128;
427 #endif
428 	case CRYPTO_CIPHER_ALG_AES:
429 		if (key_len == 16) {
430 			return MBEDTLS_CIPHER_AES_128_CBC;
431 		}
432 		if (key_len == 24) {
433 			return MBEDTLS_CIPHER_AES_192_CBC;
434 		}
435 		if (key_len == 32) {
436 			return MBEDTLS_CIPHER_AES_256_CBC;
437 		}
438 		break;
439 #ifdef MBEDTLS_DES_C
440 	case CRYPTO_CIPHER_ALG_3DES:
441 		return MBEDTLS_CIPHER_DES_EDE3_CBC;
442 	case CRYPTO_CIPHER_ALG_DES:
443 		return MBEDTLS_CIPHER_DES_CBC;
444 #endif
445 	default:
446 		break;
447 	}
448 
449 	return MBEDTLS_CIPHER_NONE;
450 }
451 
crypto_cipher_init(enum crypto_cipher_alg alg,const u8 * iv,const u8 * key,size_t key_len)452 struct crypto_cipher *crypto_cipher_init(enum crypto_cipher_alg alg,
453 					 const u8 *iv, const u8 *key,
454 					 size_t key_len)
455 {
456 	struct crypto_cipher *ctx;
457 	mbedtls_cipher_type_t cipher_type;
458 	const mbedtls_cipher_info_t *cipher_info;
459 
460 	ctx = (struct crypto_cipher *)os_zalloc(sizeof(*ctx));
461 	if (!ctx) {
462 		return NULL;
463 	}
464 
465 	cipher_type = alg_to_mbedtls_cipher(alg, key_len);
466 	if (cipher_type == MBEDTLS_CIPHER_NONE) {
467 		goto cleanup;
468 	}
469 
470 	cipher_info = mbedtls_cipher_info_from_type(cipher_type);
471 	if (cipher_info == NULL) {
472 		goto cleanup;
473 	}
474 
475 	/* Init both ctx encryption/decryption */
476 	if (crypto_init_cipher_ctx(&ctx->ctx_enc, cipher_info, iv, key,
477 				   MBEDTLS_ENCRYPT) < 0) {
478 		goto cleanup;
479 	}
480 
481 	if (crypto_init_cipher_ctx(&ctx->ctx_dec, cipher_info, iv, key,
482 				   MBEDTLS_DECRYPT) < 0) {
483 		goto cleanup;
484 	}
485 
486 	return ctx;
487 
488 cleanup:
489 	os_free(ctx);
490 	return NULL;
491 }
492 
#if 0
/*
 * NOTE(review): crypto_cipher_encrypt()/crypto_cipher_decrypt() are compiled
 * out, so this file does not provide them. Before re-enabling, check the
 * output-length handling:
 *  - per the mbedTLS cipher API, `olen` is an output parameter, so the
 *    1200 initializer does not act as a buffer limit;
 *  - the second call overwrites `olen`, so the total number of bytes
 *    written is not tracked.
 */
int crypto_cipher_encrypt(struct crypto_cipher *ctx, const u8 *plain,
			  u8 *crypt, size_t len)
{
	int ret;
	size_t olen = 1200;

	ret = mbedtls_cipher_update(&ctx->ctx_enc, plain, len, crypt, &olen);
	if (ret != 0) {
		return -1;
	}

	ret = mbedtls_cipher_finish(&ctx->ctx_enc, crypt + olen, &olen);
	if (ret != 0) {
		return -1;
	}

	return 0;
}

int crypto_cipher_decrypt(struct crypto_cipher *ctx, const u8 *crypt,
			  u8 *plain, size_t len)
{
	int ret;
	size_t olen = 1200;

	ret = mbedtls_cipher_update(&ctx->ctx_dec, crypt, len, plain, &olen);
	if (ret != 0) {
		return -1;
	}

	ret = mbedtls_cipher_finish(&ctx->ctx_dec, plain + olen, &olen);
	if (ret != 0) {
		return -1;
	}

	return 0;
}
#endif
532 
crypto_cipher_deinit(struct crypto_cipher * ctx)533 void crypto_cipher_deinit(struct crypto_cipher *ctx)
534 {
535 	mbedtls_cipher_free(&ctx->ctx_enc);
536 	mbedtls_cipher_free(&ctx->ctx_dec);
537 	os_free(ctx);
538 }
539 
/*
 * AES-CTR encrypt (or decrypt) @data in place.
 *
 * @key:   AES key of @key_len bytes
 * @nonce: 16-byte initial counter block; not modified
 *
 * Returns 0 on success or an mbedTLS error code.
 */
int aes_ctr_encrypt(const u8 *key, size_t key_len, const u8 *nonce,
		    u8 *data, size_t data_len)
{
	int ret;
	mbedtls_aes_context ctx;
	u8 counter[MBEDTLS_AES_BLOCK_SIZE];
	u8 stream_block[MBEDTLS_AES_BLOCK_SIZE];
	size_t offset = 0;

	mbedtls_aes_init(&ctx);
	ret = mbedtls_aes_setkey_enc(&ctx, key, key_len * 8);
	if (ret < 0) {
		goto cleanup;
	}

	/*
	 * mbedtls_aes_crypt_ctr() advances the counter block in place. Work on
	 * a local copy instead of casting away const and overwriting the
	 * caller's nonce.
	 */
	os_memcpy(counter, nonce, MBEDTLS_AES_BLOCK_SIZE);
	ret = mbedtls_aes_crypt_ctr(&ctx, data_len, &offset, counter,
				    stream_block, data, data);
cleanup:
	mbedtls_aes_free(&ctx);
	return ret;
}
559 
/* AES-128-CTR in place: aes_ctr_encrypt() with a fixed 16-byte key. */
int aes_128_ctr_encrypt(const u8 *key, const u8 *nonce,
			u8 *data, size_t data_len)
{
	const size_t aes128_key_len = 16;

	return aes_ctr_encrypt(key, aes128_key_len, nonce, data, data_len);
}
565 
566 
567 #ifdef MBEDTLS_NIST_KW_C
/*
 * AES Key Wrap (NIST KW mode).
 *
 * @kek:    key-encryption key of @kek_len bytes
 * @n:      length of @plain in 64-bit blocks
 * @cipher: output, (n + 1) * 8 bytes
 *
 * Returns 0 on success or an mbedTLS error code. The KW context is always
 * freed, including when setkey fails (that path used to leak it).
 */
int aes_wrap(const u8 *kek, size_t kek_len, int n, const u8 *plain, u8 *cipher)
{
	mbedtls_nist_kw_context ctx;
	size_t olen;
	int ret;

	mbedtls_nist_kw_init(&ctx);

	ret = mbedtls_nist_kw_setkey(&ctx, MBEDTLS_CIPHER_ID_AES,
				     kek, kek_len * 8, 1);
	if (ret != 0) {
		goto cleanup;
	}

	ret = mbedtls_nist_kw_wrap(&ctx, MBEDTLS_KW_MODE_KW, plain,
				   n * 8, cipher, &olen, (n + 1) * 8);

cleanup:
	mbedtls_nist_kw_free(&ctx);
	return ret;
}
587 
/*
 * AES Key Unwrap (NIST KW mode).
 *
 * @kek:    key-encryption key of @kek_len bytes
 * @n:      length of the unwrapped key in 64-bit blocks
 * @cipher: wrapped input, (n + 1) * 8 bytes
 * @plain:  output, n * 8 bytes
 *
 * Returns 0 on success or an mbedTLS error code (including an integrity
 * check failure). The KW context is always freed, including when setkey
 * fails.
 */
int aes_unwrap(const u8 *kek, size_t kek_len, int n, const u8 *cipher,
	       u8 *plain)
{
	mbedtls_nist_kw_context ctx;
	size_t olen;
	int ret;

	mbedtls_nist_kw_init(&ctx);

	ret = mbedtls_nist_kw_setkey(&ctx, MBEDTLS_CIPHER_ID_AES,
				     kek, kek_len * 8, 0);
	if (ret != 0) {
		goto cleanup;
	}

	ret = mbedtls_nist_kw_unwrap(&ctx, MBEDTLS_KW_MODE_KW, cipher,
				     (n + 1) * 8, plain, &olen, (n * 8));

cleanup:
	mbedtls_nist_kw_free(&ctx);
	return ret;
}
608 #endif
609 
crypto_mod_exp(const uint8_t * base,size_t base_len,const uint8_t * power,size_t power_len,const uint8_t * modulus,size_t modulus_len,uint8_t * result,size_t * result_len)610 int crypto_mod_exp(const uint8_t *base, size_t base_len,
611 		   const uint8_t *power, size_t power_len,
612 		   const uint8_t *modulus, size_t modulus_len,
613 		   uint8_t *result, size_t *result_len)
614 {
615 	mbedtls_mpi bn_base, bn_exp, bn_modulus, bn_result, bn_rinv;
616 	int ret = 0;
617 
618 	mbedtls_mpi_init(&bn_base);
619 	mbedtls_mpi_init(&bn_exp);
620 	mbedtls_mpi_init(&bn_modulus);
621 	mbedtls_mpi_init(&bn_result);
622 	mbedtls_mpi_init(&bn_rinv);
623 
624 	MBEDTLS_MPI_CHK(mbedtls_mpi_read_binary(&bn_base, base, base_len));
625 	MBEDTLS_MPI_CHK(mbedtls_mpi_read_binary(&bn_exp, power, power_len));
626 	MBEDTLS_MPI_CHK(mbedtls_mpi_read_binary(&bn_modulus, modulus, modulus_len));
627 
628 	MBEDTLS_MPI_CHK(mbedtls_mpi_exp_mod(&bn_result, &bn_base, &bn_exp, &bn_modulus,
629 					    &bn_rinv));
630 
631 	ret = mbedtls_mpi_write_binary(&bn_result, result, *result_len);
632 
633 cleanup:
634 	mbedtls_mpi_free(&bn_base);
635 	mbedtls_mpi_free(&bn_exp);
636 	mbedtls_mpi_free(&bn_modulus);
637 	mbedtls_mpi_free(&bn_result);
638 	mbedtls_mpi_free(&bn_rinv);
639 
640 	return ret;
641 }
642 
/*
 * PBKDF2-HMAC-SHA1 (as used to derive the WPA PSK from passphrase + SSID).
 *
 * @passphrase: NUL-terminated passphrase
 * @ssid:       salt, @ssid_len bytes
 * @iterations: PBKDF2 iteration count
 * @buf:        output buffer of @buflen bytes; exactly @buflen bytes are
 *              derived. The original always derived 32 bytes, overflowing
 *              short buffers and under-filling long ones.
 *
 * Returns 0 on success, -1 on failure.
 */
int pbkdf2_sha1(const char *passphrase, const u8 *ssid, size_t ssid_len,
		int iterations, u8 *buf, size_t buflen)
{
	mbedtls_md_context_t sha1_ctx;
	const mbedtls_md_info_t *info_sha1;
	int ret;

	mbedtls_md_init(&sha1_ctx);

	info_sha1 = mbedtls_md_info_from_type(MBEDTLS_MD_SHA1);
	if (info_sha1 == NULL) {
		ret = -1;
		goto cleanup;
	}

	if (mbedtls_md_setup(&sha1_ctx, info_sha1, 1) != 0) {
		ret = -1;
		goto cleanup;
	}

	ret = mbedtls_pkcs5_pbkdf2_hmac(&sha1_ctx, (const u8 *) passphrase,
					os_strlen(passphrase), ssid,
					ssid_len, iterations, buflen, buf);
	if (ret != 0) {
		ret = -1;
	}

cleanup:
	mbedtls_md_free(&sha1_ctx);
	return ret;
}
676 
677 #ifdef MBEDTLS_DES_C
/*
 * Single-DES ECB encrypt one 8-byte block.
 *
 * @key is a 7-byte (56-bit) key. It is expanded to 8 bytes by spreading
 * the bits over the top 7 bits of each byte and setting bit 0 of every
 * byte. The DES context is freed on every path, including a setkey
 * failure, which previously returned early without mbedtls_des_free().
 *
 * Returns 0 on success or an mbedTLS error code.
 */
int des_encrypt(const u8 *clear, const u8 *key, u8 *cypher)
{
	int ret;
	mbedtls_des_context des;
	u8 pkey[8], next, tmp;
	int i;

	/* Add parity bits to the key */
	next = 0;
	for (i = 0; i < 7; i++) {
		tmp = key[i];
		pkey[i] = (tmp >> i) | next | 1;
		next = tmp << (7 - i);
	}
	pkey[i] = next | 1;

	mbedtls_des_init(&des);
	ret = mbedtls_des_setkey_enc(&des, pkey);
	if (ret < 0) {
		goto cleanup;
	}
	ret = mbedtls_des_crypt_ecb(&des, clear, cypher);

cleanup:
	mbedtls_des_free(&des);
	return ret;
}
704 #endif
705 
706 /* Only enable this if all other ciphers are using MbedTLS implementation */
707 #if defined(MBEDTLS_CCM_C) && defined(MBEDTLS_CMAC_C) && defined(MBEDTLS_NIST_KW_C)
/*
 * AES-CCM authenticated encryption.
 *
 * @nonce: always read as 13 bytes (hard-coded below)
 * @M:     tag length in bytes, written to @auth
 *
 * Encrypts @plain_len bytes of @plain into @crypt, authenticating @aad too.
 * Returns 0 on success or an mbedTLS error code.
 */
int aes_ccm_ae(const u8 *key, size_t key_len, const u8 *nonce,
	       size_t M, const u8 *plain, size_t plain_len,
	       const u8 *aad, size_t aad_len, u8 *crypt, u8 *auth)
{
	mbedtls_ccm_context ccm;
	int ret;

	mbedtls_ccm_init(&ccm);

	ret = mbedtls_ccm_setkey(&ccm, MBEDTLS_CIPHER_ID_AES,
				 key, key_len * 8);
	if (ret >= 0) {
		ret = mbedtls_ccm_encrypt_and_tag(&ccm, plain_len, nonce, 13,
						  aad, aad_len, plain, crypt,
						  auth, M);
	} else {
		wpa_printf(MSG_ERROR, "mbedtls_ccm_setkey failed");
	}

	mbedtls_ccm_free(&ccm);
	return ret;
}
732 
/*
 * AES-CCM authenticated decryption.
 *
 * @nonce: always read as 13 bytes (hard-coded below)
 * @auth:  expected tag of @M bytes
 *
 * Decrypts @crypt_len bytes of @crypt into @plain and verifies the tag over
 * the data and @aad. Returns 0 on success or an mbedTLS error code
 * (including authentication failure).
 */
int aes_ccm_ad(const u8 *key, size_t key_len, const u8 *nonce,
	       size_t M, const u8 *crypt, size_t crypt_len,
	       const u8 *aad, size_t aad_len, const u8 *auth,
	       u8 *plain)
{
	mbedtls_ccm_context ccm;
	int ret;

	mbedtls_ccm_init(&ccm);

	ret = mbedtls_ccm_setkey(&ccm, MBEDTLS_CIPHER_ID_AES,
				 key, key_len * 8);
	if (ret >= 0) {
		ret = mbedtls_ccm_star_auth_decrypt(&ccm, crypt_len,
						    nonce, 13, aad, aad_len,
						    crypt, plain, auth, M);
	}

	mbedtls_ccm_free(&ccm);
	return ret;
}
758 #endif
759 
760 #ifdef MBEDTLS_ARC4_C
/*
 * RC4 with keystream skipping: discard the first @skip keystream bytes,
 * then XOR @data_len bytes of @data in place.
 *
 * Fixes relative to the original:
 *  - a @skip that is not a multiple of 16 now skips the remainder too
 *    (the old loop stopped as soon as fewer than 16 bytes were left);
 *  - the temporary buffer is no longer leaked on the error path;
 *  - the RC4 context is always freed;
 *  - the final crypt result is checked.
 *
 * Returns 0 on success, -1 on allocation or cipher failure.
 */
int rc4_skip(const u8 *key, size_t keylen, size_t skip,
	     u8 *data, size_t data_len)
{
	int ret = -1;
	unsigned char skip_buf_in[16];
	unsigned char skip_buf_out[16];
	mbedtls_arc4_context ctx;
	unsigned char *obuf = os_malloc(data_len);

	if (!obuf) {
		wpa_printf(MSG_ERROR, "%s:memory allocation failed", __func__);
		return -1;
	}

	/*
	 * The bytes being skipped are thrown away, so the input content does
	 * not matter. Zero it so no indeterminate memory is read.
	 */
	os_memset(skip_buf_in, 0, sizeof(skip_buf_in));

	mbedtls_arc4_init(&ctx);
	mbedtls_arc4_setup(&ctx, key, keylen);

	while (skip > 0) {
		size_t len = skip;

		if (len > sizeof(skip_buf_in)) {
			len = sizeof(skip_buf_in);
		}
		if (mbedtls_arc4_crypt(&ctx, len, skip_buf_in,
				       skip_buf_out) != 0) {
			wpa_printf(MSG_ERROR, "rc4 encryption failed");
			goto cleanup;
		}
		skip -= len;
	}

	if (mbedtls_arc4_crypt(&ctx, data_len, data, obuf) != 0) {
		wpa_printf(MSG_ERROR, "rc4 encryption failed");
		goto cleanup;
	}

	os_memcpy(data, obuf, data_len);
	ret = 0;

cleanup:
	mbedtls_arc4_free(&ctx);
	/* obuf holds plaintext/ciphertext, so clear it before freeing */
	bin_clear_free(obuf, data_len);
	return ret;
}
797 #endif
798 
799 #ifdef MBEDTLS_CMAC_C
/*
 * AES-CMAC (OMAC1) over a list of buffers.
 *
 * @key:  AES key; @key_len must be 16, 24 or 32
 * @mac:  output, 16 bytes
 *
 * Returns 0 on success, -1 for a NULL @key/@mac,
 * MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE for an unsupported key length,
 * or another mbedTLS error code.
 */
int omac1_aes_vector(const u8 *key, size_t key_len, size_t num_elem,
		     const u8 *addr[], const size_t *len, u8 *mac)
{
	const mbedtls_cipher_info_t *cipher_info;
	mbedtls_cipher_type_t cipher_type;
	mbedtls_cipher_context_t ctx;
	size_t i;
	int ret;

	if (key == NULL || mac == NULL) {
		return -1;
	}

	switch (key_len) {
	case 16:
		cipher_type = MBEDTLS_CIPHER_AES_128_ECB;
		break;
	case 24:
		cipher_type = MBEDTLS_CIPHER_AES_192_ECB;
		break;
	case 32:
		cipher_type = MBEDTLS_CIPHER_AES_256_ECB;
		break;
	default:
		cipher_type = MBEDTLS_CIPHER_NONE;
		break;
	}

	/*
	 * Initialise before the first jump to cleanup. The original ran
	 * mbedtls_cipher_free() on an uninitialised context when the key
	 * length was unsupported (undefined behaviour).
	 */
	mbedtls_cipher_init(&ctx);

	cipher_info = mbedtls_cipher_info_from_type(cipher_type);
	if (cipher_info == NULL) {
		/* Failing at this point must be due to a build issue */
		ret = MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE;
		goto cleanup;
	}

	ret = mbedtls_cipher_setup(&ctx, cipher_info);
	if (ret != 0) {
		goto cleanup;
	}

	ret = mbedtls_cipher_cmac_starts(&ctx, key, key_len * 8);
	if (ret != 0) {
		goto cleanup;
	}

	for (i = 0; i < num_elem; i++) {
		ret = mbedtls_cipher_cmac_update(&ctx, addr[i], len[i]);
		if (ret != 0) {
			goto cleanup;
		}
	}

	ret = mbedtls_cipher_cmac_finish(&ctx, mac);
cleanup:
	mbedtls_cipher_free(&ctx);
	return ret;
}
857 
/* AES-128-CMAC over a list of buffers; @mac receives 16 bytes. */
int omac1_aes_128_vector(const u8 *key, size_t num_elem,
			 const u8 *addr[], const size_t *len, u8 *mac)
{
	const size_t aes128_key_len = 16;

	return omac1_aes_vector(key, aes128_key_len, num_elem, addr, len, mac);
}
863 
/* AES-128-CMAC over a single buffer; @mac receives 16 bytes. */
int omac1_aes_128(const u8 *key, const u8 *data, size_t data_len, u8 *mac)
{
	const u8 *elems[1] = { data };
	size_t lens[1] = { data_len };

	return omac1_aes_128_vector(key, 1, elems, lens, mac);
}
868 #endif
869 
/*
 * Generate a Diffie-Hellman key pair for the group (@generator, @prime).
 *
 * @privkey: output, @prime_len random bytes
 * @pubkey:  output, @prime_len bytes; generator^privkey mod prime,
 *           big-endian and left-padded with zeros
 *
 * Returns 0 on success, -1 if random generation or the modular
 * exponentiation fails.
 */
int crypto_dh_init(u8 generator, const u8 *prime, size_t prime_len, u8 *privkey,
		   u8 *pubkey)
{
	size_t out_len = prime_len;

	if (os_get_random(privkey, prime_len) < 0) {
		return -1;
	}

	/*
	 * If the random value compares above the prime, clear its top byte to
	 * bring it below (assumes prime[0] != 0 — TODO confirm for all groups).
	 */
	if (os_memcmp(privkey, prime, prime_len) > 0) {
		privkey[0] = 0;
	}

	if (crypto_mod_exp(&generator, 1, privkey, prime_len, prime, prime_len,
			   pubkey, &out_len) < 0) {
		return -1;
	}

	/* Right-align a short result and zero-fill the leading bytes. */
	if (out_len < prime_len) {
		size_t pad = prime_len - out_len;

		os_memmove(pubkey + pad, pubkey, out_len);
		os_memset(pubkey, 0, pad);
	}

	return 0;
}
896