1#!/usr/bin/env python3
2#
3# Copyright 2022 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import numpy as np
19import scipy.signal as signal
20import scipy.io.wavfile as wavfile
21import struct
22import argparse
23
24import lc3
25import tables as T, appendix_c as C
26
27import mdct, energy, bwdet, sns, tns, spec, ltpf
28import bitstream
29
30### ------------------------------------------------------------------------ ###
31
class Decoder:
    """Python model of the LC3 decoder, built from the per-stage modules.

    Stages used: bandwidth detection (`bwdet`), spectrum decoding (`spec`),
    TNS, SNS, inverse MDCT (`mdct`) and LTPF.
    """

    def __init__(self, dt_ms, sr_hz):
        """Instantiate every decoding stage for one configuration.

        Args:
            dt_ms: Frame duration in milliseconds; must be 7.5 or 10.
            sr_hz: Sample rate in Hz; must be 8000, 16000, 24000, 32000
                or 48000.

        Raises:
            KeyError: If `dt_ms` or `sr_hz` is not a supported value.
        """

        dt = { 7.5: T.DT_7M5, 10: T.DT_10M }[dt_ms]

        sr = {  8000: T.SRATE_8K , 16000: T.SRATE_16K, 24000: T.SRATE_24K,
               32000: T.SRATE_32K, 48000: T.SRATE_48K }[sr_hz]

        # `sr` is the table index (not Hz); decode() compares it with the
        # bandwidth index read from the bitstream.
        self.sr = sr
        # Table sizes: `synthesize` zero-pads the spectrum from `ne` to `ns`
        # values before the inverse MDCT.
        self.ne = T.NE[dt][sr]
        self.ns = T.NS[dt][sr]

        self.mdct = mdct.MdctInverse(dt, sr)

        self.bwdet = bwdet.BandwidthDetector(dt, sr)
        self.spec = spec.SpectrumSynthesis(dt, sr)
        self.tns = tns.TnsSynthesis(dt)
        self.sns = sns.SnsSynthesis(dt, sr)
        self.ltpf = ltpf.LtpfSynthesis(dt, sr)

    def decode(self, data):
        """Parse one frame and decode its spectrum.

        Reads the fields in bitstream order: bandwidth, spectrum side info,
        TNS data, the pitch (LTPF) flag, SNS data, optional LTPF data, then
        the spectrum itself.

        Args:
            data: Bytes of a single encoded frame.

        Returns:
            Tuple `(x, bw, pitch)`: the spectrum returned by
            `spec.decode`, the bandwidth index, and the pitch flag bit.

        Raises:
            ValueError: If the signaled bandwidth exceeds the configured
                sample-rate index.
        """

        b = bitstream.BitstreamReader(data)

        bw = self.bwdet.get(b)
        if bw > self.sr:
            raise ValueError('Invalid bandwidth indication')

        self.spec.load(b)

        self.tns.load(b, bw, len(data))

        pitch = b.read_bit()

        self.sns.load(b)

        # LTPF state is configured here; synthesize() only runs the filter.
        if pitch:
            self.ltpf.load(b)
        else:
            self.ltpf.disable()

        x = self.spec.decode(b, bw, len(data))

        return (x, bw, pitch)

    def synthesize(self, x, bw, pitch, nbytes):
        """Convert a decoded spectrum into time-domain samples.

        Args:
            x: Spectrum as returned by `decode`.
            bw: Bandwidth index as returned by `decode`.
            pitch: Pitch flag as returned by `decode`. It is unused here
                because LTPF was already loaded or disabled in `decode`.
            nbytes: Size in bytes of the encoded frame, forwarded to LTPF.

        Returns:
            The output of the LTPF stage (time-domain samples).
        """

        x = self.tns.run(x, bw)

        x = self.sns.run(x)

        # Zero-pad the spectrum up to the inverse-MDCT input length.
        x = np.append(x, np.zeros(self.ns - self.ne))
        x = self.mdct.run(x)

        # Use the frame size passed by the caller. Reading a global `data`
        # here only worked by accident from the command-line script.
        x = self.ltpf.run(x, nbytes)

        return x

    def run(self, data):
        """Decode and synthesize one frame of bytes `data`; return the samples."""

        (x, bw, pitch) = self.decode(data)

        x = self.synthesize(x, bw, pitch, len(data))

        return x
98
99### ------------------------------------------------------------------------ ###
100
def check_appendix_c(dt):
    """Compare the C decoder against the Appendix C reference output.

    The decoder is set up at a hard-coded 16000 Hz. Every reference frame is
    decoded in sequence, but once a mismatch has been seen the remaining
    frames are no longer compared.

    Args:
        dt: Frame-duration index into the `T` and `C` tables.

    Returns:
        A truthy value when every decoded frame differs from
        `C.X_HAT_CLIP` by less than 1 at every sample, falsy otherwise.
    """

    dt_us = int(T.DT_MS[dt] * 1000)
    decoder = lc3.setup_decoder(dt_us, 16000)

    passed = True

    for idx, frame in enumerate(C.BYTES_AC[dt]):
        pcm = lc3.decode(decoder, bytes(frame))
        # Short-circuit: the reference is only looked up while still passing.
        passed = passed and np.max(np.abs(pcm - C.X_HAT_CLIP[dt][idx])) < 1

    return passed
113
def check():
    """Run `check_appendix_c` for every frame duration.

    Stops at the first failing duration and returns its result. Otherwise
    returns the result of the last duration checked, or True if there are
    none.
    """

    result = True

    for dt in range(T.NUM_DT):
        result = check_appendix_c(dt)
        if not result:
            break

    return result
122
123### ------------------------------------------------------------------------ ###
124
if __name__ == "__main__":

    # Decode an LC3 bitstream file with both the Python model (`Decoder`)
    # and the C implementation (`lc3`), optionally writing each output
    # to a WAV file.

    parser = argparse.ArgumentParser(description='LC3 Decoder Test Framework')
    parser.add_argument('lc3_file',
        help='Input bitstream file', type=argparse.FileType('rb'))
    parser.add_argument('--pyout',
        help='Python output file', type=argparse.FileType('wb'))
    parser.add_argument('--cout',
        help='C output file', type=argparse.FileType('wb'))
    args = parser.parse_args()

    # Use the binary handle argparse already opened (instead of reopening
    # the file by name) and make sure it is closed.
    with args.lc3_file as f_lc3:

        ### File Header ###

        # Fields used below:
        #   [0] magic 0xcc1c,
        #   [1] header size (seek target),
        #   [2] sample rate / 100,
        #   [3] bitrate / 100,
        #   [4] channel count,
        #   [5] frame duration in 1/100 ms.
        header_fmt = '=HHHHHHHI'
        header_size = struct.calcsize(header_fmt)

        raw_header = f_lc3.read(header_size)
        if len(raw_header) != header_size:
            raise ValueError('Invalid bitstream file')

        header = struct.unpack(header_fmt, raw_header)

        if header[0] != 0xcc1c:
            raise ValueError('Invalid bitstream file')

        if header[4] != 1:
            raise ValueError('Unsupported number of channels')

        sr_hz = header[2] * 100
        bitrate = header[3] * 100
        nchannels = header[4]
        dt_ms = header[5] / 100

        f_lc3.seek(header[1])

        ### Setup ###

        dec = Decoder(dt_ms, sr_hz)
        dec_c = lc3.setup_decoder(int(dt_ms * 1000), sr_hz)

        # Collect per-frame chunks and concatenate once at the end; calling
        # np.append inside the loop would copy the whole output every frame.
        # The int16 seed keeps the dtype of the empty/sequential-append result.
        frames_py = [np.empty(0).astype(np.int16)]
        frames_c = [np.empty(0).astype(np.int16)]

        ### Decoding loop ###

        nframes = 0

        while True:

            # Each frame is preceded by its size as a 16-bit integer.
            data = f_lc3.read(2)
            if len(data) != 2:
                break

            (frame_nbytes,) = struct.unpack('=H', data)

            print('Decoding frame %d' % nframes, end='\r')

            data = f_lc3.read(frame_nbytes)
            if len(data) != frame_nbytes:
                raise ValueError('Truncated frame %d' % nframes)

            x = dec.run(data)
            frames_py.append(np.ravel(
                np.clip(np.round(x), -32768, 32767).astype(np.int16)))

            x_c = lc3.decode(dec_c, data)
            frames_c.append(np.ravel(x_c))

            nframes += 1

    pcm_py = np.concatenate(frames_py)
    pcm_c = np.concatenate(frames_c)

    print('done ! %16s' % '')

    ### Terminate ###

    # Write through the handles argparse opened (binary mode), then close them.
    if args.pyout:
        with args.pyout:
            wavfile.write(args.pyout, sr_hz, pcm_py)
    if args.cout:
        with args.cout:
            wavfile.write(args.cout, sr_hz, pcm_c)
196
197### ------------------------------------------------------------------------ ###
198