1#!/usr/bin/env python
2# Copyright 2023 The ChromiumOS Authors
3# SPDX-License-Identifier: Apache-2.0
4import ctypes
5import mmap
6import os
7import re
8import struct
9import sys
10import time
11from glob import glob
12
13# MT8195 audio firmware load/debug gadget
14
15# Note that the hardware handling here is only partial: in practice
16# the audio DSP depends on clock and power well devices whose drivers
17# live elsewhere in the kernel.  Those aren't duplicated here.  Make
18# sure the DSP has been started by a working kernel driver first.
19#
20# See gen_img.py for docs on the image format itself.  The way this
21# script works is to map the device memory regions and registers via
22# /dev/mem and copy the two segments while resetting the DSP.
23#
24# In the kernel driver, the address/size values come from devicetree.
25# But currently the MediaTek architecture is one kernel driver per SOC
# (i.e. the devicetree values in the kernel source are tied to the
27# specific SOC anyway), so it really doesn't matter and we hard-code
28# the addresses for simplicity.
29#
30# (For future reference: in /proc/device-tree on current ChromeOS
31# kernels, the host registers are a "cfg" platform resource on the
32# "adsp@10803000" node.  The sram is likewise the "sram" resource on
33# that device node, and the two dram areas are "memory-region"
34# phandles pointing to "adsp_mem_region" and "adsp_dma_mem_region"
35# nodes under "/reserved-memory").
36
# Magic number expected in the first four bytes (little-endian) of a
# firmware image; main() refuses to load a file that lacks it
FILE_MAGIC = 0xE463BE95

# Runtime mmap objects, keyed by region name: main() creates one entry
# for every region returned by mappings()
maps = {}
41
42
# Returns a string (e.g. "mt8195", "mt8186", "mt8188") if a supported
# adsp is detected, or None if not.  "Not detected" covers both a
# devicetree without any adsp@* node (or no devicetree at all) and a
# node whose "compatible" property does not match "mtNNNN-dsp".
def detect():
    nodes = glob("/proc/device-tree/**/adsp@*/compatible", recursive=True)
    if not nodes:
        # Nothing matched: honor the documented contract and report
        # "no DSP" instead of failing with IndexError on nodes[0]
        return None
    compat = readfile(nodes[0], "r")
    m = re.match(r'.*(mt\d{4})-dsp', compat)
    return m.group(1) if m else None
50
51
# Parse devicetree to find the MMIO mappings.  The first "adsp@*" node
# found supplies a "reg-names" list and a matching "reg" array; each
# name is paired with an (address, size) couple decoded as two
# big-endian 64-bit values.  The node's "memory-region" property is
# decoded as two 32-bit phandles, which are looked up among the
# children of /reserved-memory.  Those regions have no names of their
# own, so they are called dram0/dram1 after their position in
# "memory-region" (dram1 is the main region to which code is linked,
# dram0 is presumably a dma pool but unused by current firmware).
# Returns a dict mapping name to a (addr, size) tuple.
def mappings():
    node = glob("/proc/device-tree/**/adsp@*/", recursive=True)[0]

    # The names are NUL-terminated, so split() leaves one empty string
    # after the last terminator; drop it
    names = readfile(node + "reg-names", "r").split('\0')[:-1]
    cells = struct.unpack(f">{2 * len(names)}Q", readfile(node + "reg"))

    regions = {}
    for idx, name in enumerate(names):
        regions[name] = (cells[2 * idx], cells[2 * idx + 1])

    phandles = struct.unpack(">II", readfile(node + "memory-region"))
    for idx, phandle in enumerate(phandles):
        for rmem in glob("/proc/device-tree/reserved-memory/*/"):
            ph_path = rmem + "phandle"
            if not os.path.exists(ph_path):
                continue
            if struct.unpack(">I", readfile(ph_path))[0] != phandle:
                continue
            # First reserved-memory node carrying this phandle wins
            regions[f"dram{idx}"] = struct.unpack(">QQ", readfile(rmem + "reg"))
            break
    return regions
72
73
# Register API for 8195.  All control goes through the "cfg" mapping;
# stop() and start() are sequences of individual read-modify-write
# accesses to RESET_SW and their order is deliberate.
class MT8195:
    # Fields of RESET_SW
    _RUNSTALL = 8  # RUNSTALL: halts the CPU while set
    _RESET_BITS = 3  # Low two bits: "BRESET|DRESET"
    _ALT_RESET_EN = 0x10  # Enable "alternate reset" boot vector

    def __init__(self, maps):
        # Address of the mapped "cfg" block inside this process
        cfg_addr = ctypes.addressof(ctypes.c_int.from_buffer(maps["cfg"]))

        # Describe the registers by their offset from that address
        cfg = Regs(cfg_addr)
        cfg.ALTRESETVEC = 0x0004  # Xtensa boot address
        cfg.RESET_SW = 0x0024  # Xtensa halt/reset/boot control
        cfg.PDEBUGBUS0 = 0x000C  # Unclear, enabled by host, unused by SOF?
        cfg.SRAM_POOL_CON = 0x0930  # SRAM power control: low 4 bits (banks?) enable
        cfg.EMI_MAP_ADDR = 0x981C  # == host SRAM mapping - 0x40000000 (controls MMIO map?)
        cfg.freeze()
        self.cfg = cfg

    # Offsets into dram1 scanned by old_log()
    def logrange(self):
        return range(0x700000, 0x800000)

    # Halt the core, then put it in reset
    def stop(self):
        regs = self.cfg
        regs.RESET_SW |= self._RUNSTALL
        regs.RESET_SW |= self._RESET_BITS

    # (Re)boot the core at boot_vector: halt and reset it, select the
    # alternate reset vector, then release reset and finally RUNSTALL
    def start(self, boot_vector):
        self.stop()
        regs = self.cfg
        regs.RESET_SW |= self._ALT_RESET_EN
        regs.ALTRESETVEC = boot_vector
        regs.RESET_SW &= ~self._RESET_BITS  # Release reset bits
        regs.RESET_SW &= ~self._RUNSTALL  # Clear RUNSTALL: go!
100
101
# Register API for 8186/8188.  These have registers spread across two
# blocks, the "cfg" and "sec" mappings.
class MT818x:
    _ALL_BITS = 0xFFFFFFFF
    _RUNSTALL = 1 << 31  # IO_CONFIG: RUNSTALL, stops the core while set
    _RESET_BITS = 0x11  # SW_RSTN: SW_RSTN_C0|SW_DBG_RSTN_C0

    def __init__(self, maps):
        cfg_base = ctypes.addressof(ctypes.c_int.from_buffer(maps["cfg"]))
        sec_base = ctypes.addressof(ctypes.c_int.from_buffer(maps["sec"]))

        cfg = Regs(cfg_base)
        cfg.SW_RSTN = 0x00
        cfg.IO_CONFIG = 0x0C
        cfg.freeze()
        self.cfg = cfg

        sec = Regs(sec_base)
        sec.ALTVEC_C0 = 0x04
        sec.ALTVECSEL = 0x0C
        sec.freeze()
        self.sec = sec

    # Offsets into dram1 scanned by old_log()
    def logrange(self):
        return range(0x700000, 0x800000)

    # Stall the core, wait 100ms, then assert reset
    def stop(self):
        self.cfg.IO_CONFIG |= self._RUNSTALL
        time.sleep(0.1)
        self.cfg.SW_RSTN |= self._RESET_BITS

    # Boot the core at boot_vector.  Note: 8186 and 8188 use different
    # bits in ALTVECSEL, but it's safe to write both to enable the
    # alternate boot vector
    def start(self, boot_vector):
        cfg, sec = self.cfg, self.sec
        cfg.IO_CONFIG |= self._RUNSTALL  # Set RUNSTALL
        sec.ALTVEC_C0 = boot_vector
        sec.ALTVECSEL = 0x03  # Enable alternate vector
        cfg.SW_RSTN |= self._RESET_BITS  # Assert reset
        cfg.SW_RSTN &= self._ALL_BITS ^ self._RESET_BITS  # Release reset
        cfg.IO_CONFIG &= self._ALL_BITS ^ self._RUNSTALL  # Clear RUNSTALL
134
135
# Register API for 8196.  Like 8186/8188 the registers live in two
# blocks, the "cfg" and "sec" mappings.
class MT8196:
    _RUNSTALL = 0x1000  # Bit toggled in HIFI_RUNSTALL around reset
    _RESET_BITS = 0x11  # Bits toggled in CFGREG_SW_RSTN
    _MBOX_IRQS = 3  # Bits set in MBOX_IRQ_EN before boot

    def __init__(self, maps):
        cfg_base = ctypes.addressof(ctypes.c_int.from_buffer(maps["cfg"]))
        sec_base = ctypes.addressof(ctypes.c_int.from_buffer(maps["sec"]))

        cfg = Regs(cfg_base)
        cfg.CFGREG_SW_RSTN = 0x0000
        cfg.MBOX_IRQ_EN = 0x009C
        cfg.HIFI_RUNSTALL = 0x0108
        cfg.freeze()
        self.cfg = cfg

        sec = Regs(sec_base)
        sec.ALTVEC_C0 = 0x04
        sec.ALTVECSEL = 0x0C
        sec.freeze()
        self.sec = sec

    # Offsets into dram1 scanned by old_log()
    def logrange(self):
        return range(0x580000, 0x600000)

    # Set the run-stall bit, then the reset bits
    def stop(self):
        self.cfg.HIFI_RUNSTALL |= self._RUNSTALL
        self.cfg.CFGREG_SW_RSTN |= self._RESET_BITS

    # Boot the core at boot_vector: clear and then program the
    # alternate vector, stall, enable the mailbox interrupts, hold the
    # reset bits for 100ms, then release reset and finally the stall
    def start(self, boot_vector):
        cfg, sec = self.cfg, self.sec
        sec.ALTVEC_C0 = 0
        sec.ALTVECSEL = 0
        sec.ALTVEC_C0 = boot_vector
        sec.ALTVECSEL = 1
        cfg.HIFI_RUNSTALL |= self._RUNSTALL
        cfg.MBOX_IRQ_EN |= self._MBOX_IRQS
        cfg.CFGREG_SW_RSTN |= self._RESET_BITS
        time.sleep(0.1)
        cfg.CFGREG_SW_RSTN &= ~self._RESET_BITS
        cfg.HIFI_RUNSTALL &= ~self._RUNSTALL
168
169
# Temporary logging protocol, kept from before winstream logging:
# follow a null-terminated byte stream stored in the "dram1" mapping
# at the offsets given by dev.logrange().  Bytes are read one at a
# time; on reaching a zero byte everything gathered so far is written
# to stdout and the same offset is polled every 100ms until the
# firmware stores a non-zero byte there.  Returns once the whole
# range has been consumed.
def old_log(dev):
    def emit(chunk):
        sys.stdout.buffer.write(chunk)
        sys.stdout.buffer.flush()

    pending = bytearray()
    dram = maps["dram1"]
    for offset in dev.logrange():
        byte = dram[offset]
        if byte == 0:
            # Caught up with the writer: show what we have, then wait
            emit(pending)
            pending = bytearray()
            while byte == 0:
                time.sleep(0.1)
                byte = dram[offset]
        pending.append(byte)
    emit(pending)
189
190
# (Cribbed from cavstool.py)
#
# Named access to 32-bit memory-mapped registers.  Usage has two
# phases.  Until freeze() is called, assigning to a new attribute
# *defines* a register: the value is the byte offset from base_addr,
# e.g. "r.CTRL = 0x24".  After freeze() (or for a name that is already
# defined), assignment *writes* the value to the register and
# attribute reads return its current contents, so "r.CTRL |= 1" is a
# read-modify-write of the hardware register.
#
# Reading or (once frozen) writing a name that was never defined
# raises AttributeError, so hasattr()/getattr() with a default behave
# normally and a typo in a register name gets a clear error.
class Regs:
    def __init__(self, base_addr):
        # Go through vars() so the bookkeeping fields bypass our own
        # __setattr__, which would treat them as register definitions
        vars(self)["base_addr"] = base_addr
        vars(self)["ptrs"] = {}
        vars(self)["frozen"] = False

    # End the definition phase: later assignments only write registers
    def freeze(self):
        vars(self)["frozen"] = True

    def __setattr__(self, name, val):
        if not self.frozen and name not in self.ptrs:
            # Definition: val is the register's offset from base_addr
            addr = self.base_addr + val
            self.ptrs[name] = ctypes.c_uint32.from_address(addr)
        elif name in self.ptrs:
            # Write: val is the new register content
            self.ptrs[name].value = val
        else:
            raise AttributeError(f"No register named {name!r} (Regs is frozen)")

    # Only invoked for names not found by normal lookup, i.e. registers
    def __getattr__(self, name):
        # Look in the instance dict directly: using self.ptrs here
        # would recurse forever on an object whose __init__ never ran
        ptrs = vars(self).get("ptrs", {})
        if name not in ptrs:
            raise AttributeError(f"No register named {name!r}")
        return ptrs[name].value
210
211
# Return the entire contents of file f: bytes with the default mode
# "rb", or str when called with a text mode such as "r".  The file is
# closed before returning rather than left for the garbage collector.
def readfile(f, mode="rb"):
    with open(f, mode) as fobj:
        return fobj.read()
214
215
# Decode bstr, which must be exactly four bytes long, as a single
# little-endian unsigned 32-bit integer
def le4(bstr):
    assert len(bstr) == 4
    (value,) = struct.unpack("<I", bstr)
    return value
219
220
# Wrapper class for winstream logging.  Instantiate with a single
# integer argument representing a local/in-process address for the
# shared winstream memory.  The memory mapped access is encapsulated
# with a Regs object for the fields and a ctypes array for the data
# area.  The lockless algorithm in read() matches the C version in
# upstream Zephyr, don't modify in isolation.  Note that on some
# platforms word access to the data array (by e.g. copying a slice
# into a bytes object or by calling memmove) produces bus errors
# (plausibly an alignment requirement on the fabric with the DSP
# memory, where arm64 python is happy doing unaligned loads?).  Access
# to the data bytes is done bytewise for safety.
#
# Layout at addr: four 32-bit fields (WLEN, START, END, SEQ) followed
# at offset 16 by WLEN bytes of ring buffer.  read() treats SEQ as a
# running count of bytes written, which is how it knows how many new
# bytes have arrived since the previous call.
class Winstream:
    # Raises RuntimeError if the header at addr fails the sanity check
    def __init__(self, addr):
        r = Regs(addr)
        r.WLEN = 0x00
        r.START = 0x04
        r.END = 0x08
        r.SEQ = 0x0C
        r.freeze()
        # Sanity-check, the 32M size limit isn't a rule, but seems reasonable
        if r.WLEN > 0x2000000 or (r.START >= r.WLEN) or (r.END >= r.WLEN):
            raise RuntimeError("Invalid winstream")
        self.regs = r
        self.data = (ctypes.c_char * r.WLEN).from_address(addr + 16)
        # Scratch buffer reused by every read(); WLEN bytes is the most
        # a single read can return
        self.msg = bytearray(r.WLEN)
        # SEQ value seen by the previous read()
        self.seq = 0

    # Returns the text written since the previous call as a str
    # (decoded as UTF-8, invalid sequences replaced), or "" if nothing
    # is new or if the reader fell so far behind that the unread bytes
    # are no longer all in the buffer.  In the latter case the lost
    # output is skipped and reading resumes from the current position.
    def read(self):
        ws, msg, data = self.regs, self.msg, self.data
        last_seq = self.seq
        wlen = ws.WLEN
        while True:
            # Snapshot the header; it is re-checked after the copy
            start, end, seq = ws.START, ws.END, ws.SEQ
            self.seq = seq
            if seq == last_seq or start == end:
                return ""
            # SEQ is a 32-bit register, so take the difference modulo
            # 2**32 like the unsigned arithmetic of the C version.
            # Without the mask a wrapped counter yields a negative
            # count and stale buffer contents get returned.
            behind = (seq - last_seq) & 0xFFFFFFFF
            if behind > ((end - start) % wlen):
                return ""
            # The unread bytes end at END and may wrap around the end
            # of the ring: copy the part up to the end of the array
            # first, then the remainder from the start of the array
            copy = (end - behind) % wlen
            suffix = min(behind, wlen - copy)
            for i in range(suffix):
                msg[i] = data[copy + i][0]
            msglen = suffix
            l2 = behind - suffix
            if l2 > 0:
                for i in range(l2):
                    msg[msglen + i] = data[i][0]
                msglen += l2
            # Only trust the copy if the writer did not move START or
            # SEQ meanwhile; otherwise loop and take a new snapshot
            if start == ws.START and seq == ws.SEQ:
                return msg[0:msglen].decode("utf-8", "replace")
272
273
# Locates a winstream descriptor in the firmware via its 96-bit magic
# number and returns the 32-bit little-endian address field stored
# directly after it.  Only mappings whose name contains "dram" are
# searched, in the iteration order of maps.  Raises RuntimeError if
# no complete descriptor is found.
def find_winstream(maps):
    magic = b'\x74\x5f\x6a\xd0\x79\xe2\x4f\x00\xcd\xb8\xbd\xf9'
    # Some python versions produce bus errors (!) on the hardware
    # when finding a 12 byte substring (maybe a SIMD load that the
    # hardware doesn't like?).  Do it in two chunks: search for the
    # first 8 bytes, then compare the remaining 4.
    head, tail = magic[:8], magic[8:]
    for name, region in maps.items():
        if "dram" not in name:
            continue
        off = region.find(head)
        while off >= 0:
            # The tail must directly follow the head.  Accepting it
            # anywhere later in the region would turn a partial match
            # into a bogus descriptor address.
            if region[off + 8 : off + 12] == tail:
                field = region[off + 12 : off + 16]
                # A magic number cut off by the end of the region has
                # no address after it; keep looking
                if len(field) == 4:
                    return int.from_bytes(field, "little")
            off = region.find(head, off + 1)
    raise RuntimeError("Cannot find winstream descriptor in firmware runtime")
291
292
# Translate globaddr, an address in the same address space as the
# (addr, size) entries of mmio, into an address inside this process.
# The first mmio region containing globaddr is used and the result is
# the start of the same-named mapping in maps plus the offset of
# globaddr within the region.  Raises RuntimeError if no region
# contains the address.
def winstream_localaddr(globaddr, mmio, maps):
    for name, region in mmio.items():
        base, size = region[0], region[1]
        offset = globaddr - base
        if offset < 0 or offset >= size:
            continue
        local_base = ctypes.addressof(ctypes.c_int.from_buffer(maps[name]))
        return local_base + offset
    raise RuntimeError("Winstream address not inside DSP memory")
299
300
# Find the firmware's winstream and copy its output to stdout forever:
# new text is written and flushed as soon as it is read, and when
# there is none the stream is polled again after 100ms.  Never returns
# normally; errors from locating or validating the stream propagate.
def winstream_log(mmio, maps):
    descriptor_addr = find_winstream(maps)
    stream = Winstream(winstream_localaddr(descriptor_addr, mmio, maps))
    while True:
        text = stream.read()
        if not text:
            time.sleep(0.1)
            continue
        sys.stdout.write(text)
        sys.stdout.flush()
312
313
# Command-line entry point.  Commands (first argument):
#   load <file>    stop the DSP, write the firmware image, boot it,
#                  then follow its winstream log
#   log            follow the winstream log of running firmware
#   oldlog         follow the older null-terminated log area
#   mem            list the memory regions found in devicetree
#   dump <region>  write the raw contents of a region to stdout
# A missing/unknown command or missing operand prints the usage text.
# Raises RuntimeError if no DSP is detected, if the detected DSP has
# no register API here and the command needs one, or if the firmware
# image is malformed or too large.
def main():
    # Check the command line before touching any hardware
    cmd = sys.argv[1] if len(sys.argv) > 1 else None
    missing_operand = cmd in ("load", "dump") and len(sys.argv) < 3
    if cmd not in ("load", "log", "oldlog", "mem", "dump") or missing_operand:
        print(f"Usage: {sys.argv[0]} log | load <file> | oldlog | mem | dump <region>")
        return

    dsp = detect()
    if not dsp:
        raise RuntimeError("No supported audio DSP found in devicetree")

    # Probe devicetree for mappable memory regions
    mmio = mappings()

    # Open device and establish mappings.  "r+b" is read/write without
    # the create/truncate semantics of "wb+".  The mappings stay valid
    # after the file object is closed.
    with open("/dev/mem", "r+b") as devmem_fd:
        for name, region in mmio.items():
            paddr, size = region[0], region[1]
            # Round up to whole 4k pages, in integer arithmetic
            mapsz = (size + 4095) // 4096 * 4096
            maps[name] = mmap.mmap(
                devmem_fd.fileno(),
                mapsz,
                offset=paddr,
                flags=mmap.MAP_SHARED,
                prot=mmap.PROT_WRITE | mmap.PROT_READ,
            )

    if dsp == "mt8195":
        dev = MT8195(maps)
    elif dsp in ("mt8186", "mt8188"):
        dev = MT818x(maps)
    elif dsp == "mt8196":
        dev = MT8196(maps)
    else:
        # Detected but unknown here: "log", "mem" and "dump" need no
        # register access, so only refuse the commands that do
        dev = None
    if dev is None and cmd in ("load", "oldlog"):
        raise RuntimeError(f"No register API for DSP {dsp!r}")

    if cmd == "load":
        with open(sys.argv[2], "rb") as f:
            dat = f.read()
        # Image header: magic, SRAM segment length, boot vector (three
        # little-endian words), then the SRAM and DRAM segments
        if len(dat) < 12 or le4(dat[0:4]) != FILE_MAGIC:
            raise RuntimeError(f"{sys.argv[2]}: not a firmware image (bad magic)")
        sram_len = le4(dat[4:8])
        boot_vector = le4(dat[8:12])
        sram = dat[12 : 12 + sram_len]
        dram = dat[12 + sram_len :]
        # Validate everything before the first write to DSP memory
        if len(sram) != sram_len:
            raise RuntimeError(f"{sys.argv[2]}: truncated firmware image")
        if len(sram) > mmio["sram"][1]:
            raise RuntimeError("SRAM segment is larger than the sram region")
        if len(dram) > mmio["dram1"][1]:
            raise RuntimeError("DRAM segment is larger than the dram1 region")

        # Stop the device and write the regions.  Note that we don't
        # zero-fill SRAM, as that's been observed to reboot the host
        # (!!) on mt8186 when the writes near the end of the 512k
        # region.  All copies are deliberately done one byte at a time.
        dev.stop()
        for i, byte in enumerate(sram):
            maps["sram"][i] = byte
        for i, byte in enumerate(dram):
            maps["dram1"][i] = byte
        for i in range(len(dram), mmio["dram1"][1]):
            maps["dram1"][i] = 0
        dev.start(boot_vector)
        winstream_log(mmio, maps)

    elif cmd == "log":
        winstream_log(mmio, maps)

    elif cmd == "oldlog":
        old_log(dev)

    elif cmd == "mem":
        print("Memory Regions:")
        for m in mmio:
            print(f"  {m}: {mmio[m][1]} @ 0x{mmio[m][0]:08x}")

    elif cmd == "dump":
        sz = mmio[sys.argv[2]][1]
        mm = maps[sys.argv[2]]
        sys.stdout.buffer.write(mm[0:sz])
388
389
# Run main() only when executed as a script, not when imported
if __name__ == "__main__":
    main()
392