mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-30 12:22:58 +00:00
🎉 start typing integration
This commit is contained in:
@@ -10,6 +10,19 @@
|
||||
- Waveform conversion and transforming functions.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
class Codec(str, Enum):
    """
    Enumeration of supported audio codecs.

    Every member is also a `str` equal to its value, so a member can be
    compared with (or passed in place of) the plain codec name.
    """

    # Members are intentionally not annotated: static type checkers infer
    # the member type themselves and report explicit annotations on enum
    # members as errors.
    WAV = 'wav'
    MP3 = 'mp3'
    OGG = 'ogg'
    M4A = 'm4a'
    WMA = 'wma'
    FLAC = 'flac'
|
||||
|
||||
@@ -3,21 +3,22 @@
|
||||
|
||||
""" AudioAdapter class definition. """
|
||||
|
||||
import subprocess
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from importlib import import_module
|
||||
from os.path import exists
|
||||
from pathlib import Path
|
||||
from spleeter.audio import Codec
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from .. import SpleeterError
|
||||
from ..types import AudioDescriptor, Signal
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.signal import stft, hann_window
|
||||
# pylint: enable=import-error
|
||||
|
||||
from .. import SpleeterError
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
@@ -27,46 +28,72 @@ __license__ = 'MIT License'
|
||||
class AudioAdapter(ABC):
|
||||
""" An abstract class for manipulating audio signal. """
|
||||
|
||||
# Default audio adapter singleton instance.
|
||||
DEFAULT = None
|
||||
_DEFAULT: 'AudioAdapter' = None
|
||||
""" Default audio adapter singleton instance. """
|
||||
|
||||
@abstractmethod
|
||||
def load(
|
||||
self, audio_descriptor, offset, duration,
|
||||
sample_rate, dtype=np.float32):
|
||||
""" Loads the audio file denoted by the given audio descriptor
|
||||
and returns it data as a waveform. Aims to be implemented
|
||||
by client.
|
||||
self,
|
||||
audio_descriptor: AudioDescriptor,
|
||||
offset: float,
|
||||
duration: float,
|
||||
sample_rate: float,
|
||||
dtype: np.dtype = np.float32) -> Signal:
|
||||
"""
|
||||
Loads the audio file denoted by the given audio descriptor and
|
||||
returns it data as a waveform. Aims to be implemented by client.
|
||||
|
||||
:param audio_descriptor: Describe song to load, in case of file
|
||||
based audio adapter, such descriptor would
|
||||
be a file path.
|
||||
:param offset: Start offset to load from in seconds.
|
||||
:param duration: Duration to load in seconds.
|
||||
:param sample_rate: Sample rate to load audio with.
|
||||
:param dtype: Numpy data type to use, default to float32.
|
||||
:returns: Loaded data as (wf, sample_rate) tuple.
|
||||
Parameters:
|
||||
audio_descriptor (AudioDescriptor):
|
||||
Describe song to load, in case of file based audio adapter,
|
||||
such descriptor would be a file path.
|
||||
offset (float):
|
||||
Start offset to load from in seconds.
|
||||
duration (float):
|
||||
Duration to load in seconds.
|
||||
sample_rate (float):
|
||||
Sample rate to load audio with.
|
||||
dtype (numpy.dtype):
|
||||
(Optional) Numpy data type to use, default to `float32`.
|
||||
|
||||
Returns:
|
||||
Signal:
|
||||
Loaded data as (wf, sample_rate) tuple.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load_tf_waveform(
|
||||
self, audio_descriptor,
|
||||
offset=0.0, duration=1800., sample_rate=44100,
|
||||
dtype=b'float32', waveform_name='waveform'):
|
||||
""" Load the audio and convert it to a tensorflow waveform.
|
||||
self,
|
||||
audio_descriptor,
|
||||
offset: float = 0.0,
|
||||
duration: float = 1800.,
|
||||
sample_rate: int = 44100,
|
||||
dtype: bytes = b'float32',
|
||||
waveform_name: str = 'waveform') -> Dict[str, Any]:
|
||||
"""
|
||||
Load the audio and convert it to a tensorflow waveform.
|
||||
|
||||
:param audio_descriptor: Describe song to load, in case of file
|
||||
based audio adapter, such descriptor would
|
||||
be a file path.
|
||||
:param offset: Start offset to load from in seconds.
|
||||
:param duration: Duration to load in seconds.
|
||||
:param sample_rate: Sample rate to load audio with.
|
||||
:param dtype: Numpy data type to use, default to float32.
|
||||
:param waveform_name: (Optional) Name of the key in output dict.
|
||||
:returns: TF output dict with waveform as
|
||||
(T x chan numpy array) and a boolean that
|
||||
tells whether there were an error while
|
||||
trying to load the waveform.
|
||||
Parameters:
|
||||
audio_descriptor ():
|
||||
Describe song to load, in case of file based audio adapter,
|
||||
such descriptor would be a file path.
|
||||
offset (float):
|
||||
Start offset to load from in seconds.
|
||||
duration (float):
|
||||
Duration to load in seconds.
|
||||
sample_rate (int):
|
||||
Sample rate to load audio with.
|
||||
dtype (bytes):
|
||||
(Optional) Data type to use, default to `b'float32'`.
|
||||
waveform_name (str):
|
||||
(Optional) Name of the key in output dict, default to
|
||||
`'waveform'`.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
TF output dict with waveform as `(T x chan numpy array)`
|
||||
and a boolean that tells whether there were an error while
|
||||
trying to load the waveform.
|
||||
"""
|
||||
# Cast parameters to TF format.
|
||||
offset = tf.cast(offset, tf.float64)
|
||||
@@ -100,50 +127,69 @@ class AudioAdapter(ABC):
|
||||
waveform, error = results[0]
|
||||
return {
|
||||
waveform_name: waveform,
|
||||
f'{waveform_name}_error': error
|
||||
}
|
||||
f'{waveform_name}_error': error}
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self, path, data, sample_rate,
|
||||
codec=None, bitrate=None):
|
||||
""" Save the given audio data to the file denoted by
|
||||
the given path.
|
||||
self,
|
||||
path: Union[Path, str],
|
||||
data: np.ndarray,
|
||||
sample_rate: float,
|
||||
codec: Codec = None,
|
||||
bitrate: str = None):
|
||||
"""
|
||||
Save the given audio data to the file denoted by the given path.
|
||||
|
||||
:param path: Path of the audio file to save data in.
|
||||
:param data: Waveform data to write.
|
||||
:param sample_rate: Sample rate to write file in.
|
||||
:param codec: (Optional) Writing codec to use.
|
||||
:param bitrate: (Optional) Bitrate of the written audio file.
|
||||
Parameters:
|
||||
path (Union[Path, str]):
|
||||
Path like of the audio file to save data in.
|
||||
data (numpy.ndarray):
|
||||
Waveform data to write.
|
||||
sample_rate (float):
|
||||
Sample rate to write file in.
|
||||
codec (Codec):
|
||||
(Optional) Writing codec to use, default to `None`.
|
||||
bitrate (str):
|
||||
(Optional) Bitrate of the written audio file, default to
|
||||
`None`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def default(cls: type) -> 'AudioAdapter':
|
||||
"""
|
||||
Builds and returns a default audio adapter instance.
|
||||
|
||||
def get_default_audio_adapter():
|
||||
""" Builds and returns a default audio adapter instance.
|
||||
Returns:
|
||||
AudioAdapter:
|
||||
Default adapter instance to use.
|
||||
"""
|
||||
if cls._DEFAULT is None:
|
||||
from .ffmpeg import FFMPEGProcessAudioAdapter
|
||||
cls._DEFAULT = FFMPEGProcessAudioAdapter()
|
||||
return cls._DEFAULT
|
||||
|
||||
:returns: An audio adapter instance.
|
||||
"""
|
||||
if AudioAdapter.DEFAULT is None:
|
||||
from .ffmpeg import FFMPEGProcessAudioAdapter
|
||||
AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
|
||||
return AudioAdapter.DEFAULT
|
||||
@classmethod
def get(cls: type, descriptor: str) -> 'AudioAdapter':
    """
    Load dynamically an AudioAdapter from given class descriptor.

    Parameters:
        descriptor (str):
            Adapter class descriptor (module.Class). When empty or
            `None`, the default adapter is returned instead.

    Returns:
        AudioAdapter:
            Created adapter instance.

    Raises:
        SpleeterError:
            If the object named by the descriptor is not an
            AudioAdapter subclass.
    """
    if not descriptor:
        return cls.default()
    # Split 'package.module.Class' on its last dot into the module path
    # and the class name.
    module_path, _, adapter_class_name = descriptor.rpartition('.')
    adapter_module = import_module(module_path)
    adapter_class = getattr(adapter_module, adapter_class_name)
    # The attribute is expected to be the adapter *class* (it is only
    # instantiated on the last line), so it has to be checked with
    # issubclass; isinstance would reject every valid adapter class.
    # The isinstance(..., type) guard keeps issubclass from raising
    # TypeError when the attribute is not a class at all.
    is_adapter_class = (
        isinstance(adapter_class, type)
        and issubclass(adapter_class, AudioAdapter))
    if not is_adapter_class:
        raise SpleeterError(
            f'{adapter_class_name} is not a valid AudioAdapter class')
    return adapter_class()
|
||||
|
||||
@@ -3,39 +3,54 @@
|
||||
|
||||
""" This module provides audio data conversion functions. """
|
||||
|
||||
from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
# pylint: enable=import-error
|
||||
|
||||
from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def to_n_channels(waveform, n_channels):
|
||||
""" Convert a waveform to n_channels by removing or
|
||||
duplicating channels if needed (in tensorflow).
|
||||
def to_n_channels(
|
||||
waveform: tf.Tensor,
|
||||
n_channels: int) -> tf.Tensor:
|
||||
"""
|
||||
Convert a waveform to n_channels by removing or duplicating channels if
|
||||
needed (in tensorflow).
|
||||
|
||||
:param waveform: Waveform to transform.
|
||||
:param n_channels: Number of channel to reshape waveform in.
|
||||
:returns: Reshaped waveform.
|
||||
Parameters:
|
||||
waveform (tensorflow.Tensor):
|
||||
Waveform to transform.
|
||||
n_channels (int):
|
||||
Number of channel to reshape waveform in.
|
||||
|
||||
Returns:
|
||||
tensorflow.Tensor:
|
||||
Reshaped waveform.
|
||||
"""
|
||||
return tf.cond(
|
||||
tf.shape(waveform)[1] >= n_channels,
|
||||
true_fn=lambda: waveform[:, :n_channels],
|
||||
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels]
|
||||
)
|
||||
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels])
|
||||
|
||||
|
||||
def to_stereo(waveform):
|
||||
""" Convert a waveform to stereo by duplicating if mono,
|
||||
or truncating if too many channels.
|
||||
def to_stereo(waveform: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert a waveform to stereo by duplicating if mono, or truncating
|
||||
if too many channels.
|
||||
|
||||
:param waveform: a (N, d) numpy array.
|
||||
:returns: A stereo waveform as a (N, 1) numpy array.
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
a `(N, d)` numpy array.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray:
|
||||
A stereo waveform as a `(N, 2)` numpy array.
|
||||
"""
|
||||
if waveform.shape[1] == 1:
|
||||
return np.repeat(waveform, 2, axis=-1)
|
||||
@@ -44,45 +59,84 @@ def to_stereo(waveform):
|
||||
return waveform
|
||||
|
||||
|
||||
def gain_to_db(tensor, espilon=10e-10):
|
||||
""" Convert from gain to decibel in tensorflow.
|
||||
def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
|
||||
"""
|
||||
Convert from gain to decibel in tensorflow.
|
||||
|
||||
:param tensor: Tensor to convert.
|
||||
:param epsilon: Operation constant.
|
||||
:returns: Converted tensor.
|
||||
Parameters:
|
||||
tensor (tensorflow.Tensor):
|
||||
Tensor to convert
|
||||
espilon (float):
|
||||
Operation constant.
|
||||
|
||||
Returns:
|
||||
tensorflow.Tensor:
|
||||
Converted tensor.
|
||||
"""
|
||||
return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
|
||||
|
||||
|
||||
def db_to_gain(tensor):
|
||||
""" Convert from decibel to gain in tensorflow.
|
||||
def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
Convert from decibel to gain in tensorflow.
|
||||
|
||||
:param tensor_db: Tensor to convert.
|
||||
:returns: Converted tensor.
|
||||
Parameters:
|
||||
tensor (tensorflow.Tensor):
|
||||
Tensor to convert
|
||||
|
||||
Returns:
|
||||
tensorflow.Tensor:
|
||||
Converted tensor.
|
||||
"""
|
||||
return tf.pow(10., (tensor / 20.))
|
||||
|
||||
|
||||
def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
|
||||
""" Encodes given spectrogram into uint8 using decibel scale.
|
||||
|
||||
:param spectrogram: Spectrogram to be encoded as TF float tensor.
|
||||
:param db_range: Range in decibel for encoding.
|
||||
:returns: Encoded decibel spectrogram as uint8 tensor.
|
||||
def spectrogram_to_db_uint(
|
||||
spectrogram: tf.Tensor,
|
||||
db_range: float = 100.,
|
||||
**kwargs) -> tf.Tensor:
|
||||
"""
|
||||
db_spectrogram = gain_to_db(spectrogram)
|
||||
max_db_spectrogram = tf.reduce_max(db_spectrogram)
|
||||
db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range)
|
||||
Encodes given spectrogram into uint8 using decibel scale.
|
||||
|
||||
Parameters:
|
||||
spectrogram (tensorflow.Tensor):
|
||||
Spectrogram to be encoded as TF float tensor.
|
||||
db_range (float):
|
||||
Range in decibel for encoding.
|
||||
|
||||
Returns:
|
||||
tensorflow.Tensor:
|
||||
Encoded decibel spectrogram as `uint8` tensor.
|
||||
"""
|
||||
db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
|
||||
max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
|
||||
db_spectrogram: tf.Tensor = tf.maximum(
|
||||
db_spectrogram,
|
||||
max_db_spectrogram - db_range)
|
||||
return from_float32_to_uint8(db_spectrogram, **kwargs)
|
||||
|
||||
|
||||
def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
|
||||
""" Decode spectrogram from uint8 decibel scale.
|
||||
|
||||
:param db_uint_spectrogram: Decibel pectrogram to decode.
|
||||
:param min_db: Lower bound limit for decoding.
|
||||
:param max_db: Upper bound limit for decoding.
|
||||
:returns: Decoded spectrogram as float2 tensor.
|
||||
def db_uint_spectrogram_to_gain(
|
||||
db_uint_spectrogram: tf.Tensor,
|
||||
min_db: tf.Tensor,
|
||||
max_db: tf.Tensor) -> tf.Tensor:
|
||||
"""
|
||||
db_spectrogram = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
|
||||
Decode spectrogram from uint8 decibel scale.
|
||||
|
||||
Parameters:
|
||||
db_uint_spectrogram (tensorflow.Tensor):
|
||||
Decibel spectrogram to decode.
|
||||
min_db (tensorflow.Tensor):
|
||||
Lower bound limit for decoding.
|
||||
max_db (tensorflow.Tensor):
|
||||
Upper bound limit for decoding.
|
||||
|
||||
Returns:
|
||||
tensorflow.Tensor:
|
||||
Decoded spectrogram as `float32` tensor.
|
||||
"""
|
||||
db_spectrogram: tf.Tensor = from_uint8_to_float32(
|
||||
db_uint_spectrogram,
|
||||
min_db,
|
||||
max_db)
|
||||
return db_to_gain(db_spectrogram)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
""" Spectrogram specific data augmentation """
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -51,34 +51,38 @@ OPT_PARAMS = {
|
||||
'help': 'JSON filename that contains params'
|
||||
}
|
||||
|
||||
# -s opt specification (separate).
|
||||
OPT_OFFSET = {
|
||||
'dest': 'offset',
|
||||
'type': float,
|
||||
'default': 0.,
|
||||
'help': 'Set the starting offset to separate audio from.'
|
||||
}
|
||||
Offset: OptionInfo = Option(
|
||||
0.,
|
||||
'--offset',
|
||||
'-s',
|
||||
help='Set the starting offset to separate audio from')
|
||||
|
||||
# -d opt specification (separate).
|
||||
OPT_DURATION = {
|
||||
'dest': 'duration',
|
||||
'type': float,
|
||||
'default': 600.,
|
||||
'help': (
|
||||
Duration: OptionInfo = Option(
|
||||
600.,
|
||||
'--duration',
|
||||
'-d',
|
||||
help=(
|
||||
'Set a maximum duration for processing audio '
|
||||
'(only separate offset + duration first seconds of '
|
||||
'the input file)')
|
||||
}
|
||||
'the input file)'))
|
||||
|
||||
# -w opt specification (separate)
|
||||
OPT_STFT_BACKEND = {
|
||||
'dest': 'stft_backend',
|
||||
'type': str,
|
||||
'choices' : ["tensorflow", "librosa", "auto"],
|
||||
'default': "auto",
|
||||
'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
|
||||
' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
|
||||
}
|
||||
|
||||
class STFTBackendEnum(str, Enum):
    """
    Enumeration of the STFT backends selectable with `--stft-backend`.

    The data type mix-in (`str`) must come before `Enum` in the bases,
    otherwise class creation raises `TypeError`. Every member is a `str`
    equal to its lowercase backend name.
    """

    # Each member needs an actual value: a bare annotation such as
    # `AUTO: str` only records a type hint and defines no member.
    AUTO = 'auto'
    TENSORFLOW = 'tensorflow'
    LIBROSA = 'librosa'
|
||||
|
||||
|
||||
STFTBackend: OptionInfo = Option(
|
||||
STFTBackendEnum.AUTO,
|
||||
'--stft-backend',
|
||||
'-B',
|
||||
case_sensitive=False,
|
||||
help=(
|
||||
'Who should be in charge of computing the stfts. Librosa is faster '
|
||||
'than tensorflow on CPU and uses less memory. "auto" will use '
|
||||
'tensorflow when GPU acceleration is available and librosa when not'))
|
||||
|
||||
|
||||
# -c opt specification (separate).
|
||||
@@ -128,6 +132,14 @@ OPT_ADAPTER = {
|
||||
'help': 'Name of the audio adapter to use for audio I/O'
|
||||
}
|
||||
|
||||
|
||||
|
||||
AudioAdapter: OptionInfo = Option(
|
||||
'spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter',
|
||||
'--adapter',
|
||||
help='Name of the audio adapter to use for audio I/O')
|
||||
|
||||
|
||||
# -a opt specification (train, evaluate and separate).
|
||||
OPT_VERBOSE = {
|
||||
'action': 'store_true',
|
||||
|
||||
@@ -19,6 +19,10 @@ __author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
from typer import Option
|
||||
|
||||
AudioAdapter = Option()
|
||||
|
||||
|
||||
def entrypoint(arguments, params):
|
||||
""" Command entrypoint.
|
||||
|
||||
15
spleeter/types.py
Normal file
15
spleeter/types.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" Custom type aliases shared across the spleeter package. """
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
# pylint: enable=import-error
|
||||
|
||||
|
||||
# Type aliases are declared as plain assignments on purpose: annotating
# them as `type` makes static checkers treat the names as ordinary
# variables (unusable inside annotations), and neither `Any` nor a
# subscripted `Tuple` is an instance of `type` anyway.

# Opaque description of the audio to process; deliberately unconstrained.
AudioDescriptor = Any

# Pair made of a numpy array and a float.
# NOTE(review): presumably (waveform, sample_rate) - confirm against the
# audio adapters that produce it.
Signal = Tuple[np.ndarray, float]
|
||||
Reference in New Issue
Block a user