1#
2# Copyright 2024 Google LLC
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16from __future__ import annotations
17
18import array
19import ctypes
20import enum
21import glob
22import os
23import typing
24
25from ctypes import c_bool, c_byte, c_int, c_uint, c_size_t, c_void_p
26from ctypes.util import find_library
27from collections.abc import Iterable
28
29
class BaseError(Exception):
    """Common base class of every exception raised by the liblc3 wrapper."""
32
33
class InitializationError(RuntimeError, BaseError):
    """Raised when the LC3 library (or libc) cannot be located or set up."""
36
37
class InvalidArgumentError(ValueError, BaseError):
    """Raised when an argument or a codec parameter is rejected."""
40
41
42class _PcmFormat(enum.IntEnum):
43    S16 = 0
44    S24 = 1
45    S24_3LE = 2
46    FLOAT = 3
47
48
49class _Base:
50
51    def __init__(
52        self,
53        frame_duration_us: int,
54        sample_rate_hz: int,
55        num_channels: int,
56        hrmode: bool = False,
57        pcm_sample_rate_hz: int | None = None,
58        libpath: str | None = None,
59    ) -> None:
60
61        self.hrmode = hrmode
62        self.frame_duration_us = frame_duration_us
63        self.sample_rate_hz = sample_rate_hz
64        self.pcm_sample_rate_hz = pcm_sample_rate_hz or self.sample_rate_hz
65        self.num_channels = num_channels
66
67        if self.frame_duration_us not in [2500, 5000, 7500, 10000]:
68            raise InvalidArgumentError(
69                f"Invalid frame duration: {self.frame_duration_us} us ({self.frame_duration_us / 1000:.1f} ms)"
70            )
71
72        allowed_samplerate = (
73            [8000, 16000, 24000, 32000, 48000] if not self.hrmode else [48000, 96000]
74        )
75
76        if self.sample_rate_hz not in allowed_samplerate:
77            raise InvalidArgumentError(f"Invalid sample rate: {sample_rate_hz} Hz")
78
79        if libpath is None:
80            mesonpy_lib = glob.glob(
81                os.path.join(os.path.dirname(__file__), ".lc3py.mesonpy.libs", "*lc3*")
82            )
83
84            if mesonpy_lib:
85                libpath = mesonpy_lib[0]
86            else:
87                libpath = find_library("lc3")
88            if not libpath:
89                raise InitializationError("LC3 library not found")
90
91        lib = ctypes.cdll.LoadLibrary(libpath)
92
93        if not all(
94            hasattr(lib, func)
95            for func in (
96                "lc3_hr_frame_samples",
97                "lc3_hr_frame_block_bytes",
98                "lc3_hr_resolve_bitrate",
99                "lc3_hr_delay_samples",
100            )
101        ):
102            if self.hrmode:
103                raise InitializationError("High-Resolution interface not available")
104
105            lc3_hr_frame_samples = lambda hrmode, dt_us, sr_hz: lib.lc3_frame_samples(
106                dt_us, sr_hz
107            )
108            lc3_hr_frame_block_bytes = (
109                lambda hrmode, dt_us, sr_hz, num_channels, bitrate: num_channels
110                * lib.lc3_frame_bytes(dt_us, bitrate // 2)
111            )
112            lc3_hr_resolve_bitrate = (
113                lambda hrmode, dt_us, sr_hz, nbytes: lib.lc3_resolve_bitrate(
114                    dt_us, nbytes
115                )
116            )
117            lc3_hr_delay_samples = lambda hrmode, dt_us, sr_hz: lib.lc3_delay_samples(
118                dt_us, sr_hz
119            )
120            setattr(lib, "lc3_hr_frame_samples", lc3_hr_frame_samples)
121            setattr(lib, "lc3_hr_frame_block_bytes", lc3_hr_frame_block_bytes)
122            setattr(lib, "lc3_hr_resolve_bitrate", lc3_hr_resolve_bitrate)
123            setattr(lib, "lc3_hr_delay_samples", lc3_hr_delay_samples)
124
125        lib.lc3_hr_frame_samples.argtypes = [c_bool, c_int, c_int]
126        lib.lc3_hr_frame_block_bytes.argtypes = [c_bool, c_int, c_int, c_int, c_int]
127        lib.lc3_hr_resolve_bitrate.argtypes = [c_bool, c_int, c_int, c_int]
128        lib.lc3_hr_delay_samples.argtypes = [c_bool, c_int, c_int]
129        self.lib = lib
130
131        if not (libc_path := find_library("c")):
132            raise InitializationError("Unable to find libc")
133        libc = ctypes.cdll.LoadLibrary(libc_path)
134
135        self.malloc = libc.malloc
136        self.malloc.argtypes = [c_size_t]
137        self.malloc.restype = c_void_p
138
139        self.free = libc.free
140        self.free.argtypes = [c_void_p]
141
142    def get_frame_samples(self) -> int:
143        """
144        Returns the number of PCM samples in an LC3 frame.
145        """
146        ret = self.lib.lc3_hr_frame_samples(
147            self.hrmode, self.frame_duration_us, self.pcm_sample_rate_hz
148        )
149        if ret < 0:
150            raise InvalidArgumentError("Bad parameters")
151        return ret
152
153    def get_frame_bytes(self, bitrate: int) -> int:
154        """
155        Returns the size of LC3 frame blocks, from bitrate in bit per seconds.
156        A target `bitrate` equals 0 or `INT32_MAX` returns respectively
157        the minimum and maximum allowed size.
158        """
159        ret = self.lib.lc3_hr_frame_block_bytes(
160            self.hrmode,
161            self.frame_duration_us,
162            self.sample_rate_hz,
163            self.num_channels,
164            bitrate,
165        )
166        if ret < 0:
167            raise InvalidArgumentError("Bad parameters")
168        return ret
169
170    def resolve_bitrate(self, num_bytes: int) -> int:
171        """
172        Returns the bitrate in bits per seconds, from the size of LC3 frames.
173        """
174        ret = self.lib.lc3_hr_resolve_bitrate(
175            self.hrmode, self.frame_duration_us, self.sample_rate_hz, num_bytes
176        )
177        if ret < 0:
178            raise InvalidArgumentError("Bad parameters")
179        return ret
180
181    def get_delay_samples(self) -> int:
182        """
183        Returns the algorithmic delay, as a number of samples.
184        """
185        ret = self.lib.lc3_hr_delay_samples(
186            self.hrmode, self.frame_duration_us, self.pcm_sample_rate_hz
187        )
188        if ret < 0:
189            raise InvalidArgumentError("Bad parameters")
190        return ret
191
192    @classmethod
193    def _resolve_pcm_format(cls, bit_depth: int | None) -> tuple[
194        _PcmFormat,
195        type[ctypes.c_int16] | type[ctypes.Array[ctypes.c_byte]] | type[ctypes.c_float],
196    ]:
197        match bit_depth:
198            case 16:
199                return (_PcmFormat.S16, ctypes.c_int16)
200            case 24:
201                return (_PcmFormat.S24_3LE, 3 * ctypes.c_byte)
202            case None:
203                return (_PcmFormat.FLOAT, ctypes.c_float)
204            case _:
205                raise InvalidArgumentError("Could not interpret PCM bit_depth")
206
207
class Encoder(_Base):
    """
    LC3 Encoder wrapper.

    The `frame_duration_us`, in microsecond, is any of 2500, 5000, 7500, or 10000.
    The `sample_rate_hz`, in Hertz, is any of 8000, 16000, 24000, 32000
    or 48000, unless High-Resolution mode is enabled. In High-Resolution mode,
    the `sample_rate_hz` is 48000 or 96000.

    By default, one channel is processed. When `num_channels` is greater than one,
    the PCM input stream is read interleaved and consecutive LC3 frames are
    output, for each channel.

    Optional arguments:
        hrmode               : Enable High-Resolution mode, default is `False`.
        input_sample_rate_hz : Input PCM samplerate, enable downsampling of input.
        libpath              : LC3 library path and name

    Construction raises `InvalidArgumentError` when the library refuses to
    set up an encoder (presumably an unsupported `input_sample_rate_hz`),
    and `MemoryError` when the encoder state cannot be allocated.
    """

    class c_encoder_t(c_void_p):
        pass

    def __init__(
        self,
        frame_duration_us: int,
        sample_rate_hz: int,
        num_channels: int = 1,
        hrmode: bool = False,
        input_sample_rate_hz: int | None = None,
        libpath: str | None = None,
    ) -> None:

        # Created first, so that `__del__` finds them even when the
        # initialization below fails part-way.
        self.__encoders: list = []
        self.__encoder_mems: list = []

        super().__init__(
            frame_duration_us,
            sample_rate_hz,
            num_channels,
            hrmode,
            input_sample_rate_hz,
            libpath,
        )

        lib = self.lib

        if not all(
            hasattr(lib, func)
            for func in ("lc3_hr_encoder_size", "lc3_hr_setup_encoder")
        ):
            if self.hrmode:
                raise InitializationError("High-Resolution interface not available")

            # Legacy API: declare the prototypes, otherwise ctypes passes and
            # returns the state pointer as a C `int`, truncating it on 64-bit.
            lib.lc3_encoder_size.argtypes = [c_int, c_int]
            lib.lc3_encoder_size.restype = c_uint
            lib.lc3_setup_encoder.argtypes = [c_int, c_int, c_int, c_void_p]
            lib.lc3_setup_encoder.restype = self.c_encoder_t

            def lc3_hr_encoder_size(hrmode, dt_us, sr_hz):
                return lib.lc3_encoder_size(dt_us, sr_hz)

            def lc3_hr_setup_encoder(hrmode, dt_us, sr_hz, sr_pcm_hz, mem):
                return lib.lc3_setup_encoder(dt_us, sr_hz, sr_pcm_hz, mem)

            setattr(lib, "lc3_hr_encoder_size", lc3_hr_encoder_size)
            setattr(lib, "lc3_hr_setup_encoder", lc3_hr_setup_encoder)

        lib.lc3_hr_encoder_size.argtypes = [c_bool, c_int, c_int]
        lib.lc3_hr_encoder_size.restype = c_uint

        lib.lc3_hr_setup_encoder.argtypes = [c_bool, c_int, c_int, c_int, c_void_p]
        lib.lc3_hr_setup_encoder.restype = self.c_encoder_t

        lib.lc3_encode.argtypes = [
            self.c_encoder_t,
            c_int,
            c_void_p,
            c_int,
            c_int,
            c_void_p,
        ]

        # The state is sized from the (input) PCM sample rate.
        encoder_size = lib.lc3_hr_encoder_size(
            self.hrmode, self.frame_duration_us, self.pcm_sample_rate_hz
        )

        for _ in range(num_channels):
            mem = self.malloc(encoder_size)
            # Tracked right away, so that `__del__` releases it even if the
            # setup below fails (`free` accepts NULL).
            self.__encoder_mems.append(mem)
            if not mem and encoder_size:
                raise MemoryError("Unable to allocate LC3 encoder state")

            encoder = lib.lc3_hr_setup_encoder(
                self.hrmode,
                self.frame_duration_us,
                self.sample_rate_hz,
                self.pcm_sample_rate_hz,
                mem,
            )
            if not encoder:
                raise InvalidArgumentError("Unable to set up LC3 encoder")
            self.__encoders.append(encoder)

    def __del__(self) -> None:
        """Release the native encoder states allocated by `__init__`."""
        try:
            free = self.free
            mems = self.__encoder_mems
        except AttributeError:
            # `__init__` failed before anything was allocated.
            return
        self.__encoders.clear()
        while mems:
            free(mems.pop())

    @typing.overload
    def encode(
        self,
        pcm: bytes | bytearray | memoryview | Iterable[float],
        num_bytes: int,
        bit_depth: None = None,
    ) -> bytes: ...

    @typing.overload
    def encode(
        self, pcm: bytes | bytearray | memoryview, num_bytes: int, bit_depth: int
    ) -> bytes: ...

    def encode(self, pcm, num_bytes: int, bit_depth: int | None = None) -> bytes:
        """
        Encodes LC3 frame(s), for each channel.

        The `pcm` input is given in two ways. When no `bit_depth` is defined,
        it's a vector of floating point values from -1 to 1, coding the sample
        levels. When `bit_depth` is defined, `pcm` is interpreted as a byte-like
        object, each sample coded on `bit_depth` bits (16 or 24).
        The machine endianness, or little endian, is used for 16 or 24 bits
        width, respectively.
        In both cases, the `pcm` vector data is padded with zeros when
        its length is less than the required input samples for the encoder.
        Channels concatenation of encoded LC3 frames, of `num_bytes`, is returned.

        Raises:
            InvalidArgumentError: `bit_depth` is not 16, 24 or None; the float
                input is out of range or contains NaN; or the library rejects
                the encoding parameters.
        """

        nchannels = self.num_channels
        frame_samples = self.get_frame_samples()

        (pcm_fmt, pcm_t) = self._resolve_pcm_format(bit_depth)
        sample_size = ctypes.sizeof(pcm_t)
        pcm_len = nchannels * frame_samples

        if bit_depth is None:
            # Materialized first: `pcm` may be a one-shot iterable.
            pcm_buffer = array.array("f", pcm)

            # Mean level over a full block of interleaved samples. The test is
            # inverted so that NaN, which fails every comparison, is rejected.
            if not abs(sum(pcm_buffer)) / max(pcm_len, 1) < 2:
                raise InvalidArgumentError("Out of range PCM input")

            missing = pcm_len - len(pcm_buffer)
            if missing > 0:
                pcm_buffer.extend(array.array("f", [0.0]) * missing)

        else:
            # Copy as raw bytes; lengths below are in bytes, not items.
            pcm_buffer = bytearray(pcm)
            missing = pcm_len * sample_size - len(pcm_buffer)
            if missing > 0:
                pcm_buffer.extend(bytes(missing))

        data_buffer = (c_byte * num_bytes)()
        data_offset = 0

        for ich, encoder in enumerate(self.__encoders):

            # View of the interleaved PCM starting at this channel's first
            # sample; lc3_encode steps through it with a stride of `nchannels`.
            channel_pcm = (pcm_t * (pcm_len - ich)).from_buffer(
                pcm_buffer, ich * sample_size
            )

            # `num_bytes` is split as evenly as possible, the first channels
            # taking the remainder.
            data_size = num_bytes // nchannels + int(ich < num_bytes % nchannels)
            data = (c_byte * data_size).from_buffer(data_buffer, data_offset)
            data_offset += data_size

            ret = self.lib.lc3_encode(
                encoder, pcm_fmt, channel_pcm, nchannels, len(data), data
            )
            if ret < 0:
                raise InvalidArgumentError("Bad parameters")

        return bytes(data_buffer)
372
373
class Decoder(_Base):
    """
    LC3 Decoder wrapper.

    The `frame_duration_us`, in microsecond, is any of 2500, 5000, 7500, or 10000.
    The `sample_rate_hz`, in Hertz, is any of 8000, 16000, 24000, 32000
    or 48000, unless High-Resolution mode is enabled. In High-Resolution mode,
    the `sample_rate_hz` is 48000 or 96000.

    By default, one channel is processed. When `num_channels` is greater than one,
    the input is read as consecutive LC3 frames, one for each channel, and the
    PCM output stream is interleaved.

    Optional arguments:
        hrmode                : Enable High-Resolution mode, default is `False`.
        output_sample_rate_hz : Output PCM sample_rate_hz, enable upsampling of output.
        libpath               : LC3 library path and name

    Construction raises `InvalidArgumentError` when the library refuses to
    set up a decoder (presumably an unsupported `output_sample_rate_hz`),
    and `MemoryError` when the decoder state cannot be allocated.
    """

    class c_decoder_t(c_void_p):
        pass

    def __init__(
        self,
        frame_duration_us: int,
        sample_rate_hz: int,
        num_channels: int = 1,
        hrmode: bool = False,
        output_sample_rate_hz: int | None = None,
        libpath: str | None = None,
    ) -> None:

        # Created first, so that `__del__` finds them even when the
        # initialization below fails part-way.
        self.__decoders: list = []
        self.__decoder_mems: list = []

        super().__init__(
            frame_duration_us,
            sample_rate_hz,
            num_channels,
            hrmode,
            output_sample_rate_hz,
            libpath,
        )

        lib = self.lib

        if not all(
            hasattr(lib, func)
            for func in ("lc3_hr_decoder_size", "lc3_hr_setup_decoder")
        ):
            if self.hrmode:
                raise InitializationError("High-Resolution interface not available")

            # Legacy API: declare the prototypes, otherwise ctypes passes and
            # returns the state pointer as a C `int`, truncating it on 64-bit.
            lib.lc3_decoder_size.argtypes = [c_int, c_int]
            lib.lc3_decoder_size.restype = c_uint
            lib.lc3_setup_decoder.argtypes = [c_int, c_int, c_int, c_void_p]
            lib.lc3_setup_decoder.restype = self.c_decoder_t

            def lc3_hr_decoder_size(hrmode, dt_us, sr_hz):
                return lib.lc3_decoder_size(dt_us, sr_hz)

            def lc3_hr_setup_decoder(hrmode, dt_us, sr_hz, sr_pcm_hz, mem):
                return lib.lc3_setup_decoder(dt_us, sr_hz, sr_pcm_hz, mem)

            setattr(lib, "lc3_hr_decoder_size", lc3_hr_decoder_size)
            setattr(lib, "lc3_hr_setup_decoder", lc3_hr_setup_decoder)

        lib.lc3_hr_decoder_size.argtypes = [c_bool, c_int, c_int]
        lib.lc3_hr_decoder_size.restype = c_uint

        lib.lc3_hr_setup_decoder.argtypes = [c_bool, c_int, c_int, c_int, c_void_p]
        lib.lc3_hr_setup_decoder.restype = self.c_decoder_t

        lib.lc3_decode.argtypes = [
            self.c_decoder_t,
            c_void_p,
            c_int,
            c_int,
            c_void_p,
            c_int,
        ]

        # The state is sized from the (output) PCM sample rate.
        decoder_size = lib.lc3_hr_decoder_size(
            self.hrmode, self.frame_duration_us, self.pcm_sample_rate_hz
        )

        for _ in range(num_channels):
            mem = self.malloc(decoder_size)
            # Tracked right away, so that `__del__` releases it even if the
            # setup below fails (`free` accepts NULL).
            self.__decoder_mems.append(mem)
            if not mem and decoder_size:
                raise MemoryError("Unable to allocate LC3 decoder state")

            decoder = lib.lc3_hr_setup_decoder(
                self.hrmode,
                self.frame_duration_us,
                self.sample_rate_hz,
                self.pcm_sample_rate_hz,
                mem,
            )
            if not decoder:
                raise InvalidArgumentError("Unable to set up LC3 decoder")
            self.__decoders.append(decoder)

    def __del__(self) -> None:
        """Release the native decoder states allocated by `__init__`."""
        try:
            free = self.free
            mems = self.__decoder_mems
        except AttributeError:
            # `__init__` failed before anything was allocated.
            return
        self.__decoders.clear()
        while mems:
            free(mems.pop())

    @typing.overload
    def decode(
        self, data: bytes | bytearray | memoryview, bit_depth: None = None
    ) -> array.array[float]: ...

    @typing.overload
    def decode(self, data: bytes | bytearray | memoryview, bit_depth: int) -> bytes: ...

    def decode(
        self, data: bytes | bytearray | memoryview, bit_depth: int | None = None
    ) -> bytes | array.array[float]:
        """
        Decodes an LC3 frame.

        The input `data` is the channels concatenation of LC3 frames in a
        byte-like object. Interleaved PCM samples are returned according to
        the `bit_depth` indication.
        When no `bit_depth` is defined, it's a vector of floating point values
        from -1 to 1, coding the sample levels. When `bit_depth` is defined,
        it returns a byte array, each sample coded on `bit_depth` bits.
        The machine endianness, or little endian, is used for 16 or 24 bits
        width, respectively.

        Raises:
            InvalidArgumentError: `bit_depth` is not 16, 24 or None, or the
                library rejects a frame.
        """

        num_channels = self.num_channels

        (pcm_fmt, pcm_t) = self._resolve_pcm_format(bit_depth)
        sample_size = ctypes.sizeof(pcm_t)
        pcm_len = num_channels * self.get_frame_samples()
        pcm_buffer = (pcm_t * pcm_len)()

        # Writable copy, required by `from_buffer` below.
        data_buffer = bytearray(data)
        data_len = len(data_buffer)
        data_offset = 0

        for ich, decoder in enumerate(self.__decoders):
            # View of the interleaved PCM output starting at this channel's
            # first sample; lc3_decode writes with a stride of `num_channels`.
            channel_pcm = (pcm_t * (pcm_len - ich)).from_buffer(
                pcm_buffer, ich * sample_size
            )

            # The input is split as evenly as possible, the first channels
            # taking the remainder.
            data_size = data_len // num_channels + int(ich < data_len % num_channels)
            frame = (c_byte * data_size).from_buffer(data_buffer, data_offset)
            data_offset += data_size

            ret = self.lib.lc3_decode(
                decoder, frame, len(frame), pcm_fmt, channel_pcm, num_channels
            )
            if ret < 0:
                raise InvalidArgumentError("Bad parameters")

        return array.array("f", pcm_buffer) if bit_depth is None else bytes(pcm_buffer)
523