1 /* Kernel module to match connection tracking byte counter.
2  * GPL (C) 2002 Martin Devera (devik@cdi.cz).
3  */
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5 #include <linux/module.h>
6 #include <linux/bitops.h>
7 #include <linux/skbuff.h>
8 #include <linux/math64.h>
9 #include <linux/netfilter/x_tables.h>
10 #include <linux/netfilter/xt_connbytes.h>
11 #include <net/netfilter/nf_conntrack.h>
12 #include <net/netfilter/nf_conntrack_acct.h>
13 
14 MODULE_LICENSE("GPL");
15 MODULE_AUTHOR("Harald Welte <laforge@netfilter.org>");
16 MODULE_DESCRIPTION("Xtables: Number of packets/bytes per connection matching");
17 MODULE_ALIAS("ipt_connbytes");
18 MODULE_ALIAS("ip6t_connbytes");
19 
20 static bool
connbytes_mt(const struct sk_buff * skb,struct xt_action_param * par)21 connbytes_mt(const struct sk_buff *skb, struct xt_action_param *par)
22 {
23 	const struct xt_connbytes_info *sinfo = par->matchinfo;
24 	const struct nf_conn *ct;
25 	enum ip_conntrack_info ctinfo;
26 	u_int64_t what = 0;	/* initialize to make gcc happy */
27 	u_int64_t bytes = 0;
28 	u_int64_t pkts = 0;
29 	const struct nf_conn_acct *acct;
30 	const struct nf_conn_counter *counters;
31 
32 	ct = nf_ct_get(skb, &ctinfo);
33 	if (!ct)
34 		return false;
35 
36 	acct = nf_conn_acct_find(ct);
37 	if (!acct)
38 		return false;
39 
40 	counters = acct->counter;
41 	switch (sinfo->what) {
42 	case XT_CONNBYTES_PKTS:
43 		switch (sinfo->direction) {
44 		case XT_CONNBYTES_DIR_ORIGINAL:
45 			what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets);
46 			break;
47 		case XT_CONNBYTES_DIR_REPLY:
48 			what = atomic64_read(&counters[IP_CT_DIR_REPLY].packets);
49 			break;
50 		case XT_CONNBYTES_DIR_BOTH:
51 			what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets);
52 			what += atomic64_read(&counters[IP_CT_DIR_REPLY].packets);
53 			break;
54 		}
55 		break;
56 	case XT_CONNBYTES_BYTES:
57 		switch (sinfo->direction) {
58 		case XT_CONNBYTES_DIR_ORIGINAL:
59 			what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes);
60 			break;
61 		case XT_CONNBYTES_DIR_REPLY:
62 			what = atomic64_read(&counters[IP_CT_DIR_REPLY].bytes);
63 			break;
64 		case XT_CONNBYTES_DIR_BOTH:
65 			what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes);
66 			what += atomic64_read(&counters[IP_CT_DIR_REPLY].bytes);
67 			break;
68 		}
69 		break;
70 	case XT_CONNBYTES_AVGPKT:
71 		switch (sinfo->direction) {
72 		case XT_CONNBYTES_DIR_ORIGINAL:
73 			bytes = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes);
74 			pkts  = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets);
75 			break;
76 		case XT_CONNBYTES_DIR_REPLY:
77 			bytes = atomic64_read(&counters[IP_CT_DIR_REPLY].bytes);
78 			pkts  = atomic64_read(&counters[IP_CT_DIR_REPLY].packets);
79 			break;
80 		case XT_CONNBYTES_DIR_BOTH:
81 			bytes = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes) +
82 				atomic64_read(&counters[IP_CT_DIR_REPLY].bytes);
83 			pkts  = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets) +
84 				atomic64_read(&counters[IP_CT_DIR_REPLY].packets);
85 			break;
86 		}
87 		if (pkts != 0)
88 			what = div64_u64(bytes, pkts);
89 		break;
90 	}
91 
92 	if (sinfo->count.to >= sinfo->count.from)
93 		return what <= sinfo->count.to && what >= sinfo->count.from;
94 	else /* inverted */
95 		return what < sinfo->count.to || what > sinfo->count.from;
96 }
97 
connbytes_mt_check(const struct xt_mtchk_param * par)98 static int connbytes_mt_check(const struct xt_mtchk_param *par)
99 {
100 	const struct xt_connbytes_info *sinfo = par->matchinfo;
101 	int ret;
102 
103 	if (sinfo->what != XT_CONNBYTES_PKTS &&
104 	    sinfo->what != XT_CONNBYTES_BYTES &&
105 	    sinfo->what != XT_CONNBYTES_AVGPKT)
106 		return -EINVAL;
107 
108 	if (sinfo->direction != XT_CONNBYTES_DIR_ORIGINAL &&
109 	    sinfo->direction != XT_CONNBYTES_DIR_REPLY &&
110 	    sinfo->direction != XT_CONNBYTES_DIR_BOTH)
111 		return -EINVAL;
112 
113 	ret = nf_ct_netns_get(par->net, par->family);
114 	if (ret < 0)
115 		pr_info_ratelimited("cannot load conntrack support for proto=%u\n",
116 				    par->family);
117 
118 	/*
119 	 * This filter cannot function correctly unless connection tracking
120 	 * accounting is enabled, so complain in the hope that someone notices.
121 	 */
122 	if (!nf_ct_acct_enabled(par->net)) {
123 		pr_warn("Forcing CT accounting to be enabled\n");
124 		nf_ct_set_acct(par->net, true);
125 	}
126 
127 	return ret;
128 }
129 
connbytes_mt_destroy(const struct xt_mtdtor_param * par)130 static void connbytes_mt_destroy(const struct xt_mtdtor_param *par)
131 {
132 	nf_ct_netns_put(par->net, par->family);
133 }
134 
135 static struct xt_match connbytes_mt_reg __read_mostly = {
136 	.name       = "connbytes",
137 	.revision   = 0,
138 	.family     = NFPROTO_UNSPEC,
139 	.checkentry = connbytes_mt_check,
140 	.match      = connbytes_mt,
141 	.destroy    = connbytes_mt_destroy,
142 	.matchsize  = sizeof(struct xt_connbytes_info),
143 	.me         = THIS_MODULE,
144 };
145 
connbytes_mt_init(void)146 static int __init connbytes_mt_init(void)
147 {
148 	return xt_register_match(&connbytes_mt_reg);
149 }
150 
connbytes_mt_exit(void)151 static void __exit connbytes_mt_exit(void)
152 {
153 	xt_unregister_match(&connbytes_mt_reg);
154 }
155 
156 module_init(connbytes_mt_init);
157 module_exit(connbytes_mt_exit);
158