1#
2# Copyright 2022 Google LLC
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17import numpy as np
18
19import lc3
20import tables as T, appendix_c as C
21
22
# Band index tables, indexed as [dt][sr][bw]. BandwidthDetector.run treats
# BW_START[dt][sr][bw] .. BW_STOP[dt][sr][bw] as an *inclusive* index range
# into the energy vector (it slices e[start:stop+1]). The leading empty list
# of each row is the sr == 0 entry, for which no band is examined.
# NOTE(review): values presumably transcribed from the LC3 specification's
# bandwidth-detector tables -- confirm against the spec before editing.
BW_START = [
    [ [], [ 51 ], [ 45, 58 ], [ 42, 53, 60 ], [ 40, 51, 57, 61 ] ],
    [ [], [ 53 ], [ 47, 59 ], [ 44, 54, 60 ], [ 41, 51, 57, 61 ] ]
]

BW_STOP = [
    [ [], [ 63 ], [ 55, 63 ], [ 51, 58, 63 ], [ 48, 55, 60, 63 ] ],
    [ [], [ 63 ], [ 56, 63 ], [ 52, 59, 63 ], [ 49, 55, 60, 63 ] ]
]

# Stage 1 threshold, indexed by bw: a band passes when its mean energy is
# >= TQ[bw].
TQ = [ 20, 10, 10, 10 ]

# Stage 2 threshold, indexed by the Stage 1 candidate bw0: the candidate is
# kept only if the maximum 10*log10 energy ratio exceeds TC[bw0].
TC = [ 15, 23, 20, 20 ]
# Stage 2 index distance between the compared energies, indexed [dt][bw0].
L  = [ [ 4, 4, 3, 2 ], [ 4, 4, 3, 1 ] ]
37
38
39### ------------------------------------------------------------------------ ###
40
class BandwidthDetector:
    """Python model of the bandwidth detector, used to cross-check `lc3`.

    `dt` and `sr` select a row of the module-level BW_START / BW_STOP
    tables. `sr` is also the number of candidate bands and the largest
    bandwidth index `run` can return.
    """

    def __init__(self, dt, sr):
        # Stored unchanged; nothing is validated here.
        self.dt = dt
        self.sr = sr

    def run(self, e):
        """Detect the bandwidth index (0 .. sr) from band energies `e`.

        `e` must support numpy slicing and element-wise division (e.g. a
        1-D ndarray; the tests in this file pass 64 values). The result is
        stored in `self.bw` and returned.
        """
        dt, sr = self.dt, self.sr

        def band_mean(k):
            # Mean energy over the inclusive index range of band k.
            return np.mean(e[BW_START[dt][sr][k]:BW_STOP[dt][sr][k] + 1])

        # Stage 1: the candidate is one past the highest band whose mean
        # energy reaches its TQ threshold, or 0 when no band does.
        passed = [k + 1 for k in range(sr) if band_mean(k) >= TQ[k]]
        bw0 = passed[-1] if passed else 0

        # Stage 2: below full bandwidth, compare energies `dist` bins apart
        # around the start of band bw0. Without a ratio above TC[bw0],
        # fall back to the full bandwidth `sr`.
        detected = bw0
        if bw0 < sr:
            dist = L[dt][bw0]
            edge = BW_START[dt][sr][bw0]
            lower = e[edge - 2 * dist + 1:edge - dist + 2]
            upper = e[edge - dist + 1:edge + 2]
            cutoff = 10 * np.log10(1e-31 + lower / upper)
            if np.amax(cutoff) <= TC[bw0]:
                detected = sr

        self.bw = detected
        return self.bw

    def get_nbits(self):
        """Return the bit width of the bandwidth field.

        The width is 0 when sr is 0, otherwise 1 + the integer part of
        log2(sr).
        """
        if self.sr == 0:
            return 0
        return 1 + np.log2(self.sr).astype(int)

    def store(self, b):
        """Write `self.bw` (set by `run`) through `b.write_uint`."""
        value = self.bw
        b.write_uint(value, self.get_nbits())

    def get(self, b):
        """Read a bandwidth field through `b.read_uint` and return it."""
        nbits = self.get_nbits()
        return b.read_uint(nbits)
91
92### ------------------------------------------------------------------------ ###
93
def check_unit(rng, dt, sr):
    """Cross-check `lc3.bwdet_run` against `BandwidthDetector` for one (dt, sr).

    For every target `bw0` in 0..sr and every `drop` in 0..9, random
    energies are shaped to steer Stage 1 toward `bw0`. A power-of-ten ramp
    is then applied around the start of band `bw0`; a larger `drop` gives a
    gentler ramp. `rng` is the numpy Generator created in `check` and is
    consumed in a fixed order.

    Returns a truthy value only if the C implementation and the Python
    model agreed on every generated input. Once a mismatch is seen, the
    model is no longer run.
    """
    model = BandwidthDetector(dt, sr)
    agree = True

    for bw0 in range(sr + 1):
        for drop in range(10):

            # Random 'high' energies in [20, 120).
            e = 20 + 100 * rng.random(64)

            # Pull every band except band bw0-1 just under its TQ threshold,
            # so that Stage 1 is steered toward candidate bw0.
            for band in range(sr):
                if band + 1 == bw0:
                    continue
                lo = BW_START[dt][sr][band]
                hi = BW_STOP[dt][sr][band] + 1
                e[lo:hi] /= np.mean(e[lo:hi]) / TQ[band] + 1e-3

            # Stage 2 condition: divide the 2*l+1 values ending just past the
            # start of band bw0 by increasing powers of ten.
            if bw0 < sr:
                nb = L[dt][bw0]
                edge = BW_START[dt][sr][bw0]
                ramp = np.arange(2 * nb + 1) / (1 + drop)
                e[edge - 2 * nb + 1:edge + 2] /= np.power(10, ramp)

            # Compare the C implementation with the Python model.
            bw_lc3 = lc3.bwdet_run(dt, sr, e)
            if agree:
                agree = bw_lc3 == model.run(e)

    return agree
131
def check_appendix_c(dt):
    """Check `lc3.bwdet_run` on the Appendix C vectors for the given `dt`.

    Uses the 16 kHz sample-rate index. The first two entries of `C.E_B[dt]`
    are the inputs, and the first two of `C.P_BW[dt]` are the expected
    bandwidths. Both vectors are always run. Returns a truthy value when
    both results match.
    """
    sr = T.SRATE_16K
    inputs = C.E_B[dt]
    expected = C.P_BW[dt]

    ok = True
    for k in (0, 1):
        bw = lc3.bwdet_run(dt, sr, inputs[k])
        ok = ok and bw == expected[k]

    return ok
147
def check():
    """Run every bandwidth-detector check; return a truthy value if all pass.

    The unit checks cover every (dt, sr) pair and share one generator seeded
    with 1234, so their inputs are reproducible. No further check is invoked
    after the first failure.
    """
    rng = np.random.default_rng(1234)

    units_ok = all(
        check_unit(rng, dt, sr)
        for dt in range(T.NUM_DT)
        for sr in range(T.NUM_SRATE)
    )

    return units_ok and all(check_appendix_c(dt) for dt in range(T.NUM_DT))
161
162### ------------------------------------------------------------------------ ###
163