// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2018 Facebook */
3 
4 #include <stdlib.h>
5 #include <linux/in.h>
6 #include <linux/ip.h>
7 #include <linux/ipv6.h>
8 #include <linux/tcp.h>
9 #include <linux/udp.h>
10 #include <linux/bpf.h>
11 #include <linux/types.h>
12 #include <linux/if_ether.h>
13 
14 #include <bpf/bpf_endian.h>
15 #include <bpf/bpf_helpers.h>
16 #include "test_select_reuseport_common.h"
17 
/*
 * Fallback for build environments whose headers do not already provide
 * offsetof().  The compiler builtin yields the member's byte offset as a
 * constant expression without forming a pointer from a null address.
 */
#ifndef offsetof
#define offsetof(TYPE, MEMBER) __builtin_offsetof(TYPE, MEMBER)
#endif
21 
22 struct {
23 	__uint(type, BPF_MAP_TYPE_ARRAY_OF_MAPS);
24 	__uint(max_entries, 1);
25 	__type(key, __u32);
26 	__type(value, __u32);
27 } outer_map SEC(".maps");
28 
29 struct {
30 	__uint(type, BPF_MAP_TYPE_ARRAY);
31 	__uint(max_entries, NR_RESULTS);
32 	__type(key, __u32);
33 	__type(value, __u32);
34 } result_map SEC(".maps");
35 
36 struct {
37 	__uint(type, BPF_MAP_TYPE_ARRAY);
38 	__uint(max_entries, 1);
39 	__type(key, __u32);
40 	__type(value, int);
41 } tmp_index_ovr_map SEC(".maps");
42 
43 struct {
44 	__uint(type, BPF_MAP_TYPE_ARRAY);
45 	__uint(max_entries, 1);
46 	__type(key, __u32);
47 	__type(value, __u32);
48 } linum_map SEC(".maps");
49 
50 struct {
51 	__uint(type, BPF_MAP_TYPE_ARRAY);
52 	__uint(max_entries, 1);
53 	__type(key, __u32);
54 	__type(value, struct data_check);
55 } data_check_map SEC(".maps");
56 
/*
 * Finish the program with outcome _result.
 *
 * Stores _result in the local `result`, stores the line number of the
 * invocation in the local `linum`, then jumps to the `done` label.  The
 * enclosing function must provide all three.  Wrapped in do/while(0) so
 * an invocation followed by ';' is a single statement, which keeps it
 * safe in unbraced if/else arms.
 */
#define GOTO_DONE(_result) do {			\
	result = (_result);			\
	linum = __LINE__;			\
	goto done;				\
} while (0)
62 
63 SEC("sk_reuseport")
_select_by_skb_data(struct sk_reuseport_md * reuse_md)64 int _select_by_skb_data(struct sk_reuseport_md *reuse_md)
65 {
66 	__u32 linum, index = 0, flags = 0, index_zero = 0;
67 	__u32 *result_cnt, *linum_value;
68 	struct data_check data_check = {};
69 	struct cmd *cmd, cmd_copy;
70 	void *data, *data_end;
71 	void *reuseport_array;
72 	enum result result;
73 	int *index_ovr;
74 	int err;
75 
76 	data = reuse_md->data;
77 	data_end = reuse_md->data_end;
78 	data_check.len = reuse_md->len;
79 	data_check.eth_protocol = reuse_md->eth_protocol;
80 	data_check.ip_protocol = reuse_md->ip_protocol;
81 	data_check.hash = reuse_md->hash;
82 	data_check.bind_inany = reuse_md->bind_inany;
83 	if (data_check.eth_protocol == bpf_htons(ETH_P_IP)) {
84 		if (bpf_skb_load_bytes_relative(reuse_md,
85 						offsetof(struct iphdr, saddr),
86 						data_check.skb_addrs, 8,
87 						BPF_HDR_START_NET))
88 			GOTO_DONE(DROP_MISC);
89 	} else {
90 		if (bpf_skb_load_bytes_relative(reuse_md,
91 						offsetof(struct ipv6hdr, saddr),
92 						data_check.skb_addrs, 32,
93 						BPF_HDR_START_NET))
94 			GOTO_DONE(DROP_MISC);
95 	}
96 
97 	/*
98 	 * The ip_protocol could be a compile time decision
99 	 * if the bpf_prog.o is dedicated to either TCP or
100 	 * UDP.
101 	 *
102 	 * Otherwise, reuse_md->ip_protocol or
103 	 * the protocol field in the iphdr can be used.
104 	 */
105 	if (data_check.ip_protocol == IPPROTO_TCP) {
106 		struct tcphdr *th = data;
107 
108 		if (th + 1 > data_end)
109 			GOTO_DONE(DROP_MISC);
110 
111 		data_check.skb_ports[0] = th->source;
112 		data_check.skb_ports[1] = th->dest;
113 
114 		if (th->fin)
115 			/* The connection is being torn down at the end of a
116 			 * test. It can't contain a cmd, so return early.
117 			 */
118 			return SK_PASS;
119 
120 		if ((th->doff << 2) + sizeof(*cmd) > data_check.len)
121 			GOTO_DONE(DROP_ERR_SKB_DATA);
122 		if (bpf_skb_load_bytes(reuse_md, th->doff << 2, &cmd_copy,
123 				       sizeof(cmd_copy)))
124 			GOTO_DONE(DROP_MISC);
125 		cmd = &cmd_copy;
126 	} else if (data_check.ip_protocol == IPPROTO_UDP) {
127 		struct udphdr *uh = data;
128 
129 		if (uh + 1 > data_end)
130 			GOTO_DONE(DROP_MISC);
131 
132 		data_check.skb_ports[0] = uh->source;
133 		data_check.skb_ports[1] = uh->dest;
134 
135 		if (sizeof(struct udphdr) + sizeof(*cmd) > data_check.len)
136 			GOTO_DONE(DROP_ERR_SKB_DATA);
137 		if (data + sizeof(struct udphdr) + sizeof(*cmd) > data_end) {
138 			if (bpf_skb_load_bytes(reuse_md, sizeof(struct udphdr),
139 					       &cmd_copy, sizeof(cmd_copy)))
140 				GOTO_DONE(DROP_MISC);
141 			cmd = &cmd_copy;
142 		} else {
143 			cmd = data + sizeof(struct udphdr);
144 		}
145 	} else {
146 		GOTO_DONE(DROP_MISC);
147 	}
148 
149 	reuseport_array = bpf_map_lookup_elem(&outer_map, &index_zero);
150 	if (!reuseport_array)
151 		GOTO_DONE(DROP_ERR_INNER_MAP);
152 
153 	index = cmd->reuseport_index;
154 	index_ovr = bpf_map_lookup_elem(&tmp_index_ovr_map, &index_zero);
155 	if (!index_ovr)
156 		GOTO_DONE(DROP_MISC);
157 
158 	if (*index_ovr != -1) {
159 		index = *index_ovr;
160 		*index_ovr = -1;
161 	}
162 	err = bpf_sk_select_reuseport(reuse_md, reuseport_array, &index,
163 				      flags);
164 	if (!err)
165 		GOTO_DONE(PASS);
166 
167 	if (cmd->pass_on_failure)
168 		GOTO_DONE(PASS_ERR_SK_SELECT_REUSEPORT);
169 	else
170 		GOTO_DONE(DROP_ERR_SK_SELECT_REUSEPORT);
171 
172 done:
173 	result_cnt = bpf_map_lookup_elem(&result_map, &result);
174 	if (!result_cnt)
175 		return SK_DROP;
176 
177 	bpf_map_update_elem(&linum_map, &index_zero, &linum, BPF_ANY);
178 	bpf_map_update_elem(&data_check_map, &index_zero, &data_check, BPF_ANY);
179 
180 	(*result_cnt)++;
181 	return result < PASS ? SK_DROP : SK_PASS;
182 }
183 
184 char _license[] SEC("license") = "GPL";
185