1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  *
8  * This software is available to you under a choice of one of two
9  * licenses.  You may choose to be licensed under the terms of the GNU
10  * General Public License (GPL) Version 2, available from the file
11  * COPYING in the main directory of this source tree, or the
12  * OpenIB.org BSD license below:
13  *
14  *     Redistribution and use in source and binary forms, with or
15  *     without modification, are permitted provided that the following
16  *     conditions are met:
17  *
18  *      - Redistributions of source code must retain the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer.
21  *
22  *      - Redistributions in binary form must reproduce the above
23  *        copyright notice, this list of conditions and the following
24  *        disclaimer in the documentation and/or other materials
25  *        provided with the distribution.
26  *
27  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34  * SOFTWARE.
35  */
36 
37 #include <linux/sched/signal.h>
38 #include <linux/module.h>
39 #include <crypto/aead.h>
40 
41 #include <net/strparser.h>
42 #include <net/tls.h>
43 
44 #define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE
45 
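/* Decrypt one TLS record with the socket's receive AEAD transform. sgin is
 * expected to start with the AAD followed by the ciphertext; the call waits
 * for the (possibly asynchronous) crypto request to complete.
 */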
46 static int tls_do_decryption(struct sock *sk,
47 			     struct scatterlist *sgin,
48 			     struct scatterlist *sgout,
49 			     char *iv_recv,
50 			     size_t data_len,
51 			     struct aead_request *aead_req)
52 {
53 	struct tls_context *tls_ctx = tls_get_ctx(sk);
54 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
55 	int ret;
56 
57 	aead_request_set_tfm(aead_req, ctx->aead_recv);
58 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
59 	aead_request_set_crypt(aead_req, sgin, sgout,
60 			       data_len + tls_ctx->rx.tag_size,
61 			       (u8 *)iv_recv);
62 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
63 				  crypto_req_done, &ctx->async_wait);
64 
65 	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
66 	return ret;
67 }
68 
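/* Shrink a scatterlist down to target_size bytes, releasing and uncharging
 * any pages that the trimmed entries no longer reference.
 */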
69 static void trim_sg(struct sock *sk, struct scatterlist *sg,
70 		    int *sg_num_elem, unsigned int *sg_size, int target_size)
71 {
72 	int i = *sg_num_elem - 1;
73 	int trim = *sg_size - target_size;
74 
75 	if (trim <= 0) {
76 		WARN_ON(trim < 0);
77 		return;
78 	}
79 
80 	*sg_size = target_size;
81 	while (trim >= sg[i].length) {
82 		trim -= sg[i].length;
83 		sk_mem_uncharge(sk, sg[i].length);
84 		put_page(sg_page(&sg[i]));
85 		i--;
86 
87 		if (i < 0)
88 			goto out;
89 	}
90 
91 	sg[i].length -= trim;
92 	sk_mem_uncharge(sk, trim);
93 
94 out:
95 	*sg_num_elem = i + 1;
96 }
97 
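/* Trim both the plaintext and encrypted scatterlists of a partially built
 * record back to target_size (plus record overhead on the encrypted side).
 */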
98 static void trim_both_sgl(struct sock *sk, int target_size)
99 {
100 	struct tls_context *tls_ctx = tls_get_ctx(sk);
101 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
102 
103 	trim_sg(sk, ctx->sg_plaintext_data,
104 		&ctx->sg_plaintext_num_elem,
105 		&ctx->sg_plaintext_size,
106 		target_size);
107 
108 	if (target_size > 0)
109 		target_size += tls_ctx->tx.overhead_size;
110 
111 	trim_sg(sk, ctx->sg_encrypted_data,
112 		&ctx->sg_encrypted_num_elem,
113 		&ctx->sg_encrypted_size,
114 		target_size);
115 }
116 
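/* Grow the encrypted-data scatterlist so it can hold len bytes; on -ENOSPC
 * the element count is simply capped at the maximum size of the array.
 */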
117 static int alloc_encrypted_sg(struct sock *sk, int len)
118 {
119 	struct tls_context *tls_ctx = tls_get_ctx(sk);
120 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
121 	int rc = 0;
122 
123 	rc = sk_alloc_sg(sk, len,
124 			 ctx->sg_encrypted_data, 0,
125 			 &ctx->sg_encrypted_num_elem,
126 			 &ctx->sg_encrypted_size, 0);
127 
128 	if (rc == -ENOSPC)
129 		ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data);
130 
131 	return rc;
132 }
133 
134 static int alloc_plaintext_sg(struct sock *sk, int len)
135 {
136 	struct tls_context *tls_ctx = tls_get_ctx(sk);
137 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
138 	int rc = 0;
139 
140 	rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
141 			 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
142 			 tls_ctx->pending_open_record_frags);
143 
144 	if (rc == -ENOSPC)
145 		ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data);
146 
147 	return rc;
148 }
149 
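/* Release every page referenced by the scatterlist, uncharge the socket
 * memory, and reset the element count and total size to zero.
 */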
150 static void free_sg(struct sock *sk, struct scatterlist *sg,
151 		    int *sg_num_elem, unsigned int *sg_size)
152 {
153 	int i, n = *sg_num_elem;
154 
155 	for (i = 0; i < n; ++i) {
156 		sk_mem_uncharge(sk, sg[i].length);
157 		put_page(sg_page(&sg[i]));
158 	}
159 	*sg_num_elem = 0;
160 	*sg_size = 0;
161 }
162 
163 static void tls_free_both_sg(struct sock *sk)
164 {
165 	struct tls_context *tls_ctx = tls_get_ctx(sk);
166 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
167 
168 	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
169 		&ctx->sg_encrypted_size);
170 
171 	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
172 		&ctx->sg_plaintext_size);
173 }
174 
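/* Encrypt the pending record: temporarily skip past the TLS header in the
 * first encrypted-data entry, run the send AEAD transform over the chained
 * AAD + plaintext scatterlist, then restore the header offset.
 */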
175 static int tls_do_encryption(struct tls_context *tls_ctx,
176 			     struct tls_sw_context_tx *ctx,
177 			     struct aead_request *aead_req,
178 			     size_t data_len)
179 {
180 	int rc;
181 
182 	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
183 	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
184 
185 	aead_request_set_tfm(aead_req, ctx->aead_send);
186 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
187 	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
188 			       data_len, tls_ctx->tx.iv);
189 
190 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
191 				  crypto_req_done, &ctx->async_wait);
192 
193 	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
194 
195 	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
196 	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
197 
198 	return rc;
199 }
200 
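/* Assemble and transmit one TLS record: build the AAD and record header,
 * encrypt the pending plaintext, hand the ciphertext to the transport via
 * tls_push_sg() and advance the transmit record sequence number.
 */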
201 static int tls_push_record(struct sock *sk, int flags,
202 			   unsigned char record_type)
203 {
204 	struct tls_context *tls_ctx = tls_get_ctx(sk);
205 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
206 	struct aead_request *req;
207 	int rc;
208 
209 	req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
210 	if (!req)
211 		return -ENOMEM;
212 
213 	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
214 	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
215 
216 	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
217 		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
218 		     record_type);
219 
220 	tls_fill_prepend(tls_ctx,
221 			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
222 			 ctx->sg_encrypted_data[0].offset,
223 			 ctx->sg_plaintext_size, record_type);
224 
225 	tls_ctx->pending_open_record_frags = 0;
226 	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
227 
228 	rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
229 	if (rc < 0) {
230 		/* If we are called from write_space and
231 		 * we fail, we need to set this SOCK_NOSPACE
232 		 * to trigger another write_space in the future.
233 		 */
234 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
235 		goto out_req;
236 	}
237 
238 	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
239 		&ctx->sg_plaintext_size);
240 
241 	ctx->sg_encrypted_num_elem = 0;
242 	ctx->sg_encrypted_size = 0;
243 
244 	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
245 	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
246 	if (rc < 0 && rc != -EAGAIN)
247 		tls_err_abort(sk, EBADMSG);
248 
249 	tls_advance_record_sn(sk, &tls_ctx->tx);
250 out_req:
251 	aead_request_free(req);
252 	return rc;
253 }
254 
255 static int tls_sw_push_pending_record(struct sock *sk, int flags)
256 {
257 	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
258 }
259 
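/* Map up to length bytes of the iterator directly into the destination
 * scatterlist by pinning the backing pages (no data copy), optionally
 * charging the socket's memory accounting. The caller's page and size
 * counters are updated; on failure the iterator is reverted.
 */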
260 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
261 			      int length, int *pages_used,
262 			      unsigned int *size_used,
263 			      struct scatterlist *to, int to_max_pages,
264 			      bool charge)
265 {
266 	struct page *pages[MAX_SKB_FRAGS];
267 
268 	size_t offset;
269 	ssize_t copied, use;
270 	int i = 0;
271 	unsigned int size = *size_used;
272 	int num_elem = *pages_used;
273 	int rc = 0;
274 	int maxpages;
275 
276 	while (length > 0) {
277 		i = 0;
278 		maxpages = to_max_pages - num_elem;
279 		if (maxpages == 0) {
280 			rc = -EFAULT;
281 			goto out;
282 		}
283 		copied = iov_iter_get_pages(from, pages,
284 					    length,
285 					    maxpages, &offset);
286 		if (copied <= 0) {
287 			rc = -EFAULT;
288 			goto out;
289 		}
290 
291 		iov_iter_advance(from, copied);
292 
293 		length -= copied;
294 		size += copied;
295 		while (copied) {
296 			use = min_t(int, copied, PAGE_SIZE - offset);
297 
298 			sg_set_page(&to[num_elem],
299 				    pages[i], use, offset);
300 			sg_unmark_end(&to[num_elem]);
301 			if (charge)
302 				sk_mem_charge(sk, use);
303 
304 			offset = 0;
305 			copied -= use;
306 
307 			++i;
308 			++num_elem;
309 		}
310 	}
311 
312 	/* Mark the end in the last sg entry if newly added */
313 	if (num_elem > *pages_used)
314 		sg_mark_end(&to[num_elem - 1]);
315 out:
316 	if (rc)
317 		iov_iter_revert(from, size - *size_used);
318 	*size_used = size;
319 	*pages_used = num_elem;
320 
321 	return rc;
322 }
323 
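/* Copy message data from the iterator into the already allocated plaintext
 * scatterlist entries, advancing the count of pending open-record frags.
 */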
324 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
325 			     int bytes)
326 {
327 	struct tls_context *tls_ctx = tls_get_ctx(sk);
328 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
329 	struct scatterlist *sg = ctx->sg_plaintext_data;
330 	int copy, i, rc = 0;
331 
332 	for (i = tls_ctx->pending_open_record_frags;
333 	     i < ctx->sg_plaintext_num_elem; ++i) {
334 		copy = sg[i].length;
335 		if (copy_from_iter(
336 				page_address(sg_page(&sg[i])) + sg[i].offset,
337 				copy, from) != copy) {
338 			rc = -EFAULT;
339 			goto out;
340 		}
341 		bytes -= copy;
342 
343 		++tls_ctx->pending_open_record_frags;
344 
345 		if (!bytes)
346 			break;
347 	}
348 
349 out:
350 	return rc;
351 }
352 
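/* sendmsg() for a TLS_SW transmit socket: data is staged into the per-socket
 * plaintext/encrypted scatterlists (mapping user pages directly where
 * possible) and pushed out as a TLS record whenever a record fills up or the
 * message ends without MSG_MORE.
 */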
353 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
354 {
355 	struct tls_context *tls_ctx = tls_get_ctx(sk);
356 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
357 	int ret = 0;
358 	int required_size;
359 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
360 	bool eor = !(msg->msg_flags & MSG_MORE);
361 	size_t try_to_copy, copied = 0;
362 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
363 	int record_room;
364 	bool full_record;
365 	int orig_size;
366 	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
367 
368 	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
369 		return -ENOTSUPP;
370 
371 	lock_sock(sk);
372 
373 	if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
374 		goto send_end;
375 
376 	if (unlikely(msg->msg_controllen)) {
377 		ret = tls_proccess_cmsg(sk, msg, &record_type);
378 		if (ret)
379 			goto send_end;
380 	}
381 
382 	while (msg_data_left(msg)) {
383 		if (sk->sk_err) {
384 			ret = -sk->sk_err;
385 			goto send_end;
386 		}
387 
388 		orig_size = ctx->sg_plaintext_size;
389 		full_record = false;
390 		try_to_copy = msg_data_left(msg);
391 		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
392 		if (try_to_copy >= record_room) {
393 			try_to_copy = record_room;
394 			full_record = true;
395 		}
396 
397 		required_size = ctx->sg_plaintext_size + try_to_copy +
398 				tls_ctx->tx.overhead_size;
399 
400 		if (!sk_stream_memory_free(sk))
401 			goto wait_for_sndbuf;
402 alloc_encrypted:
403 		ret = alloc_encrypted_sg(sk, required_size);
404 		if (ret) {
405 			if (ret != -ENOSPC)
406 				goto wait_for_memory;
407 
408 			/* Adjust try_to_copy according to the amount that was
409 			 * actually allocated. The difference is due
410 			 * to the max sg elements limit.
411 			 */
412 			try_to_copy -= required_size - ctx->sg_encrypted_size;
413 			full_record = true;
414 		}
415 		if (!is_kvec && (full_record || eor)) {
416 			ret = zerocopy_from_iter(sk, &msg->msg_iter,
417 				try_to_copy, &ctx->sg_plaintext_num_elem,
418 				&ctx->sg_plaintext_size,
419 				ctx->sg_plaintext_data,
420 				ARRAY_SIZE(ctx->sg_plaintext_data),
421 				true);
422 			if (ret)
423 				goto fallback_to_reg_send;
424 
425 			copied += try_to_copy;
426 			ret = tls_push_record(sk, msg->msg_flags, record_type);
427 			if (ret)
428 				goto send_end;
429 			continue;
430 
431 fallback_to_reg_send:
432 			trim_sg(sk, ctx->sg_plaintext_data,
433 				&ctx->sg_plaintext_num_elem,
434 				&ctx->sg_plaintext_size,
435 				orig_size);
436 		}
437 
438 		required_size = ctx->sg_plaintext_size + try_to_copy;
439 alloc_plaintext:
440 		ret = alloc_plaintext_sg(sk, required_size);
441 		if (ret) {
442 			if (ret != -ENOSPC)
443 				goto wait_for_memory;
444 
445 			/* Adjust try_to_copy according to the amount that was
446 			 * actually allocated. The difference is due
447 			 * to the max sg elements limit.
448 			 */
449 			try_to_copy -= required_size - ctx->sg_plaintext_size;
450 			full_record = true;
451 
452 			trim_sg(sk, ctx->sg_encrypted_data,
453 				&ctx->sg_encrypted_num_elem,
454 				&ctx->sg_encrypted_size,
455 				ctx->sg_plaintext_size +
456 				tls_ctx->tx.overhead_size);
457 		}
458 
459 		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
460 		if (ret)
461 			goto trim_sgl;
462 
463 		copied += try_to_copy;
464 		if (full_record || eor) {
465 push_record:
466 			ret = tls_push_record(sk, msg->msg_flags, record_type);
467 			if (ret) {
468 				if (ret == -ENOMEM)
469 					goto wait_for_memory;
470 
471 				goto send_end;
472 			}
473 		}
474 
475 		continue;
476 
477 wait_for_sndbuf:
478 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
479 wait_for_memory:
480 		ret = sk_stream_wait_memory(sk, &timeo);
481 		if (ret) {
482 trim_sgl:
483 			trim_both_sgl(sk, orig_size);
484 			goto send_end;
485 		}
486 
487 		if (tls_is_pending_closed_record(tls_ctx))
488 			goto push_record;
489 
490 		if (ctx->sg_encrypted_size < required_size)
491 			goto alloc_encrypted;
492 
493 		goto alloc_plaintext;
494 	}
495 
496 send_end:
497 	ret = sk_stream_error(sk, msg->msg_flags, ret);
498 
499 	release_sock(sk);
500 	return copied ? copied : ret;
501 }
502 
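/* sendpage() for a TLS_SW transmit socket: the referenced page is attached
 * to the plaintext scatterlist and a record is pushed once it is full, on
 * EOR, or when the scatterlist runs out of entries.
 */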
503 int tls_sw_sendpage(struct sock *sk, struct page *page,
504 		    int offset, size_t size, int flags)
505 {
506 	struct tls_context *tls_ctx = tls_get_ctx(sk);
507 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
508 	int ret = 0;
509 	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
510 	bool eor;
511 	size_t orig_size = size;
512 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
513 	struct scatterlist *sg;
514 	bool full_record;
515 	int record_room;
516 
517 	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
518 		      MSG_SENDPAGE_NOTLAST))
519 		return -ENOTSUPP;
520 
521 	/* No MSG_EOR from splice, only look at MSG_MORE */
522 	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
523 
524 	lock_sock(sk);
525 
526 	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
527 
528 	if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
529 		goto sendpage_end;
530 
531 	/* Call the sk_stream functions to manage the sndbuf mem. */
532 	while (size > 0) {
533 		size_t copy, required_size;
534 
535 		if (sk->sk_err) {
536 			ret = -sk->sk_err;
537 			goto sendpage_end;
538 		}
539 
540 		full_record = false;
541 		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
542 		copy = size;
543 		if (copy >= record_room) {
544 			copy = record_room;
545 			full_record = true;
546 		}
547 		required_size = ctx->sg_plaintext_size + copy +
548 			      tls_ctx->tx.overhead_size;
549 
550 		if (!sk_stream_memory_free(sk))
551 			goto wait_for_sndbuf;
552 alloc_payload:
553 		ret = alloc_encrypted_sg(sk, required_size);
554 		if (ret) {
555 			if (ret != -ENOSPC)
556 				goto wait_for_memory;
557 
558 			/* Adjust copy according to the amount that was
559 			 * actually allocated. The difference is due
560 			 * to the max sg elements limit.
561 			 */
562 			copy -= required_size - ctx->sg_plaintext_size;
563 			full_record = true;
564 		}
565 
566 		get_page(page);
567 		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
568 		sg_set_page(sg, page, copy, offset);
569 		sg_unmark_end(sg);
570 
571 		ctx->sg_plaintext_num_elem++;
572 
573 		sk_mem_charge(sk, copy);
574 		offset += copy;
575 		size -= copy;
576 		ctx->sg_plaintext_size += copy;
577 		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
578 
579 		if (full_record || eor ||
580 		    ctx->sg_plaintext_num_elem ==
581 		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
582 push_record:
583 			ret = tls_push_record(sk, flags, record_type);
584 			if (ret) {
585 				if (ret == -ENOMEM)
586 					goto wait_for_memory;
587 
588 				goto sendpage_end;
589 			}
590 		}
591 		continue;
592 wait_for_sndbuf:
593 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
594 wait_for_memory:
595 		ret = sk_stream_wait_memory(sk, &timeo);
596 		if (ret) {
597 			trim_both_sgl(sk, ctx->sg_plaintext_size);
598 			goto sendpage_end;
599 		}
600 
601 		if (tls_is_pending_closed_record(tls_ctx))
602 			goto push_record;
603 
604 		goto alloc_payload;
605 	}
606 
607 sendpage_end:
608 	if (orig_size > size)
609 		ret = orig_size - size;
610 	else
611 		ret = sk_stream_error(sk, flags, ret);
612 
613 	release_sock(sk);
614 	return ret;
615 }
616 
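/* Wait for the strparser to hand us a complete record (ctx->recv_pkt),
 * honouring non-blocking mode, socket errors, shutdown and signals.
 */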
617 static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
618 				     long timeo, int *err)
619 {
620 	struct tls_context *tls_ctx = tls_get_ctx(sk);
621 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
622 	struct sk_buff *skb;
623 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
624 
625 	while (!(skb = ctx->recv_pkt)) {
626 		if (sk->sk_err) {
627 			*err = sock_error(sk);
628 			return NULL;
629 		}
630 
631 		if (sk->sk_shutdown & RCV_SHUTDOWN)
632 			return NULL;
633 
634 		if (sock_flag(sk, SOCK_DONE))
635 			return NULL;
636 
637 		if ((flags & MSG_DONTWAIT) || !timeo) {
638 			*err = -EAGAIN;
639 			return NULL;
640 		}
641 
642 		add_wait_queue(sk_sleep(sk), &wait);
643 		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
644 		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
645 		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
646 		remove_wait_queue(sk_sleep(sk), &wait);
647 
648 		/* Handle signals */
649 		if (signal_pending(current)) {
650 			*err = sock_intr_errno(timeo);
651 			return NULL;
652 		}
653 	}
654 
655 	return skb;
656 }
657 
658 /* This function decrypts the input skb either into out_iov, into out_sg,
659  * or in place within the skb's own buffers. The input parameter 'zc'
660  * indicates whether zero-copy mode should be attempted. For zero-copy,
661  * either out_iov or out_sg must be non-NULL. If both out_iov and out_sg
662  * are NULL, the decryption happens in place inside the skb buffers, i.e.
663  * zero-copy is disabled and 'zc' is updated accordingly.
664  */
665 
666 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
667 			    struct iov_iter *out_iov,
668 			    struct scatterlist *out_sg,
669 			    int *chunk, bool *zc)
670 {
671 	struct tls_context *tls_ctx = tls_get_ctx(sk);
672 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
673 	struct strp_msg *rxm = strp_msg(skb);
674 	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
675 	struct aead_request *aead_req;
676 	struct sk_buff *unused;
677 	u8 *aad, *iv, *mem = NULL;
678 	struct scatterlist *sgin = NULL;
679 	struct scatterlist *sgout = NULL;
680 	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
681 
682 	if (*zc && (out_iov || out_sg)) {
683 		if (out_iov)
684 			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
685 		else
686 			n_sgout = sg_nents(out_sg);
687 	} else {
688 		n_sgout = 0;
689 		*zc = false;
690 	}
691 
692 	n_sgin = skb_cow_data(skb, 0, &unused);
693 	if (n_sgin < 1)
694 		return -EBADMSG;
695 
696 	/* Increment to accommodate AAD */
697 	n_sgin = n_sgin + 1;
698 
699 	nsg = n_sgin + n_sgout;
700 
701 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
702 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
703 	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
704 	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
705 
706 	/* Allocate a single block of memory which contains
707 	 * aead_req || sgin[] || sgout[] || aad || iv.
708 	 * This order achieves correct alignment for aead_req, sgin, sgout.
709 	 */
710 	mem = kmalloc(mem_size, sk->sk_allocation);
711 	if (!mem)
712 		return -ENOMEM;
713 
714 	/* Segment the allocated memory */
715 	aead_req = (struct aead_request *)mem;
716 	sgin = (struct scatterlist *)(mem + aead_size);
717 	sgout = sgin + n_sgin;
718 	aad = (u8 *)(sgout + n_sgout);
719 	iv = aad + TLS_AAD_SPACE_SIZE;
720 
721 	/* Prepare IV */
722 	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
723 			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
724 			    tls_ctx->rx.iv_size);
725 	if (err < 0) {
726 		kfree(mem);
727 		return err;
728 	}
729 	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
730 
731 	/* Prepare AAD */
732 	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
733 		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
734 		     ctx->control);
735 
736 	/* Prepare sgin */
737 	sg_init_table(sgin, n_sgin);
738 	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
739 	err = skb_to_sgvec(skb, &sgin[1],
740 			   rxm->offset + tls_ctx->rx.prepend_size,
741 			   rxm->full_len - tls_ctx->rx.prepend_size);
742 	if (err < 0) {
743 		kfree(mem);
744 		return err;
745 	}
746 
747 	if (n_sgout) {
748 		if (out_iov) {
749 			sg_init_table(sgout, n_sgout);
750 			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
751 
752 			*chunk = 0;
753 			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
754 						 chunk, &sgout[1],
755 						 (n_sgout - 1), false);
756 			if (err < 0)
757 				goto fallback_to_reg_recv;
758 		} else if (out_sg) {
759 			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
760 		} else {
761 			goto fallback_to_reg_recv;
762 		}
763 	} else {
764 fallback_to_reg_recv:
765 		sgout = sgin;
766 		pages = 0;
767 		*chunk = 0;
768 		*zc = false;
769 	}
770 
771 	/* Prepare and submit AEAD request */
772 	err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
773 
774 	/* Release the pages in case iov was mapped to pages */
775 	for (; pages > 0; pages--)
776 		put_page(sg_page(&sgout[pages]));
777 
778 	kfree(mem);
779 	return err;
780 }
781 
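/* Decrypt the given record skb unless it has already been decrypted (e.g.
 * by the device offload path), then strip the TLS header and overhead from
 * the strparser message and advance the receive record sequence number.
 */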
782 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
783 			      struct iov_iter *dest, int *chunk, bool *zc)
784 {
785 	struct tls_context *tls_ctx = tls_get_ctx(sk);
786 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
787 	struct strp_msg *rxm = strp_msg(skb);
788 	int err = 0;
789 
790 #ifdef CONFIG_TLS_DEVICE
791 	err = tls_device_decrypted(sk, skb);
792 	if (err < 0)
793 		return err;
794 #endif
795 	if (!ctx->decrypted) {
796 		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
797 		if (err < 0)
798 			return err;
799 	} else {
800 		*zc = false;
801 	}
802 
803 	rxm->offset += tls_ctx->rx.prepend_size;
804 	rxm->full_len -= tls_ctx->rx.overhead_size;
805 	tls_advance_record_sn(sk, &tls_ctx->rx);
806 	ctx->decrypted = true;
807 	ctx->saved_data_ready(sk);
808 
809 	return err;
810 }
811 
812 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
813 		struct scatterlist *sgout)
814 {
815 	bool zc = true;
816 	int chunk;
817 
818 	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
819 }
820 
821 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
822 			       unsigned int len)
823 {
824 	struct tls_context *tls_ctx = tls_get_ctx(sk);
825 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
826 	struct strp_msg *rxm = strp_msg(skb);
827 
828 	if (len < rxm->full_len) {
829 		rxm->offset += len;
830 		rxm->full_len -= len;
831 
832 		return false;
833 	}
834 
835 	/* Finished with message */
836 	ctx->recv_pkt = NULL;
837 	kfree_skb(skb);
838 	__strp_unpause(&ctx->strp);
839 
840 	return true;
841 }
842 
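/* recvmsg() for a TLS_SW receive socket: records are pulled from the
 * strparser one at a time, decrypted (directly into the user buffer when
 * zero-copy is possible) and copied to userspace; the record type is
 * reported through a TLS_GET_RECORD_TYPE control message.
 */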
843 int tls_sw_recvmsg(struct sock *sk,
844 		   struct msghdr *msg,
845 		   size_t len,
846 		   int nonblock,
847 		   int flags,
848 		   int *addr_len)
849 {
850 	struct tls_context *tls_ctx = tls_get_ctx(sk);
851 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
852 	unsigned char control;
853 	struct strp_msg *rxm;
854 	struct sk_buff *skb;
855 	ssize_t copied = 0;
856 	bool cmsg = false;
857 	int target, err = 0;
858 	long timeo;
859 	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
860 
861 	flags |= nonblock;
862 
863 	if (unlikely(flags & MSG_ERRQUEUE))
864 		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
865 
866 	lock_sock(sk);
867 
868 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
869 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
870 	do {
871 		bool zc = false;
872 		int chunk = 0;
873 
874 		skb = tls_wait_data(sk, flags, timeo, &err);
875 		if (!skb)
876 			goto recv_end;
877 
878 		rxm = strp_msg(skb);
879 		if (!cmsg) {
880 			int cerr;
881 
882 			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
883 					sizeof(ctx->control), &ctx->control);
884 			cmsg = true;
885 			control = ctx->control;
886 			if (ctx->control != TLS_RECORD_TYPE_DATA) {
887 				if (cerr || msg->msg_flags & MSG_CTRUNC) {
888 					err = -EIO;
889 					goto recv_end;
890 				}
891 			}
892 		} else if (control != ctx->control) {
893 			goto recv_end;
894 		}
895 
896 		if (!ctx->decrypted) {
897 			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
898 
899 			if (!is_kvec && to_copy <= len &&
900 			    likely(!(flags & MSG_PEEK)))
901 				zc = true;
902 
903 			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
904 						 &chunk, &zc);
905 			if (err < 0) {
906 				tls_err_abort(sk, EBADMSG);
907 				goto recv_end;
908 			}
909 			ctx->decrypted = true;
910 		}
911 
912 		if (!zc) {
913 			chunk = min_t(unsigned int, rxm->full_len, len);
914 			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
915 						    chunk);
916 			if (err < 0)
917 				goto recv_end;
918 		}
919 
920 		copied += chunk;
921 		len -= chunk;
922 		if (likely(!(flags & MSG_PEEK))) {
923 			u8 control = ctx->control;
924 
925 			if (tls_sw_advance_skb(sk, skb, chunk)) {
926 				/* Return full control message to
927 				 * userspace before trying to parse
928 				 * another message type
929 				 */
930 				msg->msg_flags |= MSG_EOR;
931 				if (control != TLS_RECORD_TYPE_DATA)
932 					goto recv_end;
933 			}
934 		} else {
935 			/* MSG_PEEK right now cannot look beyond current skb
936 			 * from strparser, meaning we cannot advance skb here
937 			 * and thus unpause strparser since we'd lose the
938 			 * original one.
939 			 */
940 			break;
941 		}
942 
943 		/* If we have a new message from strparser, continue now. */
944 		if (copied >= target && !ctx->recv_pkt)
945 			break;
946 	} while (len);
947 
948 recv_end:
949 	release_sock(sk);
950 	return copied ? : err;
951 }
952 
953 ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
954 			   struct pipe_inode_info *pipe,
955 			   size_t len, unsigned int flags)
956 {
957 	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
958 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
959 	struct strp_msg *rxm = NULL;
960 	struct sock *sk = sock->sk;
961 	struct sk_buff *skb;
962 	ssize_t copied = 0;
963 	int err = 0;
964 	long timeo;
965 	int chunk;
966 	bool zc = false;
967 
968 	lock_sock(sk);
969 
970 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
971 
972 	skb = tls_wait_data(sk, flags, timeo, &err);
973 	if (!skb)
974 		goto splice_read_end;
975 
976 	/* splice does not support reading control messages */
977 	if (ctx->control != TLS_RECORD_TYPE_DATA) {
978 		err = -ENOTSUPP;
979 		goto splice_read_end;
980 	}
981 
982 	if (!ctx->decrypted) {
983 		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
984 
985 		if (err < 0) {
986 			tls_err_abort(sk, EBADMSG);
987 			goto splice_read_end;
988 		}
989 		ctx->decrypted = true;
990 	}
991 	rxm = strp_msg(skb);
992 
993 	chunk = min_t(unsigned int, rxm->full_len, len);
994 	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
995 	if (copied < 0)
996 		goto splice_read_end;
997 
998 	if (likely(!(flags & MSG_PEEK)))
999 		tls_sw_advance_skb(sk, skb, copied);
1000 
1001 splice_read_end:
1002 	release_sock(sk);
1003 	return copied ? : err;
1004 }
1005 
1006 unsigned int tls_sw_poll(struct file *file, struct socket *sock,
1007 			 struct poll_table_struct *wait)
1008 {
1009 	unsigned int ret;
1010 	struct sock *sk = sock->sk;
1011 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1012 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1013 
1014 	/* Grab POLLOUT and POLLHUP from the underlying socket */
1015 	ret = ctx->sk_poll(file, sock, wait);
1016 
1017 	/* Clear POLLIN bits, and set based on recv_pkt */
1018 	ret &= ~(POLLIN | POLLRDNORM);
1019 	if (ctx->recv_pkt)
1020 		ret |= POLLIN | POLLRDNORM;
1021 
1022 	return ret;
1023 }
1024 
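/* strparser parse_msg callback: read and validate the TLS record header and
 * return the total record length (TLS_HEADER_SIZE + payload), 0 if more data
 * is needed, or a negative error after aborting the connection.
 */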
1025 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
1026 {
1027 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1028 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1029 	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
1030 	struct strp_msg *rxm = strp_msg(skb);
1031 	size_t cipher_overhead;
1032 	size_t data_len = 0;
1033 	int ret;
1034 
1035 	/* Verify that we have a full TLS header, or wait for more data */
1036 	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
1037 		return 0;
1038 
1039 	/* Sanity-check size of on-stack buffer. */
1040 	if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
1041 		ret = -EINVAL;
1042 		goto read_failure;
1043 	}
1044 
1045 	/* Linearize header to local buffer */
1046 	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
1047 
1048 	if (ret < 0)
1049 		goto read_failure;
1050 
1051 	ctx->control = header[0];
1052 
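	/* Payload length is the 16-bit big-endian field in header bytes 3-4 */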
1053 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
1054 
1055 	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
1056 
1057 	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
1058 		ret = -EMSGSIZE;
1059 		goto read_failure;
1060 	}
1061 	if (data_len < cipher_overhead) {
1062 		ret = -EBADMSG;
1063 		goto read_failure;
1064 	}
1065 
1066 	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
1067 	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
1068 		ret = -EINVAL;
1069 		goto read_failure;
1070 	}
1071 
1072 #ifdef CONFIG_TLS_DEVICE
1073 	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1074 			     *(u64*)tls_ctx->rx.rec_seq);
1075 #endif
1076 	return data_len + TLS_HEADER_SIZE;
1077 
1078 read_failure:
1079 	tls_err_abort(strp->sk, ret);
1080 
1081 	return ret;
1082 }
1083 
1084 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
1085 {
1086 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1087 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1088 
1089 	ctx->decrypted = false;
1090 
1091 	ctx->recv_pkt = skb;
1092 	strp_pause(strp);
1093 
1094 	ctx->saved_data_ready(strp->sk);
1095 }
1096 
1097 static void tls_data_ready(struct sock *sk)
1098 {
1099 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1100 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1101 
1102 	strp_data_ready(&ctx->strp);
1103 }
1104 
1105 void tls_sw_free_resources_tx(struct sock *sk)
1106 {
1107 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1108 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1109 
1110 	crypto_free_aead(ctx->aead_send);
1111 	tls_free_both_sg(sk);
1112 
1113 	kfree(ctx);
1114 }
1115 
1116 void tls_sw_release_resources_rx(struct sock *sk)
1117 {
1118 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1119 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1120 
1121 	if (ctx->aead_recv) {
1122 		kfree_skb(ctx->recv_pkt);
1123 		ctx->recv_pkt = NULL;
1124 		crypto_free_aead(ctx->aead_recv);
1125 		strp_stop(&ctx->strp);
1126 		write_lock_bh(&sk->sk_callback_lock);
1127 		sk->sk_data_ready = ctx->saved_data_ready;
1128 		write_unlock_bh(&sk->sk_callback_lock);
1129 		release_sock(sk);
1130 		strp_done(&ctx->strp);
1131 		lock_sock(sk);
1132 	}
1133 }
1134 
1135 void tls_sw_free_resources_rx(struct sock *sk)
1136 {
1137 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1138 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1139 
1140 	tls_sw_release_resources_rx(sk);
1141 
1142 	kfree(ctx);
1143 }
1144 
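/* Set up software TLS crypto state for one direction (tx != 0 for transmit):
 * allocate the per-direction context, derive the IV and record-sequence
 * buffers from the userspace-provided AES-GCM-128 parameters, instantiate
 * the "gcm(aes)" AEAD, and for the receive side hook the socket into the
 * strparser.
 */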
1145 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
1146 {
1147 	struct tls_crypto_info *crypto_info;
1148 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
1149 	struct tls_sw_context_tx *sw_ctx_tx = NULL;
1150 	struct tls_sw_context_rx *sw_ctx_rx = NULL;
1151 	struct cipher_context *cctx;
1152 	struct crypto_aead **aead;
1153 	struct strp_callbacks cb;
1154 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
1155 	char *iv, *rec_seq;
1156 	int rc = 0;
1157 
1158 	if (!ctx) {
1159 		rc = -EINVAL;
1160 		goto out;
1161 	}
1162 
1163 	if (tx) {
1164 		if (!ctx->priv_ctx_tx) {
1165 			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
1166 			if (!sw_ctx_tx) {
1167 				rc = -ENOMEM;
1168 				goto out;
1169 			}
1170 			ctx->priv_ctx_tx = sw_ctx_tx;
1171 		} else {
1172 			sw_ctx_tx =
1173 				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
1174 		}
1175 	} else {
1176 		if (!ctx->priv_ctx_rx) {
1177 			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
1178 			if (!sw_ctx_rx) {
1179 				rc = -ENOMEM;
1180 				goto out;
1181 			}
1182 			ctx->priv_ctx_rx = sw_ctx_rx;
1183 		} else {
1184 			sw_ctx_rx =
1185 				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
1186 		}
1187 	}
1188 
1189 	if (tx) {
1190 		crypto_init_wait(&sw_ctx_tx->async_wait);
1191 		crypto_info = &ctx->crypto_send.info;
1192 		cctx = &ctx->tx;
1193 		aead = &sw_ctx_tx->aead_send;
1194 	} else {
1195 		crypto_init_wait(&sw_ctx_rx->async_wait);
1196 		crypto_info = &ctx->crypto_recv.info;
1197 		cctx = &ctx->rx;
1198 		aead = &sw_ctx_rx->aead_recv;
1199 	}
1200 
1201 	switch (crypto_info->cipher_type) {
1202 	case TLS_CIPHER_AES_GCM_128: {
1203 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1204 		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1205 		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1206 		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1207 		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1208 		rec_seq =
1209 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1210 		gcm_128_info =
1211 			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
1212 		break;
1213 	}
1214 	default:
1215 		rc = -EINVAL;
1216 		goto free_priv;
1217 	}
1218 
1219 	/* Sanity-check the IV size for stack allocations. */
1220 	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
1221 		rc = -EINVAL;
1222 		goto free_priv;
1223 	}
1224 
1225 	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1226 	cctx->tag_size = tag_size;
1227 	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1228 	cctx->iv_size = iv_size;
1229 	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1230 			   GFP_KERNEL);
1231 	if (!cctx->iv) {
1232 		rc = -ENOMEM;
1233 		goto free_priv;
1234 	}
1235 	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1236 	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1237 	cctx->rec_seq_size = rec_seq_size;
1238 	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
1239 	if (!cctx->rec_seq) {
1240 		rc = -ENOMEM;
1241 		goto free_iv;
1242 	}
1243 
1244 	if (sw_ctx_tx) {
1245 		sg_init_table(sw_ctx_tx->sg_encrypted_data,
1246 			      ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
1247 		sg_init_table(sw_ctx_tx->sg_plaintext_data,
1248 			      ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
1249 
1250 		sg_init_table(sw_ctx_tx->sg_aead_in, 2);
1251 		sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
1252 			   sizeof(sw_ctx_tx->aad_space));
1253 		sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
1254 		sg_chain(sw_ctx_tx->sg_aead_in, 2,
1255 			 sw_ctx_tx->sg_plaintext_data);
1256 		sg_init_table(sw_ctx_tx->sg_aead_out, 2);
1257 		sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
1258 			   sizeof(sw_ctx_tx->aad_space));
1259 		sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
1260 		sg_chain(sw_ctx_tx->sg_aead_out, 2,
1261 			 sw_ctx_tx->sg_encrypted_data);
1262 	}
1263 
1264 	if (!*aead) {
1265 		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
1266 		if (IS_ERR(*aead)) {
1267 			rc = PTR_ERR(*aead);
1268 			*aead = NULL;
1269 			goto free_rec_seq;
1270 		}
1271 	}
1272 
1273 	ctx->push_pending_record = tls_sw_push_pending_record;
1274 
1275 	rc = crypto_aead_setkey(*aead, gcm_128_info->key,
1276 				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1277 	if (rc)
1278 		goto free_aead;
1279 
1280 	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
1281 	if (rc)
1282 		goto free_aead;
1283 
1284 	if (sw_ctx_rx) {
1285 		/* Set up strparser */
1286 		memset(&cb, 0, sizeof(cb));
1287 		cb.rcv_msg = tls_queue;
1288 		cb.parse_msg = tls_read_size;
1289 
1290 		strp_init(&sw_ctx_rx->strp, sk, &cb);
1291 
1292 		write_lock_bh(&sk->sk_callback_lock);
1293 		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
1294 		sk->sk_data_ready = tls_data_ready;
1295 		write_unlock_bh(&sk->sk_callback_lock);
1296 
1297 		sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
1298 
1299 		strp_check_rcv(&sw_ctx_rx->strp);
1300 	}
1301 
1302 	goto out;
1303 
1304 free_aead:
1305 	crypto_free_aead(*aead);
1306 	*aead = NULL;
1307 free_rec_seq:
1308 	kfree(cctx->rec_seq);
1309 	cctx->rec_seq = NULL;
1310 free_iv:
1311 	kfree(cctx->iv);
1312 	cctx->iv = NULL;
1313 free_priv:
1314 	if (tx) {
1315 		kfree(ctx->priv_ctx_tx);
1316 		ctx->priv_ctx_tx = NULL;
1317 	} else {
1318 		kfree(ctx->priv_ctx_rx);
1319 		ctx->priv_ctx_rx = NULL;
1320 	}
1321 out:
1322 	return rc;
1323 }
1324