1#!/usr/bin/env python
2#
3# Copyright (c) 2016, The OpenThread Authors.
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8# 1. Redistributions of source code must retain the above copyright
9#    notice, this list of conditions and the following disclaimer.
10# 2. Redistributions in binary form must reproduce the above copyright
11#    notice, this list of conditions and the following disclaimer in the
12#    documentation and/or other materials provided with the distribution.
13# 3. Neither the name of the copyright holder nor the
14#    names of its contributors may be used to endorse or promote products
15#    derived from this software without specific prior written permission.
16#
17# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27# POSSIBILITY OF SUCH DAMAGE.
28#
29
30import logging
31import re
32import socket
33import threading
34import time
35
36import serial
37
38from . import settings
39
# Public API of this module.
__all__ = ['OpenThreadController']
# Module-level logger used by the controller.
logger = logging.getLogger(__name__)

# Splits device output into lines; accepts both CRLF and bare LF endings.
linesepx = re.compile(r'\r\n|\n')
44
45
class OpenThreadController(threading.Thread):
    """A simple wrapper to communicate with an OpenThread CLI device.

    The device is reached through one of:

    * ``NET<n>``: a ser2net TCP port at ``settings.SER2NET_HOSTNAME`` and
      ``settings.SER2NET_PORTBASE + n``;
    * ``host:port``: a plain TCP endpoint;
    * anything else: a serial port opened at 115200 baud with XON/XOFF.

    When ``log`` is enabled, ``run`` executes in a background thread and logs
    every line read from the device; requests pause it (via the lock) while
    they talk to the device.
    """

    # NOTE: this lock is a class attribute, so it is shared by all instances;
    # requests and log dumping on different devices serialize on it.
    _lock = threading.Lock()
    # Keeps the loop in run() going; cleared by close().
    viewing = False

    def __init__(self, port, log=False):
        """Initialize the controller and connect to the device.

        Args:
            port (str): serial port's path or name (windows), ``NET<n>`` for a
                ser2net port, or ``host:port`` for a TCP endpoint.
            log (bool): whether to log device output from a background thread.
        """
        super(OpenThreadController, self).__init__()
        # While the log thread runs, the threading module keeps a reference
        # to it, so __del__ cannot fire; a daemon thread at least lets the
        # interpreter exit if close() is never called.
        self.daemon = True
        self.port = port
        self.handle = None
        self.lines = []
        self._log = log
        self._is_net = False
        self._init()

    def _init(self):
        """Connect to the device and start the log thread if logging is on."""
        self._connect()
        if not self._log:
            return

        # Set here rather than in run(): if close() ran before run() got
        # scheduled, run() would otherwise re-enable viewing and loop forever.
        self.viewing = True
        # self.start() is overridden below to start OpenThread on the device,
        # so the thread has to be started through threading.Thread itself.
        super(OpenThreadController, self).start()

    def __del__(self):
        self.close()

    def close(self):
        """Stop the log thread if it is running, then close the connection."""
        if self.is_alive():
            self.viewing = False
            self.join()

        self._close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def _close(self):
        """Close the device handle, if any."""
        if self.handle:
            self.handle.close()
            self.handle = None

    def _connect(self):
        """Open the connection described by ``self.port``."""
        logger.debug('My port is %s', self.port)
        if self.port.startswith('NET'):
            # ser2net: NET<n> maps to TCP port SER2NET_PORTBASE + n.
            portnum = settings.SER2NET_PORTBASE + int(self.port.split('NET')[1])
            logger.debug('My port num is %d', portnum)
            address = (settings.SER2NET_HOSTNAME, portnum)
            self.handle = socket.create_connection(address)
            self.handle.setblocking(0)
            self._is_net = True
        elif ':' in self.port:
            # Split on the last colon so the port is always the final field.
            host, port = self.port.rsplit(':', 1)
            self.handle = socket.create_connection((host, port))
            self.handle.setblocking(0)
            self._is_net = True
        else:
            self.handle = serial.Serial(self.port, 115200, timeout=0, xonxoff=True)
            self._is_net = False

    def _read(self, size=512):
        """Read up to ``size`` bytes from the device without blocking.

        Returns:
            str: the data read, as text (may be empty).

        Raises:
            socket.error: on a network connection when no data is available.
        """
        if self._is_net:
            data = self.handle.recv(size)
        else:
            data = self.handle.read(size)

        # Under Python 3 sockets and serial ports return bytes while the rest
        # of this class works with text; under Python 2 ``bytes is str`` and
        # the data passes through unchanged.
        if not isinstance(data, str):
            data = data.decode('utf-8', 'replace')
        return data

    def _write(self, data):
        """Write ``data`` (text or bytes) to the device."""
        # Encode text for Python 3; under Python 2 a plain str is already bytes.
        if not isinstance(data, bytes):
            data = data.encode('utf-8')

        if self._is_net:
            self.handle.sendall(data)
        else:
            self.handle.write(data)

    def _expect(self, expected, times=50):
        """Find the `expected` line within `times` trials.

        Args:
            expected    str: the expected string
            times       int: number of trials

        Raises:
            Exception: if the line is not seen within ``times`` reads, or
                after 10 reads that returned no data.
        """
        logger.debug('[%s] Expecting [%s]', self.port, expected)
        retry_times = 10
        while times:
            if not retry_times:
                break

            line = self._readline()

            if line == expected:
                return

            if not line:
                retry_times -= 1
                time.sleep(0.1)

            times -= 1

        raise Exception('failed to find expected string[%s]' % expected)

    def _readline(self):
        """Read exactly one line from the device, nonblocking.

        A trailing partial line is kept in ``self.lines`` and completed by
        later reads.

        Returns:
            str: the next complete line, or None on no data.
        """
        if len(self.lines) > 1:
            return self.lines.pop(0)

        tail = ''
        if len(self.lines):
            tail = self.lines.pop()

        try:
            tail += self._read()
        except socket.error:
            # Normal for a non-blocking socket with nothing pending.
            logger.debug('No new data')
            time.sleep(0.1)

        self.lines += linesepx.split(tail)
        if len(self.lines) > 1:
            return self.lines.pop(0)

    def _sendline(self, line):
        """Send exactly one line to the device

        Pending input is discarded first so that later reads see only the
        response to this line.

        Args:
            line str: data send to device
        """
        self.lines = []
        try:
            self._read()
        except socket.error:
            logger.debug('Nothing cleared')

        logger.debug('sending [%s]', line)
        self._write(line + '\r\n')

        # wait for write to complete
        time.sleep(0.5)

    def _req(self, req, timeout=10):
        """Send command and wait for response.

        The command will be repeated 3 times at most in case of data loss on
        the serial port; the connection is re-opened between attempts.

        Args:
            req (str): Command to send, please do not include new line in the end.
            timeout (float): seconds to wait without receiving any line before
                an attempt is considered failed.

        Returns:
            [str]: The output lines, or None if every attempt failed.
        """
        logger.debug('DUT> %s', req)
        if self._log:
            self.pause()

        res = None
        try:
            for _ in range(3):
                try:
                    self._sendline(req)
                    self._expect(req)
                    res = self._read_response(timeout)
                    break
                except Exception:
                    logger.exception('Failed to send command')
                    # Re-open only the device handle. close()/_init() would
                    # join the log thread, which may be blocked on the lock
                    # held by this request, and then try to restart it.
                    self._close()
                    self._connect()
        finally:
            if self._log:
                self.resume()
        return res

    def _read_response(self, timeout):
        """Collect non-empty output lines until a ``Done`` line arrives.

        Raises:
            Exception: if no line arrives for ``timeout`` seconds.
        """
        res = []
        deadline = time.time() + timeout
        while True:
            line = self._readline()

            if line is None:
                if time.time() > deadline:
                    raise Exception('timed out waiting for response to finish')
                # Avoid spinning: serial reads return immediately when idle.
                time.sleep(0.01)
                continue

            logger.debug('Got line %s', line)
            if line == 'Done':
                return res

            # Any line (even an empty one) shows the device is still talking.
            deadline = time.time() + timeout
            if line:
                res.append(line)

    def run(self):
        """Threading callback: log device output until ``viewing`` is cleared."""
        while self.viewing:
            with self._lock:
                try:
                    line = self._readline()
                except Exception:
                    logger.debug('Failed to read from %s', self.port, exc_info=True)
                    line = None

            if line is None:
                # Nothing new; back off briefly instead of busy-looping.
                time.sleep(0.01)
            else:
                logger.info('%s', line)
                # Yield so that pause() callers get a chance at the lock.
                time.sleep(0)

    def is_started(self):
        """check if openthread is started

        Returns:
            bool: started or not
        """
        state = self._req('state')[0]
        return state != 'disabled'

    def start(self):
        """Start openthread
        """
        self._req('ifconfig up')
        self._req('thread start')

    def stop(self):
        """Stop openthread
        """
        self._req('thread stop')
        self._req('ifconfig down')

    def reset(self):
        """Reset openthread device, not equivalent to stop and start
        """
        logger.debug('DUT> reset')
        if self._log:
            self.pause()
        try:
            self._sendline('reset')
            try:
                self._read()
            except socket.error:
                # Nothing pending on the non-blocking socket.
                logger.debug('Nothing cleared')
        finally:
            if self._log:
                self.resume()

    def resume(self):
        """Resume dumping logs (release the lock taken by pause())."""
        self._lock.release()

    def pause(self):
        """Stop dumping logs (acquire the lock shared with the log thread)."""
        self._lock.acquire()

    @property
    def networkname(self):
        """str: Thread network name."""
        return self._req('networkname')[0]

    @networkname.setter
    def networkname(self, value):
        self._req('networkname %s' % value)

    @property
    def mode(self):
        """str: Thread mode."""
        return self._req('mode')[0]

    @mode.setter
    def mode(self, value):
        self._req('mode %s' % value)

    @property
    def mac(self):
        """str: MAC address of the device"""
        return self._req('extaddr')[0]

    @property
    def addrs(self):
        """[str]: IP addresses of the devices"""
        return self._req('ipaddr')

    @property
    def short_addr(self):
        """str: Short address"""
        return self._req('rloc16')[0]

    @property
    def channel(self):
        """int: Channel number of openthread"""
        return int(self._req('channel')[0])

    @channel.setter
    def channel(self, value):
        self._req('channel %d' % value)

    @property
    def panid(self):
        """str: Thread panid"""
        return self._req('panid')[0]

    @panid.setter
    def panid(self, value):
        self._req('panid %s' % value)

    @property
    def extpanid(self):
        """str: Thread extpanid"""
        return self._req('extpanid')[0]

    @extpanid.setter
    def extpanid(self, value):
        self._req('extpanid %s' % value)

    @property
    def child_timeout(self):
        """str: Thread child timeout in seconds"""
        return self._req('childtimeout')[0]

    @child_timeout.setter
    def child_timeout(self, value):
        self._req('childtimeout %d' % value)

    @property
    def version(self):
        """str: Open thread version"""
        return self._req('version')[0]

    def add_prefix(self, prefix, flags, prf):
        """Add network prefix.

        Args:
            prefix (str): network prefix.
            flags (str): network prefix flags, please refer thread documentation for details
            prf (str): network prf, please refer thread documentation for details
        """
        self._req('prefix add %s %s %s' % (prefix, flags, prf))
        time.sleep(1)
        self._req('netdata register')

    def remove_prefix(self, prefix):
        """Remove network prefix.
        """
        self._req('prefix remove %s' % prefix)
        time.sleep(1)
        self._req('netdata register')

    def enable_denylist(self):
        """Enable denylist feature"""
        self._req('denylist enable')

    def add_denylist(self, mac):
        """Add a mac address to denylist"""
        self._req('denylist add %s' % mac)
387