1"""
2   Copyright (c) 2023 Nordic Semiconductor ASA
3
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15"""
16
17import struct
18import inspect
19from typing import List
20from abc import ABC, abstractmethod
21
22from tlv.dataset_tlv import MeshcopTlvType
23from tlv.tlv import TLV
24
25
class DatasetEntry(ABC):
    """Abstract base class for a single dataset entry.

    Concrete entries convert themselves to and from a ``TLV`` and can be
    populated from a list of string arguments.

    Attributes:
        type: TLV type handed in by the subclass constructor.
        length: ``None`` here; subclasses with a fixed-size value overwrite it.
        maxlen: ``None`` here; subclasses with a bounded value overwrite it.
    """

    def __init__(self, type: MeshcopTlvType):
        self.type = type
        self.length = None
        self.maxlen = None

    def print_content(self, indent: int = 0, excluded_fields: List[str] = None):
        """Print every public, non-method attribute as ``name: value``.

        Attributes are visited in ``dir()`` order. ``bytes`` values are shown
        as hex strings. ``length``, ``maxlen`` and ``type`` are never printed.

        Args:
            indent: Indentation level; each level is four spaces.
            excluded_fields: Extra attribute names to skip. ``None`` (the
                default) or an empty list skips nothing extra. The list is
                only read, never modified.
        """
        # Collect the names in a private set: appending to the argument would
        # grow a shared default and silently edit the caller's own list.
        hidden = {'length', 'maxlen', 'type'}
        if excluded_fields:
            hidden.update(excluded_fields)

        indentation = " " * 4 * indent
        for attr_name in dir(self):
            if attr_name.startswith('_') or attr_name in hidden:
                continue
            value = getattr(self, attr_name)
            if inspect.ismethod(value):
                continue
            if isinstance(value, bytes):
                value = value.hex()
            print(f'{indentation}{attr_name}: {value}')

    @abstractmethod
    def to_tlv(self) -> TLV:
        """Return this entry encoded as a ``TLV``."""
        pass

    @abstractmethod
    def set_from_tlv(self, tlv: TLV):
        """Populate this entry from the value carried by ``tlv``."""
        pass

    @abstractmethod
    def set(self, args: List[str]):
        """Populate this entry from a list of string arguments."""
        pass
55
56
class ActiveTimestamp(DatasetEntry):
    """Active Timestamp dataset entry (fixed 8-byte value).

    The value is one big-endian 64-bit word: the upper 48 bits hold
    ``seconds``, the next 15 bits hold ``ticks`` and the lowest bit is
    ``ubit``.
    """

    _MAX_SECONDS = 0xFFFFFFFFFFFF  # largest value of the 48-bit seconds field

    def __init__(self):
        super().__init__(MeshcopTlvType.ACTIVETIMESTAMP)
        self.length = 8  # spec defined
        self.seconds = 0
        self.ubit = 0
        self.ticks = 0

    def set(self, args: List[str]):
        """Set ``seconds`` from ``args[0]`` (decimal string).

        Raises:
            ValueError: If ``args`` is empty, ``args[0]`` is not an integer,
                or the value does not fit the 48-bit seconds field.
        """
        if len(args) == 0:
            raise ValueError('No argument for ActiveTimestamp')
        seconds = int(args[0])
        if not 0 <= seconds <= self._MAX_SECONDS:
            raise ValueError('ActiveTimestamp seconds must fit in 48 bits')
        # Write the public attribute: it is the one to_tlv() and
        # print_content() read. The former `_seconds` was never used.
        self.seconds = seconds

    def set_from_tlv(self, tlv: TLV):
        """Decode ``seconds``, ``ticks`` and ``ubit`` from an 8-byte TLV value."""
        (value,) = struct.unpack('>Q', tlv.value)
        self.ubit = value & 0x1
        self.ticks = (value >> 1) & 0x7FFF
        # Everything above bit 15 is the seconds field (48 bits), mirroring
        # the `seconds << 16` used when encoding.
        self.seconds = (value >> 16) & self._MAX_SECONDS

    def to_tlv(self):
        """Return the entry as a ``TLV`` with type, length 8 and the packed value."""
        value = (self.seconds << 16) | (self.ticks << 1) | self.ubit
        tlv = struct.pack('>BBQ', self.type.value, self.length, value)
        return TLV.from_bytes(tlv)
81
82
class PendingTimestamp(DatasetEntry):
    """Pending Timestamp dataset entry (fixed 8-byte value).

    The value is one big-endian 64-bit word: the upper 48 bits hold
    ``seconds``, the next 15 bits hold ``ticks`` and the lowest bit is
    ``ubit``.
    """

    _MAX_SECONDS = 0xFFFFFFFFFFFF  # largest value of the 48-bit seconds field

    def __init__(self):
        super().__init__(MeshcopTlvType.PENDINGTIMESTAMP)
        self.length = 8  # spec defined
        self.seconds = 0
        self.ubit = 0
        self.ticks = 0

    def set(self, args: List[str]):
        """Set ``seconds`` from ``args[0]`` (decimal string).

        Raises:
            ValueError: If ``args`` is empty, ``args[0]`` is not an integer,
                or the value does not fit the 48-bit seconds field.
        """
        if len(args) == 0:
            raise ValueError('No argument for PendingTimestamp')
        seconds = int(args[0])
        if not 0 <= seconds <= self._MAX_SECONDS:
            raise ValueError('PendingTimestamp seconds must fit in 48 bits')
        # Write the public attribute: it is the one to_tlv() and
        # print_content() read. The former `_seconds` was never used.
        self.seconds = seconds

    def set_from_tlv(self, tlv: TLV):
        """Decode ``seconds``, ``ticks`` and ``ubit`` from an 8-byte TLV value."""
        (value,) = struct.unpack('>Q', tlv.value)
        self.ubit = value & 0x1
        self.ticks = (value >> 1) & 0x7FFF
        # Everything above bit 15 is the seconds field (48 bits), mirroring
        # the `seconds << 16` used when encoding.
        self.seconds = (value >> 16) & self._MAX_SECONDS

    def to_tlv(self):
        """Return the entry as a ``TLV`` with type, length 8 and the packed value."""
        value = (self.seconds << 16) | (self.ticks << 1) | self.ubit
        tlv = struct.pack('>BBQ', self.type.value, self.length, value)
        return TLV.from_bytes(tlv)
107
108
class NetworkKey(DatasetEntry):
    """Network Key dataset entry: 16 bytes kept as a 32-character hex string."""

    _HEX_DIGITS = frozenset('0123456789abcdefABCDEF')

    def __init__(self):
        super().__init__(MeshcopTlvType.NETWORKKEY)
        self.length = 16  # spec defined
        self.data: str = ''

    def set(self, args: List[str]):
        """Set the key from ``args[0]``, a hex string with optional ``0x`` prefix.

        ``args`` itself is not modified.

        Raises:
            ValueError: If ``args`` is empty, or the string is not exactly
                ``length * 2`` hexadecimal characters.
        """
        if len(args) == 0:
            raise ValueError('No argument for NetworkKey')
        # Strip the prefix on a local copy so the caller's list stays intact.
        nk = args[0]
        if nk.startswith(('0x', '0X')):
            nk = nk[2:]
        if len(nk) != self.length * 2:  # need length * 2 hex characters
            raise ValueError('Invalid length of NetworkKey')
        # Reject non-hex input here instead of failing later in to_tlv().
        if not self._HEX_DIGITS.issuperset(nk):
            raise ValueError('NetworkKey must contain only hex characters')
        self.data = nk

    def set_from_tlv(self, tlv: TLV):
        """Store the TLV value as a hex string."""
        self.data = tlv.value.hex()

    def to_tlv(self):
        """Return the entry as a ``TLV``.

        Raises:
            ValueError: If the stored hex string has the wrong length.
        """
        if len(self.data) != self.length * 2:  # need length * 2 hex characters
            raise ValueError('Invalid length of NetworkKey')
        value = bytes.fromhex(self.data)
        tlv = struct.pack('>BB', self.type.value, self.length) + value
        return TLV.from_bytes(tlv)
135
136
class NetworkName(DatasetEntry):
    """Network Name dataset entry: a UTF-8 string of at most ``maxlen`` bytes."""

    def __init__(self):
        super().__init__(MeshcopTlvType.NETWORKNAME)
        self.maxlen = 16
        self.data: str = ''

    def set(self, args: List[str]):
        """Set the name from ``args[0]``.

        Raises:
            ValueError: If ``args`` is empty or the UTF-8 encoding of the
                name is longer than ``maxlen`` bytes.
        """
        if len(args) == 0:
            raise ValueError('No argument for NetworkName')
        nn = args[0]
        # The limit applies to the encoded value that goes on the wire, so
        # count UTF-8 bytes; a character count under-reports non-ASCII names.
        if len(nn.encode('utf-8')) > self.maxlen:
            raise ValueError('Invalid length of NetworkName')
        self.data = nn

    def set_from_tlv(self, tlv: TLV):
        """Decode the TLV value as UTF-8 text."""
        self.data = tlv.value.decode('utf-8')

    def to_tlv(self):
        """Return the entry as a ``TLV`` whose length is the encoded byte count."""
        value = self.data.encode('utf-8')
        # The length byte must describe `value`, not the character count of
        # `self.data`; the two differ as soon as the name is not pure ASCII.
        tlv = struct.pack('>BB', self.type.value, len(value)) + value
        return TLV.from_bytes(tlv)
160
161
class ExtPanID(DatasetEntry):
    """Extended PAN ID dataset entry: 8 bytes kept as a 16-character hex string."""

    _HEX_DIGITS = frozenset('0123456789abcdefABCDEF')

    def __init__(self):
        super().__init__(MeshcopTlvType.EXTPANID)
        self.length = 8  # spec defined
        self.data: str = ''

    def set(self, args: List[str]):
        """Set the value from ``args[0]``, a hex string with optional ``0x`` prefix.

        ``args`` itself is not modified.

        Raises:
            ValueError: If ``args`` is empty, or the string is not exactly
                ``length * 2`` hexadecimal characters.
        """
        if len(args) == 0:
            raise ValueError('No argument for ExtPanID')
        # Strip the prefix on a local copy so the caller's list stays intact.
        epid = args[0]
        if epid.startswith(('0x', '0X')):
            epid = epid[2:]
        if len(epid) != self.length * 2:  # need length*2 hex characters
            raise ValueError('Invalid length of ExtPanID')
        # Reject non-hex input here instead of failing later in to_tlv().
        if not self._HEX_DIGITS.issuperset(epid):
            raise ValueError('ExtPanID must contain only hex characters')
        self.data = epid

    def set_from_tlv(self, tlv: TLV):
        """Store the TLV value as a hex string."""
        self.data = tlv.value.hex()

    def to_tlv(self):
        """Return the entry as a ``TLV``.

        Raises:
            ValueError: If the stored hex string has the wrong length.
        """
        if len(self.data) != self.length * 2:  # need length*2 hex characters
            raise ValueError('Invalid length of ExtPanID')

        value = bytes.fromhex(self.data)
        tlv = struct.pack('>BB', self.type.value, self.length) + value
        return TLV.from_bytes(tlv)
189
190
class MeshLocalPrefix(DatasetEntry):
    """Mesh-Local Prefix dataset entry: 8 bytes kept as a 16-character hex string."""

    _HEX_DIGITS = frozenset('0123456789abcdefABCDEF')

    def __init__(self):
        super().__init__(MeshcopTlvType.MESHLOCALPREFIX)
        self.length = 8  # spec defined
        self.data = ''

    def set(self, args: List[str]):
        """Set the value from ``args[0]``, a hex string with optional ``0x`` prefix.

        ``args`` itself is not modified.

        Raises:
            ValueError: If ``args`` is empty, or the string is not exactly
                ``length * 2`` hexadecimal characters.
        """
        if len(args) == 0:
            raise ValueError('No argument for MeshLocalPrefix')
        # Strip the prefix on a local copy so the caller's list stays intact.
        mlp = args[0]
        if mlp.startswith(('0x', '0X')):
            mlp = mlp[2:]
        if len(mlp) != self.length * 2:  # need length*2 hex characters
            raise ValueError('Invalid length of MeshLocalPrefix')
        # Reject non-hex input here instead of failing later in to_tlv().
        if not self._HEX_DIGITS.issuperset(mlp):
            raise ValueError('MeshLocalPrefix must contain only hex characters')
        self.data = mlp

    def set_from_tlv(self, tlv: TLV):
        """Store the TLV value as a hex string."""
        self.data = tlv.value.hex()

    def to_tlv(self):
        """Return the entry as a ``TLV``.

        Raises:
            ValueError: If the stored hex string has the wrong length.
        """
        if len(self.data) != self.length * 2:  # need length*2 hex characters
            raise ValueError('Invalid length of MeshLocalPrefix')

        value = bytes.fromhex(self.data)
        tlv = struct.pack('>BB', self.type.value, self.length) + value
        return TLV.from_bytes(tlv)
218
219
class DelayTimer(DatasetEntry):
    """Delay Timer dataset entry: one unsigned 32-bit big-endian integer."""

    _MAX_VALUE = 0xFFFFFFFF  # largest value that fits the 4-byte field

    def __init__(self):
        super().__init__(MeshcopTlvType.DELAYTIMER)
        self.length = 4  # spec defined
        self.time_remaining = 0

    def set(self, args: List[str]):
        """Set ``time_remaining`` from ``args[0]`` (decimal string).

        Raises:
            ValueError: If ``args`` is empty, ``args[0]`` is not an integer,
                or the value does not fit an unsigned 32-bit field.
        """
        if len(args) == 0:
            raise ValueError('No argument for DelayTimer')
        dt = int(args[0])
        if not 0 <= dt <= self._MAX_VALUE:
            raise ValueError('DelayTimer must fit in 32 bits')
        self.time_remaining = dt

    def set_from_tlv(self, tlv: TLV):
        """Decode ``time_remaining`` from the big-endian TLV value."""
        # Keep the attribute an int: to_tlv() packs it with the 'I' format,
        # which rejects the raw bytes object that used to be stored here.
        self.time_remaining = int.from_bytes(tlv.value, byteorder='big')

    def to_tlv(self):
        """Return the entry as a ``TLV`` with type, length 4 and the packed value."""
        value = self.time_remaining
        tlv = struct.pack('>BBI', self.type.value, self.length, value)
        return TLV.from_bytes(tlv)
240
241
class PanID(DatasetEntry):
    """PAN ID dataset entry: 2 bytes kept as a 4-character hex string."""

    _HEX_DIGITS = frozenset('0123456789abcdefABCDEF')

    def __init__(self):
        super().__init__(MeshcopTlvType.PANID)
        self.length = 2  # spec defined
        self.data: str = ''

    def set(self, args: List[str]):
        """Set the value from ``args[0]``, a hex string with optional ``0x`` prefix.

        ``args`` itself is not modified.

        Raises:
            ValueError: If ``args`` is empty, or the string is not exactly
                ``length * 2`` hexadecimal characters.
        """
        if len(args) == 0:
            raise ValueError('No argument for PanID')
        # Strip the prefix on a local copy so the caller's list stays intact.
        pid = args[0]
        if pid.startswith(('0x', '0X')):
            pid = pid[2:]
        if len(pid) != self.length * 2:  # need length*2 hex characters
            raise ValueError('Invalid length of PanID')
        # Reject non-hex input here instead of failing later in to_tlv().
        if not self._HEX_DIGITS.issuperset(pid):
            raise ValueError('PanID must contain only hex characters')
        self.data = pid

    def set_from_tlv(self, tlv: TLV):
        """Store the TLV value as a hex string."""
        self.data = tlv.value.hex()

    def to_tlv(self):
        """Return the entry as a ``TLV``.

        Raises:
            ValueError: If the stored hex string has the wrong length.
        """
        if len(self.data) != self.length * 2:  # need length*2 hex characters
            raise ValueError('Invalid length of PanID')

        value = bytes.fromhex(self.data)
        tlv = struct.pack('>BB', self.type.value, self.length) + value
        return TLV.from_bytes(tlv)
269
270
class Channel(DatasetEntry):
    """Channel dataset entry: a channel-page byte followed by a 16-bit channel."""

    def __init__(self):
        super().__init__(MeshcopTlvType.CHANNEL)
        self.length = 3  # spec defined
        self.channel_page = self.channel = 0

    def set(self, args: List[str]):
        """Set ``channel`` from ``args[0]``; ``channel_page`` is left untouched.

        Raises:
            ValueError: If ``args`` is empty or ``args[0]`` is not an integer.
        """
        if not len(args):
            raise ValueError('No argument for Channel')
        self.channel = int(args[0])

    def set_from_tlv(self, tlv: TLV):
        """Read the channel from bytes 1-2 (big-endian) and the page from byte 0."""
        raw = tlv.value
        self.channel = int.from_bytes(raw[1:3], byteorder='big')
        self.channel_page = raw[0]

    def to_tlv(self):
        """Return the entry as a ``TLV``: type, length, page, then channel."""
        header = struct.pack('>BBB', self.type.value, self.length, self.channel_page)
        channel_field = struct.pack('>H', self.channel)
        return TLV.from_bytes(header + channel_field)
293
294
class Pskc(DatasetEntry):
    """PSKc dataset entry: up to ``maxlen`` bytes kept as a hex string."""

    _HEX_DIGITS = frozenset('0123456789abcdefABCDEF')

    def __init__(self):
        super().__init__(MeshcopTlvType.PSKC)
        self.maxlen = 16
        self.data = ''

    def set(self, args: List[str]):
        """Set the value from ``args[0]``, a hex string with optional ``0x`` prefix.

        ``args`` itself is not modified.

        Raises:
            ValueError: If ``args`` is empty, the string is longer than
                ``maxlen * 2`` characters, has an odd number of characters,
                or contains non-hex characters.
        """
        if len(args) == 0:
            raise ValueError('No argument for Pskc')
        # Strip the prefix on a local copy so the caller's list stays intact.
        pskc = args[0]
        if pskc.startswith(('0x', '0X')):
            pskc = pskc[2:]
        if len(pskc) > self.maxlen * 2:
            # Report the limit from `maxlen`; `length` is None for this entry
            # and would turn this error into a TypeError.
            raise ValueError(f'Invalid length of PSKc. Can be max {self.maxlen * 2} hex characters.')
        # Whole bytes only, and hex only: fail here rather than in to_tlv().
        if len(pskc) % 2 != 0 or not self._HEX_DIGITS.issuperset(pskc):
            raise ValueError('PSKc must be an even number of hex characters')
        self.data = pskc

    def set_from_tlv(self, tlv: TLV):
        """Store the TLV value as a hex string."""
        self.data = tlv.value.hex()

    def to_tlv(self):
        """Return the entry as a ``TLV`` whose length is the decoded byte count.

        Raises:
            ValueError: If the stored hex string exceeds ``maxlen * 2``
                characters or is not valid hex.
        """
        # should not exceed max length*2 hex characters
        if len(self.data) > self.maxlen * 2:
            raise ValueError('Invalid length of Pskc')

        length_value = len(self.data) // 2
        value = bytes.fromhex(self.data)
        tlv = struct.pack('>BB', self.type.value, length_value) + value
        return TLV.from_bytes(tlv)
324
325
class SecurityPolicy(DatasetEntry):
    """Security Policy dataset entry (fixed 4-byte value).

    The value is one big-endian 32-bit word: a 16-bit rotation time followed
    by the individual policy bits listed in ``_FIELD_LAYOUT``.
    """

    # (attribute, shift, mask) for every field of the 32-bit value, MSB first.
    _FIELD_LAYOUT = (
        ('rotation_time', 16, 0xFFFF),
        ('out_of_band', 15, 0x1),
        ('native', 14, 0x1),
        ('routers_1_2', 13, 0x1),
        ('external_commissioners', 12, 0x1),
        ('reserved', 11, 0x1),
        ('commercial_commissioning_off', 10, 0x1),
        ('autonomous_enrollment_off', 9, 0x1),
        ('networkkey_provisioning_off', 8, 0x1),
        ('thread_over_ble', 7, 0x1),
        ('non_ccm_routers_off', 6, 0x1),
        ('rsv', 3, 0x7),
        ('version_threshold', 0, 0x7),
    )

    # (attribute, flag letter, attribute value when the letter is present).
    # The "*_off" attributes are inverted: the letter being present clears them.
    _FLAG_LETTERS = (
        ('out_of_band', 'o', 1),
        ('native', 'n', 1),
        ('routers_1_2', 'r', 1),
        ('external_commissioners', 'c', 1),
        ('commercial_commissioning_off', 'C', 0),
        ('autonomous_enrollment_off', 'e', 0),
        ('networkkey_provisioning_off', 'p', 0),
        ('non_ccm_routers_off', 'R', 0),
    )

    def __init__(self):
        super().__init__(MeshcopTlvType.SECURITYPOLICY)
        self.length = 4  # spec defined
        self.rotation_time = 0
        self.out_of_band = 0  # o
        self.native = 0  # n
        self.routers_1_2 = 0  # r
        self.external_commissioners = 0  # c
        self.reserved = 0
        self.commercial_commissioning_off = 0  # C
        self.autonomous_enrollment_off = 0  # e
        self.networkkey_provisioning_off = 0  # p
        self.thread_over_ble = 0
        self.non_ccm_routers_off = 0  # R
        self.rsv = 0b111
        self.version_threshold = 0

    def set(self, args: List[str]):
        """Set the policy from ``[rotation_time, flags, version_threshold]``.

        Only ``rotation_time`` is required. ``flags`` is a string of flag
        letters; when it is missing or empty the flag attributes keep their
        current values. ``version_threshold`` is likewise optional.

        Raises:
            ValueError: If ``args`` is empty, has more than three items, or a
                numeric argument is not an integer.
        """
        if len(args) == 0:
            raise ValueError('No argument for SecurityPolicy')
        if len(args) > 3:
            # Name the problem instead of surfacing a bare unpacking error.
            raise ValueError('Too many arguments for SecurityPolicy')
        rotation_time, flags, version_threshold = list(args) + [None] * (3 - len(args))
        self.rotation_time = int(rotation_time) & 0xffff

        if flags:
            for attr_name, letter, present_value in self._FLAG_LETTERS:
                is_present = letter in flags
                setattr(self, attr_name, present_value if is_present else 1 - present_value)

        if version_threshold:
            self.version_threshold = int(version_threshold) & 0b111

    def set_from_tlv(self, tlv: TLV):
        """Decode every field from the big-endian TLV value."""
        value = int.from_bytes(tlv.value, byteorder='big')
        for attr_name, shift, mask in self._FIELD_LAYOUT:
            setattr(self, attr_name, (value >> shift) & mask)

    def to_tlv(self):
        """Return the entry as a ``TLV`` with type, length 4 and the packed value."""
        value = 0
        for attr_name, shift, _mask in self._FIELD_LAYOUT:
            value |= getattr(self, attr_name) << shift
        tlv = struct.pack('>BBI', self.type.value, self.length, value)
        return TLV.from_bytes(tlv)

    def print_content(self, indent: int = 0, excluded_fields: List[str] = None):
        """Print ``rotation_time``, the flag letters and ``version_threshold``.

        Args:
            indent: Indentation level; each level is four spaces.
            excluded_fields: Optional names among ``rotation_time``,
                ``flags`` and ``version_threshold`` to leave out. Accepted so
                the method can be called like the base-class one.
        """
        flags = ''.join(
            letter
            for attr_name, letter, present_value in self._FLAG_LETTERS
            if bool(getattr(self, attr_name)) == bool(present_value)
        )
        hidden = set(excluded_fields or ())
        indentation = " " * 4 * indent
        rows = (
            ('rotation_time', self.rotation_time),
            ('flags', flags),
            ('version_threshold', self.version_threshold),
        )
        for label, shown in rows:
            if label not in hidden:
                print(f'{indentation}{label}: {shown}')
420
421
class ChannelMask(DatasetEntry):
    """Channel Mask dataset entry: a list of ``ChannelMaskEntry`` sub-TLVs."""

    def __init__(self):
        super().__init__(MeshcopTlvType.CHANNELMASK)
        self.entries: List[ChannelMaskEntry] = []

    def set(self, args: List[str]):
        """Replace all entries with a single one built from ``args[0]``.

        ``args[0]`` is a hex string with optional ``0x`` prefix. ``args``
        itself is not modified.

        Raises:
            ValueError: If ``args`` is empty or ``args[0]`` is not valid hex.
        """
        # to remain consistent with the OpenThread CLI API,
        # provided hex string is value of the first channel mask entry
        if len(args) == 0:
            raise ValueError('No argument for ChannelMask')
        # Strip the prefix on a local copy so the caller's list stays intact.
        mask_hex = args[0]
        if mask_hex.startswith(('0x', '0X')):
            mask_hex = mask_hex[2:]
        try:
            channelmsk = bytes.fromhex(mask_hex)
        except ValueError:
            raise ValueError('ChannelMask must be a hexadecimal string') from None
        self.entries = [ChannelMaskEntry()]
        self.entries[0].channel_mask = channelmsk

    def print_content(self, indent: int = 0, excluded_fields: List[str] = None):
        """Print the entry's own fields, then each mask entry one level deeper.

        Args:
            indent: Indentation level; each level is four spaces.
            excluded_fields: Optional extra attribute names to skip. Listing
                ``'entries'`` also suppresses the per-entry output. Accepted
                so the method can be called like the base-class one.
        """
        extra_hidden = list(excluded_fields or [])
        # 'entries' is always hidden from the generic dump; the list is
        # rendered entry by entry below instead.
        super().print_content(indent=indent, excluded_fields=['entries'] + extra_hidden)
        if 'entries' in extra_hidden:
            return
        indentation = " " * 4 * indent
        for i, entry in enumerate(self.entries):
            print(f'{indentation}ChannelMaskEntry {i}')
            entry.print_content(indent=indent + 1)

    def set_from_tlv(self, tlv: TLV):
        """Rebuild ``entries`` from the sub-TLVs found in the TLV value."""
        self.entries = []
        for mask_entry_tlv in TLV.parse_tlvs(tlv.value):
            new_entry = ChannelMaskEntry()
            new_entry.set_from_tlv(mask_entry_tlv)
            self.entries.append(new_entry)

    def to_tlv(self):
        """Return the entry as a ``TLV`` whose value is the encoded entries in order."""
        tlv_value = b''.join(mask_entry.to_tlv().to_bytes() for mask_entry in self.entries)
        tlv = struct.pack('>BB', self.type.value, len(tlv_value)) + tlv_value
        return TLV.from_bytes(tlv)
457
458
class ChannelMaskEntry(DatasetEntry):
    """One channel-mask sub-entry: a channel page plus its mask bytes.

    The base-class constructor is not invoked, so instances carry no
    ``type``, ``length`` or ``maxlen`` attribute.
    """

    def __init__(self):
        self.channel_page = 0
        self.channel_mask: bytes = None

    def set(self, args: List[str]):
        """Accept and ignore ``args``; this entry cannot be set from strings."""

    def set_from_tlv(self, tlv: TLV):
        """Take the page from the TLV type and the mask from the TLV value.

        Also records the mask size in ``mask_length``.
        """
        self.channel_page = tlv.type
        mask_bytes = tlv.value
        self.mask_length = len(mask_bytes)
        self.channel_mask = mask_bytes

    def to_tlv(self):
        """Return a ``TLV`` with the page as type and the mask as value."""
        header = struct.pack('>BB', self.channel_page, len(self.channel_mask))
        return TLV.from_bytes(header + self.channel_mask)
477
478
# Registry used by create_dataset_entry(): TLV type -> entry class to instantiate.
ENTRY_CLASSES = {
    # Timestamps
    MeshcopTlvType.ACTIVETIMESTAMP: ActiveTimestamp,
    MeshcopTlvType.PENDINGTIMESTAMP: PendingTimestamp,
    # Network identity and keys
    MeshcopTlvType.NETWORKKEY: NetworkKey,
    MeshcopTlvType.NETWORKNAME: NetworkName,
    MeshcopTlvType.EXTPANID: ExtPanID,
    MeshcopTlvType.MESHLOCALPREFIX: MeshLocalPrefix,
    MeshcopTlvType.DELAYTIMER: DelayTimer,
    MeshcopTlvType.PANID: PanID,
    # Radio channel, commissioning credential and policy
    MeshcopTlvType.CHANNEL: Channel,
    MeshcopTlvType.PSKC: Pskc,
    MeshcopTlvType.SECURITYPOLICY: SecurityPolicy,
    MeshcopTlvType.CHANNELMASK: ChannelMask,
}
493
494
def create_dataset_entry(type: MeshcopTlvType, args=None):
    """Build the dataset entry registered for ``type``.

    Args:
        type: Key looked up in ``ENTRY_CLASSES``.
        args: Optional arguments forwarded to the new entry's ``set()``.
            ``None`` or an empty sequence leaves the entry at its defaults.

    Returns:
        A new instance of the class registered for ``type``.

    Raises:
        ValueError: If no class is registered for ``type``.
    """
    factory = ENTRY_CLASSES.get(type)
    if not factory:
        raise ValueError(f"Invalid configuration type: {type}")

    entry = factory()
    if not args:
        return entry
    entry.set(args)
    return entry
504