1#!/usr/bin/env python3
2# Copyright(c) 2022 Intel Corporation. All rights reserved.
3# SPDX-License-Identifier: Apache-2.0
4import os
5import sys
6import struct
7import logging
8import time
9import subprocess
10import argparse
11import socketserver
12import threading
13import hashlib
14import queue
15from urllib.parse import urlparse
16
# Global device_runner shared by the log and request services: the request
# handler queues received firmware on it and the log handler consumes it.
runner = None

# pylint: disable=duplicate-code

# INADDR_ANY as default
HOST = ''
PORT_LOG = 9999
PORT_REQ = PORT_LOG + 1
BUF_SIZE = 4096

# Define the commands and the max command size
CMD_LOG_START = "start_log"
CMD_DOWNLOAD = "download"
MAX_CMD_SZ = 16

# Return value of the handler helpers on failure (0 means success)
ERR_FAIL = 1

# Header format and size for transmitting the firmware:
# size(4), filename(42), MD5 hex digest(32).
PACKET_HEADER_FORMAT_FW = 'I 42s 32s'
# Derived from the format so the two cannot drift apart (78 bytes).
HEADER_SZ = struct.calcsize(PACKET_HEADER_FORMAT_FW)

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("remote-fw")
43
44
class adsp_request_handler(socketserver.BaseRequestHandler):
    """
    The request handler class for controlling the actions of the server.

    Protocol: the client sends a command of at most MAX_CMD_SZ bytes. For
    CMD_DOWNLOAD the command is echoed back. The client then sends a header
    packed with PACKET_HEADER_FORMAT_FW, followed by the firmware bytes.
    The handler finally answers "success" or "failed".
    """

    def _recv_exact(self, size, what):
        """Return exactly *size* bytes read from the client socket.

        TCP may deliver fewer bytes per recv() than requested, so keep
        reading until the full amount has arrived.

        Raises:
            EOFError: the peer closed the connection first (*what* names the
                truncated item in the message).
        """
        buf = bytearray()  # bytearray += is amortised O(1), bytes += is not
        while len(buf) < size:
            data = self.request.recv(min(BUF_SIZE, size - len(buf)))
            if not data:
                raise EOFError(f"truncated {what}")
            buf += data
        return bytes(buf)

    def receive_fw(self):
        """Receive one firmware image from the client and store it on disk.

        Returns:
            The stored file name as bytes, or None if the transfer was
            truncated, the file name is unusable, the MD5 check fails or the
            file cannot be written.
        """
        log.info("Receiving...")
        s = struct.Struct(PACKET_HEADER_FORMAT_FW)
        try:
            # Header: size(4), filename(42) and MD5(32).
            header = self._recv_exact(s.size, "firmware header")
            fsize, fname, md5_tx_b = s.unpack(header)
            log.info(f'size:{fsize}, filename:{fname}, MD5:{md5_tx_b}')

            # Receive the firmware. We only receive the specified amount of bytes.
            total = self._recv_exact(fsize, "firmware file")
        except EOFError as e:
            log.error(f"FW transfer aborted: {e}")
            return None

        log.info(f"Done Receiving {len(total)}.")

        # struct pads the fixed-width fields with NUL bytes, which open()
        # rejects; strip them so the name is usable as a path.
        fname = fname.rstrip(b'\x00')
        # The name comes from the network: keep only its last component so a
        # client cannot make the service write outside its working directory.
        fname = os.path.basename(fname)
        if not fname or fname in (b'.', b'..'):
            log.error(f"Invalid FW file name {fname}.")
            return None

        # Check the MD5 before touching the disk so a corrupt image is never
        # left behind.
        md5_rx = hashlib.md5(total).hexdigest()
        md5_tx = md5_tx_b.rstrip(b'\x00').decode('utf-8', errors='replace')
        if md5_tx != md5_rx:
            log.error(f'MD5 mismatch: {md5_tx} vs. {md5_rx}')
            return None

        try:
            with open(fname, 'wb') as f:
                f.write(total)
        except (OSError, ValueError) as e:
            log.error(f"Get exception {e} during FW transfer.")
            return None

        return fname

    def do_download(self):
        """Receive a firmware and hand it to the global runner.

        Returns:
            0 on success, ERR_FAIL otherwise.
        """
        recv_file = self.receive_fw()

        if recv_file:
            # os.fsdecode is the inverse of how the OS maps bytes paths, so
            # the str names the very file that was written (no decode crash).
            recv_file = os.fsdecode(recv_file)

            if os.path.exists(recv_file):
                runner.set_fw_ready(recv_file)
                return 0

        log.error("Cannot find the FW file.")
        return ERR_FAIL

    def handle(self):
        """Serve one CMD_DOWNLOAD request and report the outcome."""
        cmd = self.request.recv(MAX_CMD_SZ)
        log.info(f"{self.client_address[0]} wrote: {cmd}")
        # errors='replace' so garbage is rejected below instead of crashing.
        action = cmd.decode("utf-8", errors="replace")
        log.debug(f'load {action}')

        if action != CMD_DOWNLOAD:
            log.error("incorrect load communitcation!")
            return

        # Echo the command back as an acknowledgement.
        self.request.sendall(cmd)

        if self.do_download() == 0:
            self.request.sendall("success".encode('utf-8'))
            log.info("Firmware well received. Ready to download.")
        else:
            self.request.sendall("failed".encode('utf-8'))
            log.error("Receive firmware failed.")
122
class adsp_log_handler(socketserver.BaseRequestHandler):
    """
    The log handler class for grabbing output messages of the server.

    After a valid CMD_LOG_START command, it waits until the request service
    has queued a firmware on the global runner. It then runs the runner's
    load script and streams the script's stdout to the client. finish()
    calls runner.cleanup().
    """

    def handle(self):
        cmd = self.request.recv(MAX_CMD_SZ)
        log.info(f"{self.client_address[0]} wrote: {cmd}")
        # errors='replace' so garbage is rejected below instead of crashing.
        action = cmd.decode("utf-8", errors="replace")
        log.debug(f'monitor {action}')

        if action != CMD_LOG_START:
            log.error("incorrect monitor communitcation!")
            # Do not consume a queued firmware on behalf of a client that
            # did not follow the protocol.
            return

        # Echo the command back as an acknowledgement.
        self.request.sendall(cmd)

        log.info("wait for FW ready...")
        while not runner.is_fw_ready():
            if not self.is_connection_alive():
                return

            time.sleep(1)

        log.info("FW is ready...")

        # start_new_session=True in order to get a different Process Group
        # ID. When the PGID is the same, sudo does NOT propagate signals out of
        # fear of "accidentally killing itself" (man sudo).
        # Compare:
        #
        # - Different PGID: signal is propagated and sleep is terminated
        #
        #    sudo sleep 15 & kill $!
        #
        # - Same PGID, sleep is NOT terminated
        #
        #    sudo bash -c 'sleep 15 & killall sudo'
        #
        #    ps  xfao pid,ppid,pgid,sid,comm | grep -C 5 -e PID -e sleep -e sudo

        with subprocess.Popen(runner.get_script(), stdout=subprocess.PIPE,
                              start_new_session=True) as proc:
            # Thread for monitoring the connection
            t = threading.Thread(target=self.check_connection, args=(proc,))
            t.start()

            try:
                # readline() returns b'' only at EOF, i.e. once the child has
                # closed its stdout (normally because it exited). Looping
                # until EOF forwards every buffered line. It also ends the
                # loop when the child exits with status 0, which the old
                # `if poll():` test missed (it then spun forever).
                for out in iter(proc.stdout.readline, b''):
                    self.request.sendall(out)
            except OSError:  # BrokenPipeError, ConnectionResetError, ...
                log.info("Client is disconnect.")

            # Returns once the client is gone (child terminated) or the
            # child has exited on its own.
            t.join()

            ret = proc.poll()
            if ret is not None:
                log.info(f"retrun code: {ret}")

        log.info("service complete.")

    def finish(self):
        runner.cleanup()
        log.info("Wait for next service...")

    def is_connection_alive(self):
        """Probe the client with a NUL byte; False if the send fails."""
        try:
            self.request.sendall(b'\x00')
        except OSError:
            log.info("Client is disconnect.")
            return False

        return True

    def check_connection(self, proc):
        """Watch the client connection and stop *proc* when it drops.

        Runs in its own thread. Returns after terminating the child on
        disconnect, or as soon as the child has exited by itself.
        """
        # Not to check connection alive for
        # the first 10 secs.
        time.sleep(10)

        poll_interval = 1
        log.info("Now checking client connection every %ds", poll_interval)
        while proc.poll() is None:
            if not self.is_connection_alive():
                self._terminate_child(proc)
                return

            time.sleep(poll_interval)

    def _terminate_child(self, proc):
        """Terminate the load script, escalating to SIGKILL where possible."""
        # cavstool
        child_desc = " ".join(runner.script) + f", PID={proc.pid}"
        log.info("Terminating %s", child_desc)

        try:
            # sudo does _not_ propagate SIGKILL (man sudo)
            proc.terminate()
            try:
                proc.wait(timeout=0.5)
            except subprocess.TimeoutExpired:
                log.error("SIGTERM failed on child %s", child_desc)
                if os.geteuid() == 0:  # sudo not needed and not used
                    log.error("Sending %d SIGKILL", proc.pid)
                    proc.kill()
                else:
                    log.error("Try: sudo pkill -9 -f %s", runner.load_cmd)

        except PermissionError:
            # Signalling a process started through sudo is refused when we
            # are not root; fall back to sudo kill (no shell involved).
            log.info("cannot kill proc due to it start with sudo...")
            subprocess.run(["sudo", "kill", "-9", str(proc.pid)], check=False)
231
class device_runner():
    """
    Firmware hand-off between the request and log services.

    The request service queues received firmware paths with set_fw_ready().
    The log service polls is_fw_ready(), builds the load command with
    get_script() and calls cleanup() when its session ends.
    """

    def __init__(self, args):
        self.fw_file = None   # firmware currently being served, if any
        self.script = None    # last command built by get_script()
        self.lock = threading.Lock()
        self.fw_queue = queue.Queue()

        # Board specific config
        self.board = board_config(args)
        self.load_cmd = self.board.get_cmd()

    def set_fw_ready(self, fw_recv):
        """Queue *fw_recv* (a firmware path) for the log service; falsy values are ignored."""
        if fw_recv:
            self.fw_queue.put(fw_recv)

    def is_fw_ready(self):
        """Take the next queued firmware without blocking.

        Returns False immediately when nothing is queued. This lets the
        caller's polling loop keep checking the client connection; a
        blocking get() would stall it until a firmware arrived.
        """
        try:
            fw_file = self.fw_queue.get_nowait()
        except queue.Empty:
            return False

        self.fw_file = fw_file
        log.info(f"Current FW is {self.fw_file}")

        return bool(self.fw_file)

    def cleanup(self):
        """Forget the last command and delete the served firmware file."""
        # `with` guarantees the lock is released even if removal fails.
        with self.lock:
            self.script = None
            if self.fw_file:
                try:
                    os.remove(self.fw_file)
                except FileNotFoundError:
                    pass  # already gone: the desired end state
            self.fw_file = None

    def get_script(self):
        """Build (and remember) the argv list that loads the current firmware.

        The command is prefixed with sudo unless running as root. It is
        followed by the firmware path and the board's extra parameters.
        """
        prefix = [] if os.geteuid() == 0 else ['sudo']
        self.script = prefix + [str(self.load_cmd), str(self.fw_file)]
        self.script.extend(self.board.params or [])

        log.info(f'run script: {self.script}')
        return self.script
274
class board_config():
    """
    Loader configuration for the attached board.

    Holds the command used to load firmware and the list of extra
    arguments for it. Construction exits the process when the command
    path does not exist.
    """

    def __init__(self, args):
        # An empty or missing --load-cmd falls back to ./cavstool.py
        # (relative to the current working directory).
        cmd = args.load_cmd or "./cavstool.py"
        self.load_cmd = cmd    # cmd for loading
        self.params = []       # params of loading cmd

        if os.path.exists(cmd):
            return

        log.error(f'Cannot find load cmd {self.load_cmd}.')
        sys.exit(1)

    def get_cmd(self):
        """Return the load command path."""
        return self.load_cmd

    def get_params(self):
        """Return the extra parameters passed to the load command."""
        return self.params
293
294
ap = argparse.ArgumentParser(description="RemoteHW service tool", allow_abbrev=False)
ap.add_argument("-q", "--quiet", action="store_true",
                help="No loader output, just DSP logging")
ap.add_argument("-v", "--verbose", action="store_true",
                help="More loader output, DEBUG logging level")
ap.add_argument("-s", "--server-addr",
                help="Specify the only IP address the log server will LISTEN on")
# type=int makes argparse report a bad port as a usage error instead of a
# traceback from a later int() call.
ap.add_argument("-p", "--log-port", type=int,
                help="Specify the PORT that the log server to active")
ap.add_argument("-r", "--req-port", type=int,
                help="Specify the PORT that the request server to active")
ap.add_argument("-c", "--load-cmd",
                help="Specify loading command of the board")

args = ap.parse_args()

if args.quiet:
    log.setLevel(logging.WARNING)
elif args.verbose:
    log.setLevel(logging.DEBUG)

if args.server_addr:
    # The '//' prefix makes urlparse treat the value as a network location,
    # so "host[:port]" is split into hostname and port.
    url = urlparse("//" + args.server_addr)

    if url.hostname:
        HOST = url.hostname

    try:
        addr_port = url.port  # raises ValueError for a malformed port
    except ValueError as e:
        ap.error(f"invalid port in --server-addr: {e}")

    if addr_port:
        PORT_LOG = addr_port

# An explicit --log-port overrides a port given in --server-addr.
if args.log_port is not None:
    PORT_LOG = args.log_port

if args.req_port is not None:
    PORT_REQ = args.req_port

log.info(f"Serve on LOG PORT: {PORT_LOG} REQ PORT: {PORT_REQ}")
332
333
if __name__ == "__main__":

    # Do board configuration setup (module-global, used by both handlers)
    runner = device_runner(args)

    # Set SO_REUSEADDR on the listening sockets so the service can be
    # restarted right away.
    socketserver.TCPServer.allow_reuse_address = True

    # Launch the command request service
    req_server = socketserver.TCPServer((HOST, PORT_REQ), adsp_request_handler)
    req_t = threading.Thread(target=req_server.serve_forever, daemon=True)

    # Activate the log service which outputs the board's execution result
    log_server = socketserver.TCPServer((HOST, PORT_LOG), adsp_log_handler)
    log_t = threading.Thread(target=log_server.serve_forever, daemon=True)

    try:
        log.info("Req server start...")
        req_t.start()
        log.info("Log server start...")
        log_t.start()
        req_t.join()
        log_t.join()
    except KeyboardInterrupt:
        # shutdown() waits for serve_forever() to exit, so only call it on
        # servers whose thread is actually running.
        if log_t.is_alive():
            log_server.shutdown()
        if req_t.is_alive():
            req_server.shutdown()
    finally:
        # Release the listening sockets.
        log_server.server_close()
        req_server.server_close()
358