/******************************************************************************
 *
 *  Copyright 2022 Google LLC
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at:
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 ******************************************************************************/

#include "mdct.h"
#include "tables.h"

#include <math.h>

#include "mdct_neon.h"


/* ----------------------------------------------------------------------------
 *  FFT processing
 * -------------------------------------------------------------------------- */

29 /**
30  * FFT 5 Points
31  * x, y            Input and output coefficients, of size 5xn
32  * n               Number of interleaved transform to perform (n % 2 = 0)
33  */
34 #ifndef fft_5
fft_5(const struct lc3_complex * x,struct lc3_complex * y,int n)35 LC3_HOT static inline void fft_5(
36     const struct lc3_complex *x, struct lc3_complex *y, int n)
37 {
38     static const float cos1 =  0.3090169944;  /* cos(-2Pi 1/5) */
39     static const float cos2 = -0.8090169944;  /* cos(-2Pi 2/5) */
40 
41     static const float sin1 = -0.9510565163;  /* sin(-2Pi 1/5) */
42     static const float sin2 = -0.5877852523;  /* sin(-2Pi 2/5) */
43 
44     for (int i = 0; i < n; i++, x++, y+= 5) {
45 
46         struct lc3_complex s14 =
47             { x[1*n].re + x[4*n].re, x[1*n].im + x[4*n].im };
48         struct lc3_complex d14 =
49             { x[1*n].re - x[4*n].re, x[1*n].im - x[4*n].im };
50 
51         struct lc3_complex s23 =
52             { x[2*n].re + x[3*n].re, x[2*n].im + x[3*n].im };
53         struct lc3_complex d23 =
54             { x[2*n].re - x[3*n].re, x[2*n].im - x[3*n].im };
55 
56         y[0].re = x[0].re + s14.re + s23.re;
57 
58         y[0].im = x[0].im + s14.im + s23.im;
59 
60         y[1].re = x[0].re + s14.re * cos1 - d14.im * sin1
61                           + s23.re * cos2 - d23.im * sin2;
62 
63         y[1].im = x[0].im + s14.im * cos1 + d14.re * sin1
64                           + s23.im * cos2 + d23.re * sin2;
65 
66         y[2].re = x[0].re + s14.re * cos2 - d14.im * sin2
67                           + s23.re * cos1 + d23.im * sin1;
68 
69         y[2].im = x[0].im + s14.im * cos2 + d14.re * sin2
70                           + s23.im * cos1 - d23.re * sin1;
71 
72         y[3].re = x[0].re + s14.re * cos2 + d14.im * sin2
73                           + s23.re * cos1 - d23.im * sin1;
74 
75         y[3].im = x[0].im + s14.im * cos2 - d14.re * sin2
76                           + s23.im * cos1 + d23.re * sin1;
77 
78         y[4].re = x[0].re + s14.re * cos1 + d14.im * sin1
79                           + s23.re * cos2 + d23.im * sin2;
80 
81         y[4].im = x[0].im + s14.im * cos1 - d14.re * sin1
82                           + s23.im * cos2 - d23.re * sin2;
83     }
84 }
85 #endif /* fft_5 */
86 
87 /**
88  * FFT Butterfly 3 Points
89  * x, y            Input and output coefficients
90  * twiddles        Twiddles factors, determine size of transform
91  * n               Number of interleaved transforms
92  */
93 #ifndef fft_bf3
fft_bf3(const struct lc3_fft_bf3_twiddles * twiddles,const struct lc3_complex * x,struct lc3_complex * y,int n)94 LC3_HOT static inline void fft_bf3(
95     const struct lc3_fft_bf3_twiddles *twiddles,
96     const struct lc3_complex *x, struct lc3_complex *y, int n)
97 {
98     int n3 = twiddles->n3;
99     const struct lc3_complex (*w0)[2] = twiddles->t;
100     const struct lc3_complex (*w1)[2] = w0 + n3, (*w2)[2] = w1 + n3;
101 
102     const struct lc3_complex *x0 = x, *x1 = x0 + n*n3, *x2 = x1 + n*n3;
103     struct lc3_complex *y0 = y, *y1 = y0 + n3, *y2 = y1 + n3;
104 
105     for (int i = 0; i < n; i++, y0 += 3*n3, y1 += 3*n3, y2 += 3*n3)
106         for (int j = 0; j < n3; j++, x0++, x1++, x2++) {
107 
108             y0[j].re = x0->re + x1->re * w0[j][0].re - x1->im * w0[j][0].im
109                               + x2->re * w0[j][1].re - x2->im * w0[j][1].im;
110 
111             y0[j].im = x0->im + x1->im * w0[j][0].re + x1->re * w0[j][0].im
112                               + x2->im * w0[j][1].re + x2->re * w0[j][1].im;
113 
114             y1[j].re = x0->re + x1->re * w1[j][0].re - x1->im * w1[j][0].im
115                               + x2->re * w1[j][1].re - x2->im * w1[j][1].im;
116 
117             y1[j].im = x0->im + x1->im * w1[j][0].re + x1->re * w1[j][0].im
118                               + x2->im * w1[j][1].re + x2->re * w1[j][1].im;
119 
120             y2[j].re = x0->re + x1->re * w2[j][0].re - x1->im * w2[j][0].im
121                               + x2->re * w2[j][1].re - x2->im * w2[j][1].im;
122 
123             y2[j].im = x0->im + x1->im * w2[j][0].re + x1->re * w2[j][0].im
124                               + x2->im * w2[j][1].re + x2->re * w2[j][1].im;
125         }
126 }
127 #endif /* fft_bf3 */
128 
129 /**
130  * FFT Butterfly 2 Points
131  * twiddles        Twiddles factors, determine size of transform
132  * x, y            Input and output coefficients
133  * n               Number of interleaved transforms
134  */
135 #ifndef fft_bf2
fft_bf2(const struct lc3_fft_bf2_twiddles * twiddles,const struct lc3_complex * x,struct lc3_complex * y,int n)136 LC3_HOT static inline void fft_bf2(
137     const struct lc3_fft_bf2_twiddles *twiddles,
138     const struct lc3_complex *x, struct lc3_complex *y, int n)
139 {
140     int n2 = twiddles->n2;
141     const struct lc3_complex *w = twiddles->t;
142 
143     const struct lc3_complex *x0 = x, *x1 = x0 + n*n2;
144     struct lc3_complex *y0 = y, *y1 = y0 + n2;
145 
146     for (int i = 0; i < n; i++, y0 += 2*n2, y1 += 2*n2) {
147 
148         for (int j = 0; j < n2; j++, x0++, x1++) {
149 
150             y0[j].re = x0->re + x1->re * w[j].re - x1->im * w[j].im;
151             y0[j].im = x0->im + x1->im * w[j].re + x1->re * w[j].im;
152 
153             y1[j].re = x0->re - x1->re * w[j].re + x1->im * w[j].im;
154             y1[j].im = x0->im - x1->im * w[j].re - x1->re * w[j].im;
155         }
156     }
157 }
158 #endif /* fft_bf2 */
159 
160 /**
161  * Perform FFT
162  * x, y0, y1       Input, and 2 scratch buffers of size `n`
163  * n               Number of points 30, 40, 60, 80, 90, 120, 160, 180, 240
164  * return          The buffer `y0` or `y1` that hold the result
165  *
166  * Input `x` can be the same as the `y0` second scratch buffer
167  */
fft(const struct lc3_complex * x,int n,struct lc3_complex * y0,struct lc3_complex * y1)168 static struct lc3_complex *fft(const struct lc3_complex *x, int n,
169     struct lc3_complex *y0, struct lc3_complex *y1)
170 {
171     struct lc3_complex *y[2] = { y1, y0 };
172     int i2, i3, is = 0;
173 
174     /* The number of points `n` can be decomposed as :
175      *
176      *   n = 5^1 * 3^n3 * 2^n2
177      *
178      *   for n = 40, 80, 160        n3 = 0, n2 = [3..5]
179      *       n = 30, 60, 120, 240   n3 = 1, n2 = [1..4]
180      *       n = 90, 180            n3 = 2, n2 = [1..2]
181      *
182      * Note that the expression `n & (n-1) == 0` is equivalent
183      * to the check that `n` is a power of 2. */
184 
185     fft_5(x, y[is], n /= 5);
186 
187     for (i3 = 0; n & (n-1); i3++, is ^= 1)
188         fft_bf3(lc3_fft_twiddles_bf3[i3], y[is], y[is ^ 1], n /= 3);
189 
190     for (i2 = 0; n > 1; i2++, is ^= 1)
191         fft_bf2(lc3_fft_twiddles_bf2[i2][i3], y[is], y[is ^ 1], n >>= 1);
192 
193     return y[is];
194 }
195 
196 
/* ----------------------------------------------------------------------------
 *  MDCT processing
 * -------------------------------------------------------------------------- */

201 /**
202  * Windowing of samples before MDCT
203  * dt, sr          Duration and samplerate (size of the transform)
204  * x, y            Input current and delayed samples
205  * y, d            Output windowed samples, and delayed ones
206  */
mdct_window(enum lc3_dt dt,enum lc3_srate sr,const float * x,float * d,float * y)207 LC3_HOT static void mdct_window(enum lc3_dt dt, enum lc3_srate sr,
208     const float *x, float *d, float *y)
209 {
210     int ns = LC3_NS(dt, sr), nd = LC3_ND(dt, sr);
211 
212     const float *w0 = lc3_mdct_win[dt][sr], *w1 = w0 + ns;
213     const float *w2 = w1, *w3 = w2 + nd;
214 
215     const float *x0 = x + ns-nd, *x1 = x0;
216     float *y0 = y + ns/2, *y1 = y0;
217     float *d0 = d, *d1 = d + nd;
218 
219     while (x1 > x) {
220         *(--y0) = *d0 * *(w0++) - *(--x1) * *(--w1);
221         *(y1++) = (*(d0++) = *(x0++)) * *(w2++);
222 
223         *(--y0) = *d0 * *(w0++) - *(--x1) * *(--w1);
224         *(y1++) = (*(d0++) = *(x0++)) * *(w2++);
225     }
226 
227     for (x1 += ns; x0 < x1; ) {
228         *(--y0) = *d0 * *(w0++) - *(--d1) * *(--w1);
229         *(y1++) = (*(d0++) = *(x0++)) * *(w2++) + (*d1 = *(--x1)) * *(--w3);
230 
231         *(--y0) = *d0 * *(w0++) - *(--d1) * *(--w1);
232         *(y1++) = (*(d0++) = *(x0++)) * *(w2++) + (*d1 = *(--x1)) * *(--w3);
233     }
234 }
235 
236 /**
237  * Pre-rotate MDCT coefficients of N/2 points, before FFT N/4 points FFT
238  * def             Size and twiddles factors
239  * x, y            Input and output coefficients
240  *
241  * `x` and y` can be the same buffer
242  */
mdct_pre_fft(const struct lc3_mdct_rot_def * def,const float * x,struct lc3_complex * y)243 LC3_HOT static void mdct_pre_fft(const struct lc3_mdct_rot_def *def,
244     const float *x, struct lc3_complex *y)
245 {
246     int n4 = def->n4;
247 
248     const float *x0 = x, *x1 = x0 + 2*n4;
249     const struct lc3_complex *w0 = def->w, *w1 = w0 + n4;
250     struct lc3_complex *y0 = y, *y1 = y0 + n4;
251 
252     while (x0 < x1) {
253         struct lc3_complex u, uw = *(w0++);
254         u.re = - *(--x1) * uw.re + *x0 * uw.im;
255         u.im =   *(x0++) * uw.re + *x1 * uw.im;
256 
257         struct lc3_complex v, vw = *(--w1);
258         v.re = - *(--x1) * vw.im + *x0 * vw.re;
259         v.im = - *(x0++) * vw.im - *x1 * vw.re;
260 
261         *(y0++) = u;
262         *(--y1) = v;
263     }
264 }
265 
266 /**
267  * Post-rotate FFT N/4 points coefficients, resulting MDCT N points
268  * def             Size and twiddles factors
269  * x, y            Input and output coefficients
270  *
271  * `x` and y` can be the same buffer
272  */
mdct_post_fft(const struct lc3_mdct_rot_def * def,const struct lc3_complex * x,float * y)273 LC3_HOT static void mdct_post_fft(const struct lc3_mdct_rot_def *def,
274     const struct lc3_complex *x, float *y)
275 {
276     int n4 = def->n4, n8 = n4 >> 1;
277 
278     const struct lc3_complex *w0 = def->w + n8, *w1 = w0 - 1;
279     const struct lc3_complex *x0 = x + n8, *x1 = x0 - 1;
280 
281     float *y0 = y + n4, *y1 = y0;
282 
283     for ( ; y1 > y; x0++, x1--, w0++, w1--) {
284 
285         float u0 = x0->im * w0->im + x0->re * w0->re;
286         float u1 = x1->re * w1->im - x1->im * w1->re;
287 
288         float v0 = x0->re * w0->im - x0->im * w0->re;
289         float v1 = x1->im * w1->im + x1->re * w1->re;
290 
291         *(y0++) = u0;  *(y0++) = u1;
292         *(--y1) = v0;  *(--y1) = v1;
293     }
294 }
295 
296 /**
297  * Pre-rotate IMDCT coefficients of N points, before FFT N/4 points FFT
298  * def             Size and twiddles factors
299  * x, y            Input and output coefficients
300  *
301  * `x` and `y` can be the same buffer
302  * The real and imaginary parts of `y` are swapped,
303  * to operate on FFT instead of IFFT
304  */
imdct_pre_fft(const struct lc3_mdct_rot_def * def,const float * x,struct lc3_complex * y)305 LC3_HOT static void imdct_pre_fft(const struct lc3_mdct_rot_def *def,
306     const float *x, struct lc3_complex *y)
307 {
308     int n4 = def->n4;
309 
310     const float *x0 = x, *x1 = x0 + 2*n4;
311 
312     const struct lc3_complex *w0 = def->w, *w1 = w0 + n4;
313     struct lc3_complex *y0 = y, *y1 = y0 + n4;
314 
315     while (x0 < x1) {
316         float u0 = *(x0++), u1 = *(--x1);
317         float v0 = *(x0++), v1 = *(--x1);
318         struct lc3_complex uw = *(w0++), vw = *(--w1);
319 
320         (y0  )->re = - u0 * uw.re - u1 * uw.im;
321         (y0++)->im = - u1 * uw.re + u0 * uw.im;
322 
323         (--y1)->re = - v1 * vw.re - v0 * vw.im;
324         (  y1)->im = - v0 * vw.re + v1 * vw.im;
325     }
326 }
327 
328 /**
329  * Post-rotate FFT N/4 points coefficients, resulting IMDCT N points
330  * def             Size and twiddles factors
331  * x, y            Input and output coefficients
332  *
333  * `x` and y` can be the same buffer
334  * The real and imaginary parts of `x` are swapped,
335  * to operate on FFT instead of IFFT
336  */
imdct_post_fft(const struct lc3_mdct_rot_def * def,const struct lc3_complex * x,float * y)337 LC3_HOT static void imdct_post_fft(const struct lc3_mdct_rot_def *def,
338     const struct lc3_complex *x, float *y)
339 {
340     int n4 = def->n4;
341 
342     const struct lc3_complex *w0 = def->w, *w1 = w0 + n4;
343     const struct lc3_complex *x0 = x, *x1 = x0 + n4;
344 
345     float *y0 = y, *y1 = y0 + 2*n4;
346 
347     while (x0 < x1) {
348         struct lc3_complex uz = *(x0++), vz = *(--x1);
349         struct lc3_complex uw = *(w0++), vw = *(--w1);
350 
351         *(y0++) = uz.re * uw.im - uz.im * uw.re;
352         *(--y1) = uz.re * uw.re + uz.im * uw.im;
353 
354         *(--y1) = vz.re * vw.im - vz.im * vw.re;
355         *(y0++) = vz.re * vw.re + vz.im * vw.im;
356     }
357 }
358 
359 /**
360  * Apply windowing of samples
361  * dt, sr          Duration and samplerate
362  * x, d            Middle half of IMDCT coefficients and delayed samples
363  * y, d            Output samples and delayed ones
364  */
imdct_window(enum lc3_dt dt,enum lc3_srate sr,const float * x,float * d,float * y)365 LC3_HOT static void imdct_window(enum lc3_dt dt, enum lc3_srate sr,
366     const float *x, float *d, float *y)
367 {
368     /* The full MDCT coefficients is given by symmetry :
369      *   T[   0 ..  n/4-1] = -half[n/4-1 .. 0    ]
370      *   T[ n/4 ..  n/2-1] =  half[0     .. n/4-1]
371      *   T[ n/2 .. 3n/4-1] =  half[n/4   .. n/2-1]
372      *   T[3n/4 ..    n-1] =  half[n/2-1 .. n/4  ]  */
373 
374     int n4 = LC3_NS(dt, sr) >> 1, nd = LC3_ND(dt, sr);
375     const float *w2 = lc3_mdct_win[dt][sr], *w0 = w2 + 3*n4, *w1 = w0;
376 
377     const float *x0 = d + nd-n4, *x1 = x0;
378     float *y0 = y + nd-n4, *y1 = y0, *y2 = d + nd, *y3 = d;
379 
380     while (y0 > y) {
381         *(--y0) = *(--x0) - *(x  ) * *(w1++);
382         *(y1++) = *(x1++) + *(x++) * *(--w0);
383 
384         *(--y0) = *(--x0) - *(x  ) * *(w1++);
385         *(y1++) = *(x1++) + *(x++) * *(--w0);
386     }
387 
388     while (y1 < y + nd) {
389         *(y1++) = *(x1++) + *(x++) * *(--w0);
390         *(y1++) = *(x1++) + *(x++) * *(--w0);
391     }
392 
393     while (y1 < y + 2*n4) {
394         *(y1++) = *(x  ) * *(--w0);
395         *(--y2) = *(x++) * *(w2++);
396 
397         *(y1++) = *(x  ) * *(--w0);
398         *(--y2) = *(x++) * *(w2++);
399     }
400 
401     while (y2 > y3) {
402         *(y3++) = *(x  ) * *(--w0);
403         *(--y2) = *(x++) * *(w2++);
404 
405         *(y3++) = *(x  ) * *(--w0);
406         *(--y2) = *(x++) * *(w2++);
407     }
408 }
409 
410 /**
411  * Rescale samples
412  * x, n            Input and count of samples, scaled as output
413  * scale           Scale factor
414  */
rescale(float * x,int n,float f)415 LC3_HOT static void rescale(float *x, int n, float f)
416 {
417     for (int i = 0; i < (n >> 2); i++) {
418         *(x++) *= f; *(x++) *= f;
419         *(x++) *= f; *(x++) *= f;
420     }
421 }
422 
423 /**
424  * Forward MDCT transformation
425  */
lc3_mdct_forward(enum lc3_dt dt,enum lc3_srate sr,enum lc3_srate sr_dst,const float * x,float * d,float * y)426 void lc3_mdct_forward(enum lc3_dt dt, enum lc3_srate sr,
427     enum lc3_srate sr_dst, const float *x, float *d, float *y)
428 {
429     const struct lc3_mdct_rot_def *rot = lc3_mdct_rot[dt][sr];
430     int ns_dst = LC3_NS(dt, sr_dst);
431     int ns = LC3_NS(dt, sr);
432 
433     struct lc3_complex buffer[LC3_MAX_NS / 2];
434     struct lc3_complex *z = (struct lc3_complex *)y;
435     union { float *f; struct lc3_complex *z; } u = { .z = buffer };
436 
437     mdct_window(dt, sr, x, d, u.f);
438 
439     mdct_pre_fft(rot, u.f, u.z);
440     u.z = fft(u.z, ns/2, u.z, z);
441     mdct_post_fft(rot, u.z, y);
442 
443     if (ns != ns_dst)
444         rescale(y, ns_dst, sqrtf((float)ns_dst / ns));
445 }
446 
447 /**
448  * Inverse MDCT transformation
449  */
lc3_mdct_inverse(enum lc3_dt dt,enum lc3_srate sr,enum lc3_srate sr_src,const float * x,float * d,float * y)450 void lc3_mdct_inverse(enum lc3_dt dt, enum lc3_srate sr,
451     enum lc3_srate sr_src, const float *x, float *d, float *y)
452 {
453     const struct lc3_mdct_rot_def *rot = lc3_mdct_rot[dt][sr];
454     int ns_src = LC3_NS(dt, sr_src);
455     int ns = LC3_NS(dt, sr);
456 
457     struct lc3_complex buffer[LC3_MAX_NS / 2];
458     struct lc3_complex *z = (struct lc3_complex *)y;
459     union { float *f; struct lc3_complex *z; } u = { .z = buffer };
460 
461     imdct_pre_fft(rot, x, z);
462     z = fft(z, ns/2, z, u.z);
463     imdct_post_fft(rot, z, u.f);
464 
465     if (ns != ns_src)
466         rescale(u.f, ns, sqrtf((float)ns / ns_src));
467 
468     imdct_window(dt, sr, u.f, d, y);
469 }
470