#!/usr/bin/env python3
#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import scipy.signal as signal
import scipy.io.wavfile as wavfile
import struct
import argparse

import lc3
import tables as T, appendix_c as C

import mdct, energy, bwdet, sns, tns, spec, ltpf
import bitstream

### ------------------------------------------------------------------------ ###

class Decoder:
    """Python LC3 decoder, cross-checked against the C decoder (`lc3`).

    Decoding is split in two stages: `decode()` parses one frame of
    bitstream into a spectrum plus side information, and `synthesize()`
    turns that spectrum into time-domain samples. `run()` chains both.
    """

    def __init__(self, dt_ms, sr_hz):
        """Set up the per-stage decoders for one frame duration/sample rate.

        dt_ms   Frame duration in milliseconds: 7.5 or 10
        sr_hz   Sampling rate in Hz: 8000, 16000, 24000, 32000 or 48000

        Raises KeyError for any other `dt_ms` or `sr_hz` value.
        """

        dt = { 7.5: T.DT_7M5, 10: T.DT_10M }[dt_ms]

        sr = { 8000: T.SRATE_8K , 16000: T.SRATE_16K, 24000: T.SRATE_24K,
               32000: T.SRATE_32K, 48000: T.SRATE_48K }[sr_hz]

        self.sr = sr
        # `ne` is the length of the decoded spectrum; `synthesize()` pads it
        # with zeros up to `ns` coefficients before the inverse MDCT.
        self.ne = T.NE[dt][sr]
        self.ns = T.NS[dt][sr]

        self.mdct = mdct.MdctInverse(dt, sr)

        self.bwdet = bwdet.BandwidthDetector(dt, sr)
        self.spec = spec.SpectrumSynthesis(dt, sr)
        self.tns = tns.TnsSynthesis(dt)
        self.sns = sns.SnsSynthesis(dt, sr)
        self.ltpf = ltpf.LtpfSynthesis(dt, sr)

    def decode(self, data):
        """Parse one frame of bitstream.

        data    Frame payload (bytes)

        Returns `(x, bw, pitch)`: the decoded spectrum, the bandwidth
        indication and the pitch (LTPF) flag bit.
        Raises ValueError when the bandwidth exceeds the sample rate.
        """

        # All stages read from the same reader `b`, presumably sequentially,
        # so the call order below follows the field order of the frame.
        b = bitstream.BitstreamReader(data)

        bw = self.bwdet.get(b)
        if bw > self.sr:
            raise ValueError('Invalid bandwidth indication')

        self.spec.load(b)

        self.tns.load(b, bw, len(data))

        pitch = b.read_bit()

        self.sns.load(b)

        # LTPF parameters are only loaded when the pitch flag is set;
        # otherwise the LTPF stage is explicitly disabled for this frame.
        if pitch:
            self.ltpf.load(b)
        else:
            self.ltpf.disable()

        x = self.spec.decode(b, bw, len(data))

        return (x, bw, pitch)

    def synthesize(self, x, bw, pitch, nbytes):
        """Synthesize time-domain samples from a decoded spectrum.

        x       Spectrum returned by `decode()`
        bw      Bandwidth indication returned by `decode()`
        pitch   Pitch flag returned by `decode()`; not used here, since
                `decode()` already enabled or disabled the LTPF stage
        nbytes  Size in bytes of the frame the spectrum was decoded from
        """

        x = self.tns.run(x, bw)

        x = self.sns.run(x)

        x = np.append(x, np.zeros(self.ns - self.ne))
        x = self.mdct.run(x)

        # The frame size is passed in as `nbytes`; the frame data itself is
        # not in scope here.
        x = self.ltpf.run(x, nbytes)

        return x

    def run(self, data):
        """Decode and synthesize one frame `data` (bytes); return samples."""

        (x, bw, pitch) = self.decode(data)

        x = self.synthesize(x, bw, pitch, len(data))

        return x

### ------------------------------------------------------------------------ ###

def check_appendix_c(dt):
    """Check the C decoder against the Appendix C vectors, at 16 kHz.

    dt      Frame duration index (as used to index the `T` and `C` tables)

    Returns True when every decoded frame differs by less than 1 from
    the corresponding `C.X_HAT_CLIP` reference.
    """

    ok = True

    dec_c = lc3.setup_decoder(int(T.DT_MS[dt] * 1000), 16000)

    for (i, frame) in enumerate(C.BYTES_AC[dt]):

        pcm = lc3.decode(dec_c, bytes(frame))

        # Promote to float so that an integer PCM result cannot wrap around
        # in the subtraction and mask a large error.
        err = np.asarray(pcm, dtype=np.float64) - C.X_HAT_CLIP[dt][i]
        ok = ok and np.max(np.abs(err)) < 1

    return bool(ok)

def check():
    """Run `check_appendix_c()` for every frame duration; stop at first failure."""

    return all(check_appendix_c(dt) for dt in range(T.NUM_DT))

### ------------------------------------------------------------------------ ###

def main():
    """Decode an LC3 bitstream file with both the Python and C decoders.

    Optionally writes each decoder's output to a WAV file.
    """

    parser = argparse.ArgumentParser(description='LC3 Decoder Test Framework')
    parser.add_argument('lc3_file',
        help='Input bitstream file', type=argparse.FileType('rb'))
    parser.add_argument('--pyout',
        help='Python output file', type=argparse.FileType('wb'))
    parser.add_argument('--cout',
        help='C output file', type=argparse.FileType('wb'))
    args = parser.parse_args()

    with args.lc3_file as f_lc3:

        ### File Header ###

        raw_header = f_lc3.read(18)
        if len(raw_header) != 18:
            raise ValueError('Invalid bitstream file')

        header = struct.unpack('=HHHHHHHI', raw_header)

        if header[0] != 0xcc1c:
            raise ValueError('Invalid bitstream file')

        sr_hz = header[2] * 100
        nchannels = header[4]
        dt_ms = header[5] / 100
        # header[3] * 100 holds the bitrate; it is not needed for decoding.

        if nchannels != 1:
            raise ValueError('Unsupported number of channels')

        # header[1] is the offset of the first frame (header size).
        f_lc3.seek(header[1])

        ### Setup ###

        dec = Decoder(dt_ms, sr_hz)
        dec_c = lc3.setup_decoder(int(dt_ms * 1000), sr_hz)

        # Collect per-frame chunks and concatenate once at the end, instead
        # of re-copying the whole output on every frame.
        chunks_py = []
        chunks_c = []

        ### Decoding loop ###

        nframes = 0

        while True:

            # Each frame is prefixed by its size on 2 bytes.
            data = f_lc3.read(2)
            if len(data) != 2:
                break

            (frame_nbytes,) = struct.unpack('=H', data)

            print('Decoding frame %d' % nframes, end='\r')

            data = f_lc3.read(frame_nbytes)
            if len(data) != frame_nbytes:
                # Truncated last frame: stop rather than decode partial data.
                break

            x = dec.run(data)
            chunks_py.append(
                np.clip(np.round(x), -32768, 32767).astype(np.int16))

            x_c = lc3.decode(dec_c, data)
            chunks_c.append(np.ravel(x_c))

            nframes += 1

    print('done ! %16s' % '')

    # Seed with an empty int16 array so that the result has the same dtype
    # as successive `np.append` calls on an int16 buffer.
    pcm_py = np.concatenate([np.empty(0, dtype=np.int16)] + chunks_py)
    pcm_c = np.concatenate([np.empty(0, dtype=np.int16)] + chunks_c)

    ### Terminate ###

    if args.pyout:
        with args.pyout as f_out:
            wavfile.write(f_out, sr_hz, pcm_py)
    if args.cout:
        with args.cout as f_out:
            wavfile.write(f_out, sr_hz, pcm_c)

if __name__ == "__main__":
    main()

### ------------------------------------------------------------------------ ###