1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import os
7import re
8import shutil
9import tempfile
10import yaml
11
12from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
13
14
def c_upper(name):
    """Return ``name`` upper-cased with every '-' turned into '_'."""
    return '_'.join(name.upper().split('-'))
17
18
def c_lower(name):
    """Return ``name`` lower-cased with every '-' turned into '_'."""
    return '_'.join(name.lower().split('-'))
21
22
class BaseNlLib:
    """Produces C expressions for the generated user space code.

    The snippets refer to a ``ys`` state pointer and call ``mnl_cb_run2``.
    """

    def get_family_id(self):
        """Return the C expression holding the resolved family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return an ``mnl_cb_run2(...)`` call that parses ``ys->rx_buf``.

        Dumps pass 0 for the sequence number and port id; other requests
        pass ``ys->seq`` and ``ys->portid``.  Continuation lines are wrapped
        with two tabs, ``indent`` extra tabs and one space.
        """
        sep = '\n\t\t' + '\t' * indent + ' '
        call = 'mnl_cb_run2(ys->rx_buf, len, '
        if is_dump:
            call += f'0, 0, {cb}, {data},'
        else:
            call += f'ys->seq, ys->portid,{sep}{cb}, {data},'
        return call + sep + 'ynl_cb_array, NLMSG_MIN_TYPE)'
34
35
class Type(SpecAttr):
    """Base renderer for one attribute of an attribute set.

    Holds the C-level naming (``c_name``, ``enum_name``, nested struct
    names).  Provides default implementations of the hooks used to emit:

    * kernel policy entries;
    * type-info tables;
    * struct members;
    * put/parse code;
    * setters.

    Subclasses override the underscore-prefixed hooks (``_attr_policy``,
    ``_attr_typol``, ``_attr_get``, ``_setter_lines``,
    ``_complex_member_type``) and/or the public emitters as needed.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

            # A trailing '_' is added when the nested set's name is also a
            # key of family.consts (presumably to avoid a C name clash with
            # a definition of the same name -- TODO confirm).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        # Names found in _C_KW (presumably C keywords) get a '_' suffix.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve(); declared here for readers, then deleted so that
        # any access before resolve() has run raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute ``enum_name``, the C enum constant naming this attribute."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Whether the attribute is stored as an array plus ``n_`` count."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Presence tracking kind: 'bit', 'len', 'count' or '' for none."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` member declaration for this attribute.

        Returns None unless ``presence_type()`` equals ``type_filter``, and
        also for presence types other than 'bit' and 'len'.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """C type of a non-trivial struct member, or None for simple types."""
        return None

    def free_needs_iter(self):
        """Whether free() emits a loop that needs an ``i`` counter."""
        return False

    def free(self, ri, var, ref):
        """Emit code releasing heap memory owned by this member, if any."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return C parameter declarations used by the generated setter.

        Raises:
            Exception: the subclass provides no member type.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel ``nla_policy`` table."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry in the user space type-info table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the member's presence indicator, if it has one.
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)`` for parsing the attr."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if``/``else if`` branch that parses this attribute.

        ``first`` selects ``if`` rather than ``else if``.  Returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # _attr_get() may return a single line as a bare string.
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute.

        ``ref`` holds the C member names of the enclosing nests (TypeNest
        passes it down when recursing).  For example, a member ``m`` inside
        nest ``n`` gets a setter named ``..._set_n_m()``.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        last = len(ref) - 1
        for i in range(len(ref)):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
            # Every level above the last is an enclosing nest, which tracks
            # presence with a bit.  That bit must be set, otherwise the nest's
            # put is skipped by its "if (..._present.<nest>)" guard and the
            # value never reaches the message.  Only the last level is this
            # attribute, whose 'len'/'count' presence (if any) is handled by
            # _setter_lines().
            if i < last or self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        # Setters that free without allocating get a '__' prefix
        # (presumably marking them as internal helpers -- TODO confirm).
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
212
213
class TypeUnused(Type):
    """Attribute slot declared 'unused' in the spec.

    It gets no struct member and no kernel policy entry.  The generated
    parser rejects it.
    """

    def presence_type(self):
        # Nothing is ever stored, so no presence tracking is needed.
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_get(self, ri, var):
        reject = ['return MNL_CB_ERROR;']
        return reject, None, None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        # Intentionally emits nothing.
        return None
229
230
class TypePad(Type):
    """Padding attribute: ignored by the parser and never rendered.

    It produces no struct member, put, parse, policy or setter code.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The emitters below intentionally produce no output for padding.
    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def attr_policy(self, cw):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
252
253
class TypeScalar(Type):
    """Integer attribute, optionally bound to an enum/flags definition."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Added by resolve(); declared here for readers, then deleted so that
        # any access before resolve() has run raises AttributeError.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Work out ``is_bitfield`` and the C ``type_name`` of the member."""
        self.resolve_up(super())

        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        maybe_enum = not self.is_bitfield and 'enum' in self.attr
        if maybe_enum and self.family.consts[self.attr['enum']].enum_name:
            # Use the enum's own C name (it already carries the 'enum '
            # prefix) so that a custom 'enum-name' in the spec is honoured
            # instead of being rebuilt from the family and definition names.
            self.type_name = self.family.consts[self.attr['enum']].enum_name
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_policy(self, policy):
        """Return the policy, preferring mask, then min, then enum-range checks."""
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                enum = self.family.consts[self.attr['enum']]
            else:
                enum = self.family.consts[self.checks['flags-mask']]
            # Derive the mask from the entries themselves, as the bitfield
            # path already did, instead of assuming they occupy bits
            # 0..n-1.  That assumption breaks when a definition uses a
            # value-start or explicit values.
            mask = enum.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        elif 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            low, high = enum.value_range()
            if low == 0:
                return f"NLA_POLICY_MAX({policy}, {high})"
            return f"NLA_POLICY_RANGE({policy}, {low}, {high})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types map onto the unsigned type-info of the same width.
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
325
326
class TypeFlag(Type):
    """Flag attribute: carries no payload, its presence is the value."""

    def arg_member(self, ri):
        return list()

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put = "mnl_attr_put(nlh, " + self.enum_name + ", 0, NULL)"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        # Nothing to copy: the inherited attr_get() records the presence bit.
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        # The presence bit written by setter() is all there is to set.
        return list()
342
343
class TypeString(Type):
    """String attribute stored as a heap copy; presence is its length."""

    def arg_member(self, ri):
        return ["const char *" + self.c_name]

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p("char *" + self.c_name + ";")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        parts = ['{ .type = ', policy]
        if 'max-len' in self.checks:
            parts.append(', .len = ' + str(self.checks['max-len']))
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        # Unterminated strings are accepted only when the spec opts in.
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [f"{len_mem} = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, mnl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f"memcpy({member}, {self.c_name}, {length});",
                f"{member}[{length}] = 0;"]
391
392
class TypeBinary(Type):
    """Opaque binary attribute stored as a heap copy; presence is its length."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Keep the trailing space every other _attr_typol() emits, so the
        # generated entry ends "YNL_PT_BINARY, }" rather than "YNL_PT_BINARY,}".
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """Return the kernel policy initializer.

        With only a 'min-len' check the entry has just ``.len`` (type left
        at 0, which include/net/netlink.h treats as a minimum payload
        length).  With no checks it is a plain NLA_BINARY entry.

        Raises:
            Exception: any other combination of checks.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
434
435
class TypeNest(Type):
    """Attribute wrapping a nested attribute set, embedded as a struct."""

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # No setter for the nest itself: recurse into its members, extending
        # the access path with this attribute's name.
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
465
466
class TypeMultiAttr(Type):
    """Attribute that may repeat within a message ('multi-attr').

    Values are stored as an array plus an ``n_<name>`` count.  ``base_type``
    is the single-instance Type object built for the same spec entry.
    Policy and type-info rendering are delegated to it.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def _is_nest_subtype(self):
        """True when the repeated values are nests (or no 'type' is given)."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _mnl_type(self):
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _complex_member_type(self, ri):
        if self._is_nest_subtype():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return self._is_nest_subtype()

    def free(self, ri, var, ref):
        # Test for nests first, as _complex_member_type() does, so the
        # "no 'type' key" case is handled instead of raising KeyError.
        if self._is_nest_subtype():
            # The loop counter 'i' is presumably declared by the caller when
            # free_needs_iter() is True -- TODO confirm.
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences here; the n_<name> counter is presumably a
        # local declared by the surrounding parse code -- TODO confirm.
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        if self._is_nest_subtype():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self._mnl_type()
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
535
536
class TypeArrayNest(Type):
    """Array of values wrapped in a single outer nest ('array-nest')."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the outer attribute and count its nested entries.
        counter = f'{var}->n_{self.c_name}'
        get_lines = [f'attr_{self.c_name} = attr;',
                     'mnl_attr_for_each_nested(attr2, attr)',
                     f'\t{counter}++;']
        return get_lines, None, ['const struct nlattr *attr2;']
562
563
class TypeNestTypeValue(Type):
    """Nest whose wrapping levels carry values in their attribute types.

    For each 'type-value' level the generated parser reads the attribute
    type.  It then passes those values to the nested parser as extra
    arguments.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            names = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(names)};')
            local_vars.append(f'__u32 {", ".join(names)};')
            # Descend one level per name, keeping the type of each level.
            for name in names:
                get_lines.append(f'attr_{name} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{name} = mnl_attr_get_type(attr_{name});')
                cursor = f'attr_{name}'
            extra_args = f", {', '.join(names)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
592
593
class Struct:
    """C struct rendered for an attribute set.

    With ``type_list`` only the listed attributes are included.  Without it
    every attribute of the set is used and the struct is marked ``nested``.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Default to a list (not a set): a list never compares equal to a
        # set, so set_inherited() rejects any value when nothing was given,
        # even an empty set.
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        lower_space = c_lower(space_name)
        if family.name == lower_space:
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{lower_space}"
        suffix = '_' if self.nested and space_name in family.consts else ''
        self.struct_name = 'struct ' + self.render_name + suffix
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]

        self.attrs = dict()
        self.attr_max_val = None
        max_val = 0
        for name, attr in self.attr_list:
            # '>=' lets the last attribute sharing the highest value win.
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the ``(name, attr)`` pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited members; they must match those given at init."""
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
646
647
class EnumEntry(SpecEnumEntry):
    """Enum/flags entry carrying its C constant name."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change is True when the value is not the implicit one
        # (previous + 1, or 0 for the first entry).  It is always True for
        # flags sets -- presumably so an explicit value gets rendered.
        implicit = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve(); declared here for readers, then deleted so that
        # any access before resolve() has run raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
666
667
class EnumSet(SpecEnumSet):
    """Enum or flags definition with the names used in generated C code."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            # A falsy 'enum-name' leaves no enum type name; TypeScalar then
            # falls back to the plain integer type.
            self.enum_name = None

        default_pfx = f"{family.name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return ``(low, high)`` of the entry values.

        Raises:
            Exception: the number of entries does not match the span
                from low to high.
        """
        values = [entry.value for entry in self.entries.values()]
        low, high = min(values), max(values)

        if high - low + 1 != len(self.entries):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
695
696
class AttrSet(SpecAttrSet):
    """Attribute set with the C naming needed by the code generator."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the enum naming of the set it is carved from.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve(); declared here for readers, then deleted so that
        # any access before resolve() has run raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        # Assign before reading family.c_name, as the order matters if that
        # read fails and resolution is retried.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        # The set named like the family gets an empty C name.
        if self.c_name == self.family.c_name:
            self.c_name = ''

    def new_attr(self, elem, value):
        """Instantiate the Type subclass matching ``elem['type']``.

        Attributes flagged 'multi-attr' are wrapped in TypeMultiAttr, which
        keeps the single-instance object as its base_type.
        """
        if elem['type'] in scalars:
            cls = TypeScalar
        else:
            cls = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(elem['type'])
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")
        t = cls(self.family, self, elem, value)

        if elem.get('multi-attr'):
            t = TypeMultiAttr(self.family, self, elem, value, t)

        return t
751
752
class Operation(SpecOperation):
    """Netlink operation with the names used by the generated C code."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        # True only when both the 'do' and the 'dump' mode define a request.
        self.dual_policy = all('request' in yaml.get(mode, {}) for mode in ('do', 'dump'))

        self.has_ntf = False

        # Added by resolve(); declared here for readers, then deleted so that
        # any access before resolve() has run raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that some message names this operation as its 'notify' target."""
        self.has_ntf = True
778
779
780class Family(SpecFamily):
781    def __init__(self, file_name, exclude_ops):
782        # Added by resolve:
783        self.c_name = None
784        delattr(self, "c_name")
785        self.op_prefix = None
786        delattr(self, "op_prefix")
787        self.async_op_prefix = None
788        delattr(self, "async_op_prefix")
789        self.mcgrps = None
790        delattr(self, "mcgrps")
791        self.consts = None
792        delattr(self, "consts")
793        self.hooks = None
794        delattr(self, "hooks")
795
796        super().__init__(file_name, exclude_ops=exclude_ops)
797
798        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
799        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
800
801        if 'definitions' not in self.yaml:
802            self.yaml['definitions'] = []
803
804        if 'uapi-header' in self.yaml:
805            self.uapi_header = self.yaml['uapi-header']
806        else:
807            self.uapi_header = f"linux/{self.name}.h"
808
809    def resolve(self):
810        self.resolve_up(super())
811
812        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
813            raise Exception("Codegen only supported for genetlink")
814
815        self.c_name = c_lower(self.name)
816        if 'name-prefix' in self.yaml['operations']:
817            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
818        else:
819            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
820        if 'async-prefix' in self.yaml['operations']:
821            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
822        else:
823            self.async_op_prefix = self.op_prefix
824
825        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
826
827        self.hooks = dict()
828        for when in ['pre', 'post']:
829            self.hooks[when] = dict()
830            for op_mode in ['do', 'dump']:
831                self.hooks[when][op_mode] = dict()
832                self.hooks[when][op_mode]['set'] = set()
833                self.hooks[when][op_mode]['list'] = []
834
835        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
836        self.root_sets = dict()
837        # dict space-name -> set('request', 'reply')
838        self.pure_nested_structs = dict()
839
840        self._mark_notify()
841        self._mock_up_events()
842
843        self._load_root_sets()
844        self._load_nested_sets()
845        self._load_hooks()
846
847        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
848        if self.kernel_policy == 'global':
849            self._load_global_policy()
850
851    def new_enum(self, elem):
852        return EnumSet(self, elem)
853
854    def new_attr_set(self, elem):
855        return AttrSet(self, elem)
856
857    def new_operation(self, elem, req_value, rsp_value):
858        return Operation(self, elem, req_value, rsp_value)
859
860    def _mark_notify(self):
861        for op in self.msgs.values():
862            if 'notify' in op:
863                self.ops[op['notify']].mark_has_ntf()
864
865    # Fake a 'do' equivalent of all events, so that we can render their response parsing
866    def _mock_up_events(self):
867        for op in self.yaml['operations']['list']:
868            if 'event' in op:
869                op['do'] = {
870                    'reply': {
871                        'attributes': op['event']['attributes']
872                    }
873                }
874
875    def _load_root_sets(self):
876        for op_name, op in self.msgs.items():
877            if 'attribute-set' not in op:
878                continue
879
880            req_attrs = set()
881            rsp_attrs = set()
882            for op_mode in ['do', 'dump']:
883                if op_mode in op and 'request' in op[op_mode]:
884                    req_attrs.update(set(op[op_mode]['request']['attributes']))
885                if op_mode in op and 'reply' in op[op_mode]:
886                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
887            if 'event' in op:
888                rsp_attrs.update(set(op['event']['attributes']))
889
890            if op['attribute-set'] not in self.root_sets:
891                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
892            else:
893                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
894                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
895
    def _load_nested_sets(self):
        """Discover attribute sets reachable via 'nested-attributes' from the root sets.

        Fills self.pure_nested_structs (set name -> Struct) for every set that
        is only used nested, records which members each nested struct inherits
        from its parent ('type-value' names, or 'idx' for array-nest), marks
        whether each struct is used in requests and/or replies, and reorders
        the dict so that a struct comes after the structs it nests.

        Raises:
            Exception: if an attribute set is used both as a root set and as
                a nest.
        """
        # Breadth-first walk starting from every root attribute set
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    # Create the Struct on first reference only
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                # 'type-value' names become inherited members (passed into the
                # nest's parser as extra __u32 args); array-nest elements
                # inherit their index as 'idx'.
                if 'type-value' in spec:
                    # NOTE(review): unreachable - the root/nested combination
                    # already raised above.
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                # Called once per referencing attribute; whether set_inherited()
                # replaces or merges is defined elsewhere - presumably all
                # references to a nest agree.
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Nests referenced directly from a root set take its request/reply usage
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies: a struct is "finished" once
        # every set it nests has been finished; unfinished ones are moved to the
        # end. Bounded to len**2 rounds, so a nest cycle stops the sort instead
        # of looping forever.
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts keep insertion order; re-inserting moves this struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)
        # Propagate the request / reply flags from parent to child. After the
        # reordering children precede parents, so walking in reverse visits a
        # parent before its children.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
964
965    def _load_global_policy(self):
966        global_set = set()
967        attr_set_name = None
968        for op_name, op in self.ops.items():
969            if not op:
970                continue
971            if 'attribute-set' not in op:
972                continue
973
974            if attr_set_name is None:
975                attr_set_name = op['attribute-set']
976            if attr_set_name != op['attribute-set']:
977                raise Exception('For a global policy all ops must use the same set')
978
979            for op_mode in ['do', 'dump']:
980                if op_mode in op:
981                    req = op[op_mode].get('request')
982                    if req:
983                        global_set.update(req.get('attributes', []))
984
985        self.global_policy = []
986        self.global_policy_set = attr_set_name
987        for attr in self.attr_sets[attr_set_name]:
988            if attr in global_set:
989                self.global_policy.append(attr)
990
991    def _load_hooks(self):
992        for op in self.ops.values():
993            for op_mode in ['do', 'dump']:
994                if op_mode not in op:
995                    continue
996                for when in ['pre', 'post']:
997                    if when not in op[op_mode]:
998                        continue
999                    name = op[op_mode][when]
1000                    if name in self.hooks[when][op_mode]['set']:
1001                        continue
1002                    self.hooks[when][op_mode]['set'].add(name)
1003                    self.hooks[when][op_mode]['list'].append(name)
1004
1005
class RenderInfo:
    """Everything needed to render one op (or a standalone attribute set) in one mode.

    Attributes: family; nl (the CodeWriter's nlib helper); ku_space; op;
    op_mode; cw; type_consistent (False when a non-'do' mode renders an op
    whose do and dump replies differ); attr_set (defaults to the op's
    'attribute-set'); type_name and type_name_conflict (True for a
    standalone attr set whose name is also in family.consts); struct
    (direction -> Struct built from the op's request/reply attributes).
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        # 'do' and 'dump' response parsing is identical unless their replies differ
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op and 'do' in op:
            do_reply = 'reply' in op['do']
            dump_reply = 'reply' in op['dump']
            self.type_consistent = (do_reply == dump_reply and
                                    (not do_reply or op['do']['reply'] == op['dump']['reply']))

        self.attr_set = attr_set or op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.cw = cw

        # Notifications are rendered from the op's 'do' definition
        spec_mode = 'do' if op_mode == 'notify' else op_mode
        self.struct = dict()
        if op:
            for op_dir in ('request', 'reply'):
                if op_dir in op[spec_mode]:
                    self.struct[op_dir] = Struct(family, self.attr_set,
                                                 type_list=op[spec_mode][op_dir]['attributes'])
        if spec_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1045
1046
class CodeWriter:
    """Write C source text with tab indentation and brace bookkeeping.

    Output goes to stdout, or - when ``out_file`` is given - to an anonymous
    temporary file that close_out_file() (also run from __del__) copies to
    ``out_file``.
    """

    def __init__(self, nlib, out_file=None):
        self.nlib = nlib

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # '}' deferred in case an 'else' follows
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0               # current indentation depth, in tabs
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.TemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy the buffered output to out_file; a no-op when writing to stdout.

        A '}' still deferred by block_end() is written first; otherwise
        output that ends with a closed block would lose its final brace.
        """
        if self._out == os.sys.stdout:
            return
        if self._block_end:
            self._block_end = False
            self._out.write('\t' * self._ind + '}\n')
        with open(self._out_file, 'w+') as out_file:
            self._out.seek(0)
            shutil.copyfileobj(self._out, out_file)
            self._out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        """Return True if `line` starts like an if/while/for statement."""
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation (plus `add_ind` tabs).

        Flushes a deferred '}' first (merged into the line as '} else ...'
        when the line starts with 'else') and any pending blank line. Lines
        ending in ':' (labels) are outdented by one; the line following a
        brace-less if/while/for header is indented by one.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith() instead of line[-1] so an empty line does not raise
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next emitted line."""
        self._nl = True

    def block_start(self, line=''):
        """Emit `line {` (or just `{`) and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Outdent and close a block.

        With `line` (e.g. ';' or a declarator) the '}' is written immediately
        as '}<line>'; without it the '}' is deferred so a following 'else'
        can join it.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Emit `doc` as ' *'-prefixed comment lines wrapped before column 79.

        Continuation lines get two extra spaces when `indent` is True.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a function prototype, optionally preceded by a /* doc */ comment.

        When the one-line form reaches 80 columns, a return type longer than
        3 characters goes on its own line and the arguments are wrapped,
        aligned under the opening parenthesis (tabs count as 8 columns).
        `suffix` is appended after ')' (';' for a declaration).
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Emit local variable declarations, longest first, followed by a blank line.

        `local_vars` may be a single string or a list of strings. The
        caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Emit a complete function: prototype, '{', locals, body lines, '}'."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        # Locals belong inside the braces. Emitting them before block_start()
        # placed the declarations between the prototype and the '{'.
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Emit tab-aligned '#define NAME VALUE' lines from (name, value) pairs.

        int values are printed bare and str values quoted; a value of any
        other type emits only the name.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Emit tab-aligned designated initializers '.name = value,' from (name, value) pairs.

        An empty `members` emits nothing; before, max() raised ValueError on it.
        """
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1220
1221
1222scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}
1223
1224direction_to_suffix = {
1225    'reply': '_rsp',
1226    'request': '_req',
1227    '': ''
1228}
1229
1230op_mode_to_wrapper = {
1231    'do': '',
1232    'dump': '_list',
1233    'notify': '_ntf',
1234    'event': '',
1235}
1236
1237_C_KW = {
1238    'auto',
1239    'bool',
1240    'break',
1241    'case',
1242    'char',
1243    'const',
1244    'continue',
1245    'default',
1246    'do',
1247    'double',
1248    'else',
1249    'enum',
1250    'extern',
1251    'float',
1252    'for',
1253    'goto',
1254    'if',
1255    'inline',
1256    'int',
1257    'long',
1258    'register',
1259    'return',
1260    'short',
1261    'signed',
1262    'sizeof',
1263    'static',
1264    'struct',
1265    'switch',
1266    'typedef',
1267    'union',
1268    'unsigned',
1269    'void',
1270    'volatile',
1271    'while'
1272}
1273
1274
def rdir(direction):
    """Return the opposite direction ('request' <-> 'reply'); other values pass through unchanged."""
    if direction in ('reply', 'request'):
        return 'request' if direction == 'reply' else 'reply'
    return direction
1281
1282
def op_prefix(ri, direction, deref=False):
    """Return the C name prefix '<family>_<type_name><suffix>' for ri in `direction`.

    In 'do' mode (or with no mode) the suffix is the direction suffix
    (_req/_rsp). In other modes requests become _req_dump. A reply uses the
    mode's wrapper suffix, or the direction suffix when `deref`, if do and
    dump replies agree; otherwise it uses _rsp_list, or _rsp_dump when
    `deref`.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]
    return f"{ri.family['name']}_{ri.type_name}{tail}"
1302
1303
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <prefix>') for ri in `direction`."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1306
1307
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space call for ri's op.

    The function takes the socket and, if the mode has a request, a pointer
    to the request struct. It returns a pointer to the reply struct if the
    mode has a reply, otherwise int. Dump calls get a '_dump' name suffix;
    `terminate` appends ';'.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    ret = f"{type_name(ri, rdir(direction))} *" if 'reply' in mode_spec else 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1324
1325
def print_req_prototype(ri):
    """Emit ri's ';'-terminated request prototype, preceded by the op's doc comment."""
    doc = ri.op['doc']
    print_prototype(ri, "request", doc=doc)
1328
1329
def print_dump_prototype(ri):
    """Emit ri's ';'-terminated request prototype without a doc comment."""
    print_prototype(ri, direction="request")
1332
1333
def put_typol(cw, struct):
    """Emit the ynl_policy_attr table for `struct` and the ynl_policy_nest wrapping it."""
    type_max = struct.attr_set.max_name
    policy = f'{struct.render_name}_policy'

    cw.block_start(line=f'struct ynl_policy_attr {policy}[{type_max} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {struct.render_name}_nest =')
    cw.p(f'.max_attr = {type_max},')
    cw.p(f'.table = {policy},')
    cw.block_end(line=';')
    cw.nl()
1349
1350
1351def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1352    args = [f'int {arg_name}']
1353    if enum and not ('enum-name' in enum and not enum['enum-name']):
1354        args = [f'enum {render_name} {arg_name}']
1355    cw.write_func_prot('const char *', f'{render_name}_str', args)
1356    cw.block_start()
1357    if enum and enum.type == 'flags':
1358        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1359    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
1360    cw.p('return NULL;')
1361    cw.p(f'return {map_name}[{arg_name}];')
1362    cw.block_end()
1363    cw.nl()
1364
1365
def put_op_name_fwd(family, cw):
    """Emit the forward declaration of '<family>_op_str()'."""
    fname = family.name + '_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1368
1369
def put_op_name(family, cw):
    """Emit the op-number -> name string table and the '<family>_op_str()' lookup.

    Only messages with a reply value get an entry. The entry is indexed by
    the op's enum name when request and reply values match, otherwise by
    the reply value.
    """
    map_name = f'{family.name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1383
1384
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the forward declaration of '<enum>_str()'; the parameter is int when 'enum-name' is explicitly empty."""
    anonymous = 'enum-name' in enum and not enum['enum-name']
    param = 'int value' if anonymous else f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [param], suffix=';')
1390
1391
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table for `enum` and its '<enum>_str()' lookup."""
    map_name = f'{enum.render_name}_strmap'
    entries = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in entries:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1401
1402
def put_req_nested(ri, struct):
    """Emit '<struct>_put()', which writes `obj` as a nested attribute of type attr_type."""
    cw = ri.cw
    cw.write_func_prot('int', f'{struct.render_name}_put',
                       ['struct nlmsghdr *nlh',
                        'unsigned int attr_type',
                        f'{struct.ptr_name}obj'])
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1423
1424
1425def _multi_parse(ri, struct, init_lines, local_vars):
1426    if struct.nested:
1427        iter_line = "mnl_attr_for_each_nested(attr, nested)"
1428    else:
1429        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"
1430
1431    array_nests = set()
1432    multi_attrs = set()
1433    needs_parg = False
1434    for arg, aspec in struct.member_list():
1435        if aspec['type'] == 'array-nest':
1436            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1437            array_nests.add(arg)
1438        if 'multi-attr' in aspec:
1439            multi_attrs.add(arg)
1440        needs_parg |= 'nested-attributes' in aspec
1441    if array_nests or multi_attrs:
1442        local_vars.append('int i;')
1443    if needs_parg:
1444        local_vars.append('struct ynl_parse_arg parg;')
1445        init_lines.append('parg.ys = yarg->ys;')
1446
1447    all_multi = array_nests | multi_attrs
1448
1449    for anest in sorted(all_multi):
1450        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1451
1452    ri.cw.block_start()
1453    ri.cw.write_func_lvar(local_vars)
1454
1455    for line in init_lines:
1456        ri.cw.p(line)
1457    ri.cw.nl()
1458
1459    for arg in struct.inherited:
1460        ri.cw.p(f'dst->{arg} = {arg};')
1461
1462    for anest in sorted(all_multi):
1463        aspec = struct[anest]
1464        ri.cw.p(f"if (dst->{aspec.c_name})")
1465        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1466
1467    ri.cw.nl()
1468    ri.cw.block_start(line=iter_line)
1469    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
1470    ri.cw.nl()
1471
1472    first = True
1473    for _, arg in struct.member_list():
1474        good = arg.attr_get(ri, 'dst', first=first)
1475        # First may be 'unused' or 'pad', ignore those
1476        first &= not good
1477
1478    ri.cw.block_end()
1479    ri.cw.nl()
1480
1481    for anest in sorted(array_nests):
1482        aspec = struct[anest]
1483
1484        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1485        ri.cw.p(f"dst->{aspec.c_name} = calloc({aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1486        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1487        ri.cw.p('i = 0;')
1488        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1489        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1490        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1491        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
1492        ri.cw.p('return MNL_CB_ERROR;')
1493        ri.cw.p('i++;')
1494        ri.cw.block_end()
1495        ri.cw.block_end()
1496    ri.cw.nl()
1497
1498    for anest in sorted(multi_attrs):
1499        aspec = struct[anest]
1500        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1501        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1502        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1503        ri.cw.p('i = 0;')
1504        if 'nested-attributes' in aspec:
1505            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1506        ri.cw.block_start(line=iter_line)
1507        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
1508        if 'nested-attributes' in aspec:
1509            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1510            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1511            ri.cw.p('return MNL_CB_ERROR;')
1512        elif aspec['type'] in scalars:
1513            t = aspec['type']
1514            if t[0] == 's':
1515                t = 'u' + t[1:]
1516            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
1517        else:
1518            raise Exception('Nest parsing type not supported yet')
1519        ri.cw.p('i++;')
1520        ri.cw.block_end()
1521        ri.cw.block_end()
1522        ri.cw.block_end()
1523    ri.cw.nl()
1524
1525    if struct.nested:
1526        ri.cw.p('return 0;')
1527    else:
1528        ri.cw.p('return MNL_CB_OK;')
1529    ri.cw.block_end()
1530    ri.cw.nl()
1531
1532
def parse_rsp_nested(ri, struct):
    """Emit '<struct>_parse()', which parses a nested attribute into yarg->data.

    Each inherited member becomes an extra __u32 argument.
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args += [f'__u32 {name}' for name in struct.inherited]

    local_vars = ['const struct nlattr *attr;', f'{struct.ptr_name}dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)
    _multi_parse(ri, struct, [], local_vars)
1546
1547
def parse_rsp_msg(ri, deref=False):
    """Emit the reply-message parse callback for ri's op.

    Nothing is emitted when the mode has no reply, unless it is an event. A
    reply without members gets a stub that just returns MNL_CB_OK.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                  'struct ynl_parse_arg *yarg = data;',
                  'const struct nlattr *attr;']
    prefix = op_prefix(ri, "reply", deref=deref)
    ri.cw.write_func_prot('int', f'{prefix}_parse',
                          ['const struct nlmsghdr *nlh', 'void *data'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply: nothing to parse
        ri.cw.block_start()
        ri.cw.p('return MNL_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], local_vars)
1570
1571
def print_req(ri):
    """Emit the definition of the blocking request call for ri's op.

    With a reply, the generated function allocates and returns the parsed
    reply, freeing it and returning NULL on error. Without a reply it
    returns 0, or -1 on error.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]
    cw = ri.cw

    local_vars = ['struct nlmsghdr *nlh;', 'int err;']
    if has_reply:
        local_vars.append(f'{type_name(ri, rdir(direction))} *rsp;')
        local_vars.append('struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };')

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if not has_reply:
        cw.p("err = ynl_exec(ys, nlh, NULL);")
        cw.p('if (err < 0)')
        cw.p('return -1;')
        cw.nl()
        cw.p("return 0;")
        cw.nl()
        cw.block_end()
        return

    cw.p('rsp = calloc(1, sizeof(*rsp));')
    cw.p('yrs.yarg.data = rsp;')
    cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
    cw.nl()
    cw.p("err = ynl_exec(ys, nlh, &yrs);")
    cw.p('if (err < 0)')
    cw.p('goto err_free;')
    cw.nl()

    cw.p("return rsp;")
    cw.nl()

    cw.p('err_free:')
    cw.p(call_free(ri, rdir(direction), 'rsp'))
    cw.p("return NULL;")
    cw.block_end()
1627
1628
def print_dump(ri):
    """Emit the definition of the dump call for ri's op.

    The generated function returns the first element of the parsed reply
    list, or frees the list and returns NULL on error.
    """
    direction = "request"
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    for decl in ('struct ynl_dump_state yds = {};',
                 'struct nlmsghdr *nlh;',
                 'int err;'):
        cw.p(decl)
    cw.nl()

    cw.p('yds.ys = ys;')
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    cw.p(f'yds.rsp_cmd = {rsp_cmd};')
    cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    for line in ('err = ynl_exec_dump(ys, nlh, &yds);', 'if (err < 0)', 'goto free_list;'):
        cw.p(line)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1670
1671
def call_free(ri, direction, var):
    """Return the C statement that frees `var` with the free helper for ri's `direction` type."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1674
1675
def free_arg_name(direction):
    """Return the free helper's parameter name: 'req'/'rsp' for a direction, otherwise 'obj'."""
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1680
1681
def print_alloc_wrapper(ri, direction):
    """Emit a static inline '<prefix>_alloc()' returning a zero-filled struct for `direction`."""
    name = op_prefix(ri, direction)
    struct = f'struct {name}'
    ri.cw.write_func_prot(f'static inline {struct} *', f"{name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof({struct}));')
    ri.cw.block_end()
1688
1689
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of '<prefix>_free()'; `suffix` follows the ')' (';' for a declaration).

    When ri.type_name_conflict is set, the struct name carries a trailing '_'.
    """
    name = op_prefix(ri, direction)
    struct_name = name + '_' if ri.type_name_conflict else name
    arg = free_arg_name(direction)
    ri.cw.write_func_prot('void', f"{name}_free", [f"struct {struct_name} *{arg}"], suffix=suffix)
1697
1698
def _print_type(ri, direction, struct):
    """Emit the C struct definition for `struct`.

    The layout is: presence lengths/bits (from presence_member()) in an
    anonymous '_present' sub-struct, then inherited __u32 members, then
    each attribute's own member.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.type_name_conflict and not direction:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    present = []
    for _, attr in struct.member_list():
        present.extend(filter(None, (attr.presence_member(ri.ku_space, kind)
                                     for kind in ('len', 'bit'))))
    if present:
        cw.block_start(line="struct")
        for line in present:
            cw.p(line)
        cw.block_end(line='_present;')
        cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1730
1731
def print_type(ri, direction):
    """Emit the struct for ri's `direction` ('request' or 'reply')."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1734
1735
def print_type_full(ri, struct):
    """Emit `struct` as a type without a direction suffix."""
    _print_type(ri, direction="", struct=struct)
1738
1739
def print_type_helpers(ri, direction, deref=False):
    """Emit the free prototype for `direction`, plus attribute setters for user-space requests."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, attr in ri.struct[direction].member_list():
            attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1748
1749
def print_req_type_helpers(ri):
    """Emit the request struct's alloc wrapper, then its free prototype and setters."""
    for emit in (print_alloc_wrapper, print_type_helpers):
        emit(ri, "request")
1753
1754
def print_rsp_type_helpers(ri):
    """Emit the reply type helpers when ri's op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1759
1760
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of '<op>_rsp_parse()' or '<op>_req_parse()', which takes an attribute table."""
    struct_name = ri.op.render_name + ('_rsp' if direction == "reply" else '_req')
    ri.cw.write_func_prot('void', f"{struct_name}_parse",
                          ['const struct nlattr **tb',
                           f"struct {struct_name} *req"],
                          suffix=';' if terminate else '')
1769
1770
def print_req_type(ri):
    """Emit the C struct for the request of the current op."""
    print_type(ri, direction="request")
1773
1774
def print_req_free(ri):
    """Emit the request free function when the current op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1779
1780
def print_rsp_type(ri):
    """Emit the reply struct when one applies.

    A reply struct is printed for 'do'/'dump' modes that declare a reply,
    and always for the 'event' mode; other modes produce nothing.
    """
    mode = ri.op_mode
    has_reply = mode in ('do', 'dump') and 'reply' in ri.op[mode]
    if has_reply or mode == 'event':
        print_type(ri, 'reply')
1789
1790
def print_wrapped_type(ri):
    """Emit the wrapper struct that carries a reply object.

    Dump replies get a ``next`` pointer so they can form a list.
    Notifications and events get family/cmd/next/free header fields.
    The payload itself is the 8-byte aligned ``obj`` member.  The free()
    prototype for the reply is emitted afterwards.
    """
    cw = ri.cw
    reply_type = type_name(ri, 'reply')

    cw.block_start(line=reply_type)
    mode = ri.op_mode
    if mode == 'dump':
        cw.p(f"{reply_type} *next;")
    elif mode in ('notify', 'event'):
        for header_field in ('__u16 family;',
                             '__u8 cmd;',
                             'struct ynl_ntf_base_type *next;'):
            cw.p(header_field)
        cw.p(f"void (*free)({reply_type} *ntf);")
    payload_type = type_name(ri, 'reply', deref=True)
    cw.p(f"{payload_type} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1805
1806
1807def _free_type_members_iter(ri, struct):
1808    for _, attr in struct.member_list():
1809        if attr.free_needs_iter():
1810            ri.cw.p('unsigned int i;')
1811            ri.cw.nl()
1812            break
1813
1814
1815def _free_type_members(ri, var, struct, ref=''):
1816    for _, attr in struct.member_list():
1817        attr.free(ri, var, ref)
1818
1819
def _free_type(ri, direction, struct):
    """Emit a complete C free function for *struct*.

    The generated body frees every member first.  The object itself is
    only ``free()``d when *direction* is non-empty.  Nested structs are
    rendered with an empty direction and are presumably embedded in their
    parent rather than separately allocated.
    """
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw = ri.cw
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1831
1832
def free_rsp_nested(ri, struct):
    """Emit the free function for a nested struct (empty direction: members only)."""
    _free_type(ri, direction="", struct=struct)
1835
1836
def print_rsp_free(ri):
    """Emit the reply free function when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1841
1842
def print_dump_type_free(ri):
    """Emit the free function for a dump reply list.

    The generated C walks the ``next`` chain until ``YNL_LIST_END``.  For
    each entry it frees the members of ``obj`` and then the entry itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    entry_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{entry_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1861
1862
def print_ntf_type_free(ri):
    """Emit the free function for a notification: release ``obj`` members, then the object."""
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw = ri.cw
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1871
1872
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the ``nla_policy`` array declarator for *struct*.

    With *terminate* this writes an ``extern`` forward declaration.  The
    declaration is skipped entirely for op policies that are static.
    Without *terminate* it writes the start of the definition
    (``... = {``), marked ``static`` when policy_should_be_static() says so
    for an op policy.

    The array is named after the op (plus ``_<mode>`` for dual-policy ops)
    when *ri* is given, and after the struct otherwise.
    """
    static_policy = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if static_policy:
            # A static definition cannot be referenced through an extern.
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if static_policy else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1895
1896
def print_req_policy(cw, struct, ri=None):
    """Emit the full ``nla_policy`` array definition: header, one entry per member, closing brace."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    members = struct.member_list()
    for _member_name, member in members:
        member.attr_policy(cw)
    cw.p("};")
    cw.nl()
1903
1904
def kernel_can_gen_family_struct(family):
    """Return True when *family* uses plain 'genetlink'.

    Only those families get their ``genl_family`` struct generated here.
    """
    is_plain_genetlink = family.proto == 'genetlink'
    return is_plain_genetlink
1907
1908
def policy_should_be_static(family):
    """Return True if per-op policies should be file-local.

    This holds for split-policy families and for families whose struct is
    generated here.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
1911
1912
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table declaration (header) or open its definition (source).

    With ``terminate=False`` this opens the array initializer.  The caller
    (print_kernel_op_table) fills in the entries and closes it.

    With ``terminate=True`` an ``extern`` declaration is emitted when the
    table is exported.  It is followed by prototypes for the pre/post hooks
    and for every do/dump handler of non-async ops.

    The table is exported (non-static, with an explicit element count) only
    when the family struct is not generated here, i.e. when
    kernel_can_gen_family_struct() is false.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # The explicit element count must match the number of initializers
        # print_kernel_op_table() emits, and that function skips async ops.
        # Counting them here would leave zero-filled trailing entries in
        # the exported array.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # One entry per do handler and one per dump handler.
            cnt = 0
            for op in family.ops.values():
                if op.is_async:
                    continue
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            # One entry per op.
            cnt = sum(1 for op in family.ops.values() if not op.is_async)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    # Prototypes for the user-supplied pre/post hooks.
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Prototypes for the doit/dumpit handlers the kernel code must implement.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1978
1979
def print_kernel_op_table_hdr(family, cw):
    """Emit the header side of the ops table: extern declaration and handler prototypes."""
    declare_only = True
    print_kernel_op_table_fwd(family, cw, declare_only)
1982
1983
def print_kernel_op_table(family, cw):
    """Emit the kernel genetlink ops table definition for *family*.

    The array header and its opening brace are written by
    print_kernel_op_table_fwd(terminate=False).  This function then fills
    in the initializers and closes the array:

    - 'global' and 'per-op' policies: one initializer per op.
    - 'split' policy: one initializer per do handler and per dump handler.

    Async ops are skipped in every case.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): this indexes op['do']['request'] unconditionally,
                # so a per-op family with a dump-only op (or a do without a
                # request) raises KeyError here - confirm specs never do that.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Split ops use different member names for the pre/post hooks of
        # do and dump handlers.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags relevant to this entry:
                    # 'dump'/'dump-strict' are dropped from do entries and
                    # 'strict' is dropped from dump entries.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Each split entry is tagged with the capability it implements.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    # Closes the array opened by print_kernel_op_table_fwd().
    cw.block_end(line=';')
    cw.nl()
2059
2060
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of ``<FAMILY>_NLGRP_<GROUP>`` IDs, if the family has multicast groups."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']}") + ',')
    cw.block_end(';')
    cw.nl()
2071
2072
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``genl_multicast_group`` array, indexed by the group enum, if any groups exist."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        enum_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2084
2085
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated ``genl_family`` in the kernel header, when it is generated."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2092
2093
def print_kernel_family_struct_src(family, cw):
    """Define the ``genl_family`` struct in kernel source, when it can be generated.

    The ops table is wired in as ``.ops`` for per-op policy or as
    ``.split_ops`` for split policy.  Multicast groups are added when
    present.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")
    fields = [
        '.name\t\t= ' + family.fam_key + ',',
        '.version\t= ' + family.ver_key + ',',
        '.netnsok\t= true,',
        '.parallel_ops\t= true,',
        '.module\t\t= THIS_MODULE,',
    ]
    ops_members = {
        'per-op': ('.ops\t\t', '.n_ops\t\t'),
        'split': ('.split_ops\t', '.n_split_ops\t'),
    }.get(family.kernel_policy)
    if ops_members:
        table_field, count_field = ops_members
        fields.append(f'{table_field}= {fam}_nl_ops,')
        fields.append(f'{count_field}= ARRAY_SIZE({fam}_nl_ops),')
    if family.mcgrps['list']:
        fields.append(f'.mcgrps\t\t= {fam}_nl_mcgrps,')
        fields.append(f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),')
    for field in fields:
        cw.p(field)
    cw.block_end(';')
2114
2115
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a UAPI enum block, choosing its tag from the spec object *obj*.

    If *obj* has the *enum_name* key, its value is used as the tag.  A
    falsy value there yields an anonymous enum, even if *ckey* is present.
    Otherwise ``<family>_<obj[ckey]>`` is used when *ckey* is given and
    present.  Failing both, the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2124
2125
def render_uapi(family, cw):
    """Render the UAPI header for *family* through the code writer *cw*.

    Emits, in order:

    - the include guard;
    - the family name and version defines;
    - the spec's definitions: enums and flags (with kdoc) and consts as
      ``#define``s;
    - one enum per attribute set that is not a subset of another;
    - the command enum, plus a separate async enum when the spec sets
      ``async-prefix``;
    - the multicast group name defines.
    """
    hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive consts are batched and written together; any other
    # definition type flushes the pending batch first.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # 'c-define-name' overrides the macro name of this one constant,
            # just like the per-group override used for multicast groups
            # below.  Looking it up on the family would give every constant
            # the same name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    # When set, the MAX values are emitted as #defines after the enum
    # instead of as enum members.
    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        # Only emit explicit values where they break the implicit sequence.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # With 'async-prefix', notifications/events get their own enum and are
    # left out of the command enum.
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2270
2271
def _render_user_ntf_entry(ri, op):
    """Emit one designated ``ynl_ntf_info`` initializer, indexed by *op*'s enum value.

    The entry records the allocation size, the parse callback, the policy
    and the free function.
    """
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    entry_fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in entry_fields:
        cw.p(field)
    cw.block_end(line=',')
2279
2280
def render_user_family(family, cw, prototype):
    """Emit the ``ynl_<family>_family`` descriptor for the user-space library.

    With *prototype* only an ``extern`` declaration is written.  Otherwise,
    if the family has notifications, a static ``ynl_ntf_info`` table is
    emitted first and referenced from the family struct.

    Raises:
        Exception: for a notification with neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # Render notify entries using the op named by 'notify'.
                source_op, mode = family.ops[ntf_op['notify']], "notify"
            elif 'event' in ntf_op:
                source_op, mode = ntf_op, "event"
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(RenderInfo(cw, family, "user", source_op, mode), ntf_op)
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    ntf_table = f"{family['name']}_ntf_info"
    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2312
2313
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains *full_path*.

    Walks up the directory hierarchy from *full_path* until a directory
    holding a ``MAINTAINERS`` file is found.

    Returns:
        A ``(kernel_root, sub_path)`` tuple, where *sub_path* is *full_path*
        relative to *kernel_root*.  For a relative *full_path* the root can
        be ``''``, meaning the current directory.

    Raises:
        FileNotFoundError: if the top of the path is reached without finding
            a ``MAINTAINERS`` file.
    """
    start = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        # dirname() reaches a fixed point at the top of the path ('/' for
        # absolute paths, '' for relative ones).  Stop there rather than
        # looping forever on a spec outside the kernel tree.
        if parent == full_path:
            raise FileNotFoundError(
                f"no MAINTAINERS file found above {start!r}; "
                "is the spec inside a kernel source tree?")
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        full_path = parent
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            # Drop the trailing separator left by the join above.
            return full_path, sub_path[:-1]
2322
2323
def main():
    """Command-line entry point: parse a YAML netlink spec and emit C code.

    ``--mode`` selects what is rendered:

    - 'uapi': the UAPI header;
    - 'kernel': kernel policies, op table and family struct;
    - 'user': the user-space (ynl) library.

    ``--header`` / ``--source`` chooses between the header and the source
    file; exactly one of them is required.  Output is written through a
    CodeWriter constructed with the ``-o`` path.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; the None default lets
    # us detect that neither was given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    # --exclude-op values are regular expressions, handed to Family.
    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        # Only specs with the expected dual license are accepted.
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            os.sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        os.sys.exit(1)
        return

    # The 'directional' message id model is only handled for user and
    # kernel generation.
    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        os.sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file)

    # The banner records the spec path relative to the kernel source root.
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    # Record the extra generator arguments in the banner as well.
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # Assumes out_file ends in a two-character extension such as
                # '.c'; the matching '.h' is included.
                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    # Definitions may name extra headers that must be included.
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            # Header: forward declarations only.
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            # Source: policy definitions, op table, groups and family struct.
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            # Header: enum helpers, types and prototypes.
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    if 'request' in op['dump']:
                        print_req_type(ri)
                        print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            # Source: enum helpers, policies, parsers, free/request helpers.
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            cw.p('/* Policies */')
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    # Events are parsed like a 'do' reply, but freed as events.
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2606
2607
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
2610