1 // SPDX-License-Identifier: GPL-2.0
2 #include "string2.h"
3 #include "strfilter.h"
4 
5 #include <errno.h>
6 #include <stdlib.h>
7 #include <linux/ctype.h>
8 #include <linux/string.h>
9 #include <linux/zalloc.h>
10 
/* Operators */
static const char *OP_and	= "&";	/* Logical AND */
static const char *OP_or	= "|";	/* Logical OR */
static const char *OP_not	= "!";	/* Logical NOT */

/*
 * Character classifiers for the filter grammar.
 *
 * These are inline functions rather than macros so that the argument is
 * evaluated exactly once; an expression with side effects such as `*p++`
 * is therefore safe to pass.
 */

/* Return true if @c is one of the operator characters '|', '&' or '!'. */
static inline bool is_operator(int c)
{
	return c == '|' || c == '&' || c == '!';
}

/* Return true if @c is an operator or a parenthesis. */
static inline bool is_separator(int c)
{
	return is_operator(c) || c == '(' || c == ')';
}
18 
strfilter_node__delete(struct strfilter_node * node)19 static void strfilter_node__delete(struct strfilter_node *node)
20 {
21 	if (node) {
22 		if (node->p && !is_operator(*node->p))
23 			zfree((char **)&node->p);
24 		strfilter_node__delete(node->l);
25 		strfilter_node__delete(node->r);
26 		free(node);
27 	}
28 }
29 
strfilter__delete(struct strfilter * filter)30 void strfilter__delete(struct strfilter *filter)
31 {
32 	if (filter) {
33 		strfilter_node__delete(filter->root);
34 		free(filter);
35 	}
36 }
37 
/*
 * Locate the next token in @s.
 *
 * The search starts at the position returned by skip_spaces(@s).  A
 * separator character (operator or parenthesis) forms a one-character
 * token; any other token runs until the next unescaped separator,
 * whitespace or the end of the string.
 *
 * Returns the start of the token and stores the position one past its end
 * in *@e.  At the end of the string both point at the terminating NUL.
 */
static const char *get_token(const char *s, const char **e)
{
	const char *p;

	s = skip_spaces(s);

	if (*s == '\0') {
		p = s;
		goto end;
	}

	p = s + 1;
	if (!is_separator(*s)) {
		/* End search */
retry:
		while (*p && !is_separator(*p) && !isspace(*p))
			p++;
		/*
		 * Escape and special case: '!' is also used in glob pattern.
		 * Only keep scanning when there is a character left to consume:
		 * a backslash that is the last character of the string must not
		 * push p beyond the terminating NUL.
		 */
		if (*p && (*(p - 1) == '\\' || (*p == '!' && *(p - 1) == '['))) {
			p++;
			goto retry;
		}
	}
end:
	*e = p;
	return s;
}
65 
strfilter_node__alloc(const char * op,struct strfilter_node * l,struct strfilter_node * r)66 static struct strfilter_node *strfilter_node__alloc(const char *op,
67 						    struct strfilter_node *l,
68 						    struct strfilter_node *r)
69 {
70 	struct strfilter_node *node = zalloc(sizeof(*node));
71 
72 	if (node) {
73 		node->p = op;
74 		node->l = l;
75 		node->r = r;
76 	}
77 
78 	return node;
79 }
80 
strfilter_node__new(const char * s,const char ** ep)81 static struct strfilter_node *strfilter_node__new(const char *s,
82 						  const char **ep)
83 {
84 	struct strfilter_node root, *cur, *last_op;
85 	const char *e;
86 
87 	if (!s)
88 		return NULL;
89 
90 	memset(&root, 0, sizeof(root));
91 	last_op = cur = &root;
92 
93 	s = get_token(s, &e);
94 	while (*s != '\0' && *s != ')') {
95 		switch (*s) {
96 		case '&':	/* Exchg last OP->r with AND */
97 			if (!cur->r || !last_op->r)
98 				goto error;
99 			cur = strfilter_node__alloc(OP_and, last_op->r, NULL);
100 			if (!cur)
101 				goto nomem;
102 			last_op->r = cur;
103 			last_op = cur;
104 			break;
105 		case '|':	/* Exchg the root with OR */
106 			if (!cur->r || !root.r)
107 				goto error;
108 			cur = strfilter_node__alloc(OP_or, root.r, NULL);
109 			if (!cur)
110 				goto nomem;
111 			root.r = cur;
112 			last_op = cur;
113 			break;
114 		case '!':	/* Add NOT as a leaf node */
115 			if (cur->r)
116 				goto error;
117 			cur->r = strfilter_node__alloc(OP_not, NULL, NULL);
118 			if (!cur->r)
119 				goto nomem;
120 			cur = cur->r;
121 			break;
122 		case '(':	/* Recursively parses inside the parenthesis */
123 			if (cur->r)
124 				goto error;
125 			cur->r = strfilter_node__new(s + 1, &s);
126 			if (!s)
127 				goto nomem;
128 			if (!cur->r || *s != ')')
129 				goto error;
130 			e = s + 1;
131 			break;
132 		default:
133 			if (cur->r)
134 				goto error;
135 			cur->r = strfilter_node__alloc(NULL, NULL, NULL);
136 			if (!cur->r)
137 				goto nomem;
138 			cur->r->p = strndup(s, e - s);
139 			if (!cur->r->p)
140 				goto nomem;
141 		}
142 		s = get_token(e, &e);
143 	}
144 	if (!cur->r)
145 		goto error;
146 	*ep = s;
147 	return root.r;
148 nomem:
149 	s = NULL;
150 error:
151 	*ep = s;
152 	strfilter_node__delete(root.r);
153 	return NULL;
154 }
155 
156 /*
157  * Parse filter rule and return new strfilter.
158  * Return NULL if fail, and *ep == NULL if memory allocation failed.
159  */
strfilter__new(const char * rules,const char ** err)160 struct strfilter *strfilter__new(const char *rules, const char **err)
161 {
162 	struct strfilter *filter = zalloc(sizeof(*filter));
163 	const char *ep = NULL;
164 
165 	if (filter)
166 		filter->root = strfilter_node__new(rules, &ep);
167 
168 	if (!filter || !filter->root || *ep != '\0') {
169 		if (err)
170 			*err = ep;
171 		strfilter__delete(filter);
172 		filter = NULL;
173 	}
174 
175 	return filter;
176 }
177 
strfilter__append(struct strfilter * filter,bool _or,const char * rules,const char ** err)178 static int strfilter__append(struct strfilter *filter, bool _or,
179 			     const char *rules, const char **err)
180 {
181 	struct strfilter_node *right, *root;
182 	const char *ep = NULL;
183 
184 	if (!filter || !rules)
185 		return -EINVAL;
186 
187 	right = strfilter_node__new(rules, &ep);
188 	if (!right || *ep != '\0') {
189 		if (err)
190 			*err = ep;
191 		goto error;
192 	}
193 	root = strfilter_node__alloc(_or ? OP_or : OP_and, filter->root, right);
194 	if (!root) {
195 		ep = NULL;
196 		goto error;
197 	}
198 
199 	filter->root = root;
200 	return 0;
201 
202 error:
203 	strfilter_node__delete(right);
204 	return ep ? -EINVAL : -ENOMEM;
205 }
206 
/*
 * OR @rules onto @filter.  Arguments and return value are those of
 * strfilter__append().
 */
int strfilter__or(struct strfilter *filter, const char *rules, const char **err)
{
	const bool join_with_or = true;

	return strfilter__append(filter, join_with_or, rules, err);
}
211 
/*
 * AND @rules onto @filter.  Arguments and return value are those of
 * strfilter__append().
 */
int strfilter__and(struct strfilter *filter, const char *rules,
		   const char **err)
{
	const bool join_with_or = false;

	return strfilter__append(filter, join_with_or, rules, err);
}
217 
strfilter_node__compare(struct strfilter_node * node,const char * str)218 static bool strfilter_node__compare(struct strfilter_node *node,
219 				    const char *str)
220 {
221 	if (!node || !node->p)
222 		return false;
223 
224 	switch (*node->p) {
225 	case '|':	/* OR */
226 		return strfilter_node__compare(node->l, str) ||
227 			strfilter_node__compare(node->r, str);
228 	case '&':	/* AND */
229 		return strfilter_node__compare(node->l, str) &&
230 			strfilter_node__compare(node->r, str);
231 	case '!':	/* NOT */
232 		return !strfilter_node__compare(node->r, str);
233 	default:
234 		return strglobmatch(str, node->p);
235 	}
236 }
237 
238 /* Return true if STR matches the filter rules */
strfilter__compare(struct strfilter * filter,const char * str)239 bool strfilter__compare(struct strfilter *filter, const char *str)
240 {
241 	if (!filter)
242 		return false;
243 	return strfilter_node__compare(filter->root, str);
244 }
245 
246 static int strfilter_node__sprint(struct strfilter_node *node, char *buf);
247 
248 /* sprint node in parenthesis if needed */
strfilter_node__sprint_pt(struct strfilter_node * node,char * buf)249 static int strfilter_node__sprint_pt(struct strfilter_node *node, char *buf)
250 {
251 	int len;
252 	int pt = node->r ? 2 : 0;	/* don't need to check node->l */
253 
254 	if (buf && pt)
255 		*buf++ = '(';
256 	len = strfilter_node__sprint(node, buf);
257 	if (len < 0)
258 		return len;
259 	if (buf && pt)
260 		*(buf + len) = ')';
261 	return len + pt;
262 }
263 
strfilter_node__sprint(struct strfilter_node * node,char * buf)264 static int strfilter_node__sprint(struct strfilter_node *node, char *buf)
265 {
266 	int len = 0, rlen;
267 
268 	if (!node || !node->p)
269 		return -EINVAL;
270 
271 	switch (*node->p) {
272 	case '|':
273 	case '&':
274 		len = strfilter_node__sprint_pt(node->l, buf);
275 		if (len < 0)
276 			return len;
277 		fallthrough;
278 	case '!':
279 		if (buf) {
280 			*(buf + len++) = *node->p;
281 			buf += len;
282 		} else
283 			len++;
284 		rlen = strfilter_node__sprint_pt(node->r, buf);
285 		if (rlen < 0)
286 			return rlen;
287 		len += rlen;
288 		break;
289 	default:
290 		len = strlen(node->p);
291 		if (buf)
292 			strcpy(buf, node->p);
293 	}
294 
295 	return len;
296 }
297 
strfilter__string(struct strfilter * filter)298 char *strfilter__string(struct strfilter *filter)
299 {
300 	int len;
301 	char *ret = NULL;
302 
303 	len = strfilter_node__sprint(filter->root, NULL);
304 	if (len < 0)
305 		return NULL;
306 
307 	ret = malloc(len + 1);
308 	if (ret)
309 		strfilter_node__sprint(filter->root, ret);
310 
311 	return ret;
312 }
313