// SPDX-License-Identifier: GPL-2.0-only
/*
 * Stream Parser
 *
 * Copyright (c) 2016 Tom Herbert <tom@herbertland.com>
 */

#include <linux/bpf.h>
#include <linux/errno.h>
#include <linux/errqueue.h>
#include <linux/file.h>
#include <linux/in.h>
#include <linux/kernel.h>
#include <linux/export.h>
#include <linux/init.h>
#include <linux/net.h>
#include <linux/netdevice.h>
#include <linux/poll.h>
#include <linux/rculist.h>
#include <linux/skbuff.h>
#include <linux/socket.h>
#include <linux/uaccess.h>
#include <linux/workqueue.h>
#include <net/strparser.h>
#include <net/netns/generic.h>
#include <net/sock.h>

static struct workqueue_struct *strp_wq;

struct _strp_msg {
	/* Internal cb structure. struct strp_msg must be first for passing
	 * to upper layer.
	 */
	struct strp_msg strp;
	int accum_len;
};

static inline struct _strp_msg *_strp_msg(struct sk_buff *skb)
{
	return (struct _strp_msg *)((void *)skb->cb +
		offsetof(struct qdisc_skb_cb, data));
}
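
/* <net/strparser.h> provides the public strp_msg() accessor, which returns
 * the leading struct strp_msg stored in skb->cb; that is why struct strp_msg
 * must be the first member of struct _strp_msg above. A minimal sketch of an
 * upper-layer rcv_msg callback reading it (the function name is illustrative
 * and not part of this file; the callback takes ownership of the skb since
 * the parser drops skb_head before calling it):
 *
 *	static void example_rcv_msg(struct strparser *strp, struct sk_buff *skb)
 *	{
 *		struct strp_msg *msg = strp_msg(skb);
 *
 *		pr_debug("msg at offset %d, full_len %d\n",
 *			 msg->offset, msg->full_len);
 *		kfree_skb(skb);
 *	}
 */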

/* Lower lock held */
static void strp_abort_strp(struct strparser *strp, int err)
{
	/* Unrecoverable error in receive */

	cancel_delayed_work(&strp->msg_timer_work);

	if (strp->stopped)
		return;

	strp->stopped = 1;

	if (strp->sk) {
		struct sock *sk = strp->sk;

		/* Report an error on the lower socket */
		sk->sk_err = -err;
		sk->sk_error_report(sk);
	}
}

static void strp_start_timer(struct strparser *strp, long timeo)
{
	if (timeo && timeo != LONG_MAX)
		mod_delayed_work(strp_wq, &strp->msg_timer_work, timeo);
}

/* Lower lock held */
static void strp_parser_err(struct strparser *strp, int err,
			    read_descriptor_t *desc)
{
	desc->error = err;
	kfree_skb(strp->skb_head);
	strp->skb_head = NULL;
	strp->cb.abort_parser(strp, err);
}

static inline int strp_peek_len(struct strparser *strp)
{
	if (strp->sk) {
		struct socket *sock = strp->sk->sk_socket;

		return sock->ops->peek_len(sock);
	}

	/* If we don't have an associated socket there's nothing to peek.
	 * Return int max to avoid stopping the strparser.
	 */

	return INT_MAX;
}

/* Lower socket lock held */
static int __strp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
		       unsigned int orig_offset, size_t orig_len,
		       size_t max_msg_size, long timeo)
{
	struct strparser *strp = (struct strparser *)desc->arg.data;
	struct _strp_msg *stm;
	struct sk_buff *head, *skb;
	size_t eaten = 0, cand_len;
	ssize_t extra;
	int err;
	bool cloned_orig = false;

	if (strp->paused)
		return 0;

	head = strp->skb_head;
	if (head) {
		/* Message already in progress */
		if (unlikely(orig_offset)) {
			/* Getting data with a non-zero offset when a message is
			 * in progress is not expected. If it does happen, we
			 * need to clone and pull since we can't deal with
			 * offsets in the skbs for a message except in the head.
			 */
			orig_skb = skb_clone(orig_skb, GFP_ATOMIC);
			if (!orig_skb) {
				STRP_STATS_INCR(strp->stats.mem_fail);
				desc->error = -ENOMEM;
				return 0;
			}
			if (!pskb_pull(orig_skb, orig_offset)) {
				STRP_STATS_INCR(strp->stats.mem_fail);
				kfree_skb(orig_skb);
				desc->error = -ENOMEM;
				return 0;
			}
			cloned_orig = true;
			orig_offset = 0;
		}

		if (!strp->skb_nextp) {
			/* We are going to append to the frags_list of head.
			 * Need to unshare the frag_list.
			 */
			err = skb_unclone(head, GFP_ATOMIC);
			if (err) {
				STRP_STATS_INCR(strp->stats.mem_fail);
				desc->error = err;
				return 0;
			}

			if (unlikely(skb_shinfo(head)->frag_list)) {
				/* We can't append to an sk_buff that already
				 * has a frag_list. We create a new head, point
				 * the frag_list of that to the old head, and
				 * then are able to use the old head->next for
				 * appending to the message.
				 */
				if (WARN_ON(head->next)) {
					desc->error = -EINVAL;
					return 0;
				}

				skb = alloc_skb_for_msg(head);
				if (!skb) {
					STRP_STATS_INCR(strp->stats.mem_fail);
					desc->error = -ENOMEM;
					return 0;
				}

				strp->skb_nextp = &head->next;
				strp->skb_head = skb;
				head = skb;
			} else {
				strp->skb_nextp =
				    &skb_shinfo(head)->frag_list;
			}
		}
	}

	while (eaten < orig_len) {
		/* Always clone since we will consume something */
		skb = skb_clone(orig_skb, GFP_ATOMIC);
		if (!skb) {
			STRP_STATS_INCR(strp->stats.mem_fail);
			desc->error = -ENOMEM;
			break;
		}

		cand_len = orig_len - eaten;

		head = strp->skb_head;
		if (!head) {
			head = skb;
			strp->skb_head = head;
			/* Will set skb_nextp on next packet if needed */
			strp->skb_nextp = NULL;
			stm = _strp_msg(head);
			memset(stm, 0, sizeof(*stm));
			stm->strp.offset = orig_offset + eaten;
		} else {
			/* Unclone if we are appending to an skb that we
			 * already share a frag_list with.
			 */
			if (skb_has_frag_list(skb)) {
				err = skb_unclone(skb, GFP_ATOMIC);
				if (err) {
					STRP_STATS_INCR(strp->stats.mem_fail);
					desc->error = err;
					break;
				}
			}

			stm = _strp_msg(head);
			*strp->skb_nextp = skb;
			strp->skb_nextp = &skb->next;
			head->data_len += skb->len;
			head->len += skb->len;
			head->truesize += skb->truesize;
		}

		if (!stm->strp.full_len) {
			ssize_t len;

			len = (*strp->cb.parse_msg)(strp, head);

			if (!len) {
				/* Need more header to determine length */
				if (!stm->accum_len) {
					/* Start RX timer for new message */
					strp_start_timer(strp, timeo);
				}
				stm->accum_len += cand_len;
				eaten += cand_len;
				STRP_STATS_INCR(strp->stats.need_more_hdr);
				WARN_ON(eaten != orig_len);
				break;
			} else if (len < 0) {
				if (len == -ESTRPIPE && stm->accum_len) {
					len = -ENODATA;
					strp->unrecov_intr = 1;
				} else {
					strp->interrupted = 1;
				}
				strp_parser_err(strp, len, desc);
				break;
			} else if (len > max_msg_size) {
				/* Message length exceeds maximum allowed */
				STRP_STATS_INCR(strp->stats.msg_too_big);
				strp_parser_err(strp, -EMSGSIZE, desc);
				break;
			} else if (len <= (ssize_t)head->len -
					  skb->len - stm->strp.offset) {
				/* Length must be into new skb (and also
				 * greater than zero)
				 */
				STRP_STATS_INCR(strp->stats.bad_hdr_len);
				strp_parser_err(strp, -EPROTO, desc);
				break;
			}

			stm->strp.full_len = len;
		}

		extra = (ssize_t)(stm->accum_len + cand_len) -
			stm->strp.full_len;

		if (extra < 0) {
			/* Message not complete yet. */
			if (stm->strp.full_len - stm->accum_len >
			    strp_peek_len(strp)) {
				/* Don't have the whole message in the socket
				 * buffer. Set strp->need_bytes to wait for
				 * the rest of the message. Also, set "early
				 * eaten" since we've already buffered the skb
				 * but don't consume yet per strp_read_sock.
				 */

				if (!stm->accum_len) {
					/* Start RX timer for new message */
					strp_start_timer(strp, timeo);
				}

				stm->accum_len += cand_len;
				eaten += cand_len;
				strp->need_bytes = stm->strp.full_len -
						       stm->accum_len;
				STRP_STATS_ADD(strp->stats.bytes, cand_len);
				desc->count = 0; /* Stop reading socket */
				break;
			}
			stm->accum_len += cand_len;
			eaten += cand_len;
			WARN_ON(eaten != orig_len);
			break;
		}

		/* Positive extra indicates more bytes than needed for the
		 * message
		 */

		WARN_ON(extra > cand_len);

		eaten += (cand_len - extra);

		/* Hurray, we have a new message! */
		cancel_delayed_work(&strp->msg_timer_work);
		strp->skb_head = NULL;
		strp->need_bytes = 0;
		STRP_STATS_INCR(strp->stats.msgs);

		/* Give skb to upper layer */
		strp->cb.rcv_msg(strp, head);

		if (unlikely(strp->paused)) {
			/* Upper layer paused strp */
			break;
		}
	}

	if (cloned_orig)
		kfree_skb(orig_skb);

	STRP_STATS_ADD(strp->stats.bytes, eaten);

	return eaten;
}

int strp_process(struct strparser *strp, struct sk_buff *orig_skb,
		 unsigned int orig_offset, size_t orig_len,
		 size_t max_msg_size, long timeo)
{
	read_descriptor_t desc; /* Dummy arg to strp_recv */

	desc.arg.data = strp;

	return __strp_recv(&desc, orig_skb, orig_offset, orig_len,
			   max_msg_size, timeo);
}
EXPORT_SYMBOL_GPL(strp_process);
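
/* In general mode (sk == NULL in strp_init), the upper layer feeds data to
 * the parser directly through strp_process() instead of read_sock. A minimal
 * sketch with illustrative names (the constant and wrapper are not part of
 * this file); note that __strp_recv() only clones orig_skb, so the caller
 * keeps ownership of the skb it passes in:
 *
 *	#define EXAMPLE_MAX_MSG	(1 << 20)
 *
 *	static int example_feed(struct strparser *strp, struct sk_buff *skb)
 *	{
 *		return strp_process(strp, skb, 0, skb->len,
 *				    EXAMPLE_MAX_MSG, HZ);
 *	}
 *
 * max_msg_size bounds what parse_msg may return (larger lengths abort with
 * -EMSGSIZE) and timeo, in jiffies, is handed to strp_start_timer(); passing
 * 0 leaves the message assembly timer disabled.
 */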

static int strp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
		     unsigned int orig_offset, size_t orig_len)
{
	struct strparser *strp = (struct strparser *)desc->arg.data;

	return __strp_recv(desc, orig_skb, orig_offset, orig_len,
			   strp->sk->sk_rcvbuf, strp->sk->sk_rcvtimeo);
}

static int default_read_sock_done(struct strparser *strp, int err)
{
	return err;
}

/* Called with lock held on lower socket */
static int strp_read_sock(struct strparser *strp)
{
	struct socket *sock = strp->sk->sk_socket;
	read_descriptor_t desc;

	if (unlikely(!sock || !sock->ops || !sock->ops->read_sock))
		return -EBUSY;

	desc.arg.data = strp;
	desc.error = 0;
	desc.count = 1; /* give more than one skb per call */

	/* sk should be locked here, so okay to do read_sock */
	sock->ops->read_sock(strp->sk, &desc, strp_recv);

	desc.error = strp->cb.read_sock_done(strp, desc.error);

	return desc.error;
}

/* Lower sock lock held */
void strp_data_ready(struct strparser *strp)
{
	if (unlikely(strp->stopped) || strp->paused)
		return;

	/* This check is needed to synchronize with do_strp_work.
	 * do_strp_work acquires a process lock (lock_sock) whereas
	 * the lock held here is bh_lock_sock. The two locks can be
	 * held by different threads at the same time, but bh_lock_sock
	 * allows a thread in BH context to safely check if the process
	 * lock is held. In this case, if the lock is held, queue work.
	 */
	if (sock_owned_by_user_nocheck(strp->sk)) {
		queue_work(strp_wq, &strp->work);
		return;
	}

	if (strp->need_bytes) {
		if (strp_peek_len(strp) < strp->need_bytes)
			return;
	}

	if (strp_read_sock(strp) == -ENOMEM)
		queue_work(strp_wq, &strp->work);
}
EXPORT_SYMBOL_GPL(strp_data_ready);
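
/* In receive callback mode the owner typically hooks the lower socket's
 * sk_data_ready callback and forwards it here; that callback runs in BH
 * context with the socket spinlock held, matching the locking comment above.
 * A minimal sketch, assuming a hypothetical psock container stashed in
 * sk_user_data (names are illustrative, not part of this file):
 *
 *	static void example_data_ready(struct sock *sk)
 *	{
 *		struct example_psock *psock;
 *
 *		read_lock_bh(&sk->sk_callback_lock);
 *		psock = (struct example_psock *)sk->sk_user_data;
 *		if (likely(psock))
 *			strp_data_ready(&psock->strp);
 *		read_unlock_bh(&sk->sk_callback_lock);
 *	}
 */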

static void do_strp_work(struct strparser *strp)
{
	/* We need the read lock to synchronize with strp_data_ready. We
	 * need the socket lock for calling strp_read_sock.
	 */
	strp->cb.lock(strp);

	if (unlikely(strp->stopped))
		goto out;

	if (strp->paused)
		goto out;

	if (strp_read_sock(strp) == -ENOMEM)
		queue_work(strp_wq, &strp->work);

out:
	strp->cb.unlock(strp);
}

static void strp_work(struct work_struct *w)
{
	do_strp_work(container_of(w, struct strparser, work));
}

static void strp_msg_timeout(struct work_struct *w)
{
	struct strparser *strp = container_of(w, struct strparser,
					      msg_timer_work.work);

	/* Message assembly timed out */
	STRP_STATS_INCR(strp->stats.msg_timeouts);
	strp->cb.lock(strp);
	strp->cb.abort_parser(strp, -ETIMEDOUT);
	strp->cb.unlock(strp);
}

static void strp_sock_lock(struct strparser *strp)
{
	lock_sock(strp->sk);
}

static void strp_sock_unlock(struct strparser *strp)
{
	release_sock(strp->sk);
}

int strp_init(struct strparser *strp, struct sock *sk,
	      const struct strp_callbacks *cb)
{

	if (!cb || !cb->rcv_msg || !cb->parse_msg)
		return -EINVAL;

	/* The sk (sock) arg determines the mode of the stream parser.
	 *
	 * If the sock is set then the strparser is in receive callback mode.
	 * The upper layer calls strp_data_ready to kick receive processing
	 * and strparser calls the read_sock function on the socket to
	 * get packets.
	 *
	 * If the sock is not set then the strparser is in general mode.
	 * The upper layer calls strp_process for each skb to be parsed.
	 */

	if (!sk) {
		if (!cb->lock || !cb->unlock)
			return -EINVAL;
	}

	memset(strp, 0, sizeof(*strp));

	strp->sk = sk;

	strp->cb.lock = cb->lock ? : strp_sock_lock;
	strp->cb.unlock = cb->unlock ? : strp_sock_unlock;
	strp->cb.rcv_msg = cb->rcv_msg;
	strp->cb.parse_msg = cb->parse_msg;
	strp->cb.read_sock_done = cb->read_sock_done ? : default_read_sock_done;
	strp->cb.abort_parser = cb->abort_parser ? : strp_abort_strp;

	INIT_DELAYED_WORK(&strp->msg_timer_work, strp_msg_timeout);
	INIT_WORK(&strp->work, strp_work);

	return 0;
}
EXPORT_SYMBOL_GPL(strp_init);
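
/* A minimal setup sketch for receive callback mode (sk != NULL). The
 * parse_msg callback below assumes a hypothetical wire format with a 4-byte
 * big-endian length header and reuses the example_rcv_msg sketch shown
 * earlier; none of these names are part of this file:
 *
 *	static int example_parse_msg(struct strparser *strp, struct sk_buff *skb)
 *	{
 *		struct strp_msg *msg = strp_msg(skb);
 *		__be32 hdr;
 *
 *		if (msg->offset + sizeof(hdr) > skb->len)
 *			return 0;
 *
 *		if (skb_copy_bits(skb, msg->offset, &hdr, sizeof(hdr)))
 *			return -EINVAL;
 *
 *		return sizeof(hdr) + ntohl(hdr);
 *	}
 *
 *	static const struct strp_callbacks example_cb = {
 *		.rcv_msg	= example_rcv_msg,
 *		.parse_msg	= example_parse_msg,
 *	};
 *
 *	err = strp_init(&psock->strp, csk, &example_cb);
 *
 * Returning 0 from parse_msg means "need more header", a negative value
 * aborts the parser, and a positive value is the full message length
 * measured from msg->offset (header included). With sk set, lock/unlock
 * default to lock_sock/release_sock and read_sock_done/abort_parser fall
 * back to the defaults above.
 */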

/* Sock process lock held (lock_sock) */
void __strp_unpause(struct strparser *strp)
{
	strp->paused = 0;

	if (strp->need_bytes) {
		if (strp_peek_len(strp) < strp->need_bytes)
			return;
	}
	strp_read_sock(strp);
}
EXPORT_SYMBOL_GPL(__strp_unpause);

void strp_unpause(struct strparser *strp)
{
	strp->paused = 0;

	/* Sync setting paused with RX work */
	smp_mb();

	queue_work(strp_wq, &strp->work);
}
EXPORT_SYMBOL_GPL(strp_unpause);

/* strp must already be stopped so that strp_recv will no longer be called.
 * Note that strp_done is not called with the lower socket held.
 */
void strp_done(struct strparser *strp)
{
	WARN_ON(!strp->stopped);

	cancel_delayed_work_sync(&strp->msg_timer_work);
	cancel_work_sync(&strp->work);

	if (strp->skb_head) {
		kfree_skb(strp->skb_head);
		strp->skb_head = NULL;
	}
}
EXPORT_SYMBOL_GPL(strp_done);

void strp_stop(struct strparser *strp)
{
	strp->stopped = 1;
}
EXPORT_SYMBOL_GPL(strp_stop);
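
/* Teardown order, per the comment above strp_done(): mark the parser stopped
 * first, under the owner's lock so strp_recv and queued work stop running the
 * parser, then call strp_done() without the lower socket lock held. An
 * illustrative sequence (names are not part of this file):
 *
 *	lock_sock(csk);
 *	strp_stop(&psock->strp);
 *	release_sock(csk);
 *
 *	strp_done(&psock->strp);
 */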

void strp_check_rcv(struct strparser *strp)
{
	queue_work(strp_wq, &strp->work);
}
EXPORT_SYMBOL_GPL(strp_check_rcv);

static int __init strp_dev_init(void)
{
	strp_wq = create_singlethread_workqueue("kstrp");
	if (unlikely(!strp_wq))
		return -ENOMEM;

	return 0;
}
device_initcall(strp_dev_init);