1#!/usr/bin/env python
2# coding=utf-8
3#
4# This script decodes Xtensa CPU trace dumps. It allows tracing the program
5# execution at instruction level.
6#
7# Some trivia about the Xtensa CPU trace (TRAX):
8# TRAX format mostly follows the IEEE-ISTO 5001-2003 (Nexus) standard.
9# The following Nexus Program Trace messages are implemented by TRAX:
10# - Indirect Branch Message
# - Synchronization Message
12# - Indirect Branch with Synchronization Message
13# - Correlation Message
14# TRAX outputs compressed traces with 2 MSEO bits (LSB) and 6 MDO bits (MSB),
15# packed into a byte. MSEO bits are used to split the stream into packets and messages,
16# and MDO bits carry the actual data of the messages. Each message may contain multiple packets.
17#
18# This script can be used standalone, or loaded into GDB.
19# When used standalone, it dumps the list of trace messages to stdout.
20# When used from GDB, it also invokes GDB command to dump the list of assembly
21# instructions corresponding to each of the messages.
22#
23# Standalone usage:
24#   traceparse.py <dump_file>
25#
26# Usage from GDB:
27#   xtensa-esp32-elf-gdb -n --batch program.elf -x gdbinit
28# with the following gdbinit script:
29#   set pagination off
30#   set confirm off
31#   add-symbol-file rom.elf <address of ROM .text section>
32#   source traceparse.py
33#   python parse_and_dump("/path/to/dump_file")
34#
35# Loading the ROM code is optional; if not loaded, disassembly for ROM sections of code
36# will be missing.
37#
38###
39# Copyright 2020 Espressif Systems (Shanghai) PTE LTD
40#
41# Licensed under the Apache License, Version 2.0 (the "License");
42# you may not use this file except in compliance with the License.
43# You may obtain a copy of the License at
44#
45#     http://www.apache.org/licenses/LICENSE-2.0
46#
47# Unless required by applicable law or agreed to in writing, software
48# distributed under the License is distributed on an "AS IS" BASIS,
49# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50# See the License for the specific language governing permissions and
51# limitations under the License.
52from __future__ import print_function
53
54import sys
55
# Check if loaded into GDB. The `gdb` name is expected to be defined only when
# this file is sourced from GDB's embedded Python (see "Usage from GDB" above).
# A plain comparison is used rather than `assert`: asserts are removed under
# `python -O`, which would skip the lookup and make WITH_GDB unconditionally True.
try:
    WITH_GDB = gdb.__name__ == 'gdb'  # type: ignore  # noqa: F821
except NameError:
    WITH_GDB = False
62
# MSEO bit masks (the two least significant bits of every trace byte):
MSEO_PKTEND = 0x01  # bit 0: set on the last byte of a packet
MSEO_MSGEND = 0x02  # bit 1: set on the last byte of a message

# Message type codes. The type is held in the first 6 MDO bits of the first
# packet of a message.
TVAL_INDBR = 4          # Indirect branch
TVAL_INDBRSYNC = 12     # Indirect branch w/ synchronisation
TVAL_SYNC = 9           # Synchronisation msg
TVAL_CORR = 33          # Correlation message
72
73
class TraxPacket(object):
    """
    One TRAX packet: a run of raw trace bytes ending with an MSEO_PKTEND byte.

    Each byte carries 6 data (MDO) bits in its upper bits; the lower 2 bits are
    MSEO flags and are skipped when data bits are extracted.
    """

    def __init__(self, data, truncated=False):
        # Raw bytes of the packet, MSEO bits included
        self.data = data
        self.size_bytes = len(data)
        # True when the packet may have started before the beginning of the dump
        self.truncated = truncated

    def get_bits(self, start, count=0):
        """
        Extract a little-endian bit field from the packet's MDO bits
        :param start: offset, in MDO bits, of the first bit to extract
        :param count: number of bits to extract; a non-positive value (the
                      default) extracts everything up to the end of the packet
        :return: integer assembled from the extracted bits, first bit as LSB;
                 0 if start lies past the end of the packet
        """
        first = start // 6
        offset = start % 6
        if count <= 0:
            count = 6 * len(self.data) - start

        value = 0
        filled = 0
        for idx, byte in enumerate(self.data[first:]):
            # Data bits begin above the 2 MSEO bits, plus the sub-byte offset
            # for the very first byte visited.
            lsb = 2 + offset if idx == 0 else 2
            take = min(count - filled, 8 - lsb)
            value |= ((byte >> lsb) & ((1 << take) - 1)) << filled
            filled += take
            if filled == count:
                break
        return value

    def __str__(self):
        suffix = ' (truncated)' if self.truncated else ''
        return '%d byte packet%s' % (self.size_bytes, suffix)
119
120
class TraxMessage(object):
    """
    One TRAX program trace message, built from one or more TraxPacket objects.

    Construction decodes the message type and its type-specific fields. The
    instruction range (pc_start / pc_end) is only filled in later by
    process_forward() or process_backward(), because it depends on the PC
    known from a neighbouring message.
    """

    def __init__(self, packets, truncated=False):
        """
        Create and parse a TRAX message from packets
        :param packets: list of TraxPacket objects, must not be empty
        :param truncated: whether the message was truncated in the stream;
                          truncated messages are not decoded (msg_type is None)
        :raises NotImplementedError: if the message type is not a known TVAL_XXX value
        :raises IndexError: if the message type reads a second packet but only one is present
        """
        assert len(packets) > 0
        self.packets = packets
        self.truncated = truncated
        self.msg_type = None if truncated else self._get_type()

        # Start and end of the instruction range corresponding to this message
        self.pc_start = 0   # inclusive
        self.pc_end = 0     # not inclusive
        self.pc_target = 0  # PC of the next range; 0 while still unknown
        self.is_exception = False   # whether the message indicates an exception
        self.is_correlation = False     # whether this is a correlation message

        # message-specific fields
        self.icnt = 0
        self.uaddr = 0
        self.dcont = 0
        self.btype = 0  # initialised here so the attribute exists for every message

        # decode the fields
        if not truncated:
            self._decode()

    def _get_type(self):
        """
        :return: Message type, one of TVAL_XXX values (first 6 MDO bits of the first packet)
        """
        return self.packets[0].get_bits(0, 6)

    def _decode(self):
        """
        Parse the packets and fill in the message-specific fields.

        Offsets are in MDO bits from the start of a packet; a count of -1 means
        "up to the end of the packet".
        :raises NotImplementedError: for an unknown message type
        """
        if self.msg_type == TVAL_INDBR:
            self.icnt = self.packets[0].get_bits(7, -1)
            self.btype = self.packets[0].get_bits(6, 1)
            self.uaddr = self.packets[1].get_bits(0)
            self.is_exception = self.btype > 0
        elif self.msg_type == TVAL_INDBRSYNC:
            self.icnt = self.packets[0].get_bits(8, -1)
            self.btype = self.packets[0].get_bits(7, 1)
            self.pc_target = self.packets[1].get_bits(0)
            self.dcont = self.packets[0].get_bits(6, 1)
            self.is_exception = self.btype > 0
        elif self.msg_type == TVAL_SYNC:
            self.icnt = self.packets[0].get_bits(7, -1)
            self.dcont = self.packets[0].get_bits(6, 1)
            self.pc_target = self.packets[1].get_bits(0)
        elif self.msg_type == TVAL_CORR:
            self.icnt = self.packets[0].get_bits(12, -1)
            self.is_correlation = True
        else:
            raise NotImplementedError('Unknown message type (%d)' % self.msg_type)

    def process_forward(self, cur_pc):
        """
        Given the target PC known from the previous message, determine
        the PC range corresponding to the current message.
        :param cur_pc: previous known PC
        :return: target PC after the current message
        """
        assert not self.truncated

        next_pc = cur_pc
        if self.msg_type == TVAL_INDBR:
            # uaddr is XOR-ed with the previous PC to obtain the branch target
            next_pc = cur_pc ^ self.uaddr
            self.pc_target = next_pc
            self.pc_start = cur_pc
            self.pc_end = self.pc_start + self.icnt + 1
        elif self.msg_type == TVAL_INDBRSYNC:
            next_pc = self.pc_target
            self.pc_start = cur_pc
            self.pc_end = cur_pc + self.icnt + 1
        elif self.msg_type == TVAL_SYNC:
            next_pc = self.pc_target
            self.pc_start = next_pc - self.icnt
            self.pc_end = next_pc + 1
        # TVAL_CORR: carries no PC information, range stays unresolved
        return next_pc

    def process_backward(self, cur_pc):
        """
        Given the address of the PC known from the _next_ message, determine
        the PC range corresponding to the current message.
        :param cur_pc: next known PC
        :return: target PC of the _previous_ message
        """
        assert not self.truncated
        # Backward pass is only used to resolve addresses of messages
        # up to the first SYNC/INDBRSYNC message.
        # SYNC/INDBRSYNC messages are only handled in the forward pass.
        assert self.msg_type != TVAL_INDBRSYNC
        assert self.msg_type != TVAL_SYNC

        prev_pc = cur_pc
        self.pc_target = cur_pc
        if self.msg_type == TVAL_INDBR:
            # Undo the XOR applied in the forward direction
            prev_pc ^= self.uaddr
            self.pc_start = prev_pc
            self.pc_end = prev_pc + self.icnt + 1
        # TVAL_CORR: carries no PC information, pass the PC through unchanged
        return prev_pc

    def __str__(self):
        extra = ''
        if self.truncated:
            # msg_type is None here, so it must not be formatted with %d
            desc = 'Truncated'
        elif self.msg_type == TVAL_INDBR:
            desc = 'Indirect branch'
            extra = ', icnt=%d, uaddr=0x%x, exc=%d' % (self.icnt, self.uaddr, self.is_exception)
        elif self.msg_type == TVAL_INDBRSYNC:
            desc = 'Indirect branch w/sync'
            extra = ', icnt=%d, dcont=%d, exc=%d' % (self.icnt, self.dcont, self.is_exception)
        elif self.msg_type == TVAL_SYNC:
            desc = 'Synchronization'
            extra = ', icnt=%d, dcont=%d' % (self.icnt, self.dcont)
        elif self.msg_type == TVAL_CORR:
            desc = 'Correlation'
            extra = ', icnt=%d' % self.icnt
        else:
            desc = 'Unknown (%d)' % self.msg_type
        return '%s message, %d packets, PC range 0x%08x - 0x%08x, target PC 0x%08x' % (
            desc, len(self.packets), self.pc_start, self.pc_end, self.pc_target) + extra
251
252
def load_messages(data):
    """
    Decodes TRAX data and resolves PC ranges.

    The byte stream is split into packets at bytes with MSEO_PKTEND set, and
    packets are grouped into messages at bytes with MSEO_MSGEND set. The first
    packet and the first message are marked truncated, since the dump may start
    in the middle of them. Bytes after the last MSGEND byte are ignored.
    Messages that fail to decode are reported on stderr and skipped.

    :param data: input data, bytes (a Python 2 str or a bytearray also works)
    :return: list of TraxMessage objects; empty if data holds no complete message
    :raises AssertionError: if a byte has MSEO_MSGEND set without MSEO_PKTEND
    """
    # Iterating a bytearray yields ints on both Python 2 and 3; iterating a
    # Python 2 str would yield 1-char strings and break the bit tests below.
    data = bytearray(data)

    messages = []
    packets = []
    packet_start = 0
    msg_cnt = 0

    # Iterate over the input data, splitting bytes into packets and messages
    for i, b in enumerate(data):
        if (b & MSEO_MSGEND) and not (b & MSEO_PKTEND):
            raise AssertionError('Invalid MSEO bits in b=0x%x. Not a TRAX dump?' % b)

        if b & MSEO_PKTEND:
            packets.append(TraxPacket(data[packet_start:i + 1], packet_start == 0))
            packet_start = i + 1

        if b & MSEO_MSGEND:
            msg_cnt += 1
            try:
                messages.append(TraxMessage(packets, len(messages) == 0))
            except (NotImplementedError, IndexError) as e:
                # Unknown message type, or fewer packets than the type needs
                sys.stderr.write('Failed to parse message #%03d (at %d bytes): %s\n' % (msg_cnt, i, str(e)))
            packets = []

    # Resolve PC ranges of messages.
    # Forward pass: skip messages until a message with known PC,
    # i.e. a SYNC/INDBRSYNC message. Process all messages following it.
    pc = 0
    first_sync_index = -1
    for i, m in enumerate(messages):
        if pc == 0 and m.pc_target == 0:
            continue
        if first_sync_index < 0:
            first_sync_index = i
        pc = m.process_forward(pc)

    if first_sync_index < 0:
        # No message carried an absolute PC: nothing can be resolved backwards.
        return messages

    # Now process the skipped messages in the reverse direction,
    # starting from the first message with known PC.
    pc = messages[first_sync_index].pc_start
    if pc == 0:
        # The start of the first synchronised range is unknown (0 is used as the
        # "unknown PC" sentinel above), so it cannot seed the backward pass.
        return messages
    for m in reversed(messages[:first_sync_index]):
        if m.truncated:
            break
        pc = m.process_backward(pc)

    return messages
304
305
def parse_and_dump(filename, disassemble=WITH_GDB):
    """
    Read a TRAX dump file, decode it and print every complete message.

    Truncated messages are not printed. Messages flagged as exceptions get an
    extra marker line. When running inside GDB and *disassemble* is true, the
    PC range of each printed message is passed to GDB's `disassemble` command.
    :param filename: path of the binary dump file
    :param disassemble: whether to print disassembly of each PC range;
                        only effective when running inside GDB
    """
    with open(filename, 'rb') as dump_file:
        raw = dump_file.read()

    decoded = load_messages(raw)
    sys.stderr.write('Loaded %d messages in %d bytes\n' % (len(decoded), len(raw)))

    use_gdb = disassemble and WITH_GDB
    for index, msg in enumerate(decoded):
        if msg.truncated:
            continue
        print('%04d: %s' % (index, str(msg)))
        if msg.is_exception:
            print('*** Exception occurred ***')
        if not use_gdb:
            continue
        pc_range = (msg.pc_start, msg.pc_end)
        try:
            gdb.execute('disassemble 0x%08x, 0x%08x' % pc_range)  # noqa: F821
        except gdb.MemoryError:  # noqa: F821
            print('Failed to disassemble from 0x%08x to 0x%08x' % pc_range)
329
330
def main():
    """
    Command-line entry point: decode and print the dump file given as the
    first argument.

    Writes a usage line to stderr and exits with status 1 when no file is given.
    """
    if len(sys.argv) < 2:
        # Fill in the program name; previously the '%s' was printed verbatim.
        sys.stderr.write('Usage: %s <dump_file>\n' % sys.argv[0])
        raise SystemExit(1)

    parse_and_dump(sys.argv[1])
337
338
if __name__ == '__main__':
    # Inside GDB the user calls parse_and_dump() explicitly instead.
    if not WITH_GDB:
        main()
341