// SPDX-License-Identifier: GPL-2.0-only
/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

/* Threshold for detecting small packets to copy */
#define GOOD_COPY_LEN 128

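/* Transports built on this common code embed the generic vsock transport and
 * provide a packet-output hook. Judging from virtio_transport_get_ops() and
 * the t_ops->send_pkt() calls below, the shape is roughly (a sketch; see
 * include/linux/virtio_vsock.h for the authoritative definition):
 *
 *	struct virtio_transport {
 *		struct vsock_transport transport;
 *		int (*send_pkt)(struct sk_buff *skb);
 *	};
 */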
static const struct virtio_transport *
virtio_transport_get_ops(struct vsock_sock *vsk)
{
	const struct vsock_transport *t = vsock_core_get_transport(vsk);

	if (WARN_ON(!t))
		return NULL;

	return container_of(t, struct virtio_transport, transport);
}

/* Returns the newly allocated packet on success, otherwise returns NULL. */
static struct sk_buff *
virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM + len;
	struct virtio_vsock_hdr *hdr;
	struct sk_buff *skb;
	void *payload;
	int err;

	skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL);
	if (!skb)
		return NULL;

	hdr = virtio_vsock_hdr(skb);
	hdr->type	= cpu_to_le16(info->type);
	hdr->op		= cpu_to_le16(info->op);
	hdr->src_cid	= cpu_to_le64(src_cid);
	hdr->dst_cid	= cpu_to_le64(dst_cid);
	hdr->src_port	= cpu_to_le32(src_port);
	hdr->dst_port	= cpu_to_le32(dst_port);
	hdr->flags	= cpu_to_le32(info->flags);
	hdr->len	= cpu_to_le32(len);

	if (info->msg && len > 0) {
		payload = skb_put(skb, len);
		err = memcpy_from_msg(payload, info->msg, len);
		if (err)
			goto out;

		if (msg_data_left(info->msg) == 0 &&
		    info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
			hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);

			if (info->msg->msg_flags & MSG_EOR)
				hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
		}
	}

	if (info->reply)
		virtio_vsock_skb_set_reply(skb);

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) {
		WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n");
		goto out;
	}

	return skb;

out:
	kfree_skb(skb);
	return NULL;
}

/* Packet capture */
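/* Each frame delivered to the vsockmon tap is a struct af_vsockmon_hdr,
 * followed by the raw little-endian struct virtio_vsock_hdr, followed by the
 * payload; that is exactly the layout assembled below.
 */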
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_hdr *pkt_hdr;
	struct sk_buff *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;
	size_t payload_len;
	void *payload_buf;

	/* A packet could be split to fit the RX buffer, so the payload length
	 * and the buffer pointer are taken from the skb, which accounts for
	 * any offset into the original packet.
	 */
	pkt_hdr = virtio_vsock_hdr(pkt);
	payload_len = pkt->len;
	payload_buf = pkt->data;

	skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt_hdr->src_cid;
	hdr->src_port = pkt_hdr->src_port;
	hdr->dst_cid = pkt_hdr->dst_cid;
	hdr->dst_port = pkt_hdr->dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(*pkt_hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	switch (le16_to_cpu(pkt_hdr->op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr));

	if (payload_len)
		skb_put_data(skb, payload_buf, payload_len);

	return skb;
}

void virtio_transport_deliver_tap_pkt(struct sk_buff *skb)
{
	if (virtio_vsock_skb_tap_delivered(skb))
		return;

	vsock_deliver_tap(virtio_transport_build_skb, skb);
	virtio_vsock_skb_set_tap_delivered(skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);

static u16 virtio_transport_get_type(struct sock *sk)
{
	if (sk->sk_type == SOCK_STREAM)
		return VIRTIO_VSOCK_TYPE_STREAM;
	else
		return VIRTIO_VSOCK_TYPE_SEQPACKET;
}

/* This function can only be used on connecting/connected sockets,
 * since a socket assigned to a transport is required.
 *
 * Do not use on listener sockets!
 */
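/* (Flow sketch for the function below: credit is reserved up front via
 * virtio_transport_get_credit(), the payload is sliced into skbs of at most
 * VIRTIO_VSOCK_MAX_PKT_BUF_SIZE each, and any unsent remainder is returned
 * to the credit pool with virtio_transport_put_credit().)
 */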
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	const struct virtio_transport *t_ops;
	struct virtio_vsock_sock *vvs;
	u32 pkt_len = info->pkt_len;
	u32 rest_len;
	int ret;

	info->type = virtio_transport_get_type(sk_vsock(vsk));

	t_ops = virtio_transport_get_ops(vsk);
	if (unlikely(!t_ops))
		return -EFAULT;

	src_cid = t_ops->transport.get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	rest_len = pkt_len;

	do {
		struct sk_buff *skb;
		size_t skb_len;

		skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len);

		skb = virtio_transport_alloc_skb(info, skb_len,
						 src_cid, src_port,
						 dst_cid, dst_port);
		if (!skb) {
			ret = -ENOMEM;
			break;
		}

		virtio_transport_inc_tx_pkt(vvs, skb);

		ret = t_ops->send_pkt(skb);
		if (ret < 0)
			break;

		/* Both the virtio and vhost 'send_pkt()' implementations
		 * return 'skb_len' on success, but for reliability use 'ret'
		 * instead of 'skb_len'. Also, if a partial send somehow
		 * happens (i.e. 'ret' != 'skb_len'), we break out of this
		 * loop, but still account the returned value via
		 * 'virtio_transport_put_credit()'.
		 */
		rest_len -= ret;

		if (WARN_ONCE(ret != skb_len,
			      "'send_pkt()' returns %i, but %zu expected\n",
			      ret, skb_len))
			break;
	} while (rest_len);

	virtio_transport_put_credit(vvs, rest_len);

	/* Return number of bytes, if any data has been sent. */
	if (rest_len != pkt_len)
		ret = pkt_len - rest_len;

	return ret;
}

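/* Receive-side accounting: rx_bytes counts bytes queued on rx_queue but not
 * yet consumed, and may never exceed the buf_alloc we advertised to the
 * peer; fwd_cnt advances as data is consumed and is echoed back to the peer
 * in every transmitted header (see virtio_transport_inc_tx_pkt()), which is
 * what replenishes the peer's view of our free space.
 */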
static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					u32 len)
{
	if (vvs->rx_bytes + len > vvs->buf_alloc)
		return false;

	vvs->rx_bytes += len;
	return true;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					u32 len)
{
	vvs->rx_bytes -= len;
	vvs->fwd_cnt += len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);

	spin_lock_bh(&vvs->rx_lock);
	vvs->last_fwd_cnt = vvs->fwd_cnt;
	hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->rx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

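/* Credit accounting: the peer advertises its receive buffer (peer_buf_alloc)
 * and the bytes it has consumed so far (peer_fwd_cnt), so the credit still
 * available to us is
 *
 *	peer_buf_alloc - (tx_cnt - peer_fwd_cnt)
 *
 * where tx_cnt is everything we have already reserved for transmission.
 * Senders reserve credit here before sending and hand back any unused
 * remainder via virtio_transport_put_credit().
 */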
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	if (!credit)
		return 0;

	spin_lock_bh(&vvs->tx_lock);
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	if (!credit)
		return;

	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

static ssize_t
virtio_transport_stream_do_peek(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	size_t total = 0;
	int err;

	spin_lock_bh(&vvs->rx_lock);

	skb_queue_walk(&vvs->rx_queue, skb) {
		size_t bytes;

		bytes = len - total;
		if (bytes > skb->len)
			bytes = skb->len;

		spin_unlock_bh(&vvs->rx_lock);

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		err = memcpy_to_msg(msg, skb->data, bytes);
		if (err)
			goto out;

		total += bytes;

		spin_lock_bh(&vvs->rx_lock);

		if (total == len)
			break;
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;

out:
	if (total)
		err = total;
	return err;
}

static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	size_t bytes, total = 0;
	struct sk_buff *skb;
	int err = -EFAULT;
	u32 free_space;

	spin_lock_bh(&vvs->rx_lock);

	if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes,
		      "rx_queue is empty, but rx_bytes is non-zero\n")) {
		spin_unlock_bh(&vvs->rx_lock);
		return err;
	}

	while (total < len && !skb_queue_empty(&vvs->rx_queue)) {
		skb = skb_peek(&vvs->rx_queue);

		bytes = len - total;
		if (bytes > skb->len)
			bytes = skb->len;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, skb->data, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		skb_pull(skb, bytes);

		if (skb->len == 0) {
			u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len);

			virtio_transport_dec_rx_pkt(vvs, pkt_len);
			__skb_unlink(skb, &vvs->rx_queue);
			consume_skb(skb);
		}
	}

	free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);

	spin_unlock_bh(&vvs->rx_lock);

	/* To reduce the number of credit update messages,
	 * don't update credits as long as lots of space is available.
	 * Note: the limit chosen here is arbitrary. Setting the limit
	 * too high causes extra messages. Too low causes transmitter
	 * stalls. As stalls are in theory more expensive than extra
	 * messages, we set the limit to a high value. TODO: experiment
	 * with different values.
	 */
	if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
		virtio_transport_send_credit_update(vsk);

	return total;

out:
	if (total)
		err = total;
	return err;
}

static ssize_t
virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk,
				   struct msghdr *msg)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	size_t total, len;

	spin_lock_bh(&vvs->rx_lock);

	if (!vvs->msg_count) {
		spin_unlock_bh(&vvs->rx_lock);
		return 0;
	}

	total = 0;
	len = msg_data_left(msg);

	skb_queue_walk(&vvs->rx_queue, skb) {
		struct virtio_vsock_hdr *hdr;

		if (total < len) {
			size_t bytes;
			int err;

			bytes = len - total;
			if (bytes > skb->len)
				bytes = skb->len;

			spin_unlock_bh(&vvs->rx_lock);

			/* sk_lock is held by caller so no one else can dequeue.
			 * Unlock rx_lock since memcpy_to_msg() may sleep.
			 */
			err = memcpy_to_msg(msg, skb->data, bytes);
			if (err)
				return err;

			spin_lock_bh(&vvs->rx_lock);
		}

		total += skb->len;
		hdr = virtio_vsock_hdr(skb);

		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
				msg->msg_flags |= MSG_EOR;

			break;
		}
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;
}

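/* A SEQPACKET message is a run of RW packets terminated by one carrying
 * VIRTIO_VSOCK_SEQ_EOM (and VIRTIO_VSOCK_SEQ_EOR additionally maps to
 * MSG_EOR). The dequeue below always consumes the whole message, even when
 * the user buffer is too small; the excess is dropped, matching the usual
 * seqpacket truncation behaviour.
 */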
static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
						 struct msghdr *msg,
						 int flags)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	int dequeued_len = 0;
	size_t user_buf_len = msg_data_left(msg);
	bool msg_ready = false;
	struct sk_buff *skb;

	spin_lock_bh(&vvs->rx_lock);

	if (vvs->msg_count == 0) {
		spin_unlock_bh(&vvs->rx_lock);
		return 0;
	}

	while (!msg_ready) {
		struct virtio_vsock_hdr *hdr;
		size_t pkt_len;

		skb = __skb_dequeue(&vvs->rx_queue);
		if (!skb)
			break;
		hdr = virtio_vsock_hdr(skb);
		pkt_len = (size_t)le32_to_cpu(hdr->len);

		if (dequeued_len >= 0) {
			size_t bytes_to_copy;

			bytes_to_copy = min(user_buf_len, pkt_len);

			if (bytes_to_copy) {
				int err;

				/* sk_lock is held by caller so no one else can dequeue.
				 * Unlock rx_lock since memcpy_to_msg() may sleep.
				 */
				spin_unlock_bh(&vvs->rx_lock);

				err = memcpy_to_msg(msg, skb->data, bytes_to_copy);
				if (err) {
					/* Copy of message failed. Rest of
					 * fragments will be freed without copy.
					 */
					dequeued_len = err;
				} else {
					user_buf_len -= bytes_to_copy;
				}

				spin_lock_bh(&vvs->rx_lock);
			}

			if (dequeued_len >= 0)
				dequeued_len += pkt_len;
		}

		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
			msg_ready = true;
			vvs->msg_count--;

			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
				msg->msg_flags |= MSG_EOR;
		}

		virtio_transport_dec_rx_pkt(vvs, pkt_len);
		kfree_skb(skb);
	}

	spin_unlock_bh(&vvs->rx_lock);

	virtio_transport_send_credit_update(vsk);

	return dequeued_len;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_stream_do_peek(vsk, msg, len);
	else
		return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_seqpacket_do_peek(vsk, msg);
	else
		return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);

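/* A SEQPACKET message must be delivered as a whole, so anything larger than
 * the peer's entire advertised buffer (peer_buf_alloc) can never fit and is
 * rejected up front with -EMSGSIZE rather than being left to stall.
 */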
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	spin_lock_bh(&vvs->tx_lock);

	if (len > vvs->peer_buf_alloc) {
		spin_unlock_bh(&vvs->tx_lock);
		return -EMSGSIZE;
	}

	spin_unlock_bh(&vvs->tx_lock);

	return virtio_transport_stream_enqueue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	u32 msg_count;

	spin_lock_bh(&vvs->rx_lock);
	msg_count = vvs->msg_count;
	spin_unlock_bh(&vvs->rx_lock);

	return msg_count;
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);

static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (bytes < 0)
		bytes = 0;

	return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk && psk->trans) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	}

	if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
		vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = vsk->buffer_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	skb_queue_head_init(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

/* sk_lock held by the caller */
void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		*val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = *val;

	virtio_transport_send_credit_update(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);

int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
				size_t target,
				bool *data_ready_now)
{
	*data_ready_now = vsock_stream_has_data(vsk) >= target;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return vsk->buffer_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct sk_buff *skb)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.reply = !!skb,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket. There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
					  struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(hdr->type),
		.reply = true,
	};
	struct sk_buff *reply;

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	if (!t)
		return -ENOTCONN;

	reply = virtio_transport_alloc_skb(&info, 0,
					   le64_to_cpu(hdr->dst_cid),
					   le32_to_cpu(hdr->dst_port),
					   le64_to_cpu(hdr->src_cid),
					   le32_to_cpu(hdr->src_port));
	if (!reply)
		return -ENOMEM;

	return t->send_pkt(reply);
}

/* This function should be called with sk_lock held and SOCK_DONE set */
static void virtio_transport_remove_sock(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	/* We don't need to take rx_lock, as the socket is closing and we are
	 * removing it.
	 */
	__skb_queue_purge(&vvs->rx_queue);
	vsock_remove_sock(vsk);
}

static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}

static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		virtio_transport_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}

static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}

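/* Orderly release: send SHUTDOWN, optionally linger until the peer's RST
 * flips SOCK_DONE, and otherwise arm close_work so that the socket is reset
 * and removed after VSOCK_CLOSE_TIMEOUT even if the peer never answers.
 */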
/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE))
		return true;

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
		remove_sock = virtio_transport_close(vsk);

	if (remove_sock) {
		sock_set_flag(sk, SOCK_DONE);
		virtio_transport_remove_sock(vsk);
	}
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	int skerr;
	int err;

	switch (le16_to_cpu(hdr->op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, skb);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk_error_report(sk);
	return err;
}

static void
virtio_transport_recv_enqueue(struct vsock_sock *vsk,
			      struct sk_buff *skb)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool can_enqueue, free_pkt = false;
	struct virtio_vsock_hdr *hdr;
	u32 len;

	hdr = virtio_vsock_hdr(skb);
	len = le32_to_cpu(hdr->len);

	spin_lock_bh(&vvs->rx_lock);

	can_enqueue = virtio_transport_inc_rx_pkt(vvs, len);
	if (!can_enqueue) {
		free_pkt = true;
		goto out;
	}

	if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)
		vvs->msg_count++;

	/* Try to copy small packets into the buffer of last packet queued,
	 * to avoid wasting memory queueing the entire buffer with a small
	 * payload.
	 */
	if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue)) {
		struct virtio_vsock_hdr *last_hdr;
		struct sk_buff *last_skb;

		last_skb = skb_peek_tail(&vvs->rx_queue);
		last_hdr = virtio_vsock_hdr(last_skb);

		/* If there is space in the last packet queued, we copy the
		 * new packet in its buffer. We avoid this if the last packet
		 * queued has VIRTIO_VSOCK_SEQ_EOM set, because this is the
		 * delimiter of a SEQPACKET message, so this skb is the first
		 * packet of a new message.
		 */
		if (skb->len < skb_tailroom(last_skb) &&
		    !(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) {
			memcpy(skb_put(last_skb, skb->len), skb->data, skb->len);
			free_pkt = true;
			last_hdr->flags |= hdr->flags;
			le32_add_cpu(&last_hdr->len, len);
			goto out;
		}
	}

	__skb_queue_tail(&vvs->rx_queue, skb);

out:
	spin_unlock_bh(&vvs->rx_lock);
	if (free_pkt)
		kfree_skb(skb);
}

static int
virtio_transport_recv_connected(struct sock *sk,
				struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	int err = 0;

	switch (le16_to_cpu(hdr->op)) {
	case VIRTIO_VSOCK_OP_RW:
		virtio_transport_recv_enqueue(vsk, skb);
		vsock_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		virtio_transport_send_credit_update(vsk);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0 &&
		    !sock_flag(sk, SOCK_DONE)) {
			(void)virtio_transport_reset(vsk, NULL);
			virtio_transport_do_close(vsk, true);
		}
		if (le32_to_cpu(hdr->flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	kfree_skb(skb);
	return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.remote_cid = le64_to_cpu(hdr->src_cid),
		.remote_port = le32_to_cpu(hdr->src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

static bool virtio_transport_space_update(struct sock *sk,
					  struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* Listener sockets are not associated with any transport, so we are
	 * not able to take the state to see if there is space available in the
	 * remote peer, but since they are only used to receive requests, we
	 * can assume that there is always space available in the other peer.
	 */
	if (!vvs)
		return true;

	/* buf_alloc and fwd_cnt are always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}

/* Handle server socket */
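/* A REQUEST arriving on a listener creates a connected child socket. The
 * child must end up on the same transport the request arrived on (checked
 * against vchild->transport below), otherwise the connection is reset.
 */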
static int
virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
			     struct virtio_transport *t)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;
	int ret;

	if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset_no_sock(t, skb);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset_no_sock(t, skb);
		return -ENOMEM;
	}

	child = vsock_create_connected(sk);
	if (!child) {
		virtio_transport_reset_no_sock(t, skb);
		return -ENOMEM;
	}

	sk_acceptq_added(sk);

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
			le32_to_cpu(hdr->dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
			le32_to_cpu(hdr->src_port));

	ret = vsock_assign_transport(vchild, vsk);
	/* The transport assigned (based on remote_addr) must be the same one
	 * on which we received the request.
	 */
	if (ret || vchild->transport != &t->transport) {
		release_sock(child);
		virtio_transport_reset_no_sock(t, skb);
		sock_put(child);
		return ret;
	}

	if (virtio_transport_space_update(child, skb))
		child->sk_write_space(child);

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, skb);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}

static bool virtio_transport_valid_type(u16 type)
{
	return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
	       (type == VIRTIO_VSOCK_TYPE_SEQPACKET);
}

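/* (Dispatch overview for virtio_transport_recv_pkt() below: packets are
 * routed by socket state -- TCP_LISTEN handles connection requests,
 * TCP_SYN_SENT completes a connect, TCP_ESTABLISHED carries data, credit
 * and shutdown traffic, TCP_CLOSING waits for the final RST, and anything
 * else is answered with a reset.)
 */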
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_transport *t,
			       struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(hdr->src_cid),
			le32_to_cpu(hdr->src_port));
	vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid),
			le32_to_cpu(hdr->dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(hdr->len),
					le16_to_cpu(hdr->type),
					le16_to_cpu(hdr->op),
					le32_to_cpu(hdr->flags),
					le32_to_cpu(hdr->buf_alloc),
					le32_to_cpu(hdr->fwd_cnt));

	if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) {
		(void)virtio_transport_reset_no_sock(t, skb);
		goto free_pkt;
	}

	/* The socket must be in the connected or bound table, otherwise send
	 * a reset back.
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(t, skb);
			goto free_pkt;
		}
	}

	if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) {
		(void)virtio_transport_reset_no_sock(t, skb);
		sock_put(sk);
		goto free_pkt;
	}

	if (!skb_set_owner_sk_safe(skb, sk)) {
		WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n");
		goto free_pkt;
	}

	vsk = vsock_sk(sk);

	lock_sock(sk);

	/* Check if sk has been closed before lock_sock */
	if (sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset_no_sock(t, skb);
		release_sock(sk);
		sock_put(sk);
		goto free_pkt;
	}

	space_available = virtio_transport_space_update(sk, skb);

	/* Update CID in case it has changed after a transport reset event */
	if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
		vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case TCP_LISTEN:
		virtio_transport_recv_listen(sk, skb, t);
		kfree_skb(skb);
		break;
	case TCP_SYN_SENT:
		virtio_transport_recv_connecting(sk, skb);
		kfree_skb(skb);
		break;
	case TCP_ESTABLISHED:
		virtio_transport_recv_connected(sk, skb);
		break;
	case TCP_CLOSING:
		virtio_transport_recv_disconnecting(sk, skb);
		kfree_skb(skb);
		break;
	default:
		(void)virtio_transport_reset_no_sock(t, skb);
		kfree_skb(skb);
		break;
	}

	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	kfree_skb(skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

/* Remove the skbs in a queue whose owning vsk matches.
 *
 * Each skb is freed.
 *
 * Returns the count of skbs that were reply packets.
 */
int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue)
{
	struct sk_buff_head freeme;
	struct sk_buff *skb, *tmp;
	int cnt = 0;

	skb_queue_head_init(&freeme);

	spin_lock_bh(&queue->lock);
	skb_queue_walk_safe(queue, skb, tmp) {
		if (vsock_sk(skb->sk) != vsk)
			continue;

		__skb_unlink(skb, queue);
		__skb_queue_tail(&freeme, skb);

		if (virtio_vsock_skb_reply(skb))
			cnt++;
	}
	spin_unlock_bh(&queue->lock);

	__skb_queue_purge(&freeme);

	return cnt;
}
EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs);

int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sock *sk = sk_vsock(vsk);
	struct sk_buff *skb;
	int off = 0;
	int err;

	spin_lock_bh(&vvs->rx_lock);
	/* Use __skb_recv_datagram() for race-free handling of the receive. It
	 * works for types other than dgrams.
	 */
	skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err);
	spin_unlock_bh(&vvs->rx_lock);

	if (!skb)
		return err;

	return recv_actor(sk, skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_read_skb);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");