1 /******************************************************************************
2  *
3  *  Copyright 2022 Google LLC
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at:
8  *
9  *  http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  ******************************************************************************/
18 
19 #include "sns.h"
20 #include "tables.h"
21 
22 
23 /* ----------------------------------------------------------------------------
24  *  DCT-16
25  * -------------------------------------------------------------------------- */
26 
/**
 * Matrix of DCT-16 coefficients
 *
 * M[n][k] = 2f cos( Pi k (2n + 1) / 2N )
 *
 *   k = [0..N-1], n = [0..N-1], N = 16
 *   f = sqrt(1/4N) for k=0, sqrt(1/2N) otherwise
 *
 * The row index `n` is the sample index, the column index `k` the
 * frequency index: the first column (k = 0) is the constant
 * 2 sqrt(1/64) = 0.25. With this normalization the matrix is orthonormal
 * (DCT-II), so the inverse transform is the product by its transpose.
 */
static const float dct16_m[16][16] = {

    {  2.50000000e-01,  3.51850934e-01,  3.46759961e-01,  3.38329500e-01,
       3.26640741e-01,  3.11806253e-01,  2.93968901e-01,  2.73300467e-01,
       2.50000000e-01,  2.24291897e-01,  1.96423740e-01,  1.66663915e-01,
       1.35299025e-01,  1.02631132e-01,  6.89748448e-02,  3.46542923e-02 },

    {  2.50000000e-01,  3.38329500e-01,  2.93968901e-01,  2.24291897e-01,
       1.35299025e-01,  3.46542923e-02, -6.89748448e-02, -1.66663915e-01,
      -2.50000000e-01, -3.11806253e-01, -3.46759961e-01, -3.51850934e-01,
      -3.26640741e-01, -2.73300467e-01, -1.96423740e-01, -1.02631132e-01 },

    {  2.50000000e-01,  3.11806253e-01,  1.96423740e-01,  3.46542923e-02,
      -1.35299025e-01, -2.73300467e-01, -3.46759961e-01, -3.38329500e-01,
      -2.50000000e-01, -1.02631132e-01,  6.89748448e-02,  2.24291897e-01,
       3.26640741e-01,  3.51850934e-01,  2.93968901e-01,  1.66663915e-01 },

    {  2.50000000e-01,  2.73300467e-01,  6.89748448e-02, -1.66663915e-01,
      -3.26640741e-01, -3.38329500e-01, -1.96423740e-01,  3.46542923e-02,
       2.50000000e-01,  3.51850934e-01,  2.93968901e-01,  1.02631132e-01,
      -1.35299025e-01, -3.11806253e-01, -3.46759961e-01, -2.24291897e-01 },

    {  2.50000000e-01,  2.24291897e-01, -6.89748448e-02, -3.11806253e-01,
      -3.26640741e-01, -1.02631132e-01,  1.96423740e-01,  3.51850934e-01,
       2.50000000e-01, -3.46542923e-02, -2.93968901e-01, -3.38329500e-01,
      -1.35299025e-01,  1.66663915e-01,  3.46759961e-01,  2.73300467e-01 },

    {  2.50000000e-01,  1.66663915e-01, -1.96423740e-01, -3.51850934e-01,
      -1.35299025e-01,  2.24291897e-01,  3.46759961e-01,  1.02631132e-01,
      -2.50000000e-01, -3.38329500e-01, -6.89748448e-02,  2.73300467e-01,
       3.26640741e-01,  3.46542923e-02, -2.93968901e-01, -3.11806253e-01 },

    {  2.50000000e-01,  1.02631132e-01, -2.93968901e-01, -2.73300467e-01,
       1.35299025e-01,  3.51850934e-01,  6.89748448e-02, -3.11806253e-01,
      -2.50000000e-01,  1.66663915e-01,  3.46759961e-01,  3.46542923e-02,
      -3.26640741e-01, -2.24291897e-01,  1.96423740e-01,  3.38329500e-01 },

    {  2.50000000e-01,  3.46542923e-02, -3.46759961e-01, -1.02631132e-01,
       3.26640741e-01,  1.66663915e-01, -2.93968901e-01, -2.24291897e-01,
       2.50000000e-01,  2.73300467e-01, -1.96423740e-01, -3.11806253e-01,
       1.35299025e-01,  3.38329500e-01, -6.89748448e-02, -3.51850934e-01 },

    {  2.50000000e-01, -3.46542923e-02, -3.46759961e-01,  1.02631132e-01,
       3.26640741e-01, -1.66663915e-01, -2.93968901e-01,  2.24291897e-01,
       2.50000000e-01, -2.73300467e-01, -1.96423740e-01,  3.11806253e-01,
       1.35299025e-01, -3.38329500e-01, -6.89748448e-02,  3.51850934e-01 },

    {  2.50000000e-01, -1.02631132e-01, -2.93968901e-01,  2.73300467e-01,
       1.35299025e-01, -3.51850934e-01,  6.89748448e-02,  3.11806253e-01,
      -2.50000000e-01, -1.66663915e-01,  3.46759961e-01, -3.46542923e-02,
      -3.26640741e-01,  2.24291897e-01,  1.96423740e-01, -3.38329500e-01 },

    {  2.50000000e-01, -1.66663915e-01, -1.96423740e-01,  3.51850934e-01,
      -1.35299025e-01, -2.24291897e-01,  3.46759961e-01, -1.02631132e-01,
      -2.50000000e-01,  3.38329500e-01, -6.89748448e-02, -2.73300467e-01,
       3.26640741e-01, -3.46542923e-02, -2.93968901e-01,  3.11806253e-01 },

    {  2.50000000e-01, -2.24291897e-01, -6.89748448e-02,  3.11806253e-01,
      -3.26640741e-01,  1.02631132e-01,  1.96423740e-01, -3.51850934e-01,
       2.50000000e-01,  3.46542923e-02, -2.93968901e-01,  3.38329500e-01,
      -1.35299025e-01, -1.66663915e-01,  3.46759961e-01, -2.73300467e-01 },

    {  2.50000000e-01, -2.73300467e-01,  6.89748448e-02,  1.66663915e-01,
      -3.26640741e-01,  3.38329500e-01, -1.96423740e-01, -3.46542923e-02,
       2.50000000e-01, -3.51850934e-01,  2.93968901e-01, -1.02631132e-01,
      -1.35299025e-01,  3.11806253e-01, -3.46759961e-01,  2.24291897e-01 },

    {  2.50000000e-01, -3.11806253e-01,  1.96423740e-01, -3.46542923e-02,
      -1.35299025e-01,  2.73300467e-01, -3.46759961e-01,  3.38329500e-01,
      -2.50000000e-01,  1.02631132e-01,  6.89748448e-02, -2.24291897e-01,
       3.26640741e-01, -3.51850934e-01,  2.93968901e-01, -1.66663915e-01 },

    {  2.50000000e-01, -3.38329500e-01,  2.93968901e-01, -2.24291897e-01,
       1.35299025e-01, -3.46542923e-02, -6.89748448e-02,  1.66663915e-01,
      -2.50000000e-01,  3.11806253e-01, -3.46759961e-01,  3.51850934e-01,
      -3.26640741e-01,  2.73300467e-01, -1.96423740e-01,  1.02631132e-01 },

    {  2.50000000e-01, -3.51850934e-01,  3.46759961e-01, -3.38329500e-01,
       3.26640741e-01, -3.11806253e-01,  2.93968901e-01, -2.73300467e-01,
       2.50000000e-01, -2.24291897e-01,  1.96423740e-01, -1.66663915e-01,
       1.35299025e-01, -1.02631132e-01,  6.89748448e-02, -3.46542923e-02 },

};
118 
119 /**
120  * Forward DCT-16 transformation
121  * x, y            Input and output 16 values
122  */
dct16_forward(const float * x,float * y)123 LC3_HOT static void dct16_forward(const float *x, float *y)
124 {
125     for (int i = 0, j; i < 16; i++)
126         for (y[i] = 0, j = 0; j < 16; j++)
127             y[i] += x[j] * dct16_m[j][i];
128 }
129 
130 /**
131  * Inverse DCT-16 transformation
132  * x, y            Input and output 16 values
133  */
dct16_inverse(const float * x,float * y)134 LC3_HOT static void dct16_inverse(const float *x, float *y)
135 {
136     for (int i = 0, j; i < 16; i++)
137         for (y[i] = 0, j = 0; j < 16; j++)
138             y[i] += x[j] * dct16_m[i][j];
139 }
140 
141 
142 /* ----------------------------------------------------------------------------
143  *  Scale factors
144  * -------------------------------------------------------------------------- */
145 
/**
 * Scale factors
 * dt, sr          Duration and samplerate of the frame
 * eb              Energy estimation per bands
 * att             1: Attack detected  0: Otherwise
 * scf             Output 16 scale factors
 *
 * Processing : the band energies are padded to 64 values, smoothed with
 * a [1/4 1/2 1/4] kernel, pre-emphasized, floored, and converted to the
 * half log2 domain. They are then grouped by 4 (6-tap [1 2 3 3 2 1]/12
 * window) down to 16 values, mean-removed and scaled by 0.85.
 * On attack, the 16 values are smoothed again (moving average of up to
 * 5 values), mean-removed and scaled by 0.3 (7.5 ms) or 0.5 (otherwise).
 */
LC3_HOT static void compute_scale_factors(
    enum lc3_dt dt, enum lc3_srate sr,
    const float *eb, bool att, float *scf)
{
    /* Pre-emphasis gain table :
     * Ge[b] = 10 ^ (b * g_tilt / 630) , b = [0..63]
     * (630 = 10 * 63, i.e. a tilt of g_tilt dB over the 64 bands) */

    static const float ge_table[LC3_NUM_SRATE][LC3_NUM_BANDS] = {

        [LC3_SRATE_8K] = { /* g_tilt = 14 */
            1.00000000e+00, 1.05250029e+00, 1.10775685e+00, 1.16591440e+00,
            1.22712524e+00, 1.29154967e+00, 1.35935639e+00, 1.43072299e+00,
            1.50583635e+00, 1.58489319e+00, 1.66810054e+00, 1.75567629e+00,
            1.84784980e+00, 1.94486244e+00, 2.04696827e+00, 2.15443469e+00,
            2.26754313e+00, 2.38658979e+00, 2.51188643e+00, 2.64376119e+00,
            2.78255940e+00, 2.92864456e+00, 3.08239924e+00, 3.24422608e+00,
            3.41454887e+00, 3.59381366e+00, 3.78248991e+00, 3.98107171e+00,
            4.19007911e+00, 4.41005945e+00, 4.64158883e+00, 4.88527357e+00,
            5.14175183e+00, 5.41169527e+00, 5.69581081e+00, 5.99484250e+00,
            6.30957344e+00, 6.64082785e+00, 6.98947321e+00, 7.35642254e+00,
            7.74263683e+00, 8.14912747e+00, 8.57695899e+00, 9.02725178e+00,
            9.50118507e+00, 1.00000000e+01, 1.05250029e+01, 1.10775685e+01,
            1.16591440e+01, 1.22712524e+01, 1.29154967e+01, 1.35935639e+01,
            1.43072299e+01, 1.50583635e+01, 1.58489319e+01, 1.66810054e+01,
            1.75567629e+01, 1.84784980e+01, 1.94486244e+01, 2.04696827e+01,
            2.15443469e+01, 2.26754313e+01, 2.38658979e+01, 2.51188643e+01 },

        [LC3_SRATE_16K] = { /* g_tilt = 18 */
            1.00000000e+00, 1.06800043e+00, 1.14062492e+00, 1.21818791e+00,
            1.30102522e+00, 1.38949549e+00, 1.48398179e+00, 1.58489319e+00,
            1.69266662e+00, 1.80776868e+00, 1.93069773e+00, 2.06198601e+00,
            2.20220195e+00, 2.35195264e+00, 2.51188643e+00, 2.68269580e+00,
            2.86512027e+00, 3.05994969e+00, 3.26802759e+00, 3.49025488e+00,
            3.72759372e+00, 3.98107171e+00, 4.25178630e+00, 4.54090961e+00,
            4.84969343e+00, 5.17947468e+00, 5.53168120e+00, 5.90783791e+00,
            6.30957344e+00, 6.73862717e+00, 7.19685673e+00, 7.68624610e+00,
            8.20891416e+00, 8.76712387e+00, 9.36329209e+00, 1.00000000e+01,
            1.06800043e+01, 1.14062492e+01, 1.21818791e+01, 1.30102522e+01,
            1.38949549e+01, 1.48398179e+01, 1.58489319e+01, 1.69266662e+01,
            1.80776868e+01, 1.93069773e+01, 2.06198601e+01, 2.20220195e+01,
            2.35195264e+01, 2.51188643e+01, 2.68269580e+01, 2.86512027e+01,
            3.05994969e+01, 3.26802759e+01, 3.49025488e+01, 3.72759372e+01,
            3.98107171e+01, 4.25178630e+01, 4.54090961e+01, 4.84969343e+01,
            5.17947468e+01, 5.53168120e+01, 5.90783791e+01, 6.30957344e+01 },

        [LC3_SRATE_24K] = { /* g_tilt = 22 */
            1.00000000e+00, 1.08372885e+00, 1.17446822e+00, 1.27280509e+00,
            1.37937560e+00, 1.49486913e+00, 1.62003281e+00, 1.75567629e+00,
            1.90267705e+00, 2.06198601e+00, 2.23463373e+00, 2.42173704e+00,
            2.62450630e+00, 2.84425319e+00, 3.08239924e+00, 3.34048498e+00,
            3.62017995e+00, 3.92329345e+00, 4.25178630e+00, 4.60778348e+00,
            4.99358789e+00, 5.41169527e+00, 5.86481029e+00, 6.35586411e+00,
            6.88803330e+00, 7.46476041e+00, 8.08977621e+00, 8.76712387e+00,
            9.50118507e+00, 1.02967084e+01, 1.11588399e+01, 1.20931568e+01,
            1.31057029e+01, 1.42030283e+01, 1.53922315e+01, 1.66810054e+01,
            1.80776868e+01, 1.95913107e+01, 2.12316686e+01, 2.30093718e+01,
            2.49359200e+01, 2.70237760e+01, 2.92864456e+01, 3.17385661e+01,
            3.43959997e+01, 3.72759372e+01, 4.03970086e+01, 4.37794036e+01,
            4.74450028e+01, 5.14175183e+01, 5.57226480e+01, 6.03882412e+01,
            6.54444792e+01, 7.09240702e+01, 7.68624610e+01, 8.32980665e+01,
            9.02725178e+01, 9.78309319e+01, 1.06022203e+02, 1.14899320e+02,
            1.24519708e+02, 1.34945600e+02, 1.46244440e+02, 1.58489319e+02 },

        [LC3_SRATE_32K] = { /* g_tilt = 26 */
            1.00000000e+00, 1.09968890e+00, 1.20931568e+00, 1.32987103e+00,
            1.46244440e+00, 1.60823388e+00, 1.76855694e+00, 1.94486244e+00,
            2.13874364e+00, 2.35195264e+00, 2.58641621e+00, 2.84425319e+00,
            3.12779366e+00, 3.43959997e+00, 3.78248991e+00, 4.15956216e+00,
            4.57422434e+00, 5.03022373e+00, 5.53168120e+00, 6.08312841e+00,
            6.68954879e+00, 7.35642254e+00, 8.08977621e+00, 8.89623710e+00,
            9.78309319e+00, 1.07583590e+01, 1.18308480e+01, 1.30102522e+01,
            1.43072299e+01, 1.57335019e+01, 1.73019574e+01, 1.90267705e+01,
            2.09235283e+01, 2.30093718e+01, 2.53031508e+01, 2.78255940e+01,
            3.05994969e+01, 3.36499270e+01, 3.70044512e+01, 4.06933843e+01,
            4.47500630e+01, 4.92111475e+01, 5.41169527e+01, 5.95118121e+01,
            6.54444792e+01, 7.19685673e+01, 7.91430346e+01, 8.70327166e+01,
            9.57089124e+01, 1.05250029e+02, 1.15742288e+02, 1.27280509e+02,
            1.39968963e+02, 1.53922315e+02, 1.69266662e+02, 1.86140669e+02,
            2.04696827e+02, 2.25102829e+02, 2.47543082e+02, 2.72220379e+02,
            2.99357729e+02, 3.29200372e+02, 3.62017995e+02, 3.98107171e+02 },

        [LC3_SRATE_48K] = { /* g_tilt = 30 */
            1.00000000e+00, 1.11588399e+00, 1.24519708e+00, 1.38949549e+00,
            1.55051578e+00, 1.73019574e+00, 1.93069773e+00, 2.15443469e+00,
            2.40409918e+00, 2.68269580e+00, 2.99357729e+00, 3.34048498e+00,
            3.72759372e+00, 4.15956216e+00, 4.64158883e+00, 5.17947468e+00,
            5.77969288e+00, 6.44946677e+00, 7.19685673e+00, 8.03085722e+00,
            8.96150502e+00, 1.00000000e+01, 1.11588399e+01, 1.24519708e+01,
            1.38949549e+01, 1.55051578e+01, 1.73019574e+01, 1.93069773e+01,
            2.15443469e+01, 2.40409918e+01, 2.68269580e+01, 2.99357729e+01,
            3.34048498e+01, 3.72759372e+01, 4.15956216e+01, 4.64158883e+01,
            5.17947468e+01, 5.77969288e+01, 6.44946677e+01, 7.19685673e+01,
            8.03085722e+01, 8.96150502e+01, 1.00000000e+02, 1.11588399e+02,
            1.24519708e+02, 1.38949549e+02, 1.55051578e+02, 1.73019574e+02,
            1.93069773e+02, 2.15443469e+02, 2.40409918e+02, 2.68269580e+02,
            2.99357729e+02, 3.34048498e+02, 3.72759372e+02, 4.15956216e+02,
            4.64158883e+02, 5.17947468e+02, 5.77969288e+02, 6.44946677e+02,
            7.19685673e+02, 8.03085722e+02, 8.96150502e+02, 1.00000000e+03 },
    };

    float e[LC3_NUM_BANDS];

    /* --- Copy and padding ---
     * `nb` is the number of bands given in `eb` (clamped to 64).
     * When nb < 64, each of the first n2 = 64 - nb bands is duplicated,
     * so that 2*n2 + (nb - n2) = 64 values are filled.
     * This layout requires n2 <= nb, i.e. nb >= 32. */

    int nb = LC3_MIN(lc3_band_lim[dt][sr][LC3_NUM_BANDS], LC3_NUM_BANDS);
    int n2 = LC3_NUM_BANDS - nb;

    for (int i2 = 0; i2 < n2; i2++)
        e[2*i2 + 0] = e[2*i2 + 1] = eb[i2];

    memcpy(e + 2*n2, eb + n2, (nb - n2) * sizeof(float));

    /* --- Smoothing, pre-emphasis and logarithm ---
     * In-place 3-tap smoothing [1/4 1/2 1/4]: `e0`, `e1`, `e2` rotate as a
     * window keeping the original (not yet overwritten) previous, current
     * and next values. The first band is replicated on the left edge.
     * The loop is unrolled by 3 (63 = 3 x 21 iterations); the last band,
     * replicated on the right edge, is handled after the loop. */

    const float *ge = ge_table[sr];

    float e0 = e[0], e1 = e[0], e2;
    float e_sum = 0;

    for (int i = 0; i < LC3_NUM_BANDS-1; ) {
        e[i] = (e0 * 0.25f + e1 * 0.5f + (e2 = e[i+1]) * 0.25f) * ge[i];
        e_sum += e[i++];

        e[i] = (e1 * 0.25f + e2 * 0.5f + (e0 = e[i+1]) * 0.25f) * ge[i];
        e_sum += e[i++];

        e[i] = (e2 * 0.25f + e0 * 0.5f + (e1 = e[i+1]) * 0.25f) * ge[i];
        e_sum += e[i++];
    }

    e[LC3_NUM_BANDS-1] = (e0 * 0.25f + e1 * 0.75f) * ge[LC3_NUM_BANDS-1];
    e_sum += e[LC3_NUM_BANDS-1];

    /* Noise floor at 1e-4 (-40 dB) of the mean energy, and at least 2^-32 */

    float noise_floor = fmaxf(e_sum * (1e-4f / 64), 0x1p-32f);

    /* Half log2 of the floored energies */

    for (int i = 0; i < LC3_NUM_BANDS; i++)
        e[i] = fast_log2f(fmaxf(e[i], noise_floor)) * 0.5f;

    /* --- Grouping & scaling ---
     * scf[i] is the [1 2 3 3 2 1] / 12 weighted sum of e[4i-1 .. 4i+4],
     * with the edges replicated (e[-1] := e[0], e[64] := e[63]).
     * The result is mean-removed and scaled by 0.85. */

    float scf_sum;

    scf[0] = (e[0] + e[4]) * 1.f/12 +
             (e[0] + e[3]) * 2.f/12 +
             (e[1] + e[2]) * 3.f/12  ;
    scf_sum = scf[0];

    for (int i = 1; i < 15; i++) {
        scf[i] = (e[4*i-1] + e[4*i+4]) * 1.f/12 +
                 (e[4*i  ] + e[4*i+3]) * 2.f/12 +
                 (e[4*i+1] + e[4*i+2]) * 3.f/12  ;
        scf_sum += scf[i];
    }

    scf[15] = (e[59] + e[63]) * 1.f/12 +
              (e[60] + e[63]) * 2.f/12 +
              (e[61] + e[62]) * 3.f/12  ;
    scf_sum += scf[15];

    for (int i = 0; i < 16; i++)
        scf[i] = 0.85f * (scf[i] - scf_sum * 1.f/16);

    /* --- Attack handling ---
     * Moving average over scf[i-2 .. i+2] (3 or 4 values at the edges).
     * `sn` is the running sum of the window, and `s0..s4` keep the
     * original values of the window, as `scf[]` is overwritten in place.
     * The result is mean-removed and attenuated. */

    if (!att)
        return;

    float s0, s1 = scf[0], s2 = scf[1], s3 = scf[2], s4 = scf[3];
    float sn = s1 + s2;

    scf[0] = (sn += s3) * 1.f/3;
    scf[1] = (sn += s4) * 1.f/4;
    scf_sum = scf[0] + scf[1];

    for (int i = 2; i < 14; i++, sn -= s0) {
        s0 = s1, s1 = s2, s2 = s3, s3 = s4, s4 = scf[i+2];
        scf[i] = (sn += s4) * 1.f/5;
        scf_sum += scf[i];
    }

    scf[14] = (sn      ) * 1.f/4;
    scf[15] = (sn -= s1) * 1.f/3;
    scf_sum += scf[14] + scf[15];

    for (int i = 0; i < 16; i++)
        scf[i] = (dt == LC3_DT_7M5 ? 0.3f : 0.5f) *
                 (scf[i] - scf_sum * 1.f/16);
}
341 
342 /**
343  * Codebooks
344  * scf             Input 16 scale factors
345  * lf/hfcb_idx     Output the low and high frequency codebooks index
346  */
resolve_codebooks(const float * scf,int * lfcb_idx,int * hfcb_idx)347 LC3_HOT static void resolve_codebooks(
348     const float *scf, int *lfcb_idx, int *hfcb_idx)
349 {
350     float dlfcb_max = 0, dhfcb_max = 0;
351     *lfcb_idx = *hfcb_idx = 0;
352 
353     for (int icb = 0; icb < 32; icb++) {
354         const float *lfcb = lc3_sns_lfcb[icb];
355         const float *hfcb = lc3_sns_hfcb[icb];
356         float dlfcb = 0, dhfcb = 0;
357 
358         for (int i = 0; i < 8; i++) {
359             dlfcb += (scf[  i] - lfcb[i]) * (scf[  i] - lfcb[i]);
360             dhfcb += (scf[8+i] - hfcb[i]) * (scf[8+i] - hfcb[i]);
361         }
362 
363         if (icb == 0 || dlfcb < dlfcb_max)
364             *lfcb_idx = icb, dlfcb_max = dlfcb;
365 
366         if (icb == 0 || dhfcb < dhfcb_max)
367             *hfcb_idx = icb, dhfcb_max = dhfcb;
368     }
369 }
370 
371 /**
372  * Unit energy normalize pulse configuration
373  * c               Pulse configuration
374  * cn              Normalized pulse configuration
375  */
normalize(const int * c,float * cn)376 LC3_HOT static void normalize(const int *c, float *cn)
377 {
378     int c2_sum = 0;
379     for (int i = 0; i < 16; i++)
380         c2_sum += c[i] * c[i];
381 
382     float c_norm = 1.f / sqrtf(c2_sum);
383 
384     for (int i = 0; i < 16; i++)
385         cn[i] = c[i] * c_norm;
386 }
387 
/**
 * Sub-procedure of `quantize()`, add unit pulse
 * x, y, n         Transformed residual, and vector of pulses with length
 * start, end      Current number of pulses, limit to reach
 * corr, energy    Correlation (x,y) and y energy, updated at output
 *
 * Each added pulse goes to the position maximizing the normalized
 * correlation (corr + x[i])^2 / (energy + 2 y[i] + 1). The ratios are
 * compared by cross multiplication; ties keep the lowest position.
 */
LC3_HOT static void add_pulse(const float *x, int *y, int n,
    int start, int end, float *corr, float *energy)
{
    for (int npulses = start; npulses < end; npulses++) {
        const float corr0 = *corr, energy0 = *energy;

        int sel = 0;
        float sel_c2 = (corr0 + x[0]) * (corr0 + x[0]);
        float sel_e = energy0 + 2*y[0] + 1;

        for (int i = 1; i < n; i++) {
            float ci = corr0 + x[i];
            float c2 = ci * ci;
            float e = energy0 + 2*y[i] + 1;

            if (c2 * sel_e > e * sel_c2) {
                sel = i;
                sel_c2 = c2;
                sel_e = e;
            }
        }

        *corr = corr0 + x[sel];
        *energy = energy0 + 2*y[sel] + 1;
        y[sel] += 1;
    }
}
415 
/**
 * Quantization of codebooks residual
 * scf             Input 16 scale factors (not modified)
 * lf/hfcb_idx     Codebooks index
 * c, cn           Output 4 pulse configurations candidates, normalized
 * shape/gain_idx  Output selected shape/gain indexes
 *
 * The residual between the scale factors and the selected codebook
 * vectors is DCT-16 transformed, then approximated by 4 signed pulse
 * configurations, each one built from the previous:
 *   shape 3 : 6 pulses over the 16 positions
 *   shape 2 : shape 3 extended to 8 pulses over the 16 positions
 *   shape 1 : shape 2 restricted to positions 0..9, extended to 10 pulses
 *   shape 0 : shape 1 plus 1 pulse over positions 10..15
 * The (shape, gain) pair minimizing the squared error between the
 * transformed residual and the gain-scaled normalized configuration
 * is selected.
 */
LC3_HOT static void quantize(const float *scf, int lfcb_idx, int hfcb_idx,
    int (*c)[16], float (*cn)[16], int *shape_idx, int *gain_idx)
{
    /* --- Residual --- */

    const float *lfcb = lc3_sns_lfcb[lfcb_idx];
    const float *hfcb = lc3_sns_hfcb[hfcb_idx];
    float r[16], x[16];

    for (int i = 0; i < 8; i++) {
        r[  i] = scf[  i] - lfcb[i];
        r[8+i] = scf[8+i] - hfcb[i];
    }

    dct16_forward(r, x);

    /* --- Shape 3 candidate ---
     * Project to or below pyramid N = 16, K = 6,
     * then add unit pulses until you reach K = 6, over N = 16.
     * Pulses are searched on the magnitudes `xm`, signs are applied at
     * the end. `corr` and `energy` track Sum(xm . c) and Sum(c^2) of the
     * current configuration, and are carried over to the next shapes. */

    float xm[16];
    float xm_sum = 0;

    for (int i = 0; i < 16; i++) {
        xm[i] = fabsf(x[i]);
        xm_sum += xm[i];
    }

    /* Projection factor (K - 1) / Sum|x|, the `fmaxf()` avoids a
     * division by zero on a null residual */

    float proj_factor = (6 - 1) / fmaxf(xm_sum, 1e-31f);
    float corr = 0, energy = 0;
    int npulses = 0;

    for (int i = 0; i < 16; i++) {
        c[3][i] = floorf(xm[i] * proj_factor);
        npulses += c[3][i];
        corr    += c[3][i] * xm[i];
        energy  += c[3][i] * c[3][i];
    }

    add_pulse(xm, c[3], 16, npulses, 6, &corr, &energy);
    npulses = 6;

    /* --- Shape 2 candidate ---
     * Add unit pulses until you reach K = 8 on shape 3 */

    memcpy(c[2], c[3], sizeof(c[2]));

    add_pulse(xm, c[2], 16, npulses, 8, &corr, &energy);
    npulses = 8;

    /* --- Shape 1 candidate ---
     * Remove any unit pulses from shape 2 that are not part of 0 to 9
     * Update energy and correlation terms accordingly
     * Add unit pulses until you reach K = 10, over N = 10 */

    memcpy(c[1], c[2], sizeof(c[1]));

    for (int i = 10; i < 16; i++) {
        c[1][i] = 0;
        npulses -= c[2][i];
        corr    -= c[2][i] * xm[i];
        energy  -= c[2][i] * c[2][i];
    }

    add_pulse(xm, c[1], 10, npulses, 10, &corr, &energy);
    npulses = 10;

    /* --- Shape 0 candidate ---
     * Add unit pulses until you reach K = 1, on shape 1.
     * The pulse is searched over positions 10..15 only, using the
     * correlation and energy of the whole shape 1 configuration */

    memcpy(c[0], c[1], sizeof(c[0]));

    add_pulse(xm + 10, c[0] + 10, 6, 0, 1, &corr, &energy);

    /* --- Add sign and unit energy normalize ---
     * Each pulse takes the sign of the transformed residual */

    for (int j = 0; j < 16; j++)
        for (int i = 0; i < 4; i++)
            c[i][j] = x[j] < 0 ? -c[i][j] : c[i][j];

    for (int i = 0; i < 4; i++)
        normalize(c[i], cn[i]);

    /* --- Determine shape & gain index ---
     * Search the Mean Square Error, within (shape, gain) combinations.
     * On equal errors, the lowest shape and gain indexes are kept */

    float mse_min = INFINITY;
    *shape_idx = *gain_idx = 0;

    for (int ic = 0; ic < 4; ic++) {
        const struct lc3_sns_vq_gains *cgains = lc3_sns_vq_gains + ic;
        float cmse_min = INFINITY;
        int cgain_idx = 0;

        for (int ig = 0; ig < cgains->count; ig++) {
            float g = cgains->v[ig];

            float mse = 0;
            for (int i = 0; i < 16; i++)
                mse += (x[i] - g * cn[ic][i]) * (x[i] - g * cn[ic][i]);

            if (mse < cmse_min) {
                cgain_idx = ig,
                cmse_min = mse;
            }
        }

        if (cmse_min < mse_min) {
            *shape_idx = ic, *gain_idx = cgain_idx;
            mse_min = cmse_min;
        }
    }
}
536 
537 /**
538  * Unquantization of codebooks residual
539  * lf/hfcb_idx     Low and high frequency codebooks index
540  * c               Table of normalized pulse configuration
541  * shape/gain      Selected shape/gain indexes
542  * scf             Return unquantized scale factors
543  */
unquantize(int lfcb_idx,int hfcb_idx,const float * c,int shape,int gain,float * scf)544 LC3_HOT static void unquantize(int lfcb_idx, int hfcb_idx,
545     const float *c, int shape, int gain, float *scf)
546 {
547     const float *lfcb = lc3_sns_lfcb[lfcb_idx];
548     const float *hfcb = lc3_sns_hfcb[hfcb_idx];
549     float g = lc3_sns_vq_gains[shape].v[gain];
550 
551     dct16_inverse(c, scf);
552 
553     for (int i = 0; i < 8; i++)
554         scf[i] = lfcb[i] + g * scf[i];
555 
556     for (int i = 8; i < 16; i++)
557         scf[i] = hfcb[i-8] + g * scf[i];
558 }
559 
/**
 * Sub-procedure of `enumerate()`, enumeration of a vector
 * c, n            Table of pulse configuration, and length
 * idx, ls         Return enumeration set
 *
 * The vector is scanned from its last coefficient (c[n-1]) down to the
 * first. `ls` returns the sign of the last nonzero coefficient met by the
 * scan (the lowest index one); the signs of the previous nonzero
 * coefficients are shifted into `idx`.
 * The vector must hold at least one nonzero coefficient: the scan for
 * the first significant coefficient has no bound check.
 */
static void enum_mvpq(const int *c, int n, int *idx, bool *ls)
{
    int ci, i, j;

    /* --- Scan for 1st significant coeff --- */

    for (i = 0, c += n; (ci = *(--c)) == 0 ; i++);

    *idx = 0;
    *ls = ci < 0;

    /* --- Scan remaining coefficients ---
     * `i` counts the scanned positions, `j` the number of pulses of the
     * coefficients scanned before the current one. The sign of the
     * previous significant coefficient is pushed into `idx` whenever a
     * new significant coefficient is met. */

    for (i++, j = LC3_ABS(ci); i < n; i++, j += LC3_ABS(ci)) {

        if ((ci = *(--c)) != 0) {
            *idx = (*idx << 1) | *ls;
            *ls = ci < 0;
        }

        *idx += lc3_sns_mpvq_offsets[i][j];
    }
}
588 
/**
 * Sub-procedure of `deenumerate()`, deenumeration of a vector
 * idx, ls         Enumeration set
 * npulses         Number of pulses in the set
 * c, n            Table of pulses configuration, and length
 *
 * The vector is rebuilt from c[0] upward. Offset row `i` goes from n-1
 * down to 0, the reverse order of the scan in `enum_mvpq()`.
 */
static void deenum_mvpq(int idx, bool ls, int npulses, int *c, int n)
{
    int i;

    /* --- Scan for coefficients ---
     * Stops as soon as the index is exhausted; the remaining pulses then
     * all belong to the next coefficient. */

    for (i = n-1; i >= 0 && idx; i--) {

        int ci = 0;

        /* Pulse count of this coefficient: the first `ci` whose offset
         * is lower than or equal to the remaining index */

        for (ci = 0; idx < lc3_sns_mpvq_offsets[i][npulses - ci]; ci++);
        idx -= lc3_sns_mpvq_offsets[i][npulses - ci];

        *(c++) = ls ? -ci : ci;
        npulses -= ci;

        /* A significant coefficient pops the sign of the next one */

        if (ci > 0) {
            ls = idx & 1;
            idx >>= 1;
        }
    }

    /* --- Set last significant ---
     * Remaining pulses go to the next coefficient, the rest is zeroed */

    int ci = npulses;

    if (i-- >= 0)
        *(c++) = ls ? -ci : ci;

    while (i-- >= 0)
        *(c++) = 0;
}
626 
/**
 * SNS Enumeration of PVQ configuration
 * shape           Selected shape index
 * c               Selected pulse configuration
 * idx_a, ls_a     Return enumeration set A
 * idx_b, ls_b     Return enumeration set B (shape = 0)
 *
 * Set A covers positions 0..9 for shapes 0 and 1, and all 16 positions
 * otherwise. Set B (positions 10..15) is only produced for shape 0;
 * `idx_b` and `ls_b` are left untouched for the other shapes.
 */
static void enumerate(int shape, const int *c,
    int *idx_a, bool *ls_a, int *idx_b, bool *ls_b)
{
    const int len_a = (shape < 2) ? 10 : 16;

    enum_mvpq(c, len_a, idx_a, ls_a);

    if (shape != 0)
        return;

    enum_mvpq(c + 10, 6, idx_b, ls_b);
}
642 
/**
 * SNS Deenumeration of PVQ configuration
 * shape           Selected shape index
 * idx_a, ls_a     enumeration set A
 * idx_b, ls_b     enumeration set B (shape = 0)
 * c               Return pulse configuration, 16 values
 *
 * Set A holds 10, 10, 8 or 6 pulses for shapes 0 to 3, over positions
 * 0..9 (shapes 0, 1) or all 16 positions (shapes 2, 3).
 * For shape 0, set B holds a single pulse over positions 10..15;
 * for shape 1, positions 10..15 are cleared.
 */
static void deenumerate(int shape,
    int idx_a, bool ls_a, int idx_b, bool ls_b, int *c)
{
    static const int npulses_a[] = { 10, 10, 8, 6 };
    const int len_a = (shape < 2) ? 10 : 16;

    deenum_mvpq(idx_a, ls_a, npulses_a[shape], c, len_a);

    switch (shape) {
    case 0:
        deenum_mvpq(idx_b, ls_b, 1, c + 10, 6);
        break;

    case 1:
        memset(c + 10, 0, 6 * sizeof(*c));
        break;

    default:
        break;
    }
}
662 
663 
664 /* ----------------------------------------------------------------------------
665  *  Filtering
666  * -------------------------------------------------------------------------- */
667 
668 /**
669  * Spectral shaping
670  * dt, sr          Duration and samplerate of the frame
671  * scf_q           Quantized scale factors
672  * inv             True on inverse shaping, False otherwise
673  * x               Spectral coefficients
674  * y               Return shapped coefficients
675  *
676  * `x` and `y` can be the same buffer
677  */
spectral_shaping(enum lc3_dt dt,enum lc3_srate sr,const float * scf_q,bool inv,const float * x,float * y)678 LC3_HOT static void spectral_shaping(enum lc3_dt dt, enum lc3_srate sr,
679     const float *scf_q, bool inv, const float *x, float *y)
680 {
681     /* --- Interpolate scale factors --- */
682 
683     float scf[LC3_NUM_BANDS];
684     float s0, s1 = inv ? -scf_q[0] : scf_q[0];
685 
686     scf[0] = scf[1] = s1;
687     for (int i = 0; i < 15; i++) {
688         s0 = s1, s1 = inv ? -scf_q[i+1] : scf_q[i+1];
689         scf[4*i+2] = s0 + 0.125f * (s1 - s0);
690         scf[4*i+3] = s0 + 0.375f * (s1 - s0);
691         scf[4*i+4] = s0 + 0.625f * (s1 - s0);
692         scf[4*i+5] = s0 + 0.875f * (s1 - s0);
693     }
694     scf[62] = s1 + 0.125f * (s1 - s0);
695     scf[63] = s1 + 0.375f * (s1 - s0);
696 
697     int nb = LC3_MIN(lc3_band_lim[dt][sr][LC3_NUM_BANDS], LC3_NUM_BANDS);
698     int n2 = LC3_NUM_BANDS - nb;
699 
700     for (int i2 = 0; i2 < n2; i2++)
701         scf[i2] = 0.5f * (scf[2*i2] + scf[2*i2+1]);
702 
703     if (n2 > 0)
704         memmove(scf + n2, scf + 2*n2, (nb - n2) * sizeof(float));
705 
706     /* --- Spectral shaping --- */
707 
708     const int *lim = lc3_band_lim[dt][sr];
709 
710     for (int i = 0, ib = 0; ib < nb; ib++) {
711         float g_sns = fast_exp2f(-scf[ib]);
712 
713         for ( ; i < lim[ib+1]; i++)
714             y[i] = x[i] * g_sns;
715     }
716 }
717 
718 
719 /* ----------------------------------------------------------------------------
720  *  Interface
721  * -------------------------------------------------------------------------- */
722 
723 /**
724  * SNS analysis
725  */
lc3_sns_analyze(enum lc3_dt dt,enum lc3_srate sr,const float * eb,bool att,struct lc3_sns_data * data,const float * x,float * y)726 void lc3_sns_analyze(enum lc3_dt dt, enum lc3_srate sr,
727     const float *eb, bool att, struct lc3_sns_data *data,
728     const float *x, float *y)
729 {
730     /* Processing steps :
731      * - Determine 16 scale factors from bands energy estimation
732      * - Get codebooks indexes that match thoses scale factors
733      * - Quantize the residual with the selected codebook
734      * - The pulse configuration `c[]` is enumerated
735      * - Finally shape the spectrum coefficients accordingly */
736 
737     float scf[16], cn[4][16];
738     int c[4][16];
739 
740     compute_scale_factors(dt, sr, eb, att, scf);
741 
742     resolve_codebooks(scf, &data->lfcb, &data->hfcb);
743 
744     quantize(scf, data->lfcb, data->hfcb,
745         c, cn, &data->shape, &data->gain);
746 
747     unquantize(data->lfcb, data->hfcb,
748         cn[data->shape], data->shape, data->gain, scf);
749 
750     enumerate(data->shape, c[data->shape],
751         &data->idx_a, &data->ls_a, &data->idx_b, &data->ls_b);
752 
753     spectral_shaping(dt, sr, scf, false, x, y);
754 }
755 
756 /**
757  * SNS synthesis
758  */
lc3_sns_synthesize(enum lc3_dt dt,enum lc3_srate sr,const lc3_sns_data_t * data,const float * x,float * y)759 void lc3_sns_synthesize(enum lc3_dt dt, enum lc3_srate sr,
760     const lc3_sns_data_t *data, const float *x, float *y)
761 {
762     float scf[16], cn[16];
763     int c[16];
764 
765     deenumerate(data->shape,
766         data->idx_a, data->ls_a, data->idx_b, data->ls_b, c);
767 
768     normalize(c, cn);
769 
770     unquantize(data->lfcb, data->hfcb, cn, data->shape, data->gain, scf);
771 
772     spectral_shaping(dt, sr, scf, true, x, y);
773 }
774 
/**
 * Return number of bits coding the bitstream data
 *
 * The size is the same for every shape: 2 x 5 bits of codebook indexes,
 * 1 bit of shape MSB, then 27 bits for the gain MSB(s), the `ls_a` sign
 * and the multiplexed code (1 + 1 + 25 or 2 + 1 + 24 bits).
 */
int lc3_sns_get_nbits(void)
{
    const int codebooks_bits = 2 * 5;
    const int shape_msb_bits = 1;
    const int vq_bits = 27;

    return codebooks_bits + shape_msb_bits + vq_bits;
}
782 
783 /**
784  * Put bitstream data
785  */
lc3_sns_put_data(lc3_bits_t * bits,const struct lc3_sns_data * data)786 void lc3_sns_put_data(lc3_bits_t *bits, const struct lc3_sns_data *data)
787 {
788     /* --- Codebooks --- */
789 
790     lc3_put_bits(bits, data->lfcb, 5);
791     lc3_put_bits(bits, data->hfcb, 5);
792 
793     /* --- Shape, gain and vectors --- *
794      * Write MSB bit of shape index, next LSB bits of shape and gain,
795      * and MVPQ vectors indexes are muxed */
796 
797     int shape_msb = data->shape >> 1;
798     lc3_put_bit(bits, shape_msb);
799 
800     if (shape_msb == 0) {
801         const int size_a = 2390004;
802         int submode = data->shape & 1;
803 
804         int mux_high = submode == 0 ?
805             2 * (data->idx_b + 1) + data->ls_b : data->gain & 1;
806         int mux_code = mux_high * size_a + data->idx_a;
807 
808         lc3_put_bits(bits, data->gain >> submode, 1);
809         lc3_put_bits(bits, data->ls_a, 1);
810         lc3_put_bits(bits, mux_code, 25);
811 
812     } else {
813         const int size_a = 15158272;
814         int submode = data->shape & 1;
815 
816         int mux_code = submode == 0 ?
817             data->idx_a : size_a + 2 * data->idx_a + (data->gain & 1);
818 
819         lc3_put_bits(bits, data->gain >> submode, 2);
820         lc3_put_bits(bits, data->ls_a, 1);
821         lc3_put_bits(bits, mux_code, 24);
822     }
823 }
824 
825 /**
826  * Get bitstream data
827  */
lc3_sns_get_data(lc3_bits_t * bits,struct lc3_sns_data * data)828 int lc3_sns_get_data(lc3_bits_t *bits, struct lc3_sns_data *data)
829 {
830     /* --- Codebooks --- */
831 
832     *data = (struct lc3_sns_data){
833         .lfcb = lc3_get_bits(bits, 5),
834         .hfcb = lc3_get_bits(bits, 5)
835     };
836 
837     /* --- Shape, gain and vectors --- */
838 
839     int shape_msb = lc3_get_bit(bits);
840     data->gain = lc3_get_bits(bits, 1 + shape_msb);
841     data->ls_a = lc3_get_bit(bits);
842 
843     int mux_code = lc3_get_bits(bits, 25 - shape_msb);
844 
845     if (shape_msb == 0) {
846         const int size_a = 2390004;
847 
848         if (mux_code >= size_a * 14)
849             return -1;
850 
851         data->idx_a = mux_code % size_a;
852         mux_code = mux_code / size_a;
853 
854         data->shape = (mux_code < 2);
855 
856         if (data->shape == 0) {
857             data->idx_b = (mux_code - 2) / 2;
858             data->ls_b  = (mux_code - 2) % 2;
859         } else {
860             data->gain = (data->gain << 1) + (mux_code % 2);
861         }
862 
863     } else {
864         const int size_a = 15158272;
865 
866         if (mux_code >= size_a + 1549824)
867             return -1;
868 
869         data->shape = 2 + (mux_code >= size_a);
870         if (data->shape == 2) {
871             data->idx_a = mux_code;
872         } else {
873             mux_code -= size_a;
874             data->idx_a = mux_code / 2;
875             data->gain = (data->gain << 1) + (mux_code % 2);
876         }
877     }
878 
879     return 0;
880 }
881