1#
2# Copyright 2022 Google LLC
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17import numpy as np
18
19import lc3
20import tables as T, appendix_c as C
21
22import bwdet as m_bwdet
23import ltpf as m_ltpf
24import sns as m_sns
25import tns as m_tns
26
27### ------------------------------------------------------------------------ ###
28
class SpectrumQuantization:
    """Quantization helpers shared by the spectrum analysis and synthesis."""

    def __init__(self, dt, sr):

        # Frame duration and samplerate indices, used to index `T` tables
        self.dt = dt
        self.sr = sr

    def get_gain_offset(self, nbytes):
        """Return the global gain offset for a frame of `nbytes` bytes.

        HR samplerates continue the samplerate index after index 4, and
        their offset is floored at -181.
        """

        hrmode = self.sr >= T.SRATE_48K_HR
        sr_ind = 4 + (self.sr - T.SRATE_48K_HR) if hrmode else self.sr
        sr_scale = sr_ind + 1

        g_off = -(min((nbytes * 8) // (10 * sr_scale), 115)
                  + (105 + 5 * sr_scale))

        return max(g_off, -181) if hrmode else g_off

    def get_noise_indices(self, bw, xq, lastnz):
        """Return the noise-filling flags of the coefficients below `bw_stop`.

        A coefficient is flagged when all quantized values within
        +/- `nf_width` of it (bounded by `bw_stop`) are zero. Values from
        `lastnz` onward are taken as zero, and the first
        `nf_start - nf_width` values are forced non-zero so that no index
        below `nf_start` gets flagged.
        """

        nf_start, nf_width = ((6, 1), (12, 1), (18, 2), (24, 3))[self.dt]

        bw_stop = T.I[self.dt][min(bw, T.SRATE_48K)][-1]

        coded = np.concatenate((xq[:lastnz], np.zeros(len(xq) - lastnz)))
        coded[:nf_start - nf_width] = 1

        bounds = ((max(k - nf_width, 0), min(k + nf_width + 1, bw_stop))
                  for k in range(bw_stop))

        return [ np.all(coded[lo:hi] == 0) for (lo, hi) in bounds ]
60
class SpectrumAnalysis(SpectrumQuantization):
    """Encoder side of the spectrum coding.

    `run` estimates and adjusts the global gain, quantizes the spectrum,
    computes the bit consumption (truncating to the bit budget) and the
    noise factor. `store` writes the side information, and `encode` writes
    the noise factor, the arithmetic coded spectrum and the residual bits.
    """

    def __init__(self, dt, sr):

        super().__init__(dt, sr)

        # State carried from frame to frame by the global gain estimation
        self.reset_off  = 0
        self.nbits_off  = 0
        self.nbits_spec = 0
        self.nbits_est  = 0

        # Results of the last `run`, used by `store` and `encode`
        (self.g_idx, self.noise_factor, self.xq, self.lastnz,
                self.nbits_residual_max, self.xg) = \
            (None, None, None, None, None, None)

        # Set by `compute_nbits`, used by `store` and `encode`
        (self.rate, self.lsb_mode) = (None, None)

    def estimate_gain(self, x, nbytes, nbits_spec, nbits_off, g_off):
        """Estimate the global gain of the spectrum `x`.

        A bisection over the 8-bit gain index lowers the index as long as
        the bit consumption estimated from the energies of groups of 4
        coefficients stays within `nbits_spec + nbits_off`. The index is
        then raised, when needed, so that the largest coefficient does not
        exceed the quantizer limit.

        Returns `(g_min, g_int, reset_off)`: the minimum gain index, the
        gain (index plus `g_off`), and whether the bit offset state must be
        reset (gain limited, or null spectrum).
        """

        nbits = int(nbits_spec + nbits_off + 0.5)

        ### Energy (dB) by 4 MDCT coefficients

        hr = (self.sr >= T.SRATE_48K_HR)
        nf = 0

        if hr:
            # High resolution: a noise floor, function of the bitrate and
            # of the ratio of zeroth and first moments of |x|, is added
            # to the energies
            dt = self.dt
            sr = self.sr

            dt_ms = T.DT_MS[dt]
            # `int()` truncates like `.astype(int)`, and also accepts
            # a plain Python float
            bitrate = int(8 * nbytes / (dt_ms * 1e-3))

            # Offsets by HR samplerate and frame duration index
            # (`None`: no value defined for this frame duration)
            reg_off = [ [ -6, 0, None, 2 ], [ -6, 0, None, 5 ] ]

            reg_bits = np.clip(
                bitrate // 12500 + reg_off[sr - T.SRATE_48K_HR][dt], 6, 23)

            M0 = np.sum(np.abs(x)) + 1e-5
            M1 = np.sum(np.arange(len(x)) * np.abs(x)) + 1e-5

            low_bits = (4 / dt_ms) * (2*dt_ms - min(M0/M1, 2*dt_ms))

            nf = np.max(np.abs(x)) * np.exp2(-reg_bits - low_bits)

        e = [ np.sum(x[4*k:4*(k+1)] ** 2) for k in range(len(x) // 4) ]
        e = 10 * np.log10(2**-31 + np.array(e) + nf)

        ### Compute gain index

        g_idx = 255

        for i in range(8):
            factor = 1 << (7 - i)
            g_idx -= factor
            tmp = 0
            iszero = 1

            # Bits estimate, scanning the energies from the top down
            for ei in e[-1::-1]:

                if ei * 28/20 < g_idx + g_off:
                    if iszero == 0:
                        tmp += 2.7*28/20
                else:
                    if g_idx + g_off < (ei - 43) * 28/20:
                        tmp += 2*ei*28/20 - 2*(g_idx + g_off) - 36*28/20
                    else:
                        tmp += ei*28/20 - (g_idx + g_off) + 7*28/20
                    iszero = 0

            # Too many bits at this gain: step back
            if tmp > nbits * 1.4 * 28/20 and iszero == 0:
                g_idx += factor

        ### Limit gain index

        x_max = np.amax(np.abs(x))
        if x_max > 0:
            x_lim = [ 2**15 - 0.375,  2**23 ][hr]
            g_min = 28 * np.log10(x_max / x_lim)
            g_min = np.ceil(g_min).astype(int) - g_off
            reset_off = g_idx < g_min
        else:
            g_min = 0
            reset_off = True

        if reset_off:
            g_idx = g_min

        return (g_min, g_idx + g_off, reset_off)

    def quantize(self, g_int, x):
        """Divide `x` by the gain `10^(g_int/28)` and quantize it.

        Returns `(xg, xq, lastnz)`: the scaled spectrum; the quantized
        values, rounded with an offset of 0.375 (0.5 in HR mode) and
        clipped to the signed 16-bit (24-bit in HR mode) range; and the
        count of coefficients up to the last non-zero pair (0 when none).
        """

        xg = x / 10 ** (g_int / 28)

        hr = (self.sr >= T.SRATE_48K_HR)
        offset = [ 0.375, 0.5 ][hr]
        xq_min = [ -(2**15)  , -(2**23)   ][hr]
        xq_max = [  (2**15)-1,  (2**23)-1 ][hr]

        xq = np.where(xg < 0, np.ceil(xg - offset), np.floor(xg + offset))
        xq = xq.astype(np.int32)
        xq = np.fmin(np.fmax(xq, xq_min), xq_max)

        nz_pairs = np.any([ xq[::2] != 0, xq[1::2] != 0 ], axis=0)
        lastnz = len(xq) - 2 * np.argmax(nz_pairs[-1::-1])
        if not np.any(nz_pairs):
            lastnz = 0

        return (xg, xq, lastnz)

    def compute_nbits(self, nbytes, x, lastnz, nbits_spec):
        """Compute the bit consumption of the quantized spectrum `x`.

        The first `lastnz` coefficients are coded by pairs; consumption is
        accumulated in 1/2048 bit units from `T.AC_SPEC_BITS`.

        Returns `(nbits_est, nbits_trunc, lastnz_trunc, lsb_mode)`: the
        bits for the whole `lastnz` range (including deferred LSB bits),
        the bits up to the last non-zero pair fitting in `nbits_spec`, the
        corresponding coded count, and whether the LSB mode is used.
        Also sets `self.rate` and `self.lsb_mode`, used by `encode`.
        """

        mode = [ 0,   1 ][int(self.sr < T.SRATE_96K_HR and \
                              nbytes >= 20 * (3 + min(self.sr, T.SRATE_48K)))]
        rate = [ 0, 512 ][int(self.sr < T.SRATE_96K_HR and \
                              nbytes >  20 * (1 + min(self.sr, T.SRATE_48K)))]

        nbits_est = 0
        nbits_trunc = 0
        nbits_lsb = 0
        lastnz_trunc = 2
        c = 0

        for n in range(0, lastnz, 2):
            t = c + rate
            if n > len(x) // 2:
                t += 256

            a = abs(x[n  ])
            b = abs(x[n+1])
            lev = 0
            while max(a, b) >= 4:
                # Escape symbol, then one bit per coefficient of this level
                nbits_est += \
                    T.AC_SPEC_BITS[T.AC_SPEC_LOOKUP[t + lev*1024]][16]
                if lev == 0 and mode == 1:
                    nbits_lsb += 2
                else:
                    nbits_est += 2 * 2048

                a >>= 1
                b >>= 1
                lev = min(lev + 1, 3)

            nbits_est += \
                T.AC_SPEC_BITS[T.AC_SPEC_LOOKUP[t + lev*1024]][a + 4*b]

            # One sign bit per non-zero coefficient
            a_lsb = abs(x[n  ])
            b_lsb = abs(x[n+1])
            nbits_est += (min(a_lsb, 1) + min(b_lsb, 1)) * 2048
            if lev > 0 and mode == 1:
                a_lsb >>= 1
                b_lsb >>= 1
                nbits_lsb += int(a_lsb == 0 and x[n  ] != 0)
                nbits_lsb += int(b_lsb == 0 and x[n+1] != 0)

            if (x[n] != 0 or x[n+1] != 0) and \
                    (nbits_est <= nbits_spec * 2048):
                lastnz_trunc = n + 2
                nbits_trunc = nbits_est

            # Context of the next pair
            t = 1 + (a + b) * (lev + 1) if lev <= 1 else 12 + lev
            c = (c & 15) * 16 + t

        nbits_est = (nbits_est + 2047) // 2048 + nbits_lsb
        nbits_trunc = (nbits_trunc + 2047) // 2048

        self.rate = rate
        self.lsb_mode = mode == 1 and nbits_est > nbits_spec

        return (nbits_est, nbits_trunc, lastnz_trunc, self.lsb_mode)

    def adjust_gain(self, g_idx, nbits, nbits_spec):
        """Return the adjustment of the gain index `g_idx`.

        `delta` is a bit margin derived from `nbits` through samplerate
        dependent thresholds. In HR mode the index is only increased,
        proportionally to the bit excess; otherwise it is increased by 1
        or 2 when over budget, or decreased by 1 when well under budget.
        The adjusted index stays within [0, 255].
        """

        T1 = [  80,  230,  380,  530,  680,  680,  830 ]
        T2 = [ 500, 1025, 1550, 2075, 2600, 2600, 3125 ]
        T3 = [ 850, 1700, 2550, 3400, 4250, 4250, 5100 ]

        dt = self.dt
        sr = self.sr

        if nbits < T1[sr]:
            delta = (nbits + 48) / 16

        elif nbits < T2[sr]:
            a = T1[sr] / 16 + 3
            b = T2[sr] / 48
            delta = a + (nbits - T1[sr]) * (b - a) / (T2[sr] - T1[sr])

        elif nbits < T3[sr]:
            delta = nbits / 48

        else:
            delta = T3[sr] / 48

        delta = np.fix(delta + 0.5).astype(int)

        if self.sr >= T.SRATE_48K_HR and \
            (g_idx < 255 and nbits > nbits_spec):

            factor = [ 3 + (nbits >= 520), 2, 0, 1 ][dt]
            g_incr = int(factor * (1 + (nbits - nbits_spec) / delta))
            return min(g_idx + g_incr, 255) - g_idx

        elif self.sr < T.SRATE_48K_HR and \
            ( (g_idx < 255 and nbits > nbits_spec) or \
              (g_idx >   0 and nbits < nbits_spec - (delta + 2)) ):

            if nbits < nbits_spec - (delta + 2):
                return -1

            if g_idx == 254 or nbits < nbits_spec + delta:
                return 1

            else:
                return 2

        return 0

    def estimate_noise(self, bw, xq, lastnz, x):
        """Return the 3-bit noise factor.

        The noise level is the mean magnitude of `x` over the noise
        filling indices, mapped to `8 - 16 * level`, rounded and clamped
        to [0, 7]. `x` is expected on the scale of `xq` (gain-normalized).
        """

        i_nf = self.get_noise_indices(bw, xq, lastnz)
        l_nf = sum(abs(x[:len(i_nf)] * i_nf)) / sum(i_nf) \
            if sum(i_nf) > 0 else 0

        return min(max(np.rint(8 - 16 * l_nf).astype(int), 0), 7)

    def run(self, bw, nbytes, nbits_bw, nbits_ltpf, nbits_sns, nbits_tns, x):
        """Quantize the spectrum `x` of a frame of `nbytes` bytes.

        `nbits_bw`, `nbits_ltpf`, `nbits_sns` and `nbits_tns` are the bits
        used by the other side information. Returns `(xq, lastnz, xg)`;
        results are also kept on the instance for `store` and `encode`.
        """

        ### Bit budget

        hr = self.sr >= T.SRATE_48K_HR

        nbits_gain = 8
        nbits_nf   = 3

        nbits_ari  = np.ceil(np.log2(len(x) / 2)).astype(int)
        nbits_ari += 3 + int(hr) + min((8*nbytes - 1) // 1280, 2)

        nbits_spec = 8*nbytes - \
            nbits_bw - nbits_ltpf - nbits_sns - nbits_tns - \
            nbits_gain - nbits_nf - nbits_ari

        ### Global gain estimation

        nbits_off = self.nbits_off + self.nbits_spec - self.nbits_est
        nbits_off = min(40, max(-40, nbits_off))

        nbits_off = 0 if self.reset_off else \
                    0.8 * self.nbits_off + 0.2 * nbits_off

        g_off = self.get_gain_offset(nbytes)

        (g_min, g_int, self.reset_off) = \
            self.estimate_gain(x, nbytes, nbits_spec, nbits_off, g_off)
        self.nbits_off = nbits_off
        self.nbits_spec = nbits_spec

        ### Quantization

        (xg, xq, lastnz) = self.quantize(g_int, x)

        (nbits_est, _, _, _) = \
            self.compute_nbits(nbytes, xq, lastnz, nbits_spec)

        self.nbits_est = nbits_est

        ### Adjust gain and requantize

        g_adj = self.adjust_gain(g_int - g_off, nbits_est, nbits_spec)
        g_adj = max(g_int + g_adj, g_min + g_off) - g_int

        (xg, xq, lastnz) = self.quantize(g_adj, xg)

        (nbits_est, nbits_trunc, lastnz_trunc, _) = \
            self.compute_nbits(nbytes, xq, lastnz, nbits_spec)

        self.g_idx = g_int + g_adj - g_off
        self.xq = xq
        self.lastnz = lastnz_trunc

        self.nbits_residual_max = nbits_spec - nbits_trunc + 4
        self.xg = xg

        ### Noise factor

        # Estimated on the gain-normalized spectrum `xg` (same scale as
        # `xq`), over the transmitted count `lastnz_trunc`: the decoder
        # fills noise from that count, not from the untruncated one.
        self.noise_factor = self.estimate_noise(bw, xq, lastnz_trunc, xg)

        return (self.xq, self.lastnz, self.xg)

    def store(self, b):
        """Write `lastnz`, the LSB mode flag and the 8-bit gain index."""

        ne = T.I[self.dt][self.sr][-1]
        nbits_lastnz = np.ceil(np.log2(ne/2)).astype(int)

        b.write_uint((self.lastnz >> 1) - 1, nbits_lastnz)
        b.write_uint(self.lsb_mode, 1)
        b.write_uint(self.g_idx, 8)

    def encode(self, bits):
        """Write the noise factor, the quantized spectrum and the residual.

        Pairs up to `self.lastnz` are arithmetic coded. In LSB mode the
        first-level LSBs (and related signs) are deferred and written in
        place of the residual bits.
        """

        ### Noise factor

        bits.write_uint(self.noise_factor, 3)

        ### Quantized data

        lsbs = []

        x = self.xq
        c = 0

        for n in range(0, self.lastnz, 2):
            t = c + self.rate
            if n > len(x) // 2:
                t += 256

            a = abs(x[n  ])
            b = abs(x[n+1])
            lev = 0
            while max(a, b) >= 4:

                bits.ac_encode(
                    T.AC_SPEC_CUMFREQ[T.AC_SPEC_LOOKUP[t + lev*1024]][16],
                    T.AC_SPEC_FREQ[T.AC_SPEC_LOOKUP[t + lev*1024]][16])

                if lev == 0 and self.lsb_mode:
                    lsb_0 = a & 1
                    lsb_1 = b & 1
                else:
                    bits.write_bit(a & 1)
                    bits.write_bit(b & 1)

                a >>= 1
                b >>= 1
                lev = min(lev + 1, 3)

            bits.ac_encode(
                T.AC_SPEC_CUMFREQ[T.AC_SPEC_LOOKUP[t + lev*1024]][a + 4*b],
                T.AC_SPEC_FREQ[T.AC_SPEC_LOOKUP[t + lev*1024]][a + 4*b])

            a_lsb = abs(x[n  ])
            b_lsb = abs(x[n+1])
            if lev > 0 and self.lsb_mode:
                # `lsb_0` / `lsb_1` were set at level 0 of the loop above,
                # which always runs when `lev > 0`
                a_lsb >>= 1
                b_lsb >>= 1

                lsbs.append(lsb_0)
                if a_lsb == 0 and x[n+0] != 0:
                    lsbs.append(int(x[n+0] < 0))

                lsbs.append(lsb_1)
                if b_lsb == 0 and x[n+1] != 0:
                    lsbs.append(int(x[n+1] < 0))

            if a_lsb > 0:
                bits.write_bit(int(x[n+0] < 0))

            if b_lsb > 0:
                bits.write_bit(int(x[n+1] < 0))

            t = 1 + (a + b) * (lev + 1) if lev <= 1 else 12 + lev
            c = (c & 15) * 16 + t

        ### Residual data

        if self.lsb_mode == 0:
            nbits_residual = min(bits.get_bits_left(), self.nbits_residual_max)

            # One bit per coded (below `lastnz`) non-zero coefficient, in
            # the decoder's reading order. The budget is checked before
            # writing, as the decoder does before reading.
            for i in range(self.lastnz):

                if nbits_residual <= 0:
                    break

                if self.xq[i] == 0:
                    continue

                bits.write_bit(self.xg[i] >= self.xq[i])
                nbits_residual -= 1

        else:
            # Clamp at 0: a negative count would slice from the end
            nbits_residual = max(0, min(bits.get_bits_left(), len(lsbs)))
            for lsb in lsbs[:nbits_residual]:
                bits.write_bit(lsb)
444
class SpectrumSynthesis(SpectrumQuantization):
    """Decoder side of the spectrum coding: reads the side information,
    decodes the quantized spectrum and rebuilds the scaled spectrum."""

    def __init__(self, dt, sr):

        super().__init__(dt, sr)

        (self.lastnz, self.lsb_mode, self.g_idx) = \
            (None, None, None)

    def fill_noise(self, bw, x, lastnz, f_nf, nf_seed):
        """Set the noise filling indices of `x` to +/- `(8 - f_nf) / 16`.

        Signs come from a 16-bit linear congruential generator seeded by
        `nf_seed`. `x` is modified in place and returned.
        """

        i_nf = self.get_noise_indices(bw, x, lastnz)
        l_nf = (8 - f_nf)/16

        # Plain Python int: keeps the comparison below a Python bool
        nf_seed = int(nf_seed)

        for k in np.flatnonzero(i_nf):
            nf_seed = (13849 + nf_seed * 31821) & 0xffff
            x[k] = [ -l_nf, l_nf ][nf_seed < 0x8000]

        return x

    def load(self, b):
        """Read `lastnz`, the LSB mode flag and the 8-bit gain index.

        Raises ValueError when `lastnz` exceeds the spectrum length.
        """

        ne = T.I[self.dt][self.sr][-1]
        nbits_lastnz = np.ceil(np.log2(ne/2)).astype(int)

        self.lastnz = (b.read_uint(nbits_lastnz) + 1) << 1
        self.lsb_mode = b.read_uint(1)
        self.g_idx = b.read_uint(8)

        if self.lastnz > ne:
            raise ValueError('Invalid count of coded samples')

    def decode(self, bits, bw, nbytes):
        """Decode the spectrum and return it, rescaled by the global gain.

        Raises ValueError when a pair escapes over more than 14 levels, or
        when the bitstream is exhausted after the quantized data.
        """

        ### Noise factor

        f_nf = bits.read_uint(3)

        ### Quantized data

        ne = T.I[self.dt][self.sr][-1]
        x  = np.zeros(ne)
        rate = [ 0, 512 ][int(self.sr < T.SRATE_96K_HR and \
                              nbytes >  20 * (1 + min(self.sr, T.SRATE_48K)))]

        levs = np.zeros(len(x), dtype=np.intc)
        c = 0

        for n in range(0, self.lastnz, 2):
            t = c + rate
            if n > len(x) // 2:
                t += 256

            # Symbol 16 escapes to the next level, up to 14 levels
            for lev in range(14):

                s = t + min(lev, 3) * 1024

                sym = bits.ac_decode(
                    T.AC_SPEC_CUMFREQ[T.AC_SPEC_LOOKUP[s]],
                    T.AC_SPEC_FREQ[T.AC_SPEC_LOOKUP[s]])

                if sym < 16:
                    break

                if self.lsb_mode == 0 or lev > 0:
                    x[n  ] += bits.read_bit() << lev
                    x[n+1] += bits.read_bit() << lev

            else:
                # No terminating symbol within 14 levels. (Testing `lev`
                # after the loop cannot detect it: `range(14)` ends at 13.)
                raise ValueError('Out of range value')

            a = sym %  4
            b = sym // 4

            levs[n  ] = lev
            levs[n+1] = lev

            x[n  ] += a << lev
            x[n+1] += b << lev

            if x[n] and bits.read_bit():
                x[n] = -x[n]

            if x[n+1] and bits.read_bit():
                x[n+1] = -x[n+1]

            lev = min(lev, 3)
            t = 1 + (a + b) * (lev + 1) if lev <= 1 else 12 + lev
            c = (c & 15) * 16 + t

        ### Residual data

        nbits_residual = bits.get_bits_left()
        if nbits_residual < 0:
            raise ValueError('Out of bitstream')

        if self.lsb_mode == 0:

            # Builtin `bool`: the `np.bool` alias is missing in NumPy 1.24-1.26
            xr = np.zeros(len(x), dtype=bool)

            for i in range(len(x)):

                if nbits_residual <= 0:
                    # Only the first `i` coefficients got a residual bit
                    xr = xr[:i]
                    break

                if x[i] == 0:
                    continue

                xr[i] = bits.read_bit()
                nbits_residual -= 1

        else:

            for i in range(len(levs)):

                if nbits_residual <= 0:
                    break

                if levs[i] <= 0:
                    continue

                lsb = bits.read_bit()
                nbits_residual -= 1
                if not lsb:
                    continue

                sign = int(x[i] < 0)

                if x[i] == 0:

                    if nbits_residual <= 0:
                        break

                    sign = bits.read_bit()
                    nbits_residual -= 1

                x[i] += [ 1, -1 ][sign]

        ### Set residual and noise

        # Noise seed: sum of |x(k)| * k; only its 16 LSBs matter
        nf_seed = int(np.sum(np.abs(x.astype(np.intc)) *
                             np.arange(len(x), dtype=np.int64))) & 0xffff

        zero_frame = (self.lastnz <= 2 and x[0] == 0 and x[1] == 0
                      and self.g_idx <= 0 and f_nf >= 7)

        if self.lsb_mode == 0:

            for i in range(len(xr)):

                if not x[i]:
                    continue

                if xr[i]:
                    x[i] += 0.3125 if x[i] > 0 else 0.1875
                else:
                    x[i] -= 0.3125 if x[i] < 0 else 0.1875

        if not zero_frame:
            x = self.fill_noise(bw, x, self.lastnz, f_nf, nf_seed)

        ### Rescale coefficients

        g_int = self.get_gain_offset(nbytes) + self.g_idx
        x *= 10 ** (g_int / 28)

        return x
611
612
def initial_state():
    """Return a fresh state dictionary for `lc3.spec_analyze`."""
    return dict(nbits_off=0.0, nbits_spare=0)
615
616
617### ------------------------------------------------------------------------ ###
618
def check_estimate_gain(rng, dt, sr):
    """Compare the Python and C global gain estimations on 10 random
    spectra of increasing amplitude.

    One gain index mismatch is tolerated; while at most one has been
    seen, reset flag differences are tolerated too.
    """

    model = SpectrumAnalysis(dt, sr)
    ne = T.I[dt][sr][-1]

    ok = True
    n_mismatch = 0

    for scale in range(10):
        x = rng.random(ne) * scale * 1e2

        nbytes = 20 + int(rng.random() * 100)
        nbits_budget = 8 * nbytes - int(rng.random() * 100)
        nbits_off = rng.random() * 10
        g_off = 10 - int(rng.random() * 20)

        _, g_int, reset_off = model.estimate_gain(
            x, nbytes, nbits_budget, nbits_off, g_off)

        g_int_c, reset_off_c, _ = lc3.spec_estimate_gain(
            dt, sr, x, nbytes, nbits_budget, nbits_off, -g_off)

        if g_int_c != g_int:
            n_mismatch += 1
        tolerated = n_mismatch <= 1

        ok = ok and (g_int_c == g_int or tolerated)
        ok = ok and (reset_off_c == reset_off or tolerated)

    return ok
648
def check_quantization(rng, dt, sr):
    """Compare the Python and C quantizations for every gain in
    [-128, 127], each on a fresh random spectrum."""

    model = SpectrumAnalysis(dt, sr)
    ne = T.I[dt][sr][-1]

    ok = True

    for g_int in range(-128, 128):
        x = rng.random(ne) * 1e2

        # Unused draw, kept so that the random sequence seen by the
        # following checks is unchanged
        rng.random()

        xg, _, nq = model.quantize(g_int, x)
        xg_c, nq_c = lc3.spec_quantize(dt, sr, g_int, x)

        ok = ok and np.amax(np.abs(1 - xg_c / xg)) < 1e-6 and nq_c == nq

    return ok
668
def check_compute_nbits(rng, dt, sr):
    """Compare the Python and C bit consumptions on random spectra.

    The C function is called once with a budget of 0 (its count is
    compared with the Python total), and once with the actual budget
    (truncated count, coded count and LSB mode).
    """

    model = SpectrumAnalysis(dt, sr)
    ne = T.I[dt][sr][-1]

    ok = True

    for nbytes in range(20, 150):
        nbits_budget = nbytes * 8 - int(rng.random() * 100)
        xq = (rng.random(ne) * 8).astype(int)

        # Even coded count, at least half the spectrum, with a non-zero
        # last pair
        nq = ne // 2 + int(rng.random() * ne // 2)
        nq -= nq % 2
        if not (xq[nq-2] or xq[nq-1]):
            xq[nq-2] = 1

        nbits, nbits_trunc, nq_trunc, lsb_mode = \
            model.compute_nbits(nbytes, xq, nq, nbits_budget)

        nbits_c, _, _ = lc3.spec_compute_nbits(dt, sr, nbytes, xq, nq, 0)

        nbits_trunc_c, nq_trunc_c, lsb_mode_c = \
            lc3.spec_compute_nbits(dt, sr, nbytes, xq, nq, nbits_budget)

        ok = (ok and nbits_c == nbits
                 and nbits_trunc_c == nbits_trunc
                 and nq_trunc_c == nq_trunc
                 and lsb_mode_c == lsb_mode)

    return ok
701
def check_adjust_gain(rng, dt, sr):
    """Compare the Python and C gain adjustments for a few gain indices,
    over a range of bit counts with a budget between 95% and 105% of
    the count."""

    model = SpectrumAnalysis(dt, sr)

    cases = ((g_idx, nbits)
             for g_idx in (0, 128, 254, 255)
             for nbits in range(50, 5000, 5))

    ok = True

    for g_idx, nbits in cases:
        nbits_budget = int(nbits * (0.95 + rng.random() * 0.1))

        g_adj = model.adjust_gain(g_idx, nbits, nbits_budget)
        g_adj_c = lc3.spec_adjust_gain(
            dt, sr, g_idx, nbits, nbits_budget, 0)

        ok = ok and g_adj_c == g_adj

    return ok
720
def check_unit(rng, dt, sr):
    """Run the Python analysis chain on 10 random frames of 100 bytes and
    compare the spectrum analysis with the C `spec_analyze`.

    The bandwidth detector only runs below HR samplerates; otherwise its
    bit count is taken as 0.
    """

    hrmode = sr >= T.SRATE_48K_HR
    state_c = initial_state()

    bwdet = m_bwdet.BandwidthDetector(dt, sr)
    ltpf = m_ltpf.LtpfAnalysis(dt, sr)
    tns = m_tns.TnsAnalysis(dt)
    sns = m_sns.SnsAnalysis(dt, sr)
    analysis = SpectrumAnalysis(dt, sr)

    nbytes = 100
    ns = T.NS[dt][sr]
    ne = T.I[dt][sr][-1]

    ok = True

    for _ in range(10):
        x = rng.random(ns) * 1e4
        e = rng.random(min(len(x), 64)) * 1e10

        if not hrmode:
            bwdet.run(e)
        pitch_present = ltpf.run(x)
        tns.run(x[:ne], sr, False, nbytes)
        sns.run(e, False, 0, x)

        # The samplerate index is given as the bandwidth
        nbits_bw = 0 if hrmode else bwdet.get_nbits()
        (_, nq, xg) = analysis.run(sr, nbytes, nbits_bw,
            ltpf.get_nbits(), sns.get_nbits(), tns.get_nbits(), x[:ne])

        (xg_c, side_c) = lc3.spec_analyze(dt, sr,
            nbytes, pitch_present, tns.get_data(), state_c, x[:ne])

        ok = (ok and side_c['g_idx'] == analysis.g_idx
                 and side_c['nq'] == nq
                 and np.amax(np.abs(1 - xg_c / xg)) < 1e-6)

    return ok
760
def check_noise(rng, dt, bw, hrmode = False):
    """Compare the Python and C noise factor estimations.

    Random quantized spectra are drawn, and the unquantized values are
    built by moving each non-zero value away from zero by the quantizer
    offset (0.375, or 0.5 in HR mode).
    """

    analysis = SpectrumAnalysis(dt, bw)
    xq_off = (0.375, 0.5)[hrmode]
    ne = T.I[dt][bw][-1]

    ok = True

    for _ in range(10):
        xq = ((rng.random(ne) - 0.5) * 10 ** 0.5).astype(int)
        nq = ne - int(rng.random() * 5)
        x = xq + np.sign(xq) * xq_off

        nf = analysis.estimate_noise(bw, xq, nq, x)
        nf_c = lc3.spec_estimate_noise(dt, bw, hrmode, x, nq)

        ok = ok and nf_c == nf

    return ok
781
def check_appendix_c(dt):
    """Check the C spectrum functions against the Appendix C test data
    (16 kHz; `i0` indexes the tables by frame duration from 7.5 ms).

    The coefficient-wise comparisons use `np.any`: they only require one
    matching coefficient, and are loose sanity checks.
    """

    i0 = dt - T.DT_7M5
    sr = T.SRATE_16K

    def requant(xg):
        # Recover quantized values from the gain-scaled spectrum:
        # `quantize` rounds `xg + 0.375` down for positive values and
        # `xg - 0.375` up for negative ones, i.e. truncates `xg` moved
        # away from zero by 0.375.
        return np.trunc(xg + np.select(
            [xg < 0, xg > 0], np.array([ -0.375, 0.375 ])))

    ok = True

    state_c = initial_state()

    for i in range(len(C.X_F[i0])):

        ### Global gain estimation

        g_int = lc3.spec_estimate_gain(dt, sr, C.X_F[i0][i],
            0, C.NBITS_SPEC[i0][i], C.NBITS_OFFSET[i0][i], -C.GG_OFF[i0][i])[0]
        ok = ok and g_int == C.GG_IND[i0][i] + C.GG_OFF[i0][i]

        ### Quantization and bit consumption

        (xg, nq) = lc3.spec_quantize(dt, sr,
            C.GG_IND[i0][i] + C.GG_OFF[i0][i], C.X_F[i0][i])
        ok = ok and np.any((requant(xg) - C.X_Q[i0][i]) == 0)
        ok = ok and nq == C.LASTNZ[i0][i]
        nbits = lc3.spec_compute_nbits(dt, sr,
            C.NBYTES[i0], C.X_Q[i0][i], C.LASTNZ[i0][i], 0)[0]
        ok = ok and nbits == C.NBITS_EST[i0][i]

        ### Gain adjustment and requantization

        g_adj = lc3.spec_adjust_gain(dt, sr,
            C.GG_IND[i0][i], C.NBITS_EST[i0][i], C.NBITS_SPEC[i0][i], 0)
        ok = ok and g_adj == C.GG_IND_ADJ[i0][i] - C.GG_IND[i0][i]

        requantized = C.GG_IND_ADJ[i0][i] != C.GG_IND[i0][i]

        if requantized:

            (xg, _) = lc3.spec_quantize(dt, sr,
                C.GG_IND_ADJ[i0][i] + C.GG_OFF[i0][i], C.X_F[i0][i])
            lastnz = C.LASTNZ_REQ[i0][i]
            ok = ok and np.any(
                ((requant(xg) - C.X_Q_REQ[i0][i])[:lastnz]) == 0)

        ### Complete analysis

        tns_data = {
            'nfilters' : C.NUM_TNS_FILTERS[i0][i],
            'lpc_weighting' : [ True, True ],
            'rc_order' : [ C.RC_ORDER[i0][i][0], 0 ],
            'rc' : [ C.RC_I_1[i0][i] - 8, np.zeros(8, dtype = np.intc) ]
        }

        (xg, side) = lc3.spec_analyze(dt, sr, C.NBYTES[i0],
            C.PITCH_PRESENT[i0][i], tns_data, state_c, C.X_F[i0][i])

        xq = requant(xg)

        ok = ok and np.abs(state_c['nbits_off'] - C.NBITS_OFFSET[i0][i]) < 1e-5

        if requantized:
            xq_ref = C.X_Q_REQ[i0][i]
            nq = C.LASTNZ_REQ[i0][i]
            g_idx_ref = C.GG_IND_ADJ[i0][i]
        else:
            xq_ref = C.X_Q[i0][i]
            nq = C.LASTNZ[i0][i]
            g_idx_ref = C.GG_IND[i0][i]

        ok = ok and side['g_idx'] == g_idx_ref
        ok = ok and side['nq'] == nq
        # Analysis output against the reference
        ok = ok and np.any((xq[:nq] - xq_ref[:nq]) == 0)
        ok = ok and side['lsb_mode'] == C.LSB_MODE[i0][i]

        ### Noise factor

        gg = C.GG_ADJ[i0][i] if requantized else C.GG[i0][i]

        nf = lc3.spec_estimate_noise(
                dt, C.P_BW[i0][i], False, C.X_F[i0][i] / gg, nq)
        ok = ok and nf == C.F_NF[i0][i]

    return ok
856
def check():
    """Run every spectrum check; once one fails, later ones are skipped.

    Returns a truthy value when all checks pass.
    """

    rng = np.random.default_rng(1234)

    configs = [ (dt, sr, False)
                for dt in range(T.NUM_DT)
                for sr in range(T.SRATE_8K, T.SRATE_48K + 1) ]
    configs += [ (dt, sr, True)
                 for dt in ( T.DT_2M5, T.DT_5M, T.DT_10M )
                 for sr in ( T.SRATE_48K_HR, T.SRATE_96K_HR ) ]

    suite = (check_estimate_gain, check_quantization, check_compute_nbits,
             check_adjust_gain, check_unit)

    ok = True

    for (dt, sr, hrmode) in configs:
        for check_fn in suite:
            ok = ok and check_fn(rng, dt, sr)
        ok = ok and check_noise(rng, dt, sr, hrmode)

    for dt in ( T.DT_7M5, T.DT_10M ):
        ok = ok and check_appendix_c(dt)

    return ok
884
885### ------------------------------------------------------------------------ ###
886