1#
2# Copyright (c) 2010-2024 Antmicro
3#
4# This file is licensed under the MIT License.
5# Full license text is available in 'licenses/MIT.txt'.
6#
7
8import math
9import random
10import time
11from typing import List
12
13
class CacheLine:
    """
    Represents a cache line in a cache set.

    tag (int): The tag of the cache line.
    use_count (int): Used for replacement policies.
    insertion_time (float): The time when the cache line was (re)initialized.
    last_access_time (float): The time when the cache line was last accessed
        (0 while the line is free).
    free (bool): True when the line holds no valid data and can be filled
        without evicting anything.
    """

    def __init__(self):
        self.init()

    def init(self, tag: int = 0, free: bool = True) -> None:
        """
        (Re)initialize the line in place.

        tag (int): Tag to store in the line (defaults to 0).
        free (bool): Whether the line is empty (defaults to True).

        The use counter is reset and the insertion time is set to the current
        wall-clock time on every call.
        """
        now = time.time()
        self.tag = tag
        self.free = free
        self.use_count: int = 0
        self.insertion_time: float = now
        # Filling a line is itself an access. Leaving the timestamp at 0 for an
        # occupied line would make it look like the least recently used entry
        # and get it evicted first by a policy that picks the smallest
        # last_access_time.
        self.last_access_time: float = 0 if free else now

    def __str__(self) -> str:
        return f"[CacheLine]: tag: {self.tag:b}, free: {self.free}, use: {self.use_count}, insertion: {self.insertion_time}, last access: {self.last_access_time}"
37
38
39class Cache:
40    """
41    Cache memory model.
42
43    name (str): Cache name, used in the `printd` debug helpers.
44    cache_width (int): log2(cache_size).
45    block_width (int): log2(cache_block_size).
46    memory_width (int): log2(memory_size).
47
48    lines_per_set (int): cache mapping policy selection:
49        * -1 for fully associative
50        * 1 for direct mapping
51        * 2^n for n-way associativity
52
53    replacement_policy (str | None): Selected line eviction policy (defaults to None):
54        * FIFO: first in first out
55        * LRU: least recently used
56        * LFU: least frequently used
57        * None: random
58
59    debug (bool): print debug messages (defaults to False).
60    """
61
62    def __init__(
63        self,
64        name: str,
65        cache_width: int,
66        block_width: int,
67        memory_width: int,
68        lines_per_set: int,
69        replacement_policy: str | None = None,
70        debug: bool = False
71    ):
72        self.name = name
73        self.debug = debug
74
75        # Width of the memories
76        self._cache_width = cache_width
77        self._block_width = block_width
78        self._memory_width = memory_width
79
80        # Convert width to size in bytes
81        self._cache_size = 2 ** self._cache_width
82        self._block_size = 2 ** self._block_width
83        self._memory_size = 2 ** self._memory_width
84
85        self._num_lines = self._cache_size // self._block_size
86        self._lines = [CacheLine() for i in range(self._num_lines)]
87
88        if lines_per_set == -1:
89            # special configuration case for fully associative mapping
90            lines_per_set = self._num_lines
91
92        if not (lines_per_set & (lines_per_set - 1) == 0) or lines_per_set == 0:
93            raise Exception('Lines per set must be a power of two (1, 2, 4, 8, ...)')
94
95        self._lines_per_set = lines_per_set
96        self._sets = self._num_lines // lines_per_set
97        self._set_width = int(math.log(self._sets, 2))
98
99        self._replacement_policy = replacement_policy if replacement_policy is not None else 'RAND'
100
101        # Statistics
102        self.misses = 0
103        self.hits = 0
104        self.invalidations = 0
105        self.flushes = 0
106
107    def read(self, addr: int) -> None:
108        sset = self._addr_get_set(addr)
109        line = self._line_lookup(addr)
110        self.printd(f'[read] attempt to fetch {hex(addr)} (set {sset})')
111
112        if line and not line.free:
113            self.printd('[read] rhit')
114            self.hits += 1
115            line.use_count += 1
116            line.last_access_time = time.time()
117        else:
118            self.printd('[read] rmiss')
119            self.misses += 1
120            self._load(addr)
121
122    def write(self, addr: int) -> None:
123        sset = self._addr_get_set(addr)
124        line = self._line_lookup(addr)
125        self.printd(f'[write] attempted write to {hex(addr)} (set {sset})')
126
127        if line:
128            self.printd('[write] whit')
129            self.hits += 1
130            line.last_access_time = time.time()
131        else:
132            self.printd('[write] wmiss')
133            self.misses += 1
134            self._load(addr)
135
136    def flush(self) -> None:
137        self.printd('[flush] flushing all lines!')
138        self.flushes += 1
139        self._lines = [CacheLine() for i in range(self._num_lines)]
140
141    def _select_evicted_index(self, lines_in_set: list) -> int:
142        if self._replacement_policy == 'RAND':
143            return random.randint(0, self._lines_per_set - 1)
144        elif self._replacement_policy == 'LFU':
145            return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].use_count)
146        elif self._replacement_policy == 'FIFO':
147            return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].insertion_time)
148        elif self._replacement_policy == 'LRU':
149            return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].last_access_time)
150        else:
151            raise Exception(f"Unknown replacement policy: {self._replacement_policy}! Exiting!")
152
153    def _load(self, addr: int) -> None:
154        self.printd(f'[load] loading @ {hex(addr)} to cache from Main Memory')
155        tag = self._addr_get_tag(addr)
156        set_index = self._addr_get_set(addr)
157        lines_in_set = self._get_lines_in_set(set_index)
158
159        # Determine the index of the cache line to load into
160        free_line_index = next((index for index, obj in enumerate(lines_in_set) if obj.free), None)
161        if free_line_index is not None:
162            index = free_line_index
163            self.printd(f'[load] loaded new cache index: {free_line_index} in the set {set_index}')
164        else:
165            self.printd(f"[load] lines in set {set_index}:")
166            self.printd(' selecting a line to invalidate:\n', '\n'.join(f'{index}: {line}' for index, line in enumerate(lines_in_set)), sep='')
167            index = self._select_evicted_index(lines_in_set)
168            self.printd(f'[load] invalidated index: {index} in the set {set_index}')
169            self.invalidations += 1
170
171        lines_in_set[index].init(tag, False)
172
173    @staticmethod
174    def _extract_bits(value: int, start_bit: int, end_bit: int) -> int:
175        num_bits = end_bit - start_bit + 1
176        mask = ((1 << num_bits) - 1) << start_bit
177        extracted_bits = (value & mask) >> start_bit
178        return extracted_bits
179
180    def _addr_get_tag(self, addr: int) -> int:
181        start = self._block_width + self._set_width
182        end = self._memory_width
183        return self._extract_bits(addr, start, end)
184
185    def _addr_get_set(self, addr: int) -> int:
186        start = self._block_width
187        end = self._block_width + self._set_width - 1
188        return self._extract_bits(addr, start, end)
189
190    def _addr_get_offset(self, addr: int) -> int:
191        start = 0
192        end = self._block_width - 1
193        return self._extract_bits(addr, start, end)
194
195    def _get_lines_in_set(self, set_index: int) -> List[CacheLine]:
196        line_index = set_index * self._lines_per_set
197        return self._lines[
198            line_index:
199            line_index + self._lines_per_set
200        ]
201
202    def _line_lookup(self, addr: int) -> CacheLine | None:
203        tag = self._addr_get_tag(addr)
204        lines_in_set = self._get_lines_in_set(self._addr_get_set(addr))
205        return next((line for line in lines_in_set if line.tag == tag), None)
206
207    def printd(self, *args, **kwargs):
208        if self.debug:
209            print(f'[{self.name}]', *args, **kwargs)
210
211    def print_addr_info(self, addr: int, format: str = 'hex') -> None:
212        convop = {'bin': bin, 'hex': hex, 'dec': int}.get(format, hex)
213        print(f'addr: {convop(addr)}')
214        print(f'tag : {convop(self._addr_get_tag(addr))}')
215        print(f'set : {convop(self._addr_get_set(addr))}')
216        print(f'off : {convop(self._addr_get_offset(addr))}')
217
218    def print_cache_info(self) -> None:
219        print(f'{self.name} configuration:')
220        print(f'Cache size:          {self._cache_size} bytes')
221        print(f'Block size:          {self._block_size} bytes')
222        print(f'Number of lines:     {self._num_lines}')
223        print(f'Number of sets:      {self._sets} ({self._lines_per_set} lines per set)')
224        print(f'Replacement policy:  {self._replacement_policy if self._replacement_policy is not None else "RAND"}')
225
226        if self.debug:
227            print(f'Cache block width:   {self._block_width} bits')
228            print(f'Addressable memory:  {self._memory_size} bytes')
229            tag_width = self._memory_width - self._block_width - self._set_width
230            print('Addressing parameters:')
231            print(f'Tag: {tag_width} bits')
232            print(f'Set: {self._set_width} bits')
233            print(f'Block: {self._block_width} bits\n')
234
235        print()
236
237    def print_hmr(self) -> None:
238        ratio = (self.hits / ((self.hits + self.misses) if self.misses else 1)) * 100
239        print(f'Misses: {self.misses}')
240        print(f'Hits: {self.hits}')
241        print(f'Invalidations: {self.invalidations}')
242        print(f'Hit ratio: {round(ratio, 2)}%')
243
244    def print_debug_lines(self, include_empty_tags: bool = False) -> None:
245        tag_width = self._memory_width - self._block_width - self._set_width
246        print(f'tag: {tag_width} bits')
247        print(f'set: {self._set_width} bits')
248        print(f'block: {self._block_width} bits')
249
250        for id, line in enumerate(self._lines):
251            if line.tag or include_empty_tags:
252                print(line)
253                if self._lines_per_set and (id + 1) % self._lines_per_set == 0:
254                    print()
255