1# 2# Copyright (c) 2010-2024 Antmicro 3# 4# This file is licensed under the MIT License. 5# Full license text is available in 'licenses/MIT.txt'. 6# 7 8import math 9import random 10import time 11from typing import List 12 13 14class CacheLine: 15 """ 16 Represents a cache line in a cache set. 17 18 tag (int): The tag of the cache line. 19 use_count (int): Used for replacement policies. 20 insertion_time (float): The time when the cache line was inserted. 21 last_access_time (float): The time when the cache line was last accessed. 22 free (bool): Indicates if the line contains valid data. 23 """ 24 25 def __init__(self): 26 self.init() 27 28 def init(self, tag: int = 0, free: bool = True): 29 self.tag = tag 30 self.free = free 31 self.use_count: int = 0 32 self.insertion_time: float = time.time() 33 self.last_access_time: float = 0 34 35 def __str__(self) -> str: 36 return f"[CacheLine]: tag: {self.tag:b}, free: {self.free}, use: {self.use_count}, insertion: {self.insertion_time}, last access: {self.last_access_time}" 37 38 39class Cache: 40 """ 41 Cache memory model. 42 43 name (str): Cache name, used in the `printd` debug helpers. 44 cache_width (int): log2(cache_size). 45 block_width (int): log2(cache_block_size). 46 memory_width (int): log2(memory_size). 47 48 lines_per_set (int): cache mapping policy selection: 49 * -1 for fully associative 50 * 1 for direct mapping 51 * 2^n for n-way associativity 52 53 replacement_policy (str | None): Selected line eviction policy (defaults to None): 54 * FIFO: first in first out 55 * LRU: least recently used 56 * LFU: least frequently used 57 * None: random 58 59 debug (bool): print debug messages (defaults to False). 60 """ 61 62 def __init__( 63 self, 64 name: str, 65 cache_width: int, 66 block_width: int, 67 memory_width: int, 68 lines_per_set: int, 69 replacement_policy: str | None = None, 70 debug: bool = False 71 ): 72 self.name = name 73 self.debug = debug 74 75 # Width of the memories 76 self._cache_width = cache_width 77 self._block_width = block_width 78 self._memory_width = memory_width 79 80 # Convert width to size in bytes 81 self._cache_size = 2 ** self._cache_width 82 self._block_size = 2 ** self._block_width 83 self._memory_size = 2 ** self._memory_width 84 85 self._num_lines = self._cache_size // self._block_size 86 self._lines = [CacheLine() for i in range(self._num_lines)] 87 88 if lines_per_set == -1: 89 # special configuration case for fully associative mapping 90 lines_per_set = self._num_lines 91 92 if not (lines_per_set & (lines_per_set - 1) == 0) or lines_per_set == 0: 93 raise Exception('Lines per set must be a power of two (1, 2, 4, 8, ...)') 94 95 self._lines_per_set = lines_per_set 96 self._sets = self._num_lines // lines_per_set 97 self._set_width = int(math.log(self._sets, 2)) 98 99 self._replacement_policy = replacement_policy if replacement_policy is not None else 'RAND' 100 101 # Statistics 102 self.misses = 0 103 self.hits = 0 104 self.invalidations = 0 105 self.flushes = 0 106 107 def read(self, addr: int) -> None: 108 sset = self._addr_get_set(addr) 109 line = self._line_lookup(addr) 110 self.printd(f'[read] attempt to fetch {hex(addr)} (set {sset})') 111 112 if line and not line.free: 113 self.printd('[read] rhit') 114 self.hits += 1 115 line.use_count += 1 116 line.last_access_time = time.time() 117 else: 118 self.printd('[read] rmiss') 119 self.misses += 1 120 self._load(addr) 121 122 def write(self, addr: int) -> None: 123 sset = self._addr_get_set(addr) 124 line = self._line_lookup(addr) 125 self.printd(f'[write] attempted write to {hex(addr)} (set {sset})') 126 127 if line: 128 self.printd('[write] whit') 129 self.hits += 1 130 line.last_access_time = time.time() 131 else: 132 self.printd('[write] wmiss') 133 self.misses += 1 134 self._load(addr) 135 136 def flush(self) -> None: 137 self.printd('[flush] flushing all lines!') 138 self.flushes += 1 139 self._lines = [CacheLine() for i in range(self._num_lines)] 140 141 def _select_evicted_index(self, lines_in_set: list) -> int: 142 if self._replacement_policy == 'RAND': 143 return random.randint(0, self._lines_per_set - 1) 144 elif self._replacement_policy == 'LFU': 145 return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].use_count) 146 elif self._replacement_policy == 'FIFO': 147 return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].insertion_time) 148 elif self._replacement_policy == 'LRU': 149 return min(range(len(lines_in_set)), key=lambda i: lines_in_set[i].last_access_time) 150 else: 151 raise Exception(f"Unknown replacement policy: {self._replacement_policy}! Exiting!") 152 153 def _load(self, addr: int) -> None: 154 self.printd(f'[load] loading @ {hex(addr)} to cache from Main Memory') 155 tag = self._addr_get_tag(addr) 156 set_index = self._addr_get_set(addr) 157 lines_in_set = self._get_lines_in_set(set_index) 158 159 # Determine the index of the cache line to load into 160 free_line_index = next((index for index, obj in enumerate(lines_in_set) if obj.free), None) 161 if free_line_index is not None: 162 index = free_line_index 163 self.printd(f'[load] loaded new cache index: {free_line_index} in the set {set_index}') 164 else: 165 self.printd(f"[load] lines in set {set_index}:") 166 self.printd(' selecting a line to invalidate:\n', '\n'.join(f'{index}: {line}' for index, line in enumerate(lines_in_set)), sep='') 167 index = self._select_evicted_index(lines_in_set) 168 self.printd(f'[load] invalidated index: {index} in the set {set_index}') 169 self.invalidations += 1 170 171 lines_in_set[index].init(tag, False) 172 173 @staticmethod 174 def _extract_bits(value: int, start_bit: int, end_bit: int) -> int: 175 num_bits = end_bit - start_bit + 1 176 mask = ((1 << num_bits) - 1) << start_bit 177 extracted_bits = (value & mask) >> start_bit 178 return extracted_bits 179 180 def _addr_get_tag(self, addr: int) -> int: 181 start = self._block_width + self._set_width 182 end = self._memory_width 183 return self._extract_bits(addr, start, end) 184 185 def _addr_get_set(self, addr: int) -> int: 186 start = self._block_width 187 end = self._block_width + self._set_width - 1 188 return self._extract_bits(addr, start, end) 189 190 def _addr_get_offset(self, addr: int) -> int: 191 start = 0 192 end = self._block_width - 1 193 return self._extract_bits(addr, start, end) 194 195 def _get_lines_in_set(self, set_index: int) -> List[CacheLine]: 196 line_index = set_index * self._lines_per_set 197 return self._lines[ 198 line_index: 199 line_index + self._lines_per_set 200 ] 201 202 def _line_lookup(self, addr: int) -> CacheLine | None: 203 tag = self._addr_get_tag(addr) 204 lines_in_set = self._get_lines_in_set(self._addr_get_set(addr)) 205 return next((line for line in lines_in_set if line.tag == tag), None) 206 207 def printd(self, *args, **kwargs): 208 if self.debug: 209 print(f'[{self.name}]', *args, **kwargs) 210 211 def print_addr_info(self, addr: int, format: str = 'hex') -> None: 212 convop = {'bin': bin, 'hex': hex, 'dec': int}.get(format, hex) 213 print(f'addr: {convop(addr)}') 214 print(f'tag : {convop(self._addr_get_tag(addr))}') 215 print(f'set : {convop(self._addr_get_set(addr))}') 216 print(f'off : {convop(self._addr_get_offset(addr))}') 217 218 def print_cache_info(self) -> None: 219 print(f'{self.name} configuration:') 220 print(f'Cache size: {self._cache_size} bytes') 221 print(f'Block size: {self._block_size} bytes') 222 print(f'Number of lines: {self._num_lines}') 223 print(f'Number of sets: {self._sets} ({self._lines_per_set} lines per set)') 224 print(f'Replacement policy: {self._replacement_policy if self._replacement_policy is not None else "RAND"}') 225 226 if self.debug: 227 print(f'Cache block width: {self._block_width} bits') 228 print(f'Addressable memory: {self._memory_size} bytes') 229 tag_width = self._memory_width - self._block_width - self._set_width 230 print('Addressing parameters:') 231 print(f'Tag: {tag_width} bits') 232 print(f'Set: {self._set_width} bits') 233 print(f'Block: {self._block_width} bits\n') 234 235 print() 236 237 def print_hmr(self) -> None: 238 ratio = (self.hits / ((self.hits + self.misses) if self.misses else 1)) * 100 239 print(f'Misses: {self.misses}') 240 print(f'Hits: {self.hits}') 241 print(f'Invalidations: {self.invalidations}') 242 print(f'Hit ratio: {round(ratio, 2)}%') 243 244 def print_debug_lines(self, include_empty_tags: bool = False) -> None: 245 tag_width = self._memory_width - self._block_width - self._set_width 246 print(f'tag: {tag_width} bits') 247 print(f'set: {self._set_width} bits') 248 print(f'block: {self._block_width} bits') 249 250 for id, line in enumerate(self._lines): 251 if line.tag or include_empty_tags: 252 print(line) 253 if self._lines_per_set and (id + 1) % self._lines_per_set == 0: 254 print() 255