1 /******************************************************************************
2 *
3 * Copyright 2022 Google LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at:
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 ******************************************************************************/
18
19 #include "bits.h"
20 #include "common.h"
21
22
23 /* ----------------------------------------------------------------------------
24 * Common
25 * -------------------------------------------------------------------------- */
26
/* Reading helpers, defined in the "Reading" section below,
 * needed early by lc3_setup_bits() */
static inline int ac_get(struct lc3_bits_buffer *buffer);
static inline void accu_load(
    struct lc3_bits_accu *accu, struct lc3_bits_buffer *buffer);
29
30 /**
31 * Arithmetic coder return range bits
32 * ac Arithmetic coder
33 * return 1 + log2(ac->range)
34 */
ac_get_range_bits(const struct lc3_bits_ac * ac)35 static int ac_get_range_bits(const struct lc3_bits_ac *ac)
36 {
37 int nbits = 0;
38
39 for (unsigned r = ac->range; r; r >>= 1, nbits++);
40
41 return nbits;
42 }
43
44 /**
45 * Arithmetic coder return pending bits
46 * ac Arithmetic coder
47 * return Pending bits
48 */
ac_get_pending_bits(const struct lc3_bits_ac * ac)49 static int ac_get_pending_bits(const struct lc3_bits_ac *ac)
50 {
51 return 26 - ac_get_range_bits(ac) +
52 ((ac->cache >= 0) + ac->carry_count) * 8;
53 }
54
55 /**
56 * Return number of bits left in the bitstream
57 * bits Bitstream context
58 * return >= 0: Number of bits left < 0: Overflow
59 */
get_bits_left(const struct lc3_bits * bits)60 static int get_bits_left(const struct lc3_bits *bits)
61 {
62 const struct lc3_bits_buffer *buffer = &bits->buffer;
63 const struct lc3_bits_accu *accu = &bits->accu;
64 const struct lc3_bits_ac *ac = &bits->ac;
65
66 uintptr_t end = (uintptr_t)buffer->p_bw +
67 (bits->mode == LC3_BITS_MODE_READ ? LC3_ACCU_BITS/8 : 0);
68
69 uintptr_t start = (uintptr_t)buffer->p_fw -
70 (bits->mode == LC3_BITS_MODE_READ ? LC3_AC_BITS/8 : 0);
71
72 int n = end > start ? (int)(end - start) : -(int)(start - end);
73
74 return 8 * n - (accu->n + accu->nover + ac_get_pending_bits(ac));
75 }
76
77 /**
78 * Setup bitstream writing
79 */
lc3_setup_bits(struct lc3_bits * bits,enum lc3_bits_mode mode,void * buffer,int len)80 void lc3_setup_bits(struct lc3_bits *bits,
81 enum lc3_bits_mode mode, void *buffer, int len)
82 {
83 *bits = (struct lc3_bits){
84 .mode = mode,
85 .accu = {
86 .n = mode == LC3_BITS_MODE_READ ? LC3_ACCU_BITS : 0,
87 },
88 .ac = {
89 .range = 0xffffff,
90 .cache = -1
91 },
92 .buffer = {
93 .start = (uint8_t *)buffer, .end = (uint8_t *)buffer + len,
94 .p_fw = (uint8_t *)buffer, .p_bw = (uint8_t *)buffer + len,
95 }
96 };
97
98 if (mode == LC3_BITS_MODE_READ) {
99 struct lc3_bits_ac *ac = &bits->ac;
100 struct lc3_bits_accu *accu = &bits->accu;
101 struct lc3_bits_buffer *buffer = &bits->buffer;
102
103 ac->low = ac_get(buffer) << 16;
104 ac->low |= ac_get(buffer) << 8;
105 ac->low |= ac_get(buffer);
106
107 accu_load(accu, buffer);
108 }
109 }
110
/**
 * Return number of bits left in the bitstream
 * bits            Bitstream context
 * return          Number of bits left, 0 when the bitstream overflowed
 */
int lc3_get_bits_left(const struct lc3_bits *bits)
{
    int nleft = get_bits_left(bits);

    return nleft < 0 ? 0 : nleft;
}
118
119 /**
120 * Return number of bits left in the bitstream
121 */
lc3_check_bits(const struct lc3_bits * bits)122 int lc3_check_bits(const struct lc3_bits *bits)
123 {
124 const struct lc3_bits_ac *ac = &bits->ac;
125
126 return -(get_bits_left(bits) < 0 || ac->error);
127 }
128
129
130 /* ----------------------------------------------------------------------------
131 * Writing
132 * -------------------------------------------------------------------------- */
133
134 /**
135 * Flush the bits accumulator
136 * accu Bitstream accumulator
137 * buffer Bitstream buffer
138 */
accu_flush(struct lc3_bits_accu * accu,struct lc3_bits_buffer * buffer)139 static inline void accu_flush(
140 struct lc3_bits_accu *accu, struct lc3_bits_buffer *buffer)
141 {
142 int nbytes = LC3_MIN(accu->n >> 3,
143 LC3_MAX(buffer->p_bw - buffer->p_fw, 0));
144
145 accu->n -= 8 * nbytes;
146
147 for ( ; nbytes; accu->v >>= 8, nbytes--)
148 *(--buffer->p_bw) = accu->v & 0xff;
149
150 if (accu->n >= 8)
151 accu->n = 0;
152 }
153
154 /**
155 * Arithmetic coder put byte
156 * buffer Bitstream buffer
157 * byte Byte to output
158 */
ac_put(struct lc3_bits_buffer * buffer,int byte)159 static inline void ac_put(struct lc3_bits_buffer *buffer, int byte)
160 {
161 if (buffer->p_fw < buffer->end)
162 *(buffer->p_fw++) = byte;
163 }
164
165 /**
166 * Arithmetic coder range shift
167 * ac Arithmetic coder
168 * buffer Bitstream buffer
169 */
ac_shift(struct lc3_bits_ac * ac,struct lc3_bits_buffer * buffer)170 LC3_HOT static inline void ac_shift(
171 struct lc3_bits_ac *ac, struct lc3_bits_buffer *buffer)
172 {
173 if (ac->low < 0xff0000 || ac->carry)
174 {
175 if (ac->cache >= 0)
176 ac_put(buffer, ac->cache + ac->carry);
177
178 for ( ; ac->carry_count > 0; ac->carry_count--)
179 ac_put(buffer, ac->carry ? 0x00 : 0xff);
180
181 ac->cache = ac->low >> 16;
182 ac->carry = 0;
183 }
184 else
185 ac->carry_count++;
186
187 ac->low = (ac->low << 8) & 0xffffff;
188 }
189
/**
 * Arithmetic coder termination
 * ac              Arithmetic coder
 * buffer          Bitstream buffer
 *
 * Selects the terminating value `val`: `low` rounded up to a multiple of
 * `mask + 1` (modulo 2^24). One more bit is used when `val + mask` would
 * reach `low + range`. Whole bytes are then pushed through ac_shift() /
 * ac_put(). The last `nbits` bits (at most 8) are merged into the most
 * significant bits of the byte at `buffer->p_fw`. The low `8 - nbits` bits
 * already present in that byte are kept.
 *
 * Locals:
 *   nbits         Count of terminating bits, starting at 25 - range bits
 *                 and possibly incremented once; at most 8 after the
 *                 byte-shifting loop
 *   end_val       Value of the `nbits` MSBs of the last, partial, byte
 */
static void ac_terminate(struct lc3_bits_ac *ac,
    struct lc3_bits_buffer *buffer)
{
    int nbits = 25 - ac_get_range_bits(ac);
    unsigned mask = 0xffffff >> nbits;
    unsigned val = ac->low + mask;
    unsigned high = ac->low + ac->range;

    /* Did the sums go past the 24-bit register? */
    bool over_val = val >> 24;
    bool over_high = high >> 24;

    val = (val & 0xffffff) & ~mask;
    high = (high & 0xffffff);

    /* Only compare when both lie on the same side of the 2^24 wrap */
    if (over_val == over_high) {

        /* The span of `val` reaches `high`: use one more bit */
        if (val + mask >= high) {
            nbits++;
            mask >>= 1;
            val = ((ac->low + mask) & 0xffffff) & ~mask;
        }

        /* `val` wrapped below `low`: propagate a carry into the bytes
         * held back by ac_shift() (cached byte and deferred 0xff run) */
        ac->carry |= val < ac->low;
    }

    ac->low = val;

    /* Shift out a whole byte per 8 bits beyond the last ones, then the
     * byte holding the last `nbits` bits, which ends up in the cache */
    for (; nbits > 8; nbits -= 8)
        ac_shift(ac, buffer);
    ac_shift(ac, buffer);

    int end_val = ac->cache >> (8 - nbits);

    /* Deferred bytes pending: write the cache and all but one deferred
     * 0xff byte; the remaining one becomes the partial last byte */
    if (ac->carry_count) {
        ac_put(buffer, ac->cache);
        for ( ; ac->carry_count > 1; ac->carry_count--)
            ac_put(buffer, 0xff);

        end_val = nbits < 8 ? 0 : 0xff;
    }

    /* Merge into the MSBs of the next byte, keeping its low bits.
     * NOTE(review): presumably those are raw bits previously written
     * backward by accu_flush() when both fronts meet; confirm */
    if (buffer->p_fw < buffer->end) {
        *buffer->p_fw &= 0xff >> nbits;
        *buffer->p_fw |= end_val << (8 - nbits);
    }
}
242
243 /**
244 * Flush and terminate bitstream
245 */
lc3_flush_bits(struct lc3_bits * bits)246 void lc3_flush_bits(struct lc3_bits *bits)
247 {
248 struct lc3_bits_ac *ac = &bits->ac;
249 struct lc3_bits_accu *accu = &bits->accu;
250 struct lc3_bits_buffer *buffer = &bits->buffer;
251
252 int nleft = buffer->p_bw - buffer->p_fw;
253 for (int n = 8 * nleft - accu->n; n > 0; n -= 32)
254 lc3_put_bits(bits, 0, LC3_MIN(n, 32));
255
256 accu_flush(accu, buffer);
257
258 ac_terminate(ac, buffer);
259 }
260
261 /**
262 * Write from 1 to 32 bits,
263 * exceeding the capacity of the accumulator
264 */
lc3_put_bits_generic(struct lc3_bits * bits,unsigned v,int n)265 LC3_HOT void lc3_put_bits_generic(struct lc3_bits *bits, unsigned v, int n)
266 {
267 struct lc3_bits_accu *accu = &bits->accu;
268
269 /* --- Fulfill accumulator and flush -- */
270
271 int n1 = LC3_MIN(LC3_ACCU_BITS - accu->n, n);
272 if (n1) {
273 accu->v |= v << accu->n;
274 accu->n = LC3_ACCU_BITS;
275 }
276
277 accu_flush(accu, &bits->buffer);
278
279 /* --- Accumulate remaining bits -- */
280
281 accu->v = v >> n1;
282 accu->n = n - n1;
283 }
284
285 /**
286 * Arithmetic coder renormalization
287 */
lc3_ac_write_renorm(struct lc3_bits * bits)288 LC3_HOT void lc3_ac_write_renorm(struct lc3_bits *bits)
289 {
290 struct lc3_bits_ac *ac = &bits->ac;
291
292 for ( ; ac->range < 0x10000; ac->range <<= 8)
293 ac_shift(ac, &bits->buffer);
294 }
295
296
297 /* ----------------------------------------------------------------------------
298 * Reading
299 * -------------------------------------------------------------------------- */
300
301 /**
302 * Arithmetic coder get byte
303 * buffer Bitstream buffer
304 * return Byte read, 0 on overflow
305 */
ac_get(struct lc3_bits_buffer * buffer)306 static inline int ac_get(struct lc3_bits_buffer *buffer)
307 {
308 return buffer->p_fw < buffer->end ? *(buffer->p_fw++) : 0;
309 }
310
311 /**
312 * Load the accumulator
313 * accu Bitstream accumulator
314 * buffer Bitstream buffer
315 */
accu_load(struct lc3_bits_accu * accu,struct lc3_bits_buffer * buffer)316 static inline void accu_load(struct lc3_bits_accu *accu,
317 struct lc3_bits_buffer *buffer)
318 {
319 int nbytes = LC3_MIN(accu->n >> 3, buffer->p_bw - buffer->start);
320
321 accu->n -= 8 * nbytes;
322
323 for ( ; nbytes; nbytes--) {
324 accu->v >>= 8;
325 accu->v |= (unsigned)*(--buffer->p_bw) << (LC3_ACCU_BITS - 8);
326 }
327
328 if (accu->n >= 8) {
329 accu->nover = LC3_MIN(accu->nover + accu->n, LC3_ACCU_BITS);
330 accu->v >>= accu->n;
331 accu->n = 0;
332 }
333 }
334
335 /**
336 * Read from 1 to 32 bits,
337 * exceeding the capacity of the accumulator
338 */
lc3_get_bits_generic(struct lc3_bits * bits,int n)339 LC3_HOT unsigned lc3_get_bits_generic(struct lc3_bits *bits, int n)
340 {
341 struct lc3_bits_accu *accu = &bits->accu;
342 struct lc3_bits_buffer *buffer = &bits->buffer;
343
344 /* --- Fulfill accumulator and read -- */
345
346 accu_load(accu, buffer);
347
348 int n1 = LC3_MIN(LC3_ACCU_BITS - accu->n, n);
349 unsigned v = (accu->v >> accu->n) & ((1u << n1) - 1);
350 accu->n += n1;
351
352 /* --- Second round --- */
353
354 int n2 = n - n1;
355
356 if (n2) {
357 accu_load(accu, buffer);
358
359 v |= ((accu->v >> accu->n) & ((1u << n2) - 1)) << n1;
360 accu->n += n2;
361 }
362
363 return v;
364 }
365
366 /**
367 * Arithmetic coder renormalization
368 */
lc3_ac_read_renorm(struct lc3_bits * bits)369 LC3_HOT void lc3_ac_read_renorm(struct lc3_bits *bits)
370 {
371 struct lc3_bits_ac *ac = &bits->ac;
372
373 for ( ; ac->range < 0x10000; ac->range <<= 8)
374 ac->low = ((ac->low << 8) | ac_get(&bits->buffer)) & 0xffffff;
375 }
376