1 /******************************************************************************
2  *
3  *  Copyright 2022 Google LLC
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at:
8  *
9  *  http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  ******************************************************************************/
18 
19 #include "bits.h"
20 #include "common.h"
21 
22 
23 /* ----------------------------------------------------------------------------
24  *  Common
25  * -------------------------------------------------------------------------- */
26 
/* ac_get() and accu_load() are defined in the "Reading" section below;
 * they are declared here because lc3_setup_bits() already uses them to
 * prime a context opened in read mode. */
static inline int ac_get(struct lc3_bits_buffer *);
static inline void accu_load(struct lc3_bits_accu *, struct lc3_bits_buffer *);
29 
30 /**
31  * Arithmetic coder return range bits
32  * ac              Arithmetic coder
33  * return          1 + log2(ac->range)
34  */
ac_get_range_bits(const struct lc3_bits_ac * ac)35 static int ac_get_range_bits(const struct lc3_bits_ac *ac)
36 {
37     int nbits = 0;
38 
39     for (unsigned r = ac->range; r; r >>= 1, nbits++);
40 
41     return nbits;
42 }
43 
44 /**
45  * Arithmetic coder return pending bits
46  * ac              Arithmetic coder
47  * return          Pending bits
48  */
ac_get_pending_bits(const struct lc3_bits_ac * ac)49 static int ac_get_pending_bits(const struct lc3_bits_ac *ac)
50 {
51     return 26 - ac_get_range_bits(ac) +
52         ((ac->cache >= 0) + ac->carry_count) * 8;
53 }
54 
55 /**
56  * Return number of bits left in the bitstream
57  * bits            Bitstream context
58  * return          >= 0: Number of bits left  < 0: Overflow
59  */
get_bits_left(const struct lc3_bits * bits)60 static int get_bits_left(const struct lc3_bits *bits)
61 {
62     const struct lc3_bits_buffer *buffer = &bits->buffer;
63     const struct lc3_bits_accu *accu = &bits->accu;
64     const struct lc3_bits_ac *ac = &bits->ac;
65 
66     uintptr_t end = (uintptr_t)buffer->p_bw +
67         (bits->mode == LC3_BITS_MODE_READ ? LC3_ACCU_BITS/8 : 0);
68 
69     uintptr_t start = (uintptr_t)buffer->p_fw -
70         (bits->mode == LC3_BITS_MODE_READ ? LC3_AC_BITS/8 : 0);
71 
72     int n = end > start ? (int)(end - start) : -(int)(start - end);
73 
74     return 8 * n - (accu->n + accu->nover + ac_get_pending_bits(ac));
75 }
76 
77 /**
78  * Setup bitstream writing
79  */
lc3_setup_bits(struct lc3_bits * bits,enum lc3_bits_mode mode,void * buffer,int len)80 void lc3_setup_bits(struct lc3_bits *bits,
81     enum lc3_bits_mode mode, void *buffer, int len)
82 {
83     *bits = (struct lc3_bits){
84         .mode = mode,
85         .accu = {
86             .n = mode == LC3_BITS_MODE_READ ? LC3_ACCU_BITS : 0,
87         },
88         .ac = {
89             .range = 0xffffff,
90             .cache = -1
91         },
92         .buffer = {
93             .start = (uint8_t *)buffer, .end  = (uint8_t *)buffer + len,
94             .p_fw  = (uint8_t *)buffer, .p_bw = (uint8_t *)buffer + len,
95         }
96     };
97 
98     if (mode == LC3_BITS_MODE_READ) {
99         struct lc3_bits_ac *ac = &bits->ac;
100         struct lc3_bits_accu *accu = &bits->accu;
101         struct lc3_bits_buffer *buffer = &bits->buffer;
102 
103         ac->low  = ac_get(buffer) << 16;
104         ac->low |= ac_get(buffer) <<  8;
105         ac->low |= ac_get(buffer);
106 
107         accu_load(accu, buffer);
108     }
109 }
110 
/**
 * Return number of bits left in the bitstream
 * bits            Bitstream context
 * return          Number of bits left, 0 when the bitstream overflowed
 */
int lc3_get_bits_left(const struct lc3_bits *bits)
{
    int nleft = get_bits_left(bits);

    return nleft < 0 ? 0 : nleft;
}
118 
119 /**
120  * Return number of bits left in the bitstream
121  */
lc3_check_bits(const struct lc3_bits * bits)122 int lc3_check_bits(const struct lc3_bits *bits)
123 {
124     const struct lc3_bits_ac *ac = &bits->ac;
125 
126     return -(get_bits_left(bits) < 0 || ac->error);
127 }
128 
129 
130 /* ----------------------------------------------------------------------------
131  *  Writing
132  * -------------------------------------------------------------------------- */
133 
134 /**
135  * Flush the bits accumulator
136  * accu            Bitstream accumulator
137  * buffer          Bitstream buffer
138  */
accu_flush(struct lc3_bits_accu * accu,struct lc3_bits_buffer * buffer)139 static inline void accu_flush(
140     struct lc3_bits_accu *accu, struct lc3_bits_buffer *buffer)
141 {
142     int nbytes = LC3_MIN(accu->n >> 3,
143         LC3_MAX(buffer->p_bw - buffer->p_fw, 0));
144 
145     accu->n -= 8 * nbytes;
146 
147     for ( ; nbytes; accu->v >>= 8, nbytes--)
148         *(--buffer->p_bw) = accu->v & 0xff;
149 
150     if (accu->n >= 8)
151         accu->n = 0;
152 }
153 
154 /**
155  * Arithmetic coder put byte
156  * buffer          Bitstream buffer
157  * byte            Byte to output
158  */
ac_put(struct lc3_bits_buffer * buffer,int byte)159 static inline void ac_put(struct lc3_bits_buffer *buffer, int byte)
160 {
161     if (buffer->p_fw < buffer->end)
162         *(buffer->p_fw++) = byte;
163 }
164 
165 /**
166  * Arithmetic coder range shift
167  * ac              Arithmetic coder
168  * buffer          Bitstream buffer
169  */
ac_shift(struct lc3_bits_ac * ac,struct lc3_bits_buffer * buffer)170 LC3_HOT static inline void ac_shift(
171     struct lc3_bits_ac *ac, struct lc3_bits_buffer *buffer)
172 {
173     if (ac->low < 0xff0000 || ac->carry)
174     {
175         if (ac->cache >= 0)
176             ac_put(buffer, ac->cache + ac->carry);
177 
178         for ( ; ac->carry_count > 0; ac->carry_count--)
179             ac_put(buffer, ac->carry ? 0x00 : 0xff);
180 
181          ac->cache = ac->low >> 16;
182          ac->carry = 0;
183     }
184     else
185          ac->carry_count++;
186 
187     ac->low = (ac->low << 8) & 0xffffff;
188 }
189 
/**
 * Arithmetic coder termination
 * ac              Arithmetic coder
 * buffer          Bitstream buffer
 *
 * Rounds `low` up so that only its `nbits` most significant bits matter,
 * shifts out the complete bytes through `ac_shift()`, then writes the
 * remaining bits (1 to 8, assuming `range` fits in 24 bits) as the most
 * significant bits of the byte at the forward pointer, keeping that
 * byte's lower bits.
 */
static void ac_terminate(struct lc3_bits_ac *ac,
    struct lc3_bits_buffer *buffer)
{
    /* `mask` covers the low-order bits of `low` that are not transmitted,
     * `val` is `low` rounded up to a multiple of `mask + 1`,
     * `high` is the upper end of the coder interval */
    int nbits = 25 - ac_get_range_bits(ac);
    unsigned mask = 0xffffff >> nbits;
    unsigned val  = ac->low + mask;
    unsigned high = ac->low + ac->range;

    /* Note whether `val` and `high` went past the 24-bit register */
    bool over_val  = val  >> 24;
    bool over_high = high >> 24;

    val  = (val  & 0xffffff) & ~mask;
    high = (high & 0xffffff);

    /* The 24-bit parts are only comparable when both overflowed,
     * or neither did */
    if (over_val == over_high) {

        /* Not enough room below `high`: use one more bit */
        if (val + mask >= high) {
            nbits++;
            mask >>= 1;
            val = ((ac->low + mask) & 0xffffff) & ~mask;
        }

        /* A 24-bit `val` below `low` means the round-up wrapped,
         * propagate it as a carry */
        ac->carry |= val < ac->low;
    }

    ac->low = val;

    /* Shift out the complete bytes; the last `ac_shift()` moves the final
     * partial byte into `cache`, or defers it by counting `carry_count` */
    for (; nbits > 8; nbits -= 8)
        ac_shift(ac, buffer);
    ac_shift(ac, buffer);

    /* When no byte is deferred: top `nbits` bits of the final byte */
    int end_val = ac->cache >> (8 - nbits);

    /* Bytes deferred: emit the cache and all deferred 0xff bytes but one,
     * the last one is then encoded through `end_val` */
    if (ac->carry_count) {
        ac_put(buffer, ac->cache);
        for ( ; ac->carry_count > 1; ac->carry_count--)
            ac_put(buffer, 0xff);

        end_val = nbits < 8 ? 0 : 0xff;
    }

    /* Merge `end_val` into the top `nbits` bits of the next forward byte,
     * keeping its lower bits (NOTE(review): presumably bits already
     * written there from the back of the buffer - confirm) */
    if (buffer->p_fw < buffer->end) {
        *buffer->p_fw &= 0xff >> nbits;
        *buffer->p_fw |= end_val << (8 - nbits);
    }
}
242 
243 /**
244  * Flush and terminate bitstream
245  */
lc3_flush_bits(struct lc3_bits * bits)246 void lc3_flush_bits(struct lc3_bits *bits)
247 {
248     struct lc3_bits_ac *ac = &bits->ac;
249     struct lc3_bits_accu *accu = &bits->accu;
250     struct lc3_bits_buffer *buffer = &bits->buffer;
251 
252     int nleft = buffer->p_bw - buffer->p_fw;
253     for (int n = 8 * nleft - accu->n; n > 0; n -= 32)
254         lc3_put_bits(bits, 0, LC3_MIN(n, 32));
255 
256     accu_flush(accu, buffer);
257 
258     ac_terminate(ac, buffer);
259 }
260 
261 /**
262  * Write from 1 to 32 bits,
263  * exceeding the capacity of the accumulator
264  */
lc3_put_bits_generic(struct lc3_bits * bits,unsigned v,int n)265 LC3_HOT void lc3_put_bits_generic(struct lc3_bits *bits, unsigned v, int n)
266 {
267     struct lc3_bits_accu *accu = &bits->accu;
268 
269     /* --- Fulfill accumulator and flush -- */
270 
271     int n1 = LC3_MIN(LC3_ACCU_BITS - accu->n, n);
272     if (n1) {
273         accu->v |= v << accu->n;
274         accu->n = LC3_ACCU_BITS;
275     }
276 
277     accu_flush(accu, &bits->buffer);
278 
279     /* --- Accumulate remaining bits -- */
280 
281     accu->v = v >> n1;
282     accu->n = n - n1;
283 }
284 
285 /**
286  * Arithmetic coder renormalization
287  */
lc3_ac_write_renorm(struct lc3_bits * bits)288 LC3_HOT void lc3_ac_write_renorm(struct lc3_bits *bits)
289 {
290     struct lc3_bits_ac *ac = &bits->ac;
291 
292     for ( ; ac->range < 0x10000; ac->range <<= 8)
293         ac_shift(ac, &bits->buffer);
294 }
295 
296 
297 /* ----------------------------------------------------------------------------
298  *  Reading
299  * -------------------------------------------------------------------------- */
300 
301 /**
302  * Arithmetic coder get byte
303  * buffer          Bitstream buffer
304  * return          Byte read, 0 on overflow
305  */
ac_get(struct lc3_bits_buffer * buffer)306 static inline int ac_get(struct lc3_bits_buffer *buffer)
307 {
308     return buffer->p_fw < buffer->end ? *(buffer->p_fw++) : 0;
309 }
310 
311 /**
312  * Load the accumulator
313  * accu            Bitstream accumulator
314  * buffer          Bitstream buffer
315  */
accu_load(struct lc3_bits_accu * accu,struct lc3_bits_buffer * buffer)316 static inline void accu_load(struct lc3_bits_accu *accu,
317     struct lc3_bits_buffer *buffer)
318 {
319     int nbytes = LC3_MIN(accu->n >> 3, buffer->p_bw - buffer->start);
320 
321     accu->n -= 8 * nbytes;
322 
323     for ( ; nbytes; nbytes--) {
324         accu->v >>= 8;
325         accu->v |= *(--buffer->p_bw) << (LC3_ACCU_BITS - 8);
326     }
327 
328     if (accu->n >= 8) {
329         accu->nover = LC3_MIN(accu->nover + accu->n, LC3_ACCU_BITS);
330         accu->v >>= accu->n;
331         accu->n = 0;
332     }
333 }
334 
335 /**
336  * Read from 1 to 32 bits,
337  * exceeding the capacity of the accumulator
338  */
lc3_get_bits_generic(struct lc3_bits * bits,int n)339 LC3_HOT unsigned lc3_get_bits_generic(struct lc3_bits *bits, int n)
340 {
341     struct lc3_bits_accu *accu = &bits->accu;
342     struct lc3_bits_buffer *buffer = &bits->buffer;
343 
344     /* --- Fulfill accumulator and read -- */
345 
346     accu_load(accu, buffer);
347 
348     int n1 = LC3_MIN(LC3_ACCU_BITS - accu->n, n);
349     unsigned v = (accu->v >> accu->n) & ((1u << n1) - 1);
350     accu->n += n1;
351 
352     /* --- Second round --- */
353 
354     int n2 = n - n1;
355 
356     if (n2) {
357         accu_load(accu, buffer);
358 
359         v |= ((accu->v >> accu->n) & ((1u << n2) - 1)) << n1;
360         accu->n += n2;
361     }
362 
363     return v;
364 }
365 
366 /**
367  * Arithmetic coder renormalization
368  */
lc3_ac_read_renorm(struct lc3_bits * bits)369 LC3_HOT void lc3_ac_read_renorm(struct lc3_bits *bits)
370 {
371     struct lc3_bits_ac *ac = &bits->ac;
372 
373     for ( ; ac->range < 0x10000; ac->range <<= 8)
374         ac->low = ((ac->low << 8) | ac_get(&bits->buffer)) & 0xffffff;
375 }
376