1#
2# Copyright 2022 Google LLC
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17import numpy as np
18
19import lc3
20import tables as T, appendix_c as C
21
22### ------------------------------------------------------------------------ ###
23
class Tns:
    """State and tables shared by the TNS analysis and synthesis models.

    Instance attributes:
      dt             frame-duration index (indexes the SUB_LIM / FREQ_LIM tables)
      nfilters       number of TNS filters in use
      lpc_weighting  LPC weighting flag
      rc_order       per-filter order, shape (2,)
      rc             per-filter coefficient indices, 8 encoding a zero
                     coefficient, shape (2, n)
    """

    # Subdivision boundaries (spectral line indices) used for the normalized
    # autocorrelation: SUB_LIM[dt][bw][f] lists the boundaries for filter f.
    # The two trailing bandwidth entries reuse the FB limits; None marks an
    # unsupported frame duration / bandwidth combination.

    SUB_LIM_2M5_NB   = [ [   3,  10,  20  ] ]
    SUB_LIM_2M5_WB   = [ [   3,  20,  40  ] ]
    SUB_LIM_2M5_SSWB = [ [   3,  30,  60  ] ]
    SUB_LIM_2M5_SWB  = [ [   3,  40,  80  ] ]
    SUB_LIM_2M5_FB   = [ [   3,  51, 100  ] ]

    SUB_LIM_2M5 = [
        SUB_LIM_2M5_NB , SUB_LIM_2M5_WB, SUB_LIM_2M5_SSWB,
        SUB_LIM_2M5_SWB, SUB_LIM_2M5_FB, SUB_LIM_2M5_FB, SUB_LIM_2M5_FB ]

    SUB_LIM_5M_NB    = [ [  6,  23,  40 ] ]
    SUB_LIM_5M_WB    = [ [  6,  43,  80 ] ]
    SUB_LIM_5M_SSWB  = [ [  6,  63, 120 ] ]
    SUB_LIM_5M_SWB   = [ [  6,  43,  80 ], [  80, 120, 160 ] ]
    SUB_LIM_5M_FB    = [ [  6,  53, 100 ], [ 100, 150, 200 ] ]

    SUB_LIM_5M = [
        SUB_LIM_5M_NB , SUB_LIM_5M_WB, SUB_LIM_5M_SSWB,
        SUB_LIM_5M_SWB, SUB_LIM_5M_FB, SUB_LIM_5M_FB, SUB_LIM_5M_FB ]

    SUB_LIM_7M5_NB   = [ [   9,  26,  43,  60 ] ]
    SUB_LIM_7M5_WB   = [ [   9,  46,  83, 120 ] ]
    SUB_LIM_7M5_SSWB = [ [   9,  66, 123, 180 ] ]
    SUB_LIM_7M5_SWB  = [ [   9,  46,  82, 120 ], [ 120, 159, 200, 240 ] ]
    SUB_LIM_7M5_FB   = [ [   9,  56, 103, 150 ], [ 150, 200, 250, 300 ] ]

    SUB_LIM_7M5 = [
        SUB_LIM_7M5_NB , SUB_LIM_7M5_WB, SUB_LIM_7M5_SSWB,
        SUB_LIM_7M5_SWB, SUB_LIM_7M5_FB, None, None ]

    SUB_LIM_10M_NB   = [ [  12,  34,  57,  80 ] ]
    SUB_LIM_10M_WB   = [ [  12,  61, 110, 160 ] ]
    SUB_LIM_10M_SSWB = [ [  12,  88, 164, 240 ] ]
    SUB_LIM_10M_SWB  = [ [  12,  61, 110, 160 ], [ 160, 213, 266, 320 ] ]
    SUB_LIM_10M_FB   = [ [  12,  74, 137, 200 ], [ 200, 266, 333, 400 ] ]

    SUB_LIM_10M = [
        SUB_LIM_10M_NB , SUB_LIM_10M_WB, SUB_LIM_10M_SSWB,
        SUB_LIM_10M_SWB, SUB_LIM_10M_FB, SUB_LIM_10M_FB, SUB_LIM_10M_FB ]

    SUB_LIM = [ SUB_LIM_2M5, SUB_LIM_5M, SUB_LIM_7M5, SUB_LIM_10M ]


    # Filter ranges: filter f covers the spectral lines
    # FREQ_LIM[dt][bw][f] .. FREQ_LIM[dt][bw][f+1] (end excluded).

    FREQ_LIM_2M5_NB   = [   3,  20 ]
    FREQ_LIM_2M5_WB   = [   3,  40 ]
    FREQ_LIM_2M5_SSWB = [   3,  60 ]
    FREQ_LIM_2M5_SWB  = [   3,  80 ]
    FREQ_LIM_2M5_FB   = [   3, 100 ]

    FREQ_LIM_2M5 = [
        FREQ_LIM_2M5_NB , FREQ_LIM_2M5_WB, FREQ_LIM_2M5_SSWB,
        FREQ_LIM_2M5_SWB, FREQ_LIM_2M5_FB, FREQ_LIM_2M5_FB, FREQ_LIM_2M5_FB ]

    FREQ_LIM_5M_NB    = [   6,  40 ]
    FREQ_LIM_5M_WB    = [   6,  80 ]
    FREQ_LIM_5M_SSWB  = [   6, 120 ]
    FREQ_LIM_5M_SWB   = [   6,  80, 160 ]
    FREQ_LIM_5M_FB    = [   6, 100, 200 ]

    FREQ_LIM_5M = [
        FREQ_LIM_5M_NB , FREQ_LIM_5M_WB, FREQ_LIM_5M_SSWB,
        FREQ_LIM_5M_SWB, FREQ_LIM_5M_FB, FREQ_LIM_5M_FB, FREQ_LIM_5M_FB ]

    FREQ_LIM_7M5_NB   = [   9,  60 ]
    FREQ_LIM_7M5_WB   = [   9, 120 ]
    FREQ_LIM_7M5_SSWB = [   9, 180 ]
    FREQ_LIM_7M5_SWB  = [   9, 120, 240 ]
    FREQ_LIM_7M5_FB   = [   9, 150, 300 ]

    FREQ_LIM_7M5 = [
        FREQ_LIM_7M5_NB , FREQ_LIM_7M5_WB, FREQ_LIM_7M5_SSWB,
        FREQ_LIM_7M5_SWB, FREQ_LIM_7M5_FB, None, None ]

    FREQ_LIM_10M_NB   = [  12,  80 ]
    FREQ_LIM_10M_WB   = [  12, 160 ]
    FREQ_LIM_10M_SSWB = [  12, 240 ]
    FREQ_LIM_10M_SWB  = [  12, 160, 320 ]
    FREQ_LIM_10M_FB   = [  12, 200, 400 ]

    FREQ_LIM_10M = [
        FREQ_LIM_10M_NB , FREQ_LIM_10M_WB, FREQ_LIM_10M_SSWB,
        FREQ_LIM_10M_SWB, FREQ_LIM_10M_FB, FREQ_LIM_10M_FB, FREQ_LIM_10M_FB ]

    FREQ_LIM = [ FREQ_LIM_2M5, FREQ_LIM_5M, FREQ_LIM_7M5, FREQ_LIM_10M ]


    def __init__(self, dt):

        self.dt = dt

        # No filter active yet. `rc` uses the same (2, 8) layout as
        # `TnsSynthesis.load()` (index 8 <=> zero coefficient), so that
        # `get_data()` also works on a freshly constructed instance.
        self.nfilters = 0
        self.lpc_weighting = False
        self.rc_order = np.zeros(2, dtype=np.intc)
        self.rc = np.full((2, 8), 8, dtype=np.intc)

    def get_data(self):
        """Return the side information as a dict.

        Coefficient indices are re-centered around 0 (minus 8) and padded
        with zeros up to 8 columns.
        """

        rc = np.append(self.rc - 8, np.zeros((2, 8 - len(self.rc[0]))), axis=1)

        return { 'nfilters' : self.nfilters,
                 'lpc_weighting' : self.lpc_weighting,
                 'rc_order' : self.rc_order, 'rc' : rc }

    def get_nbits(self):
        """Return the number of bits used by the TNS side information.

        Table costs are in 1/2048 bit units. Each filter contributes its
        1-bit activation flag (2048) plus its order and coefficient costs,
        rounded up to a whole bit.
        """

        w = int(self.lpc_weighting)
        nbits = 0

        for f in range(self.nfilters):
            rc_order = self.rc_order[f]
            rc = self.rc[f]

            # Order tables are indexed by `rc_order - 1`, as done when
            # encoding the order in `TnsAnalysis.store()`. A disabled filter
            # (order 0) only costs its activation bit.
            nbits_order = \
                T.TNS_ORDER_BITS[w][rc_order - 1] if rc_order > 0 else 0
            nbits_coef = sum(T.TNS_COEF_BITS[k][rc[k]]
                             for k in range(rc_order))

            nbits += ((2048 + nbits_order + nbits_coef) + 2047) >> 11

        return nbits
142
143
class TnsAnalysis(Tns):
    """Encoder side of TNS: LPC analysis, reflection coefficient
    quantization, lattice filtering of the spectrum and bitstream writing."""

    def __init__(self, dt):

        super().__init__(dt)

    def compute_lpc_coeffs(self, bw, f, x):
        """Compute the LPC model of filter `f` over spectrum `x`.

        Returns `(pred_gain, a)`: the prediction gain `r[0] / err` from the
        Levinson-Durbin recursion, and the LPC coefficients `a` (`a[0] == 1`)
        of order 4 when `dt <= T.DT_5M`, 8 otherwise.
        """

        ### Normalized autocorrelation function

        S = Tns.SUB_LIM[self.dt][bw][f]
        maxorder = [ 4, 8 ][self.dt > T.DT_5M]
        nsub = len(S) - 1

        r = np.append([ 3 ], np.zeros(maxorder))
        e = [ sum(x[S[s]:S[s+1]] ** 2) for s in range(nsub) ]

        # Normalize only when every subdivision carries energy: a single
        # silent subdivision would otherwise divide by zero (inf / nan).
        # Else keep r = [3, 0, ..., 0], which gives a prediction gain of 1.
        if all(es > 0 for es in e):
            for k in range(len(r)):
                c = [ np.dot(x[S[s]:S[s+1]-k], x[S[s]+k:S[s+1]])
                          for s in range(nsub) ]

                r[k] = np.sum( np.array(c) / np.array(e) )

        # Gaussian lag window
        r *= np.exp(-0.5 * (0.02 * np.pi * np.arange(1+maxorder)) ** 2)

        ### Levinson-Durbin recursion

        err = r[0]
        a = np.ones(len(r))

        for k in range(1, len(a)):

            rc = -sum(a[:k] * r[k:0:-1]) / err

            a[1:k] += rc * a[k-1:0:-1]
            a[k] = rc

            err *= 1 - rc ** 2

        return (r[0] / err, a)

    def lpc_weight(self, pred_gain, a):
        """Bandwidth-expand `a` with a factor going from 0.85 (gain 1.5)
        to 1 (gain 2), applied as `gamma ** k` on coefficient k."""

        gamma = 1 - (1 - 0.85) * (2 - pred_gain) / (2 - 1.5)
        return a * np.power(gamma, np.arange(len(a)))

    def coeffs_reflexion(self, a):
        """Convert LPC coefficients `a` (with `a[0] == 1`) into
        `len(a) - 1` reflection coefficients (step-down recursion)."""

        rc = np.zeros(len(a)-1)
        b  = a.copy()

        for k in range(len(rc), 0, -1):
            rc[k-1] = b[k]
            e = 1 - rc[k-1] ** 2
            b[1:k] = (b[1:k] - rc[k-1] * b[k-1:0:-1]) / e

        return rc

    def quantization(self, rc, lpc_weighting):
        """Quantize reflection coefficients on an arcsine grid.

        Returns `(rc_order, rc_q, rc_i)`: `rc_i` are indices in [0, 16], 8
        meaning a zero coefficient; `rc_q` the matching dequantized values,
        rounded to Q15; `rc_order` one past the last non-zero index, or 0
        when every coefficient quantizes to zero. `lpc_weighting` is not
        used here.
        """

        delta = np.pi / 17
        rc_i = np.rint(np.arcsin(rc) / delta).astype(int) + 8
        rc_q = np.sin(delta * (rc_i - 8))
        rc_q = np.rint(rc_q * 2**15) / 2**15

        # `argmin` over an all-True "== 8" mask returns 0, which would report
        # a full order for an all-zero set: locate non-zero entries instead.
        nonzero = np.flatnonzero(rc_i != 8)
        rc_order = int(nonzero[-1]) + 1 if len(nonzero) > 0 else 0

        return (rc_order, rc_q, rc_i)

    def filtering(self, st, x, rc_order, rc):
        """Filter `x` through the analysis lattice of order `rc_order`.

        `st` is the lattice state, updated in place so that it carries over
        from one filter to the next. Returns a new array.
        """

        y = np.empty(len(x))

        for i in range(len(x)):

            xi = x[i]
            s1 = xi

            for k in range(rc_order):
                s0 = st[k]
                st[k] = s1

                s1  = rc[k] * xi + s0
                xi += rc[k] * s0

            y[i] = xi

        return y

    def run(self, x, bw, nn_flag, nbytes):
        """Analyze and filter spectrum `x` for bandwidth `bw`.

        `nn_flag` forces every filter off; `nbytes` selects LPC weighting.
        Stores `nfilters`, `lpc_weighting`, `rc_order` and the quantization
        indices `rc` on the instance, and returns the filtered spectrum.
        """

        fstate = np.zeros(8)
        y = x.copy()

        self.nfilters = len(Tns.SUB_LIM[self.dt][bw])
        maxorder = [ 4, 8 ][self.dt > T.DT_5M]

        self.lpc_weighting = nbytes < 120 * (1 + self.dt) / 8

        self.rc_order = np.zeros(2, dtype=np.intc)
        self.rc = np.zeros((2, maxorder), dtype=np.intc)

        for f in range(self.nfilters):

            (pred_gain, a) = self.compute_lpc_coeffs(bw, f, x)

            tns_off = pred_gain <= 1.5 or nn_flag
            if tns_off:
                continue

            if self.lpc_weighting and pred_gain < 2:
                a = self.lpc_weight(pred_gain, a)

            rc = self.coeffs_reflexion(a)

            (rc_order, rc_q, rc_i) = \
                self.quantization(rc, self.lpc_weighting)

            self.rc_order[f] = rc_order
            self.rc[f] = rc_i

            if rc_order > 0:
                i0 = Tns.FREQ_LIM[self.dt][bw][f]
                i1 = Tns.FREQ_LIM[self.dt][bw][f+1]

                y[i0:i1] = self.filtering(
                    fstate, x[i0:i1], rc_order, rc_q)

        return y

    def store(self, b):
        """Write the side information to bitstream writer `b`: per filter,
        a 1-bit activation flag, then the arithmetic-coded order and
        coefficient indices."""

        for f in range(self.nfilters):
            lpc_weighting = self.lpc_weighting
            rc_order = self.rc_order[f]
            rc = self.rc[f]

            b.write_bit(min(rc_order, 1))

            if rc_order > 0:
                b.ac_encode(
                    T.TNS_ORDER_CUMFREQ[int(lpc_weighting)][rc_order-1],
                    T.TNS_ORDER_FREQ[int(lpc_weighting)][rc_order-1]    )

            for k in range(rc_order):
                b.ac_encode(T.TNS_COEF_CUMFREQ[k][rc[k]],
                            T.TNS_COEF_FREQ[k][rc[k]]    )
290
291
class TnsSynthesis(Tns):
    """Decoder side of TNS: parse side information and apply the inverse
    lattice filters to the decoded spectrum."""

    def filtering(self, st, x, rc_order, rc):
        """Run `x` through the inverse lattice of order `rc_order`.

        `st` holds the lattice state and is updated in place, so it carries
        over between successive filters. Returns a new array.
        """

        y = x.copy()
        top = rc_order - 1

        for n, xn in enumerate(x):
            t = xn - rc[top] * st[top]

            for k in reversed(range(top)):
                t -= rc[k] * st[k]
                st[k + 1] = t * rc[k] + st[k]

            st[0] = t
            y[n] = t

        return y

    def load(self, b, bw, nbytes):
        """Read the side information for bandwidth `bw` from reader `b`."""

        self.nfilters = len(Tns.SUB_LIM[self.dt][bw])
        self.lpc_weighting = nbytes < 120 * (1 + self.dt) / 8
        self.rc_order = np.zeros(2, dtype=np.intc)
        self.rc = np.full((2, 8), 8, dtype=np.intc)

        w = int(self.lpc_weighting)

        for f in range(self.nfilters):
            if b.read_bit():
                order = 1 + b.ac_decode(
                    T.TNS_ORDER_CUMFREQ[w], T.TNS_ORDER_FREQ[w])
                self.rc_order[f] = order

                for k in range(order):
                    self.rc[f][k] = b.ac_decode(
                        T.TNS_COEF_CUMFREQ[k], T.TNS_COEF_FREQ[k])

    def run(self, x, bw):
        """Return a copy of `x` with every active TNS filter undone."""

        state = np.zeros(8)
        y = x.copy()
        delta = np.pi / 17

        for f in range(self.nfilters):
            order = self.rc_order[f]
            if order <= 0:
                continue

            # Dequantize the indices (8 <=> zero), rounded to Q15
            rc_q = np.rint(np.sin(delta * (self.rc[f] - 8)) * 2**15) / 2**15

            lims = Tns.FREQ_LIM[self.dt][bw]
            lo, hi = lims[f], lims[f + 1]

            y[lo:hi] = self.filtering(state, x[lo:hi], order, rc_q)

        return y
351
352
353### ------------------------------------------------------------------------ ###
354
def _tns_analysis_matches(analysis, y, y_c, data_c):
    """Compare the C analysis output (`y_c`, `data_c`) with the model's."""

    same = data_c['lpc_weighting'] == analysis.lpc_weighting
    same = same and data_c['nfilters'] == analysis.nfilters

    for f in range(analysis.nfilters):
        order = analysis.rc_order[f]
        order_c = data_c['rc_order'][f]
        rc_c = 8 + data_c['rc'][f]
        same = same and order_c == order
        same = same and not np.any(rc_c[:order] - analysis.rc[f][:order])

    same = same and lc3.tns_get_nbits(data_c) == analysis.get_nbits()
    return same and np.amax(np.abs(y_c - y)) < 1e-2

def check_analysis(rng, dt, bw):
    """Cross-check `TnsAnalysis` against `lc3.tns_analyze`.

    Ten random spectra, raised to an exponent growing from 0.5 to 2.3, are
    analyzed with and without `nn_flag`, at the LPC-weighting byte limit
    and one byte above it. Returns True when every run matches.
    """

    analysis = TnsAnalysis(dt)
    nbytes_lim = int((48 * T.DT_MS[dt]) // 8)
    ok = True

    for i in range(10):
        ne = T.I[dt][bw][-1]
        x = pow(rng.random(ne) * 1e2, .5 + i/5)

        for nn_flag in (True, False):
            for nbytes in (nbytes_lim, nbytes_lim + 1):
                y = analysis.run(x, bw, nn_flag, nbytes)
                (y_c, data_c) = lc3.tns_analyze(dt, bw, nn_flag, nbytes, x)

                ok = ok and _tns_analysis_matches(analysis, y, y_c, data_c)

    return ok
386
def check_synthesis(rng, dt, bw):
    """Cross-check `TnsSynthesis` against `lc3.tns_synthesize`.

    100 random spectra are filtered with random orders and coefficient
    indices. Returns True when every output matches within 1e-4.
    """

    ok = True
    synthesis = TnsSynthesis(dt)

    maxorder = [ 4, 8 ][dt > T.DT_5M]
    ne = T.I[dt][bw][-1]

    for i in range(100):

        x  = rng.random(ne) * 1e2

        # Same filter count as `TnsSynthesis.load()` for this configuration
        synthesis.nfilters = len(Tns.SUB_LIM[dt][bw])
        synthesis.rc_order = rng.integers(0, 1+maxorder, 2)
        synthesis.rc = rng.integers(0, 17, 16).reshape(2, 8)

        y = synthesis.run(x, bw)
        y_c = lc3.tns_synthesize(dt, bw, synthesis.get_data(), x)

        # The largest deviation must be within tolerance. Taking the max of
        # the element-wise test would pass as soon as one sample matched.
        ok = ok and np.amax(np.abs(y_c - y)) < 1e-4

    return ok
408
def check_analysis_appendix_c(dt):
    """Check each C TNS analysis step against the `appendix_c` vectors.

    The vectors are indexed from `T.DT_7M5` (`i0`), at 16 kHz, so `dt` is
    expected to be one of the durations they cover. Returns True when every
    step matches.
    """

    i0 = dt - T.DT_7M5
    sr = T.SRATE_16K

    ok = True

    for i in range(len(C.X_S[i0])):

        (_, a) = lc3.tns_compute_lpc_coeffs(dt, sr, C.X_S[i0][i])
        ok = ok and np.amax(np.abs(a[0] - C.TNS_LEV_A[i0][i])) < 1e-5

        rc = lc3.tns_lpc_reflection(dt, a[0])
        ok = ok and np.amax(np.abs(rc - C.TNS_LEV_RC[i0][i])) < 1e-5

        # Every quantization index must match, not merely one of them
        (rc_order, rc_i) = lc3.tns_quantize_rc(dt, C.TNS_LEV_RC[i0][i])
        ok = ok and rc_order == C.RC_ORDER[i0][i][0]
        ok = ok and not np.any((rc_i + 8) - C.RC_I_1[i0][i])

        rc_q = lc3.tns_unquantize_rc(rc_i, rc_order)
        ok = ok and np.amax(np.abs(rc_q - C.RC_Q_1[i0][i])) < 1e-6

        (x, side) = lc3.tns_analyze(dt, sr, False, C.NBYTES[i0], C.X_S[i0][i])
        ok = ok and side['nfilters'] == 1
        ok = ok and side['rc_order'][0] == C.RC_ORDER[i0][i][0]
        ok = ok and not np.any((side['rc'][0] + 8) - C.RC_I_1[i0][i])
        ok = ok and lc3.tns_get_nbits(side) == C.NBITS_TNS[i0][i]
        ok = ok and np.amax(np.abs(x - C.X_F[i0][i])) < 1e-3

    return ok
443
def check_synthesis_appendix_c(dt):
    """Check `lc3.tns_synthesize` against the `appendix_c` vectors.

    The 16 kHz vectors are indexed from `T.DT_7M5`; for 10 ms frames an
    additional 48 kHz two-filter vector is checked. Returns True on match.
    """

    i0 = dt - T.DT_7M5
    sr = T.SRATE_16K
    lpc_weighting = C.NBYTES[i0] < 120 * (1 + dt) / 8

    ok = True

    for i, x_hat_q in enumerate(C.X_HAT_Q[i0]):
        side = {
            'nfilters' : 1,
            'lpc_weighting' : lpc_weighting,
            'rc_order': C.RC_ORDER[i0][i],
            'rc': [ C.RC_I_1[i0][i] - 8, C.RC_I_2[i0][i] - 8 ]
        }

        gain = 10 ** ((C.GG_IND_ADJ[i0][i] + C.GG_OFF[i0][i]) / 28)
        x = lc3.tns_synthesize(dt, sr, side, x_hat_q * gain)

        ok = ok and np.amax(np.abs(x - C.X_HAT_TNS[i0][i])) < 1e-3

    if dt != T.DT_10M:
        return ok

    side_48k = {
        'nfilters' : 2,
        'lpc_weighting' : False,
        'rc_order': C.RC_ORDER_48K_10M,
        'rc': [ C.RC_I_1_48K_10M - 8, C.RC_I_2_48K_10M - 8 ]
    }

    x = lc3.tns_synthesize(dt, T.SRATE_48K, side_48k, C.X_HAT_F_48K_10M)
    return ok and np.amax(np.abs(x - C.X_HAT_TNS_48K_10M)) < 1e-3
482
def check():
    """Run every TNS check; return True only if all of them pass."""

    rng = np.random.default_rng(1234)
    ok = True

    for dt in range(T.NUM_DT):
        for sr in range(T.SRATE_8K, T.SRATE_48K + 1):
            ok = ok and check_analysis(rng, dt, sr)
            ok = ok and check_synthesis(rng, dt, sr)

    for dt in ( T.DT_2M5, T.DT_5M, T.DT_10M ):
        for sr in ( T.SRATE_48K_HR, T.SRATE_96K_HR ):
            ok = ok and check_analysis(rng, dt, sr)

    # The appendix C results must count toward the verdict; they were
    # previously computed and discarded.
    for dt in ( T.DT_7M5, T.DT_10M ):
        ok = ok and check_analysis_appendix_c(dt)
        ok = ok and check_synthesis_appendix_c(dt)

    return ok
502
503### ------------------------------------------------------------------------ ###
504