1"""
2  Copyright (c) 2024, The OpenThread Authors.
3  All rights reserved.
4
5  Redistribution and use in source and binary forms, with or without
6  modification, are permitted provided that the following conditions are met:
7  1. Redistributions of source code must retain the above copyright
8     notice, this list of conditions and the following disclaimer.
9  2. Redistributions in binary form must reproduce the above copyright
10     notice, this list of conditions and the following disclaimer in the
11     documentation and/or other materials provided with the distribution.
12  3. Neither the name of the copyright holder nor the
13     names of its contributors may be used to endorse or promote products
14     derived from this software without specific prior written permission.
15
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  POSSIBILITY OF SUCH DAMAGE.
27"""
28
29from abc import abstractmethod
30from ble.ble_connection_constants import BBTC_SERVICE_UUID, BBTC_TX_CHAR_UUID, \
31    BBTC_RX_CHAR_UUID
32from ble.ble_stream import BleStream
33from ble.ble_stream_secure import BleStreamSecure
34from ble import ble_scanner
35from tlv.tlv import TLV
36from tlv.tcat_tlv import TcatTLVType
37from cli.command import Command, CommandResultNone, CommandResultTLV
38from dataset.dataset import ThreadDataset
39from utils import select_device_by_user_input
40from os import path
41from time import time
42from secrets import token_bytes
43from hashlib import sha256
44import hmac
45import binascii
46
47
class HelpCommand(Command):
    """List every registered command followed by its indented help text."""

    def get_help_string(self) -> str:
        return 'Display help and return.'

    async def execute_default(self, args, context):
        # context['commands'] maps each command name to an object that
        # exposes print_help().
        registry = context['commands']
        for cmd_name in registry:
            print(f'{cmd_name}')
            registry[cmd_name].print_help(indent=1)
        return CommandResultNone()
59
60
class DataNotPrepared(Exception):
    """Raised by a command's prepare_data() when it cannot build its request.

    BleCommand.execute_default catches it and prints the failure instead of
    sending anything to the device.
    """
63
64
class BleCommand(Command):
    """Base class for commands that send one TLV over the secure BLE stream.

    Subclasses provide the log line and the encoded request. This class checks
    that a device is connected, sends the request, and decodes the response.
    """

    @abstractmethod
    def get_log_string(self) -> str:
        """Return the message printed just before the request is sent."""
        pass

    @abstractmethod
    def prepare_data(self, args, context):
        """Return the encoded request bytes to send to the device.

        Raise DataNotPrepared when the request cannot be built from
        ``args``/``context``; the failure is printed and nothing is sent.
        """
        pass

    async def execute_default(self, args, context):
        """Send the prepared request and wrap the decoded reply.

        Returns:
            CommandResultTLV with the decoded response, or CommandResultNone
            when not connected, when preparation fails, or when the device
            returned no data.
        """
        bless: BleStreamSecure = context.get('ble_sstream')
        if bless is None:
            print("TCAT Device not connected.")
            return CommandResultNone()

        print(self.get_log_string())
        try:
            data = self.prepare_data(args, context)
        except DataNotPrepared as err:
            print('Command failed', err)
            return CommandResultNone()

        response = await bless.send_with_resp(data)
        if not response:
            # Previously a bare None was returned here; keep the result type
            # consistent with every other exit path.
            return CommandResultNone()
        return CommandResultTLV(TLV.from_bytes(response))
92
93
class HelloCommand(BleCommand):
    """Send a fixed "Hello world!" payload inside an APPLICATION TLV."""

    def get_help_string(self) -> str:
        return 'Send round trip "Hello world!" message.'

    def get_log_string(self) -> str:
        return 'Sending hello world...'

    def prepare_data(self, args, context):
        greeting = 'Hello world!'.encode('ascii')
        request = TLV(TcatTLVType.APPLICATION.value, greeting)
        return request.to_bytes()
104
105
class CommissionCommand(BleCommand):
    """Send the dataset stored in ``context['dataset']`` as an ACTIVE_DATASET TLV."""

    def get_help_string(self) -> str:
        return 'Update the connected device with current dataset.'

    def get_log_string(self) -> str:
        return 'Commissioning...'

    def prepare_data(self, args, context):
        active: ThreadDataset = context['dataset']
        request = TLV(TcatTLVType.ACTIVE_DATASET.value, active.to_bytes())
        return request.to_bytes()
118
119
class DecommissionCommand(BleCommand):
    """Send an empty DECOMMISSION TLV to the connected device."""

    def get_help_string(self) -> str:
        return 'Stop Thread interface and decommission device from current network.'

    def get_log_string(self) -> str:
        return 'Disabling Thread and decommissioning device...'

    def prepare_data(self, args, context):
        request = TLV(TcatTLVType.DECOMMISSION.value, b'')
        return request.to_bytes()
130
131
class GetDeviceIdCommand(BleCommand):
    """Request the device identifier with an empty GET_DEVICE_ID TLV."""

    def get_help_string(self) -> str:
        return 'Get unique identifier for the TCAT device.'

    def get_log_string(self) -> str:
        return 'Retrieving device id.'

    def prepare_data(self, args, context):
        request = TLV(TcatTLVType.GET_DEVICE_ID.value, b'')
        return request.to_bytes()
142
143
class GetExtPanIDCommand(BleCommand):
    """Request the extended PAN ID with an empty GET_EXT_PAN_ID TLV."""

    def get_help_string(self) -> str:
        return 'Get extended PAN ID that is commissioned in the active dataset.'

    def get_log_string(self) -> str:
        return 'Retrieving extended PAN ID.'

    def prepare_data(self, args, context):
        request = TLV(TcatTLVType.GET_EXT_PAN_ID.value, b'')
        return request.to_bytes()
154
155
class GetProvisioningUrlCommand(BleCommand):
    """Request the provisioning URL with an empty GET_PROVISIONING_URL TLV."""

    def get_help_string(self) -> str:
        return 'Get a URL for an application suited to commission the TCAT device.'

    def get_log_string(self) -> str:
        return 'Retrieving provisioning url.'

    def prepare_data(self, args, context):
        request = TLV(TcatTLVType.GET_PROVISIONING_URL.value, b'')
        return request.to_bytes()
166
167
class GetNetworkNameCommand(BleCommand):
    """Request the network name with an empty GET_NETWORK_NAME TLV."""

    def get_help_string(self) -> str:
        return 'Get the Thread network name that is commissioned in the active dataset.'

    def get_log_string(self) -> str:
        return 'Retrieving network name.'

    def prepare_data(self, args, context):
        request = TLV(TcatTLVType.GET_NETWORK_NAME.value, b'')
        return request.to_bytes()
178
179
class PresentHash(BleCommand):
    """Present an HMAC proving knowledge of a PSKd, PSKc or install code.

    ``args[0]`` selects the code kind (``pskd``, ``pskc`` or ``install``) and
    ``args[1]`` carries the code itself. The code is the HMAC-SHA256 key. The
    message is the stored peer challenge followed by the peer public key.
    """

    def get_log_string(self) -> str:
        return 'Presenting hash.'

    def get_help_string(self) -> str:
        return 'Present calculated hash.'

    def prepare_data(self, args, context):
        """Build the PRESENT_*_HASH TLV for the requested code.

        Raises:
            DataNotPrepared: if arguments are missing or malformed, or if the
                peer public key or peer challenge is not available yet.
        """
        # A missing argument used to raise IndexError, which
        # BleCommand.execute_default does not catch.
        if len(args) < 2:
            raise DataNotPrepared("Usage: <pskd|pskc|install> <code>")
        code_name, code_text = args[0], args[1]

        if code_name == "pskd":
            code = bytes(code_text, 'utf-8')
            tlv_type = TcatTLVType.PRESENT_PSKD_HASH.value
        elif code_name == "pskc":
            try:
                code = bytes.fromhex(code_text)
            except ValueError as err:
                raise DataNotPrepared("PSKc must be a hexadecimal string.") from err
            tlv_type = TcatTLVType.PRESENT_PSKC_HASH.value
        elif code_name == "install":
            code = bytes(code_text, 'utf-8')
            tlv_type = TcatTLVType.PRESENT_INSTALL_CODE_HASH.value
        else:
            raise DataNotPrepared("Hash code name incorrect.")

        bless: BleStreamSecure = context['ble_sstream']
        if bless.peer_public_key is None:
            raise DataNotPrepared("Peer certificate not present.")
        if bless.peer_challenge is None:
            raise DataNotPrepared("Peer challenge not present.")

        mac = hmac.new(code, digestmod=sha256)
        mac.update(bless.peer_challenge)
        mac.update(bless.peer_public_key)
        return TLV(tlv_type, mac.digest()).to_bytes()
216
217
class GetPskdHash(Command):
    """Ask the device for a PSKd HMAC over a fresh challenge and verify it.

    The PSKd is taken from ``args[0]``. The expected HMAC-SHA256 (key: PSKd;
    message: the random challenge followed by the peer public key) is computed
    locally and compared with the value the device returns.
    """

    def get_log_string(self) -> str:
        return 'Retrieving peer PSKd hash.'

    def get_help_string(self) -> str:
        return 'Get calculated PSKd hash.'

    async def execute_default(self, args, context):
        """Send a GET_PSKD_HASH request and report whether the reply matches.

        Returns:
            CommandResultTLV with the device response, or CommandResultNone
            when not connected, the PSKd or peer certificate is missing, or
            the device returned no data.
        """
        # Same connection guard as BleCommand; previously this raised
        # KeyError when no device was connected.
        bless: BleStreamSecure = context.get('ble_sstream')
        if bless is None:
            print("TCAT Device not connected.")
            return CommandResultNone()

        print(self.get_log_string())
        if not args:
            print('Command failed', 'PSKd not provided.')
            return CommandResultNone()
        if bless.peer_public_key is None:
            print("Peer certificate not present.")
            return CommandResultNone()

        challenge_size = 8
        challenge = token_bytes(challenge_size)
        pskd = bytes(args[0], 'utf-8')
        data = TLV(TcatTLVType.GET_PSKD_HASH.value, challenge).to_bytes()
        response = await bless.send_with_resp(data)
        if not response:
            return CommandResultNone()

        tlv_response = TLV.from_bytes(response)
        if tlv_response.value is not None:
            expected = hmac.new(pskd, digestmod=sha256)
            expected.update(challenge)
            expected.update(bless.peer_public_key)
            # Constant-time comparison is the recommended way to check a MAC.
            if hmac.compare_digest(expected.digest(), tlv_response.value):
                print('Requested hash is valid.')
            else:
                print('Requested hash is NOT valid.')
        return CommandResultTLV(tlv_response)
254
255
class GetRandomNumberChallenge(Command):
    """Fetch the device's random challenge and store it on the secure stream.

    A valid 8-byte challenge is saved as ``bless.peer_challenge``, which
    PresentHash mixes into its HMAC.
    """

    def get_log_string(self) -> str:
        return 'Retrieving random challenge.'

    def get_help_string(self) -> str:
        return 'Get the device random number challenge.'

    async def execute_default(self, args, context):
        """Request the challenge and store it if it has the expected length.

        Returns:
            CommandResultTLV with the device response, or CommandResultNone
            when not connected, when no data is returned, or when the
            challenge length is not 8 bytes.
        """
        # Same connection guard as BleCommand; previously this raised
        # KeyError when no device was connected.
        bless: BleStreamSecure = context.get('ble_sstream')
        if bless is None:
            print("TCAT Device not connected.")
            return CommandResultNone()

        print(self.get_log_string())
        data = TLV(TcatTLVType.GET_RANDOM_NUMBER_CHALLENGE.value, bytes()).to_bytes()
        response = await bless.send_with_resp(data)
        if not response:
            return CommandResultNone()

        tlv_response = TLV.from_bytes(response)
        if tlv_response.value is not None:
            if len(tlv_response.value) != 8:
                print('Challenge format invalid.')
                return CommandResultNone()
            bless.peer_challenge = tlv_response.value
        return CommandResultTLV(tlv_response)
283
284
class PingCommand(Command):
    """Echo random bytes off the device in a PING TLV and report the RTT.

    ``args[0]`` optionally sets the payload size in bytes (default 10, range
    0..512).
    """

    def get_help_string(self) -> str:
        return 'Send echo request to TCAT device.'

    async def execute_default(self, args, context):
        """Send one PING and print the round-trip time in milliseconds.

        Returns:
            CommandResultTLV with the echoed TLV, or CommandResultNone when
            not connected, the size argument is invalid, or no data came back.
        """
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the
        # measurement the way time.time() can.
        from time import perf_counter

        bless: BleStreamSecure = context.get('ble_sstream')
        if bless is None:
            print("TCAT Device not connected.")
            return CommandResultNone()

        payload_size = 10
        max_payload = 512
        if len(args) > 0:
            try:
                payload_size = int(args[0])
            except ValueError:
                print(f'Invalid payload size: {args[0]}')
                return CommandResultNone()
            if payload_size > max_payload:
                print(f'Payload size too large. Maximum supported value is {max_payload}')
                return CommandResultNone()
            if payload_size < 0:
                # token_bytes() raises ValueError on a negative length.
                print('Payload size must not be negative.')
                return CommandResultNone()

        to_send = token_bytes(payload_size)
        data = TLV(TcatTLVType.PING.value, to_send).to_bytes()
        start = perf_counter()
        response = await bless.send_with_resp(data)
        elapsed_ms = 1e3 * (perf_counter() - start)
        if not response:
            return CommandResultNone()

        tlv_response = TLV.from_bytes(response)
        if tlv_response.value != to_send:
            print("Received malformed response.")

        print(f"Roundtrip time: {elapsed_ms} ms")
        return CommandResultTLV(tlv_response)
314
315
class ThreadStartCommand(BleCommand):
    """Send an empty THREAD_START TLV to enable the device's Thread interface."""

    def get_help_string(self) -> str:
        return 'Enable thread interface.'

    def get_log_string(self) -> str:
        return 'Enabling Thread...'

    def prepare_data(self, args, context):
        request = TLV(TcatTLVType.THREAD_START.value, b'')
        return request.to_bytes()
326
327
class ThreadStopCommand(BleCommand):
    """Send an empty THREAD_STOP TLV to disable the device's Thread interface."""

    def get_help_string(self) -> str:
        return 'Disable thread interface.'

    def get_log_string(self) -> str:
        return 'Disabling Thread...'

    def prepare_data(self, args, context):
        request = TLV(TcatTLVType.THREAD_STOP.value, b'')
        return request.to_bytes()
338
339
class ThreadStateCommand(Command):
    """Group the ``start`` and ``stop`` subcommands for the Thread interface."""

    def __init__(self):
        # Let the base class run its own initialisation before the
        # subcommand table is installed (previously skipped).
        super().__init__()
        self._subcommands = {'start': ThreadStartCommand(), 'stop': ThreadStopCommand()}

    def get_help_string(self) -> str:
        return 'Manipulate state of the Thread interface of the connected device.'

    async def execute_default(self, args, context):
        # NOTE(review): presumably only reached when no known subcommand was
        # given; confirm the dispatch logic in cli.command.Command.
        print('Invalid usage. Provide a subcommand.')
        return CommandResultNone()
351
352
class ScanCommand(Command):
    """Scan for TCAT devices, connect to the chosen one and secure the link.

    On a successful handshake the secure stream is stored as
    ``context['ble_sstream']`` for the other commands to use.
    """

    def get_help_string(self) -> str:
        return 'Perform scan for TCAT devices.'

    async def execute_default(self, args, context):
        # Tear down any existing session first. close() is awaited elsewhere
        # (DisconnectCommand), so it is a coroutine; calling it without await
        # never ran it and leaked the old connection.
        previous = context.get('ble_sstream')
        if previous is not None:
            try:
                await previous.close()
            finally:
                del context['ble_sstream']

        tcat_devices = await ble_scanner.scan_tcat_devices()
        device = select_device_by_user_input(tcat_devices)
        if device is None:
            return CommandResultNone()

        print(f'Connecting to {device}')
        ble_stream = await BleStream.create(device.address, BBTC_SERVICE_UUID, BBTC_TX_CHAR_UUID, BBTC_RX_CHAR_UUID)
        ble_sstream = BleStreamSecure(ble_stream)
        cmd_args = context.get('cmd_args')
        cert_path = cmd_args.cert_path if cmd_args else 'auth'
        try:
            ble_sstream.load_cert(
                certfile=path.join(cert_path, 'commissioner_cert.pem'),
                keyfile=path.join(cert_path, 'commissioner_key.pem'),
                cafile=path.join(cert_path, 'ca_cert.pem'),
            )
            print('Setting up secure channel...')
            established = await ble_sstream.do_handshake()
        except Exception:
            # Do not leave the raw BLE link open when loading certificates or
            # the handshake fails with an error.
            await ble_stream.disconnect()
            raise

        if established:
            print('Done')
            context['ble_sstream'] = ble_sstream
        else:
            print('Secure channel not established.')
            await ble_stream.disconnect()
        return CommandResultNone()
388
389
class DisconnectCommand(Command):
    """Close the secure BLE session and forget it."""

    def get_help_string(self) -> str:
        return 'Disconnect client from TCAT device'

    async def execute_default(self, args, context):
        bless = context.get('ble_sstream')
        if bless is None:
            print("TCAT Device not connected.")
            return CommandResultNone()
        try:
            await bless.close()
        finally:
            # Remove the closed stream (as ScanCommand does) so later commands
            # report "not connected" instead of using a dead connection.
            del context['ble_sstream']
        return CommandResultNone()
401