1 /*
2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7 *
8 * This software is available to you under a choice of one of two
9 * licenses. You may choose to be licensed under the terms of the GNU
10 * General Public License (GPL) Version 2, available from the file
11 * COPYING in the main directory of this source tree, or the
12 * OpenIB.org BSD license below:
13 *
14 * Redistribution and use in source and binary forms, with or
15 * without modification, are permitted provided that the following
16 * conditions are met:
17 *
18 * - Redistributions of source code must retain the above
19 * copyright notice, this list of conditions and the following
20 * disclaimer.
21 *
22 * - Redistributions in binary form must reproduce the above
23 * copyright notice, this list of conditions and the following
24 * disclaimer in the documentation and/or other materials
25 * provided with the distribution.
26 *
27 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34 * SOFTWARE.
35 */
36
#include <linux/sched/signal.h>
#include <linux/module.h>
#include <linux/splice.h>
#include <crypto/aead.h>

#include <net/strparser.h>
#include <net/tls.h>
43
/* Largest explicit IV expected after the TLS record header; used to size the
 * on-stack header buffer in tls_read_size().
 */
#define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE
45
/*
 * Run one AEAD decryption of a received record and wait for it to finish.
 *
 * @sgin holds the AAD followed by the ciphertext and tag; the plaintext is
 * written to @sgout (which may alias @sgin for in-place decryption).
 * @data_len is the plaintext length; the authentication tag length is added
 * to form the AEAD crypt length. The request is waited on through
 * ctx->async_wait. Returns 0 or a negative error from the crypto layer.
 */
static int tls_do_decryption(struct sock *sk,
			     struct scatterlist *sgin,
			     struct scatterlist *sgout,
			     char *iv_recv,
			     size_t data_len,
			     struct aead_request *aead_req)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *rx = tls_sw_ctx_rx(tls_ctx);
	size_t crypt_len = data_len + tls_ctx->rx.tag_size;

	aead_request_set_tfm(aead_req, rx->aead_recv);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
	aead_request_set_crypt(aead_req, sgin, sgout, crypt_len,
			       (u8 *)iv_recv);
	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
				  crypto_req_done, &rx->async_wait);

	return crypto_wait_req(crypto_aead_decrypt(aead_req), &rx->async_wait);
}
68
trim_sg(struct sock * sk,struct scatterlist * sg,int * sg_num_elem,unsigned int * sg_size,int target_size)69 static void trim_sg(struct sock *sk, struct scatterlist *sg,
70 int *sg_num_elem, unsigned int *sg_size, int target_size)
71 {
72 int i = *sg_num_elem - 1;
73 int trim = *sg_size - target_size;
74
75 if (trim <= 0) {
76 WARN_ON(trim < 0);
77 return;
78 }
79
80 *sg_size = target_size;
81 while (trim >= sg[i].length) {
82 trim -= sg[i].length;
83 sk_mem_uncharge(sk, sg[i].length);
84 put_page(sg_page(&sg[i]));
85 i--;
86
87 if (i < 0)
88 goto out;
89 }
90
91 sg[i].length -= trim;
92 sk_mem_uncharge(sk, trim);
93
94 out:
95 *sg_num_elem = i + 1;
96 }
97
trim_both_sgl(struct sock * sk,int target_size)98 static void trim_both_sgl(struct sock *sk, int target_size)
99 {
100 struct tls_context *tls_ctx = tls_get_ctx(sk);
101 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
102
103 trim_sg(sk, ctx->sg_plaintext_data,
104 &ctx->sg_plaintext_num_elem,
105 &ctx->sg_plaintext_size,
106 target_size);
107
108 if (target_size > 0)
109 target_size += tls_ctx->tx.overhead_size;
110
111 trim_sg(sk, ctx->sg_encrypted_data,
112 &ctx->sg_encrypted_num_elem,
113 &ctx->sg_encrypted_size,
114 target_size);
115 }
116
alloc_encrypted_sg(struct sock * sk,int len)117 static int alloc_encrypted_sg(struct sock *sk, int len)
118 {
119 struct tls_context *tls_ctx = tls_get_ctx(sk);
120 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
121 int rc = 0;
122
123 rc = sk_alloc_sg(sk, len,
124 ctx->sg_encrypted_data, 0,
125 &ctx->sg_encrypted_num_elem,
126 &ctx->sg_encrypted_size, 0);
127
128 if (rc == -ENOSPC)
129 ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data);
130
131 return rc;
132 }
133
alloc_plaintext_sg(struct sock * sk,int len)134 static int alloc_plaintext_sg(struct sock *sk, int len)
135 {
136 struct tls_context *tls_ctx = tls_get_ctx(sk);
137 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
138 int rc = 0;
139
140 rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
141 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
142 tls_ctx->pending_open_record_frags);
143
144 if (rc == -ENOSPC)
145 ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data);
146
147 return rc;
148 }
149
free_sg(struct sock * sk,struct scatterlist * sg,int * sg_num_elem,unsigned int * sg_size)150 static void free_sg(struct sock *sk, struct scatterlist *sg,
151 int *sg_num_elem, unsigned int *sg_size)
152 {
153 int i, n = *sg_num_elem;
154
155 for (i = 0; i < n; ++i) {
156 sk_mem_uncharge(sk, sg[i].length);
157 put_page(sg_page(&sg[i]));
158 }
159 *sg_num_elem = 0;
160 *sg_size = 0;
161 }
162
tls_free_both_sg(struct sock * sk)163 static void tls_free_both_sg(struct sock *sk)
164 {
165 struct tls_context *tls_ctx = tls_get_ctx(sk);
166 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
167
168 free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
169 &ctx->sg_encrypted_size);
170
171 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
172 &ctx->sg_plaintext_size);
173 }
174
/*
 * Encrypt @data_len plaintext bytes from ctx->sg_aead_in into
 * ctx->sg_aead_out and wait for completion.
 *
 * The first encrypted entry starts with tx.prepend_size bytes reserved for
 * the record header (tls_push_record() writes it there). That area is
 * hidden from the entry while the AEAD runs and restored afterwards,
 * whatever the outcome. Returns 0 or a negative crypto error.
 */
static int tls_do_encryption(struct tls_context *tls_ctx,
			     struct tls_sw_context_tx *ctx,
			     struct aead_request *aead_req,
			     size_t data_len)
{
	struct scatterlist *first = &ctx->sg_encrypted_data[0];
	int rc;

	first->offset += tls_ctx->tx.prepend_size;
	first->length -= tls_ctx->tx.prepend_size;

	aead_request_set_tfm(aead_req, ctx->aead_send);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
			       data_len, tls_ctx->tx.iv);
	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
				  crypto_req_done, &ctx->async_wait);

	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);

	first->offset -= tls_ctx->tx.prepend_size;
	first->length += tls_ctx->tx.prepend_size;

	return rc;
}
200
/*
 * Close the currently open TX record, encrypt it and hand it to TCP.
 *
 * Steps:
 * - terminate both scatterlists;
 * - build the AAD;
 * - write the record header at the start of sg_encrypted_data[0];
 * - encrypt with tls_do_encryption().
 *
 * Only after encryption succeeds are the plaintext pages released and the
 * encrypted list passed to tls_push_sg(). The TX record sequence number is
 * advanced once tls_push_sg() has been called, whatever it returned.
 *
 * @record_type: TLS content type placed in the header and in the AAD.
 *
 * Returns 0, -ENOMEM if the AEAD request cannot be allocated, the negative
 * encryption error, or the tls_push_sg() result (-EAGAIN is not treated as
 * fatal; other push errors also abort the connection via tls_err_abort()).
 */
static int tls_push_record(struct sock *sk, int flags,
			   unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct aead_request *req;
	int rc;

	req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
	if (!req)
		return -ENOMEM;

	/* NOTE(review): assumes both lists hold at least one entry. */
	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);

	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
		     record_type);

	/* Record header goes into the space reserved at the front of the
	 * first encrypted entry (hidden from the AEAD by tls_do_encryption()).
	 */
	tls_fill_prepend(tls_ctx,
			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
			 ctx->sg_encrypted_data[0].offset,
			 ctx->sg_plaintext_size, record_type);

	/* From here on the record is closed: callers that see
	 * tls_is_pending_closed_record() retry the push instead of appending.
	 */
	tls_ctx->pending_open_record_frags = 0;
	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);

	rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
	if (rc < 0) {
		/* If we are called from write_space and
		 * we fail, we need to set this SOCK_NOSPACE
		 * to trigger another write_space in the future.
		 */
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
		goto out_req;
	}

	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
		&ctx->sg_plaintext_size);

	/* Forget the encrypted entries without releasing them; presumably
	 * tls_push_sg() takes over their pages — confirm in tls_main.c.
	 */
	ctx->sg_encrypted_num_elem = 0;
	ctx->sg_encrypted_size = 0;

	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags.
	 * NOTE(review): @flags is passed unchanged here; any filtering must
	 * happen inside tls_push_sg().
	 */
	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
	if (rc < 0 && rc != -EAGAIN)
		tls_err_abort(sk, EBADMSG);

	tls_advance_record_sn(sk, &tls_ctx->tx);
out_req:
	aead_request_free(req);
	return rc;
}
254
tls_sw_push_pending_record(struct sock * sk,int flags)255 static int tls_sw_push_pending_record(struct sock *sk, int flags)
256 {
257 return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
258 }
259
/*
 * Map up to @length bytes of @from into the scatterlist @to without copying.
 *
 * Pages are obtained with iov_iter_get_pages() and appended to @to starting
 * at index *@pages_used, using at most @to_max_pages entries in total. Each
 * batch is limited to the on-stack pages[] array, so a caller-supplied
 * @to_max_pages larger than MAX_SKB_FRAGS cannot overflow it. If @charge is
 * true each appended byte is charged to @sk with sk_mem_charge().
 *
 * *@pages_used and *@size_used are updated to cover everything appended,
 * on success and on error alike.
 *
 * Returns 0 on success. On failure returns -EFAULT and reverts @from by the
 * bytes this call advanced it. Page references already taken for appended
 * entries are NOT dropped here; the caller must release them.
 */
static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      int length, int *pages_used,
			      unsigned int *size_used,
			      struct scatterlist *to, int to_max_pages,
			      bool charge)
{
	struct page *pages[MAX_SKB_FRAGS];

	size_t offset;
	ssize_t copied, use;
	int i = 0;
	unsigned int size = *size_used;
	int num_elem = *pages_used;
	int rc = 0;
	int maxpages;

	while (length > 0) {
		i = 0;
		/* Never request more pages than pages[] can hold. */
		maxpages = min_t(int, to_max_pages - num_elem,
				 ARRAY_SIZE(pages));
		if (maxpages <= 0) {
			rc = -EFAULT;
			goto out;
		}
		copied = iov_iter_get_pages(from, pages,
					    length,
					    maxpages, &offset);
		if (copied <= 0) {
			rc = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);

		length -= copied;
		size += copied;
		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);

			sg_set_page(&to[num_elem],
				    pages[i], use, offset);
			sg_unmark_end(&to[num_elem]);
			if (charge)
				sk_mem_charge(sk, use);

			/* Only the first page of a batch starts mid-page. */
			offset = 0;
			copied -= use;

			++i;
			++num_elem;
		}
	}

	/* Mark the end in the last sg entry if newly added */
	if (num_elem > *pages_used)
		sg_mark_end(&to[num_elem - 1]);
out:
	if (rc)
		iov_iter_revert(from, size - *size_used);
	*size_used = size;
	*pages_used = num_elem;

	return rc;
}
323
memcopy_from_iter(struct sock * sk,struct iov_iter * from,int bytes)324 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
325 int bytes)
326 {
327 struct tls_context *tls_ctx = tls_get_ctx(sk);
328 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
329 struct scatterlist *sg = ctx->sg_plaintext_data;
330 int copy, i, rc = 0;
331
332 for (i = tls_ctx->pending_open_record_frags;
333 i < ctx->sg_plaintext_num_elem; ++i) {
334 copy = sg[i].length;
335 if (copy_from_iter(
336 page_address(sg_page(&sg[i])) + sg[i].offset,
337 copy, from) != copy) {
338 rc = -EFAULT;
339 goto out;
340 }
341 bytes -= copy;
342
343 ++tls_ctx->pending_open_record_frags;
344
345 if (!bytes)
346 break;
347 }
348
349 out:
350 return rc;
351 }
352
/*
 * sendmsg() handler for a socket using software TLS transmit.
 *
 * User data is appended to the open plaintext record (at most
 * TLS_MAX_PAYLOAD_SIZE bytes per record), and room for the matching
 * encrypted output is reserved. When the record becomes full, or the
 * caller did not pass MSG_MORE, the record is encrypted and pushed with
 * tls_push_record(). In that case user-backed (non-kvec) iterators are
 * first tried zero-copy with zerocopy_from_iter(). Otherwise, or if that
 * mapping fails, the data is copied into pages from alloc_plaintext_sg().
 *
 * Only MSG_MORE, MSG_DONTWAIT and MSG_NOSIGNAL are accepted. A control
 * message may select another record type (tls_proccess_cmsg()).
 *
 * Returns the number of bytes consumed from @msg, or a negative errno if
 * nothing was consumed. This includes a failure to flush work left pending
 * by an earlier call.
 */
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	int ret = 0;
	int required_size;
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	bool eor = !(msg->msg_flags & MSG_MORE);
	size_t try_to_copy, copied = 0;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	int record_room;
	bool full_record;
	int orig_size;
	bool is_kvec = msg->msg_iter.type & ITER_KVEC;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
		return -ENOTSUPP;

	lock_sock(sk);

	/* Propagate the error; ret was previously left at 0 here, so a failed
	 * flush of pending work made sendmsg() return 0.
	 */
	ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo);
	if (ret)
		goto send_end;

	if (unlikely(msg->msg_controllen)) {
		ret = tls_proccess_cmsg(sk, msg, &record_type);
		if (ret)
			goto send_end;
	}

	while (msg_data_left(msg)) {
		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto send_end;
		}

		orig_size = ctx->sg_plaintext_size;
		full_record = false;
		try_to_copy = msg_data_left(msg);
		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
		if (try_to_copy >= record_room) {
			try_to_copy = record_room;
			full_record = true;
		}

		required_size = ctx->sg_plaintext_size + try_to_copy +
				tls_ctx->tx.overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
alloc_encrypted:
		ret = alloc_encrypted_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - ctx->sg_encrypted_size;
			full_record = true;
		}
		if (!is_kvec && (full_record || eor)) {
			ret = zerocopy_from_iter(sk, &msg->msg_iter,
				try_to_copy, &ctx->sg_plaintext_num_elem,
				&ctx->sg_plaintext_size,
				ctx->sg_plaintext_data,
				ARRAY_SIZE(ctx->sg_plaintext_data),
				true);
			if (ret)
				goto fallback_to_reg_send;

			copied += try_to_copy;
			ret = tls_push_record(sk, msg->msg_flags, record_type);
			if (ret)
				goto send_end;
			continue;

fallback_to_reg_send:
			/* Undo the partially mapped user pages. */
			trim_sg(sk, ctx->sg_plaintext_data,
				&ctx->sg_plaintext_num_elem,
				&ctx->sg_plaintext_size,
				orig_size);
		}

		required_size = ctx->sg_plaintext_size + try_to_copy;
alloc_plaintext:
		ret = alloc_plaintext_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - ctx->sg_plaintext_size;
			full_record = true;

			trim_sg(sk, ctx->sg_encrypted_data,
				&ctx->sg_encrypted_num_elem,
				&ctx->sg_encrypted_size,
				ctx->sg_plaintext_size +
				tls_ctx->tx.overhead_size);
		}

		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
		if (ret)
			goto trim_sgl;

		copied += try_to_copy;
		if (full_record || eor) {
push_record:
			ret = tls_push_record(sk, msg->msg_flags, record_type);
			if (ret) {
				if (ret == -ENOMEM)
					goto wait_for_memory;

				goto send_end;
			}
		}

		continue;

wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
trim_sgl:
			trim_both_sgl(sk, orig_size);
			goto send_end;
		}

		/* Resume where we stopped: retry a closed record's push, or
		 * continue the allocation that ran out of memory.
		 */
		if (tls_is_pending_closed_record(tls_ctx))
			goto push_record;

		if (ctx->sg_encrypted_size < required_size)
			goto alloc_encrypted;

		goto alloc_plaintext;
	}

send_end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);

	release_sock(sk);
	return copied ? copied : ret;
}
502
/*
 * sendpage() handler for software TLS transmit.
 *
 * Appends [@offset, @offset + @size) of @page to the open plaintext record,
 * one scatterlist entry (holding its own page reference) per iteration. A
 * record is pushed when it is full, when this is the last page (no MSG_MORE
 * or MSG_SENDPAGE_NOTLAST), or when the plaintext sg array is exhausted.
 *
 * Returns the number of bytes queued, or a negative errno if none were.
 * This includes a failure to flush work left pending by an earlier call.
 */
int tls_sw_sendpage(struct sock *sk, struct page *page,
		    int offset, size_t size, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	int ret = 0;
	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	bool eor;
	size_t orig_size = size;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	struct scatterlist *sg;
	bool full_record;
	int record_room;

	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
		      MSG_SENDPAGE_NOTLAST))
		return -ENOTSUPP;

	/* No MSG_EOR from splice, only look at MSG_MORE */
	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));

	lock_sock(sk);

	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

	/* Propagate the error instead of silently returning 0. */
	ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
	if (ret)
		goto sendpage_end;

	/* Call the sk_stream functions to manage the sndbuf mem. */
	while (size > 0) {
		size_t copy, required_size;

		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto sendpage_end;
		}

		full_record = false;
		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
		copy = size;
		if (copy >= record_room) {
			copy = record_room;
			full_record = true;
		}
		required_size = ctx->sg_plaintext_size + copy +
				tls_ctx->tx.overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
alloc_payload:
		ret = alloc_encrypted_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit. The shortfall is measured
			 * against the encrypted list that was just allocated
			 * (as in tls_sw_sendmsg()); using the plaintext size
			 * would subtract copy + overhead and underflow copy.
			 */
			copy -= required_size - ctx->sg_encrypted_size;
			full_record = true;
		}

		get_page(page);
		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
		sg_set_page(sg, page, copy, offset);
		sg_unmark_end(sg);

		ctx->sg_plaintext_num_elem++;

		sk_mem_charge(sk, copy);
		offset += copy;
		size -= copy;
		ctx->sg_plaintext_size += copy;
		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;

		if (full_record || eor ||
		    ctx->sg_plaintext_num_elem ==
		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
push_record:
			ret = tls_push_record(sk, flags, record_type);
			if (ret) {
				if (ret == -ENOMEM)
					goto wait_for_memory;

				goto sendpage_end;
			}
		}
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
			trim_both_sgl(sk, ctx->sg_plaintext_size);
			goto sendpage_end;
		}

		if (tls_is_pending_closed_record(tls_ctx))
			goto push_record;

		goto alloc_payload;
	}

sendpage_end:
	if (orig_size > size)
		ret = orig_size - size;
	else
		ret = sk_stream_error(sk, flags, ret);

	release_sock(sk);
	return ret;
}
616
tls_wait_data(struct sock * sk,int flags,long timeo,int * err)617 static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
618 long timeo, int *err)
619 {
620 struct tls_context *tls_ctx = tls_get_ctx(sk);
621 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
622 struct sk_buff *skb;
623 DEFINE_WAIT_FUNC(wait, woken_wake_function);
624
625 while (!(skb = ctx->recv_pkt)) {
626 if (sk->sk_err) {
627 *err = sock_error(sk);
628 return NULL;
629 }
630
631 if (sk->sk_shutdown & RCV_SHUTDOWN)
632 return NULL;
633
634 if (sock_flag(sk, SOCK_DONE))
635 return NULL;
636
637 if ((flags & MSG_DONTWAIT) || !timeo) {
638 *err = -EAGAIN;
639 return NULL;
640 }
641
642 add_wait_queue(sk_sleep(sk), &wait);
643 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
644 sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
645 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
646 remove_wait_queue(sk_sleep(sk), &wait);
647
648 /* Handle signals */
649 if (signal_pending(current)) {
650 *err = sock_intr_errno(timeo);
651 return NULL;
652 }
653 }
654
655 return skb;
656 }
657
658 /* This function decrypts the input skb into either out_iov or in out_sg
659 * or in skb buffers itself. The input parameter 'zc' indicates if
660 * zero-copy mode needs to be tried or not. With zero-copy mode, either
661 * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
662 * NULL, then the decryption happens inside skb buffers itself, i.e.
663 * zero-copy gets disabled and 'zc' is updated.
664 */
665
decrypt_internal(struct sock * sk,struct sk_buff * skb,struct iov_iter * out_iov,struct scatterlist * out_sg,int * chunk,bool * zc)666 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
667 struct iov_iter *out_iov,
668 struct scatterlist *out_sg,
669 int *chunk, bool *zc)
670 {
671 struct tls_context *tls_ctx = tls_get_ctx(sk);
672 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
673 struct strp_msg *rxm = strp_msg(skb);
674 int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
675 struct aead_request *aead_req;
676 struct sk_buff *unused;
677 u8 *aad, *iv, *mem = NULL;
678 struct scatterlist *sgin = NULL;
679 struct scatterlist *sgout = NULL;
680 const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
681
682 if (*zc && (out_iov || out_sg)) {
683 if (out_iov)
684 n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
685 else
686 n_sgout = sg_nents(out_sg);
687 } else {
688 n_sgout = 0;
689 *zc = false;
690 }
691
692 n_sgin = skb_cow_data(skb, 0, &unused);
693 if (n_sgin < 1)
694 return -EBADMSG;
695
696 /* Increment to accommodate AAD */
697 n_sgin = n_sgin + 1;
698
699 nsg = n_sgin + n_sgout;
700
701 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
702 mem_size = aead_size + (nsg * sizeof(struct scatterlist));
703 mem_size = mem_size + TLS_AAD_SPACE_SIZE;
704 mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
705
706 /* Allocate a single block of memory which contains
707 * aead_req || sgin[] || sgout[] || aad || iv.
708 * This order achieves correct alignment for aead_req, sgin, sgout.
709 */
710 mem = kmalloc(mem_size, sk->sk_allocation);
711 if (!mem)
712 return -ENOMEM;
713
714 /* Segment the allocated memory */
715 aead_req = (struct aead_request *)mem;
716 sgin = (struct scatterlist *)(mem + aead_size);
717 sgout = sgin + n_sgin;
718 aad = (u8 *)(sgout + n_sgout);
719 iv = aad + TLS_AAD_SPACE_SIZE;
720
721 /* Prepare IV */
722 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
723 iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
724 tls_ctx->rx.iv_size);
725 if (err < 0) {
726 kfree(mem);
727 return err;
728 }
729 memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
730
731 /* Prepare AAD */
732 tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
733 tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
734 ctx->control);
735
736 /* Prepare sgin */
737 sg_init_table(sgin, n_sgin);
738 sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
739 err = skb_to_sgvec(skb, &sgin[1],
740 rxm->offset + tls_ctx->rx.prepend_size,
741 rxm->full_len - tls_ctx->rx.prepend_size);
742 if (err < 0) {
743 kfree(mem);
744 return err;
745 }
746
747 if (n_sgout) {
748 if (out_iov) {
749 sg_init_table(sgout, n_sgout);
750 sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
751
752 *chunk = 0;
753 err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
754 chunk, &sgout[1],
755 (n_sgout - 1), false);
756 if (err < 0)
757 goto fallback_to_reg_recv;
758 } else if (out_sg) {
759 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
760 } else {
761 goto fallback_to_reg_recv;
762 }
763 } else {
764 fallback_to_reg_recv:
765 sgout = sgin;
766 pages = 0;
767 *chunk = 0;
768 *zc = false;
769 }
770
771 /* Prepare and submit AEAD request */
772 err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
773
774 /* Release the pages in case iov was mapped to pages */
775 for (; pages > 0; pages--)
776 put_page(sg_page(&sgout[pages]));
777
778 kfree(mem);
779 return err;
780 }
781
/*
 * Make sure the queued record @skb is decrypted, then strip its framing.
 *
 * With CONFIG_TLS_DEVICE, tls_device_decrypted() is consulted first.
 * Software decryption runs only if ctx->decrypted is still false (zero-copy
 * into @dest may be used); otherwise *@zc is cleared. Afterwards rxm is
 * narrowed to the plaintext, the RX sequence number is advanced, the record
 * is flagged decrypted and the saved data_ready callback is invoked.
 * Returns 0 or a negative errno (nothing is modified on error).
 */
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
			      struct iov_iter *dest, int *chunk, bool *zc)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *rx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *msg_info = strp_msg(skb);
	int rc = 0;

#ifdef CONFIG_TLS_DEVICE
	rc = tls_device_decrypted(sk, skb);
	if (rc < 0)
		return rc;
#endif
	if (rx->decrypted) {
		*zc = false;
	} else {
		rc = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
		if (rc < 0)
			return rc;
	}

	/* Skip header + explicit IV and drop all per-record overhead. */
	msg_info->offset += tls_ctx->rx.prepend_size;
	msg_info->full_len -= tls_ctx->rx.overhead_size;
	tls_advance_record_sn(sk, &tls_ctx->rx);
	rx->decrypted = true;
	rx->saved_data_ready(sk);

	return rc;
}
811
decrypt_skb(struct sock * sk,struct sk_buff * skb,struct scatterlist * sgout)812 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
813 struct scatterlist *sgout)
814 {
815 bool zc = true;
816 int chunk;
817
818 return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
819 }
820
tls_sw_advance_skb(struct sock * sk,struct sk_buff * skb,unsigned int len)821 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
822 unsigned int len)
823 {
824 struct tls_context *tls_ctx = tls_get_ctx(sk);
825 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
826 struct strp_msg *rxm = strp_msg(skb);
827
828 if (len < rxm->full_len) {
829 rxm->offset += len;
830 rxm->full_len -= len;
831
832 return false;
833 }
834
835 /* Finished with message */
836 ctx->recv_pkt = NULL;
837 kfree_skb(skb);
838 __strp_unpause(&ctx->strp);
839
840 return true;
841 }
842
/*
 * recvmsg() handler for software TLS receive.
 *
 * Waits for the record queued by the strparser (tls_wait_data()), decrypts
 * it if that has not happened yet, and copies its plaintext to @msg. The
 * content type of the first record read is reported as a SOL_TLS /
 * TLS_GET_RECORD_TYPE control message. Reading stops at a record of a
 * different type, and after a fully consumed non-data record.
 *
 * Zero-copy decryption straight into the user buffer is tried only when
 * all of these hold:
 * - the iterator is not a kvec;
 * - MSG_PEEK is not set;
 * - the whole record fits in @len.
 * Otherwise the record is decrypted in the skb and copied with
 * skb_copy_datagram_msg().
 *
 * Returns the number of bytes copied, or, if nothing was copied, a negative
 * errno (0 when tls_wait_data() hit shutdown without an error).
 */
int tls_sw_recvmsg(struct sock *sk,
		   struct msghdr *msg,
		   size_t len,
		   int nonblock,
		   int flags,
		   int *addr_len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	unsigned char control;
	struct strp_msg *rxm;
	struct sk_buff *skb;
	ssize_t copied = 0;
	bool cmsg = false;
	int target, err = 0;
	long timeo;
	bool is_kvec = msg->msg_iter.type & ITER_KVEC;

	flags |= nonblock;

	if (unlikely(flags & MSG_ERRQUEUE))
		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);

	lock_sock(sk);

	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
	do {
		bool zc = false;
		int chunk = 0;

		skb = tls_wait_data(sk, flags, timeo, &err);
		if (!skb)
			goto recv_end;

		rxm = strp_msg(skb);
		if (!cmsg) {
			int cerr;

			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
					sizeof(ctx->control), &ctx->control);
			cmsg = true;
			control = ctx->control;
			/* A non-data record is only returned if userspace can
			 * also learn its type, i.e. the cmsg fit untruncated.
			 */
			if (ctx->control != TLS_RECORD_TYPE_DATA) {
				if (cerr || msg->msg_flags & MSG_CTRUNC) {
					err = -EIO;
					goto recv_end;
				}
			}
		} else if (control != ctx->control) {
			/* Never mix record types in one recvmsg() call. */
			goto recv_end;
		}

		if (!ctx->decrypted) {
			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;

			if (!is_kvec && to_copy <= len &&
			    likely(!(flags & MSG_PEEK)))
				zc = true;

			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
						 &chunk, &zc);
			if (err < 0) {
				tls_err_abort(sk, EBADMSG);
				goto recv_end;
			}
			/* Also set inside decrypt_skb_update(). */
			ctx->decrypted = true;
		}

		/* Plaintext is still in the skb: copy it out. */
		if (!zc) {
			chunk = min_t(unsigned int, rxm->full_len, len);
			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
						    chunk);
			if (err < 0)
				goto recv_end;
		}

		copied += chunk;
		len -= chunk;
		if (likely(!(flags & MSG_PEEK))) {
			u8 control = ctx->control;

			if (tls_sw_advance_skb(sk, skb, chunk)) {
				/* Return full control message to
				 * userspace before trying to parse
				 * another message type
				 */
				msg->msg_flags |= MSG_EOR;
				if (control != TLS_RECORD_TYPE_DATA)
					goto recv_end;
			}
		} else {
			/* MSG_PEEK right now cannot look beyond current skb
			 * from strparser, meaning we cannot advance skb here
			 * and thus unpause strparser since we'd lose original
			 * one.
			 */
			break;
		}

		/* Stop once the low-water mark is met, unless the strparser
		 * has already queued another record we can continue with.
		 */
		if (copied >= target && !ctx->recv_pkt)
			break;
	} while (len);

recv_end:
	release_sock(sk);
	return copied ? : err;
}
952
tls_sw_splice_read(struct socket * sock,loff_t * ppos,struct pipe_inode_info * pipe,size_t len,unsigned int flags)953 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
954 struct pipe_inode_info *pipe,
955 size_t len, unsigned int flags)
956 {
957 struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
958 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
959 struct strp_msg *rxm = NULL;
960 struct sock *sk = sock->sk;
961 struct sk_buff *skb;
962 ssize_t copied = 0;
963 int err = 0;
964 long timeo;
965 int chunk;
966 bool zc = false;
967
968 lock_sock(sk);
969
970 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
971
972 skb = tls_wait_data(sk, flags, timeo, &err);
973 if (!skb)
974 goto splice_read_end;
975
976 /* splice does not support reading control messages */
977 if (ctx->control != TLS_RECORD_TYPE_DATA) {
978 err = -ENOTSUPP;
979 goto splice_read_end;
980 }
981
982 if (!ctx->decrypted) {
983 err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
984
985 if (err < 0) {
986 tls_err_abort(sk, EBADMSG);
987 goto splice_read_end;
988 }
989 ctx->decrypted = true;
990 }
991 rxm = strp_msg(skb);
992
993 chunk = min_t(unsigned int, rxm->full_len, len);
994 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
995 if (copied < 0)
996 goto splice_read_end;
997
998 if (likely(!(flags & MSG_PEEK)))
999 tls_sw_advance_skb(sk, skb, copied);
1000
1001 splice_read_end:
1002 release_sock(sk);
1003 return copied ? : err;
1004 }
1005
tls_sw_poll(struct file * file,struct socket * sock,struct poll_table_struct * wait)1006 unsigned int tls_sw_poll(struct file *file, struct socket *sock,
1007 struct poll_table_struct *wait)
1008 {
1009 unsigned int ret;
1010 struct sock *sk = sock->sk;
1011 struct tls_context *tls_ctx = tls_get_ctx(sk);
1012 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1013
1014 /* Grab POLLOUT and POLLHUP from the underlying socket */
1015 ret = ctx->sk_poll(file, sock, wait);
1016
1017 /* Clear POLLIN bits, and set based on recv_pkt */
1018 ret &= ~(POLLIN | POLLRDNORM);
1019 if (ctx->recv_pkt)
1020 ret |= POLLIN | POLLRDNORM;
1021
1022 return ret;
1023 }
1024
tls_read_size(struct strparser * strp,struct sk_buff * skb)1025 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
1026 {
1027 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1028 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1029 char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
1030 struct strp_msg *rxm = strp_msg(skb);
1031 size_t cipher_overhead;
1032 size_t data_len = 0;
1033 int ret;
1034
1035 /* Verify that we have a full TLS header, or wait for more data */
1036 if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
1037 return 0;
1038
1039 /* Sanity-check size of on-stack buffer. */
1040 if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
1041 ret = -EINVAL;
1042 goto read_failure;
1043 }
1044
1045 /* Linearize header to local buffer */
1046 ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
1047
1048 if (ret < 0)
1049 goto read_failure;
1050
1051 ctx->control = header[0];
1052
1053 data_len = ((header[4] & 0xFF) | (header[3] << 8));
1054
1055 cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
1056
1057 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
1058 ret = -EMSGSIZE;
1059 goto read_failure;
1060 }
1061 if (data_len < cipher_overhead) {
1062 ret = -EBADMSG;
1063 goto read_failure;
1064 }
1065
1066 if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
1067 header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
1068 ret = -EINVAL;
1069 goto read_failure;
1070 }
1071
1072 #ifdef CONFIG_TLS_DEVICE
1073 handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1074 *(u64*)tls_ctx->rx.rec_seq);
1075 #endif
1076 return data_len + TLS_HEADER_SIZE;
1077
1078 read_failure:
1079 tls_err_abort(strp->sk, ret);
1080
1081 return ret;
1082 }
1083
tls_queue(struct strparser * strp,struct sk_buff * skb)1084 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
1085 {
1086 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1087 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1088
1089 ctx->decrypted = false;
1090
1091 ctx->recv_pkt = skb;
1092 strp_pause(strp);
1093
1094 ctx->saved_data_ready(strp->sk);
1095 }
1096
tls_data_ready(struct sock * sk)1097 static void tls_data_ready(struct sock *sk)
1098 {
1099 struct tls_context *tls_ctx = tls_get_ctx(sk);
1100 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1101
1102 strp_data_ready(&ctx->strp);
1103 }
1104
tls_sw_free_resources_tx(struct sock * sk)1105 void tls_sw_free_resources_tx(struct sock *sk)
1106 {
1107 struct tls_context *tls_ctx = tls_get_ctx(sk);
1108 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1109
1110 crypto_free_aead(ctx->aead_send);
1111 tls_free_both_sg(sk);
1112
1113 kfree(ctx);
1114 }
1115
tls_sw_release_resources_rx(struct sock * sk)1116 void tls_sw_release_resources_rx(struct sock *sk)
1117 {
1118 struct tls_context *tls_ctx = tls_get_ctx(sk);
1119 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1120
1121 if (ctx->aead_recv) {
1122 kfree_skb(ctx->recv_pkt);
1123 ctx->recv_pkt = NULL;
1124 crypto_free_aead(ctx->aead_recv);
1125 strp_stop(&ctx->strp);
1126 write_lock_bh(&sk->sk_callback_lock);
1127 sk->sk_data_ready = ctx->saved_data_ready;
1128 write_unlock_bh(&sk->sk_callback_lock);
1129 release_sock(sk);
1130 strp_done(&ctx->strp);
1131 lock_sock(sk);
1132 }
1133 }
1134
/*
 * Fully tear down software RX for @sk: release the RX resources via
 * tls_sw_release_resources_rx(), then free the RX context. The context
 * pointer is looked up before the release call, as it is freed afterwards.
 */
void tls_sw_free_resources_rx(struct sock *sk)
{
	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_get_ctx(sk));

	tls_sw_release_resources_rx(sk);
	kfree(rx_ctx);
}
1144
tls_set_sw_offload(struct sock * sk,struct tls_context * ctx,int tx)1145 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
1146 {
1147 struct tls_crypto_info *crypto_info;
1148 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
1149 struct tls_sw_context_tx *sw_ctx_tx = NULL;
1150 struct tls_sw_context_rx *sw_ctx_rx = NULL;
1151 struct cipher_context *cctx;
1152 struct crypto_aead **aead;
1153 struct strp_callbacks cb;
1154 u16 nonce_size, tag_size, iv_size, rec_seq_size;
1155 char *iv, *rec_seq;
1156 int rc = 0;
1157
1158 if (!ctx) {
1159 rc = -EINVAL;
1160 goto out;
1161 }
1162
1163 if (tx) {
1164 if (!ctx->priv_ctx_tx) {
1165 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
1166 if (!sw_ctx_tx) {
1167 rc = -ENOMEM;
1168 goto out;
1169 }
1170 ctx->priv_ctx_tx = sw_ctx_tx;
1171 } else {
1172 sw_ctx_tx =
1173 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
1174 }
1175 } else {
1176 if (!ctx->priv_ctx_rx) {
1177 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
1178 if (!sw_ctx_rx) {
1179 rc = -ENOMEM;
1180 goto out;
1181 }
1182 ctx->priv_ctx_rx = sw_ctx_rx;
1183 } else {
1184 sw_ctx_rx =
1185 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
1186 }
1187 }
1188
1189 if (tx) {
1190 crypto_init_wait(&sw_ctx_tx->async_wait);
1191 crypto_info = &ctx->crypto_send.info;
1192 cctx = &ctx->tx;
1193 aead = &sw_ctx_tx->aead_send;
1194 } else {
1195 crypto_init_wait(&sw_ctx_rx->async_wait);
1196 crypto_info = &ctx->crypto_recv.info;
1197 cctx = &ctx->rx;
1198 aead = &sw_ctx_rx->aead_recv;
1199 }
1200
1201 switch (crypto_info->cipher_type) {
1202 case TLS_CIPHER_AES_GCM_128: {
1203 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1204 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1205 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1206 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1207 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1208 rec_seq =
1209 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1210 gcm_128_info =
1211 (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
1212 break;
1213 }
1214 default:
1215 rc = -EINVAL;
1216 goto free_priv;
1217 }
1218
1219 /* Sanity-check the IV size for stack allocations. */
1220 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
1221 rc = -EINVAL;
1222 goto free_priv;
1223 }
1224
1225 cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1226 cctx->tag_size = tag_size;
1227 cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1228 cctx->iv_size = iv_size;
1229 cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1230 GFP_KERNEL);
1231 if (!cctx->iv) {
1232 rc = -ENOMEM;
1233 goto free_priv;
1234 }
1235 memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1236 memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1237 cctx->rec_seq_size = rec_seq_size;
1238 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
1239 if (!cctx->rec_seq) {
1240 rc = -ENOMEM;
1241 goto free_iv;
1242 }
1243
1244 if (sw_ctx_tx) {
1245 sg_init_table(sw_ctx_tx->sg_encrypted_data,
1246 ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
1247 sg_init_table(sw_ctx_tx->sg_plaintext_data,
1248 ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
1249
1250 sg_init_table(sw_ctx_tx->sg_aead_in, 2);
1251 sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
1252 sizeof(sw_ctx_tx->aad_space));
1253 sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
1254 sg_chain(sw_ctx_tx->sg_aead_in, 2,
1255 sw_ctx_tx->sg_plaintext_data);
1256 sg_init_table(sw_ctx_tx->sg_aead_out, 2);
1257 sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
1258 sizeof(sw_ctx_tx->aad_space));
1259 sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
1260 sg_chain(sw_ctx_tx->sg_aead_out, 2,
1261 sw_ctx_tx->sg_encrypted_data);
1262 }
1263
1264 if (!*aead) {
1265 *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
1266 if (IS_ERR(*aead)) {
1267 rc = PTR_ERR(*aead);
1268 *aead = NULL;
1269 goto free_rec_seq;
1270 }
1271 }
1272
1273 ctx->push_pending_record = tls_sw_push_pending_record;
1274
1275 rc = crypto_aead_setkey(*aead, gcm_128_info->key,
1276 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1277 if (rc)
1278 goto free_aead;
1279
1280 rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
1281 if (rc)
1282 goto free_aead;
1283
1284 if (sw_ctx_rx) {
1285 /* Set up strparser */
1286 memset(&cb, 0, sizeof(cb));
1287 cb.rcv_msg = tls_queue;
1288 cb.parse_msg = tls_read_size;
1289
1290 strp_init(&sw_ctx_rx->strp, sk, &cb);
1291
1292 write_lock_bh(&sk->sk_callback_lock);
1293 sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
1294 sk->sk_data_ready = tls_data_ready;
1295 write_unlock_bh(&sk->sk_callback_lock);
1296
1297 sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
1298
1299 strp_check_rcv(&sw_ctx_rx->strp);
1300 }
1301
1302 goto out;
1303
1304 free_aead:
1305 crypto_free_aead(*aead);
1306 *aead = NULL;
1307 free_rec_seq:
1308 kfree(cctx->rec_seq);
1309 cctx->rec_seq = NULL;
1310 free_iv:
1311 kfree(cctx->iv);
1312 cctx->iv = NULL;
1313 free_priv:
1314 if (tx) {
1315 kfree(ctx->priv_ctx_tx);
1316 ctx->priv_ctx_tx = NULL;
1317 } else {
1318 kfree(ctx->priv_ctx_rx);
1319 ctx->priv_ctx_rx = NULL;
1320 }
1321 out:
1322 return rc;
1323 }
1324