#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math

class Bitstream:
    """State shared by the bitstream reader and writer.

    A single byte buffer carries two independent streams:

    * arithmetic-coded bytes, accessed forward from the start of the buffer
      (cursor ``bp``);
    * raw bits, accessed backward from the last byte (cursor ``bp_bw``),
      least significant bit first within each byte (``mask_bw``).

    ``low`` and ``range`` hold the 24-bit arithmetic-coder interval.
    """

    def __init__(self, data):

        self.bytes = data

        # Backward raw-bit cursor: starts at bit 0 of the last byte.
        self.bp_bw = len(data) - 1
        self.mask_bw = 1

        # Forward arithmetic-coder cursor and 24-bit interval state.
        self.bp = 0
        self.low = 0
        self.range = 0xffffff

    def _advance_bw(self):
        """Move the backward raw-bit cursor to the next bit position."""

        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            # Byte exhausted: continue from bit 0 of the preceding byte.
            self.mask_bw = 1
            self.bp_bw -= 1

    def _nbits_bw(self):
        """Return the number of raw bits consumed from the end so far."""

        nbits = 8 * len(self.bytes)
        return nbits - (8 * self.bp_bw + 8 - int(math.log2(self.mask_bw)))

    def dump(self):
        """Print the buffer in hex, 20 bytes per line."""

        b = self.bytes

        for i in range(0, len(b), 20):
            print(''.join('{:02x} '.format(x) for x in b[i:i + 20]))

class BitstreamReader(Bitstream):
    """Read raw bits and arithmetic-coded symbols from a byte buffer."""

    def __init__(self, data):
        """Wrap *data* and preload the decoder window with its first 3 bytes.

        Raises IndexError if *data* is shorter than 3 bytes.
        """

        super().__init__(data)

        self.low = ((self.bytes[0] << 16) |
                    (self.bytes[1] << 8) |
                    self.bytes[2])
        self.bp = 3

    def read_bit(self):
        """Read the next raw bit from the backward stream, as a bool."""

        bit = bool(self.bytes[self.bp_bw] & self.mask_bw)
        self._advance_bw()
        return bit

    def read_uint(self, nbits):
        """Read an *nbits*-wide unsigned integer, least significant bit first."""

        val = 0
        for k in range(nbits):
            val |= self.read_bit() << k

        return val

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode one arithmetic-coded symbol.

        Args:
          cum_freqs: cumulative frequency of each symbol, expected to be
            non-decreasing, on the 1/1024 scale implied by ``range >> 10``.
          sym_freqs: frequency of each symbol, on the same scale.

        Returns:
          The index of the decoded symbol.

        Raises:
          ValueError: if the coder state does not fall inside any symbol's
            interval.
        """

        r = self.range >> 10
        if self.low >= r << 10:
            raise ValueError('Invalid ac bitstream')

        # Find the last symbol whose interval starts at or below `low`.
        # Stop at index 0 instead of walking into negative (end-relative)
        # Python indices when `low` lies below the first interval.
        val = len(cum_freqs) - 1
        while self.low < r * cum_freqs[val]:
            if val == 0:
                raise ValueError('Invalid ac bitstream')
            val -= 1

        self.low -= r * cum_freqs[val]
        self.range = r * sym_freqs[val]

        # Renormalize: pull in one input byte for each byte of range lost.
        while self.range < 0x10000:
            self.range <<= 8

            self.low = ((self.low << 8) & 0xffffff) + self.bytes[self.bp]
            self.bp += 1

        return val

    def get_bits_left(self):
        """Return the number of bits not yet consumed by either stream."""

        nbits = 8 * len(self.bytes)

        nbits_bw = self._nbits_bw()

        # Forward bytes consumed beyond the 3 preloaded in __init__, plus a
        # correction derived from the current range.
        nbits_ac = 8 * (self.bp - 3) + \
            (25 - int(math.floor(math.log2(self.range))))

        return nbits - (nbits_bw + nbits_ac)

class BitstreamWriter(Bitstream):
    """Write raw bits and arithmetic-coded symbols into a byte buffer."""

    def __init__(self, nbytes):
        """Allocate a zero-filled buffer of *nbytes* bytes."""

        super().__init__(bytearray(nbytes))

        # Pending output byte (-1: none yet), the carry still to be added to
        # it, and the number of 0xff bytes held back behind it.
        self.cache = -1
        self.carry = 0
        self.carry_count = 0

    def write_bit(self, bit):
        """Write one raw bit to the backward stream (any non-zero value is 1)."""

        if bit == 0:
            self.bytes[self.bp_bw] &= ~self.mask_bw
        else:
            self.bytes[self.bp_bw] |= self.mask_bw

        self._advance_bw()

    def write_uint(self, val, nbits):
        """Write the low *nbits* bits of *val*, least significant bit first."""

        for _ in range(nbits):
            self.write_bit(val & 1)
            val >>= 1

    def ac_shift(self):
        """Shift the top byte of ``low`` towards the forward output.

        A byte that a later carry could still modify is not written yet:
        the most recent one is kept in ``cache`` and any 0xff bytes after it
        are only counted in ``carry_count``.
        """

        if self.low < 0xff0000 or self.carry == 1:

            # Emit the cached byte (with its carry) and the held-back run,
            # which becomes 0x00 bytes when a carry occurred.
            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            while self.carry_count > 0:
                self.bytes[self.bp] = (self.carry + 0xff) & 0xff
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        else:
            # Top byte is 0xff and no carry yet: hold it back.
            self.carry_count += 1

        self.low = (self.low << 8) & 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol given its cumulative and own frequency.

        Frequencies are on the 1/1024 scale implied by ``range >> 10``.

        Raises:
          ValueError: if *sym_freq* is not positive. Such a symbol would
            collapse ``range`` and the renormalization loop would keep
            shifting bytes out, overwriting the rest of the buffer.
        """

        if sym_freq <= 0:
            raise ValueError('Invalid symbol frequency')

        r = self.range >> 10
        self.low += r * cum_freq
        if (self.low >> 24) != 0:
            # Overflow past 24 bits: a carry must reach the cached byte.
            self.carry = 1

        self.low &= 0xffffff
        self.range = r * sym_freq
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return the number of bits still available in the buffer."""

        nbits = 8 * len(self.bytes)

        nbits_bw = self._nbits_bw()

        # Forward bytes written, plus the cached byte and the held-back 0xff
        # run, plus a correction derived from the current range.
        nbits_ac = 8 * self.bp + (25 - int(math.floor(math.log2(self.range))))
        if self.cache >= 0:
            nbits_ac += 8
        nbits_ac += 8 * self.carry_count

        return nbits - (nbits_bw + nbits_ac)

    def terminate(self):
        """Flush the arithmetic coder and return the buffer.

        Rounds ``low`` up to a value with few significant bits that stays
        inside the final interval. Its full bytes are shifted out. Its
        remaining most significant bits go into the next byte, leaving that
        byte's lower bits untouched.

        Returns:
          The underlying bytearray.
        """

        # Smallest bit count at which `range` has a set bit.
        bits = 1
        while self.range >> (24 - bits) == 0:
            bits += 1

        # Round `low` up to a multiple of (mask + 1). Track 24-bit overflow
        # of both that value and of the interval's upper end.
        mask = 0xffffff >> bits
        val = self.low + mask

        over1 = val >> 24
        val &= 0x00ffffff
        high = self.low + self.range
        over2 = high >> 24
        high &= 0x00ffffff
        val &= ~mask

        if over1 == over2:

            # Rounded value does not fit below `high`: use one more bit.
            if val + mask >= high:
                bits += 1
                mask >>= 1
                val = ((self.low + mask) & 0x00ffffff) & ~mask

            # Rounding wrapped past 24 bits: propagate a carry.
            if val < self.low:
                self.carry = 1

        self.low = val
        # Shift out each byte that holds some of the `bits` bits; afterwards
        # `bits` is the number of bits used in the final byte.
        while bits > 0:
            self.ac_shift()
            bits -= 8
        bits += 8

        val = self.cache

        if self.carry_count > 0:
            self.bytes[self.bp] = self.cache
            self.bp += 1

            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            val = 0xff >> (8 - bits)

        # Write the `bits` most significant bits of `val`, MSB first.
        mask = 0x80
        for _ in range(bits):

            if val & mask == 0:
                self.bytes[self.bp] &= ~mask
            else:
                self.bytes[self.bp] |= mask

            mask >>= 1

        return self.bytes