1"""
2REST client for Leshan demo server
3##################################
4
5Copyright (c) 2023 Nordic Semiconductor ASA
6
7SPDX-License-Identifier: Apache-2.0
8
9"""
10
11from __future__ import annotations
12
13import json
14import binascii
15import time
16from datetime import datetime
17from contextlib import contextmanager
18import requests
19
class Leshan:
    """This class represents a Leshan client that interacts with demo server's REST API"""
    def __init__(self, url: str):
        """
        Initialize Leshan client and check if server is available.

        Parameters:
        - url: Base URL of the Leshan REST API; every request path is appended to it.

        Raises:
        - RuntimeError: If the server cannot be reached or does not answer
          '/security/clients' with a list.
        """
        self.api_url = url
        self.timeout = 10
        # Content format requested from Leshan for LwM2M payloads ('TLV' is another option).
        self.format = "SENML_CBOR"
        self._s = requests.Session()
        try:
            resp = self.get('/security/clients')
            if not isinstance(resp, list):
                raise RuntimeError('Did not receive list of endpoints')
        except requests.exceptions.ConnectionError as exc:
            raise RuntimeError('Leshan not responding') from exc

    @staticmethod
    def handle_response(resp: requests.models.Response):
        """
        Handle the response received from the server.

        Parameters:
        - resp: The response object received from the server.

        Returns:
        - The parsed JSON body (dict or list), or None when the body is empty.

        Raises:
        - RuntimeError: If the status code is outside the 2xx range.
        """
        if resp.status_code >= 300 or resp.status_code < 200:
            raise RuntimeError(f'Error {resp.status_code}: {resp.text}')
        if resp.text:
            return json.loads(resp.text)
        return None

    def _format_query(self) -> str:
        """Return the '?timeout=..[&format=..]' query used by PUT/POST; format is omitted when unset."""
        query = f'?timeout={self.timeout}'
        if self.format is not None:
            query += f'&format={self.format}'
        return query

    @staticmethod
    def _normalize_paths(paths: list[str]) -> list[str]:
        """Return *paths* with a leading '/' prepended to every entry that lacks one."""
        return [path if path.startswith('/') else '/' + path for path in paths]

    def get(self, path: str):
        """Send HTTP GET query with typical parameters"""
        params = {'timeout': self.timeout}
        if self.format is not None:
            params['format'] = self.format
        resp = self._s.get(f'{self.api_url}{path}', params=params, timeout=self.timeout)
        return Leshan.handle_response(resp)

    def put_raw(self, path: str, data: str | dict | None = None, headers: dict | None = None, params: dict | None = None):
        """Send HTTP PUT query without any default parameters"""
        resp = self._s.put(f'{self.api_url}{path}', data=data, headers=headers, params=params, timeout=self.timeout)
        return Leshan.handle_response(resp)

    def put(self, path: str, data: str | dict, uri_options: str = ''):
        """
        Send HTTP PUT query with typical parameters.

        A dict *data* is serialized to JSON. *uri_options* is appended verbatim
        after the timeout/format query, so it must start with '&'.
        """
        if isinstance(data, dict):
            data = json.dumps(data)
        return self.put_raw(f'{path}{self._format_query()}' + uri_options, data=data, headers={'content-type': 'application/json'})

    def post(self, path: str, data: str | dict | None = None):
        """
        Send HTTP POST query.

        With a body (dict is serialized to JSON) the JSON content type and the
        timeout/format query are added; without a body the path is posted as-is.
        """
        if isinstance(data, dict):
            data = json.dumps(data)
        if data is None:
            headers = None
            uri_options = ''
        else:
            headers = {'content-type': 'application/json'}
            uri_options = self._format_query()
        resp = self._s.post(f'{self.api_url}{path}' + uri_options, data=data, headers=headers, timeout=self.timeout)
        return Leshan.handle_response(resp)

    def delete_raw(self, path: str):
        """Send HTTP DELETE query"""
        resp = self._s.delete(f'{self.api_url}{path}', timeout=self.timeout)
        return Leshan.handle_response(resp)

    def delete(self, endpoint: str, path: str):
        """Send LwM2M DELETE command"""
        return self.delete_raw(f'/clients/{endpoint}/{path}')

    def execute(self, endpoint: str, path: str):
        """Send LwM2M EXECUTE command"""
        return self.post(f'/clients/{endpoint}/{path}')

    def write(self, endpoint: str, path: str, value: bool | int | float | str | bytes | datetime):
        """
        Send LwM2M WRITE command to a single resource or resource instance.

        *path* is 'obj/inst/res' for a resource or 'obj/inst/res/resinst' for a
        resource instance; leading/trailing '/' are tolerated.
        """
        # Strip slashes so '/1/0/1' is not miscounted as a 4-level path.
        path = path.strip('/')
        segments = path.split('/')
        kind = 'singleResource' if len(segments) == 3 else 'resourceInstance'
        rid = segments[-1]
        return self.put(f'/clients/{endpoint}/{path}', self._define_resource(rid, value, kind))

    def write_attributes(self, endpoint: str, path: str, attributes: dict):
        """Send LwM2M Write-Attributes to given path
            example:
                leshan.write_attributes(endpoint, '1/2/3', {'pmin': 10, 'pmax': 40})
        """
        return self.put_raw(f'/clients/{endpoint}/{path}/attributes', params=attributes)

    def remove_attributes(self, endpoint: str, path: str, attributes: list):
        """Remove attributes from given path with LwM2M Write-Attributes (names sent without values)
            example:
                leshan.remove_attributes(endpoint, '1/2/3', ['pmin', 'pmax'])
        """
        attrs = '&'.join(attributes)
        return self.put_raw(f'/clients/{endpoint}/{path}/attributes?' + attrs)

    def update_obj_instance(self, endpoint: str, path: str, resources: dict):
        """Update object instance (partial update, replace=false)"""
        data = self._define_obj_inst(path, resources)
        return self.put(f'/clients/{endpoint}/{path}', data, uri_options='&replace=false')

    def replace_obj_instance(self, endpoint: str, path: str, resources: dict):
        """Replace object instance (replace=true)"""
        data = self._define_obj_inst(path, resources)
        return self.put(f'/clients/{endpoint}/{path}', data, uri_options='&replace=true')

    def create_obj_instance(self, endpoint: str, path: str, resources: dict):
        """Send LwM2M CREATE command; *path* is 'obj/inst', the instance ID goes in the body"""
        data = self._define_obj_inst(path, resources)
        path = '/'.join(path.split('/')[:-1])  # Create call should not have instance ID in path
        return self.post(f'/clients/{endpoint}/{path}', data)

    @classmethod
    def _type_to_string(cls, value):
        """
        Convert a Python value to its corresponding Leshan type name.

        Parameters:
        - value: The value to be converted.

        Returns:
        - str: 'boolean', 'integer', 'float', 'time', 'opaque' or 'string'.
        """
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return 'boolean'
        if isinstance(value, int):
            return 'integer'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, datetime):
            return 'time'
        if isinstance(value, bytes):
            return 'opaque'
        return 'string'

    @classmethod
    def _convert_type(cls, value):
        """Wrapper for special types that are not understood by Json"""
        if isinstance(value, datetime):
            return int(value.timestamp())
        if isinstance(value, bytes):
            return binascii.b2a_hex(value).decode()
        return value

    @classmethod
    def _define_obj_inst(cls, path: str, resources: dict):
        """Define an object instance for Leshan; dict values become multi-instance resources"""
        data = {
            "kind": "instance",
            "id": int(path.split('/')[-1]),  # ID is last element of path
            "resources": []
        }
        for key, value in resources.items():
            kind = 'multiResource' if isinstance(value, dict) else 'singleResource'
            data['resources'].append(cls._define_resource(key, value, kind))
        return data

    @classmethod
    def _define_resource(cls, rid, value, kind='singleResource'):
        """
        Define a resource for Leshan.

        For 'multiResource', *value* is a dict of instance ID -> value and the
        type is taken from its first value.

        Raises:
        - RuntimeError: If *kind* is not a supported resource kind.
        """
        if kind in ('singleResource', 'resourceInstance'):
            return {
                "id": rid,
                "kind": kind,
                "value": cls._convert_type(value),
                "type": cls._type_to_string(value)
            }
        if kind == 'multiResource':
            return {
                "id": rid,
                "kind": kind,
                "values": value,
                "type": cls._type_to_string(list(value.values())[0])
            }
        raise RuntimeError(f'Unhandled type {kind}')

    @classmethod
    def _decode_value(cls, val_type: str, value):
        """
        Decode the Leshan representation of a value back to a Python value.
        """
        if val_type == 'BOOLEAN':
            # bool('false') would be True; interpret textual booleans explicitly.
            if isinstance(value, str):
                return value.strip().lower() == 'true'
            return bool(value)
        if val_type == 'INTEGER':
            return int(value)
        return value

    @classmethod
    def _decode_resource(cls, content: dict):
        """
        Decode the Leshan representation of a resource back to a Python dictionary.
        """
        if content['kind'] in ('singleResource', 'resourceInstance'):
            return {content['id']: cls._decode_value(content['type'], content['value'])}
        if content['kind'] == 'multiResource':
            values = {int(riid): cls._decode_value(content['type'], value)
                      for riid, value in content['values'].items()}
            return {content['id']: values}
        raise RuntimeError(f'Unhandled type {content["kind"]}')

    @classmethod
    def _decode_obj_inst(cls, content):
        """
        Decode the Leshan representation of an object instance back to a Python dictionary.
        """
        resources = {}
        for resource in content['resources']:
            resources.update(cls._decode_resource(resource))
        return {content['id']: resources}

    @classmethod
    def _decode_obj(cls, content):
        """
        Decode the Leshan representation of an object back to a Python dictionary.
        """
        instances = {}
        for instance in content['instances']:
            instances.update(cls._decode_obj_inst(instance))
        return {content['id']: instances}

    def read(self, endpoint: str, path: str):
        """
        Send LwM2M READ command and decode the response to a Python dictionary.

        Returns the raw response when Leshan reports 'success' false; a single
        resource (instance) is returned as its plain decoded value.
        """
        resp = self.get(f'/clients/{endpoint}/{path}')
        if not resp['success']:
            return resp
        content = resp['content']
        if content['kind'] == 'obj':
            return self._decode_obj(content)
        if content['kind'] == 'instance':
            return self._decode_obj_inst(content)
        if content['kind'] in ('singleResource', 'resourceInstance'):
            return self._decode_value(content['type'], content['value'])
        if content['kind'] == 'multiResource':
            return self._decode_resource(content)
        raise RuntimeError(f'Unhandled type {content["kind"]}')

    @classmethod
    def parse_composite(cls, payload: dict):
        """
        Decode the Leshan's response to composite query back to a Python dictionary.

        The result is nested by path level: {obj: {inst: {res: value-or-{resinst: value}}}}.

        Raises:
        - RuntimeError: If a status other than 'CONTENT(205)' (or no content) is
          reported, or a path has more than four levels.
        """
        data = {}
        if 'status' in payload:
            if payload['status'] != 'CONTENT(205)' or 'content' not in payload:
                raise RuntimeError('No content received')
            payload = payload['content']
        for path, content in payload.items():
            if path == "/":
                for obj in content['objects']:
                    data.update(cls._decode_obj(obj))
                continue
            keys = [int(key) for key in path.lstrip("/").split('/')]
            if len(keys) == 1:
                data.update(cls._decode_obj(content))
                continue
            if len(keys) > 4:
                raise RuntimeError(f'Unhandled path {path}')
            # Create/walk the nested dicts for every level above the decoded node.
            parent = data
            for key in keys[:-1]:
                parent = parent.setdefault(key, {})
            if len(keys) == 2:
                parent.update(cls._decode_obj_inst(content))
            else:
                parent.update(cls._decode_resource(content))
        return data

    def _composite_params(self, paths: list[str] | None = None):
        """Common URI parameters for composite query"""
        parameters = {
            'pathformat': self.format,
            'nodeformat': self.format,
            'timeout': self.timeout
        }
        if paths is not None:
            parameters['paths'] = ','.join(self._normalize_paths(paths))
        return parameters

    def composite_read(self, endpoint: str, paths: list[str]):
        """Send LwM2M Composite-Read command and decode the response to a Python dictionary"""
        parameters = self._composite_params(paths)
        resp = self._s.get(f'{self.api_url}/clients/{endpoint}/composite', params=parameters, timeout=self.timeout)
        payload = Leshan.handle_response(resp)
        return self.parse_composite(payload)

    def composite_write(self, endpoint: str, resources: dict):
        """
        Send LwM2m Composite-Write operation.

        Targeted resources are defined as a dictionary with the following structure:
        {
            "/1/0/1": 60,
            "/1/0/6": True,
            "/16/0/0": {
                "0": "aa",
                "1": "bb",
                "2": "cc",
                "3": "dd"
            }
        }

        Objects or object instances cannot be targeted.
        """
        data = {}
        parameters = self._composite_params()
        for path, value in resources.items():
            path = path if path.startswith('/') else '/' + path
            level = len(path.split('/')) - 1
            rid = int(path.split('/')[-1])
            if level == 3:
                if isinstance(value, dict):
                    value = self._define_resource(rid, value, kind='multiResource')
                else:
                    value = self._define_resource(rid, value)
            elif level == 4:
                value = self._define_resource(rid, value, kind='resourceInstance')
            else:
                raise RuntimeError(f'Unhandled path {path}')
            data[path] = value

        resp = self._s.put(f'{self.api_url}/clients/{endpoint}/composite', params=parameters, json=data, timeout=self.timeout)
        return Leshan.handle_response(resp)

    def discover(self, endpoint: str, path: str):
        """Send LwM2M DISCOVER and return a dict mapping each link 'url' to its 'attributes'"""
        resp = self.handle_response(self._s.get(f'{self.api_url}/clients/{endpoint}/{path}/discover', timeout=self.timeout))
        return {obj['url']: obj['attributes'] for obj in resp['objectLinks']}

    def create_psk_device(self, endpoint: str, passwd: str):
        """Register PSK security for *endpoint* (identity = endpoint, key = hex of UTF-8 *passwd*)"""
        psk = binascii.b2a_hex(passwd.encode()).decode()
        # Built as a dict so endpoint names needing JSON escaping are serialized correctly.
        self.put('/security/clients/', {
            "endpoint": endpoint,
            "tls": {"mode": "psk", "details": {"identity": endpoint, "key": psk}},
        })

    def delete_device(self, endpoint: str):
        """Remove the security configuration of *endpoint*"""
        self.delete_raw(f'/security/clients/{endpoint}')

    @staticmethod
    def _bootstrap_config(endpoint: str, server_uri: str, passwd: str) -> dict:
        """
        Build the bootstrap configuration provisioning one PSK LwM2M server at *server_uri*.

        The PSK identity and secret are byte lists of the UTF-8 encoded *endpoint*
        and *passwd*, matching how PSK keys are derived elsewhere in this class.
        """
        return {
            "servers": {
                "0": {
                    "binding": "U",
                    "defaultMinPeriod": 1,
                    "lifetime": 86400,
                    "notifIfDisabled": False,
                    "shortId": 1,
                },
            },
            "security": {
                "1": {
                    "bootstrapServer": False,
                    "clientOldOffTime": 1,
                    "publicKeyOrId": list(endpoint.encode()),
                    "secretKey": list(passwd.encode()),
                    "securityMode": "PSK",
                    "serverId": 1,
                    "serverSmsNumber": "",
                    "smsBindingKeyParam": [],
                    "smsBindingKeySecret": [],
                    "smsSecurityMode": "NO_SEC",
                    "uri": server_uri,
                },
            },
            "oscore": {},
            "toDelete": ["/0", "/1"],
        }

    def create_bs_device(self, endpoint: str, server_uri: str, bs_passwd: str, passwd: str):
        """
        Register *endpoint* for bootstrap.

        PSK *bs_passwd* secures the bootstrap connection; the bootstrap config
        provisions server *server_uri* with PSK identity *endpoint* and key *passwd*.
        """
        psk = binascii.b2a_hex(bs_passwd.encode()).decode()
        self.put('/security/clients/', {
            "tls": {"mode": "psk", "details": {"identity": endpoint, "key": psk}},
            "endpoint": endpoint,
        })
        self.post(f'/bootstrap/{endpoint}', self._bootstrap_config(endpoint, server_uri, passwd))

    def delete_bs_device(self, endpoint: str):
        """Remove both the security and the bootstrap configuration of *endpoint*"""
        self.delete_raw(f'/security/clients/{endpoint}')
        self.delete_raw(f'/bootstrap/{endpoint}')

    def observe(self, endpoint: str, path: str):
        """Send LwM2M OBSERVE for *path*"""
        return self.post(f'/clients/{endpoint}/{path}/observe', data="")

    def cancel_observe(self, endpoint: str, path: str):
        """Actively cancel an observation (request sent with the 'active' flag)"""
        return self.delete_raw(f'/clients/{endpoint}/{path}/observe?active')

    def passive_cancel_observe(self, endpoint: str, path: str):
        """Cancel an observation without the 'active' flag"""
        return self.delete_raw(f'/clients/{endpoint}/{path}/observe')

    def composite_observe(self, endpoint: str, paths: list[str]):
        """Send LwM2M Observe-Composite and decode the initial response"""
        parameters = self._composite_params(paths)
        resp = self._s.post(f'{self.api_url}/clients/{endpoint}/composite/observe', params=parameters, timeout=self.timeout)
        payload = Leshan.handle_response(resp)
        return self.parse_composite(payload)

    def cancel_composite_observe(self, endpoint: str, paths: list[str]):
        """Actively cancel a composite observation (request sent with the 'active' flag)"""
        paths = self._normalize_paths(paths)
        return self.delete_raw(f'/clients/{endpoint}/composite/observe?paths=' + ','.join(paths) + '&active')

    def passive_cancel_composite_observe(self, endpoint: str, paths: list[str]):
        """Cancel a composite observation without the 'active' flag"""
        paths = self._normalize_paths(paths)
        return self.delete_raw(f'/clients/{endpoint}/composite/observe?paths=' + ','.join(paths))

    @contextmanager
    def get_event_stream(self, endpoint: str, timeout: int | None = None):
        """
        Get stream of events regarding the given endpoint.

        Events are notifications, updates and sends.

        The event stream must be closed after the use, so this must be used in 'with' statement like this:
            with leshan.get_event_stream('native_sim') as events:
                data = events.next_event('SEND')

        If timeout happens, the event streams returns None.
        *timeout* defaults to self.timeout.
        """
        if timeout is None:
            timeout = self.timeout
        # Leshan filters the event stream by the 'ep' query parameter.
        r = requests.get(f'{self.api_url}/event', params={'ep': endpoint}, stream=True,
                         headers={'Accept': 'text/event-stream'}, timeout=timeout)
        if r.encoding is None:
            r.encoding = 'utf-8'
        try:
            yield LeshanEventsIterator(r, timeout)
        finally:
            r.close()
437
class LeshanEventsIterator:
    """Iterator for Leshan event stream"""
    def __init__(self, req: requests.Response, timeout: int):
        """Initialize the iterator in line mode (one decoded text line per step)"""
        self._it = req.iter_lines(chunk_size=1, decode_unicode=True)
        self._timeout = timeout

    def next_event(self, event: str):
        """
        Finds the next occurrence of a specific event in the stream.

        After an 'event: <event>' line, the next 'data: ' line is JSON-decoded.
        SEND and composite NOTIFICATION payloads are converted with
        Leshan.parse_composite; other events return the decoded JSON as-is.

        If timeout happens (or the stream ends), this returns None.
        """
        deadline = time.time() + self._timeout
        try:
            for line in self._it:
                if line == f'event: {event}':
                    for line in self._it:
                        if not line.startswith('data: '):
                            continue
                        data = json.loads(line.removeprefix('data: '))
                        if event == 'SEND' or (event == 'NOTIFICATION' and data['kind'] == 'composite'):
                            return Leshan.parse_composite(data['val'])
                        if event == 'NOTIFICATION':
                            return Leshan.parse_composite({data['res']: data['val']})
                        return data
                if time.time() > deadline:
                    return None
        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError):
            # While streaming, requests re-raises urllib3's read timeout as
            # ConnectionError rather than Timeout, so both mean "no event in time".
            return None
        return None
469