1#
2# Copyright 2022 Google LLC
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17import math
18
class Bitstream:
    """Byte buffer walked by two cursors, plus arithmetic-coder registers.

    ``bp`` advances forward from the first byte, while ``bp_bw`` and
    ``mask_bw`` select a single bit and advance backward from the least
    significant bit of the last byte. ``low`` and ``range`` hold the
    arithmetic-coder state; subclasses initialise them for reading or
    writing.
    """

    def __init__(self, data):

        self.bytes = data

        # Backward bit cursor: byte index and single-bit mask, starting at
        # the least significant bit of the last byte.
        self.bp_bw, self.mask_bw = len(data) - 1, 1

        # Forward byte cursor and coder registers.
        self.bp, self.low, self.range = 0, 0, 0xffffff

    def dump(self):
        """Print the buffer in hex, 20 space-terminated bytes per line."""

        data = self.bytes
        for start in range(0, len(data), 20):
            row = data[start:start + 20]
            print(''.join(map('{:02x} '.format, row)))
39
class BitstreamReader(Bitstream):
    """Read side of the shared buffer.

    Arithmetic-coded symbols are decoded forward from the start of the
    buffer (``bp``), while raw bits are read backward from the last byte
    (``bp_bw``/``mask_bw``), least significant bit first.
    """

    def __init__(self, data):

        super().__init__(data)

        # Prime the 24-bit ``low`` register with the first three bytes,
        # most significant byte first; decoding then continues at byte 3.
        self.low = ( (self.bytes[0] << 16) |
                     (self.bytes[1] <<  8) |
                     (self.bytes[2]      ) )
        self.bp = 3

    def read_bit(self):
        """Read one raw bit from the backward stream.

        Returns:
            The bit as a ``bool``.

        Raises:
            IndexError: if the backward cursor has already consumed the
                first byte of the buffer.
        """

        # A negative index would silently wrap around to the end of the
        # buffer (Python sequence semantics) and return unrelated bits.
        # Fail instead, matching the IndexError ``ac_decode`` raises when
        # the forward stream runs off the end of the buffer.
        if self.bp_bw < 0:
            raise IndexError('bitstream overrun: read past the first byte')

        bit = bool(self.bytes[self.bp_bw] & self.mask_bw)

        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

        return bit

    def read_uint(self, nbits):
        """Read an ``nbits``-bit unsigned integer, least significant bit first."""

        val = 0
        for k in range(nbits):
            val |= self.read_bit() << k

        return val

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode one symbol from the arithmetic-coded stream.

        Args:
            cum_freqs: cumulative frequency of each symbol, presumably
                non-decreasing with ``cum_freqs[0] == 0``.
            sym_freqs: frequency of each symbol.
            NOTE(review): the ``range >> 10`` scaling suggests frequencies
            totalling 1024 -- confirm against callers.

        Returns:
            Index of the decoded symbol.

        Raises:
            ValueError: if ``low`` is outside the scaled frequency range.
            IndexError: if renormalisation reads past the end of the buffer.
        """

        r = self.range >> 10
        if self.low >= r << 10:
            raise ValueError('Invalid ac bitstream')

        # Highest symbol whose scaled cumulative frequency is <= low.
        val = len(cum_freqs) - 1
        while self.low < r * cum_freqs[val]:
            val -= 1

        self.low -= r * cum_freqs[val]
        self.range = r * sym_freqs[val]

        # Renormalise: shift in one byte at a time until range is back
        # above 16 bits, keeping ``low`` at 24 bits.
        while self.range < 0x10000:
            self.range <<= 8

            self.low <<= 8
            self.low &= 0xffffff
            self.low += self.bytes[self.bp]
            self.bp += 1

        return val

    def get_bits_left(self):
        """Return the number of bits not yet consumed by either stream.

        The result is negative when more bits have been consumed than the
        buffer holds. The arithmetic part counts the bytes read after the
        initial three plus ``25 - floor(log2(range))`` bits of coder state.
        """

        nbits = 8 * len(self.bytes)

        # Bits consumed from the back: whole bytes behind ``bp_bw`` plus
        # the bits already taken from the current byte.
        nbits_bw = nbits - \
            (8*self.bp_bw + 8 - int(math.log2(self.mask_bw)))

        nbits_ac = 8 * (self.bp - 3) + \
            (25 - int(math.floor(math.log2(self.range))))

        return nbits - (nbits_bw + nbits_ac)
103
class BitstreamWriter(Bitstream):
    """Write side of the shared buffer.

    Arithmetic-coded bytes are emitted forward from index 0 (``bp``) while
    raw bits are written backward from the last byte (``bp_bw``/``mask_bw``),
    least significant bit first. Call ``terminate()`` to flush the
    arithmetic coder and obtain the buffer.
    """

    def __init__(self, nbytes):

        super().__init__(bytearray(nbytes))

        # Carry-propagation state for the arithmetic coder: ``cache`` is
        # the last top byte not yet written (-1: none yet), ``carry`` is a
        # pending +1 for it, and ``carry_count`` is the number of deferred
        # 0xff bytes that a carry would also ripple through.
        self.cache = -1
        self.carry = 0
        self.carry_count = 0

    def write_bit(self, bit):
        """Write one raw bit to the backward stream.

        Args:
            bit: ``0`` clears the bit; any other value sets it.

        Raises:
            IndexError: if the backward cursor has already filled the first
                byte of the buffer.
        """

        mask = self.mask_bw
        bp = self.bp_bw

        # A negative index would silently wrap around to the end of the
        # buffer and overwrite the first raw bits written. Fail instead,
        # matching the IndexError ``ac_shift`` raises when the forward
        # stream runs off the end of the buffer.
        if bp < 0:
            raise IndexError('bitstream overrun: write past the first byte')

        if bit == 0:
            self.bytes[bp] &= ~mask
        else:
            self.bytes[bp] |= mask

        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

    def write_uint(self, val, nbits):
        """Write the low ``nbits`` bits of ``val``, least significant first."""

        for k in range(nbits):
            self.write_bit(val & 1)
            val >>= 1

    def ac_shift(self):
        """Shift the top byte out of ``low``, resolving pending carries.

        If that byte is 0xff and no carry is pending, it is deferred
        (``carry_count``) because a later carry could still change it.
        Otherwise the cached byte plus carry and all deferred bytes are
        written forward at ``bp``, and the top byte becomes the new cache.
        """

        if self.low < 0xff0000 or self.carry == 1:

            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            # Deferred 0xff bytes become 0x00 when the carry ripples through.
            while self.carry_count > 0:
                self.bytes[self.bp] = (self.carry + 0xff) & 0xff
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        else:
            self.carry_count += 1

        self.low <<= 8
        self.low &= 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol given its cumulative and own frequency.

        NOTE(review): the ``range >> 10`` scaling suggests frequencies
        totalling 1024 -- confirm against callers.
        """

        r = self.range >> 10
        self.low += r * cum_freq
        # Overflow past 24 bits is a carry into bytes not yet written.
        if (self.low >> 24) != 0:
            self.carry = 1

        self.low &= 0xffffff
        self.range = r * sym_freq
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return the number of bits still free between the two streams.

        The arithmetic part counts bytes written, the pending cache byte,
        the deferred bytes, and ``25 - floor(log2(range))`` bits of coder
        state. The result is negative if the two streams need more bits
        than the buffer holds.
        """

        nbits = 8 * len(self.bytes)

        # Bits written at the back: whole bytes behind ``bp_bw`` plus the
        # bits already placed in the current byte.
        nbits_bw = nbits - \
            (8*self.bp_bw + 8 - int(math.log2(self.mask_bw)))

        nbits_ac = 8 * self.bp + (25 - int(math.floor(math.log2(self.range))))
        if self.cache >= 0:
            nbits_ac += 8
        if self.carry_count > 0:
            nbits_ac += 8 * self.carry_count

        return nbits - (nbits_bw + nbits_ac)

    def terminate(self):
        """Flush the arithmetic coder and return the output buffer.

        Picks a value inside the final interval ``[low, low + range)`` that
        needs as few significant bits as possible, and shifts those bits
        out. It then writes any pending bytes and stores the remaining
        ``bits`` (1 to 8) high bits into ``bytes[bp]``, leaving that byte's
        lower bits untouched.

        Returns:
            The underlying ``bytearray`` (the same object as ``self.bytes``).
        """

        # Smallest ``bits`` such that range >= 2**(24 - bits).
        bits = 1
        while self.range >> (24 - bits) == 0:
            bits += 1

        # Round ``low`` up to a multiple of 2**(24 - bits).
        mask = 0xffffff >> bits
        val = self.low + mask

        over1 = val >> 24
        val &= 0x00ffffff
        high = self.low + self.range
        over2 = high >> 24
        high &= 0x00ffffff
        val = val & ~mask

        if over1 == over2:

            # Use one more bit if the continuations of ``val`` could reach
            # ``high``.
            if val + mask >= high:
                bits += 1
                mask >>= 1
                val = ((self.low + mask) & 0x00ffffff) & ~mask

            # ``val`` below ``low`` means the rounding wrapped past 24 bits;
            # propagate it as a carry.
            if val < self.low:
                self.carry = 1

        self.low = val
        while bits > 0:
            self.ac_shift()
            bits -= 8
        bits += 8

        val = self.cache

        if self.carry_count > 0:
            self.bytes[self.bp] = self.cache
            self.bp += 1

            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            # NOTE(review): ``0xff >> (8 - bits)`` is right-aligned, but the
            # loop below takes bits from 0x80 downward, so the written bits
            # are all ones only when bits == 8. Presumably this mirrors the
            # reference pseudo-code bit-exactly -- confirm before changing.
            val = 0xff >> (8 - bits)

        # Write the top ``bits`` bits of ``val`` into the current byte.
        mask = 0x80
        for k in range(bits):

            if (val & mask) == 0:
                self.bytes[self.bp] &= ~mask
            else:
                self.bytes[self.bp] |= mask

            mask >>= 1

        return self.bytes
241