add pyproject.toml for poetry transition

Author: Félix Voituret
Date: 2021-01-08 17:32:39 +01:00
Parent: ad0171f6dd
Commit: 2b479eb683
25 changed files with 3741 additions and 1624 deletions

.github/workflows/pytest.yml

@@ -1,4 +1,4 @@
name: pytest
name: test
on:
pull_request:
branches:
@@ -15,13 +15,6 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v2
id: spleeter-pip-cache
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- uses: actions/cache@v2
env:
model-release: 1
@@ -31,11 +24,29 @@ jobs:
key: models-${{ env.model-release }}
restore-keys: |
models-${{ env.model-release }}
- name: Install dependencies
- name: Install ffmpeg
run: |
sudo apt-get update && sudo apt-get install -y ffmpeg
pip install --upgrade pip setuptools
pip install pytest==5.4.3 pytest-xdist==1.32.0 pytest-forked==1.1.3 musdb museval
python setup.py install
- name: Install Poetry
uses: dschep/install-poetry-action@v1.2
- name: Cache Poetry virtualenv
uses: actions/cache@v1
id: cache
with:
path: ~/.virtualenvs
key: poetry-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
poetry-${{ hashFiles('**/poetry.lock') }}
- name: Set Poetry config
run: |
poetry config settings.virtualenvs.in-project false
poetry config settings.virtualenvs.path ~/.virtualenvs
- name: Install Dependencies
run: poetry install
if: steps.cache.outputs.cache-hit != 'true'
- name: Code quality checks
run: |
poetry run black spleeter --check
poetry run isort spleeter --check
- name: Test with pytest
run: make test
run: poetry run pytest tests/

poetry.lock (generated, new file, 1931 lines)
File diff suppressed because it is too large.

pyproject.toml (new file, 84 lines)

@@ -0,0 +1,84 @@
[tool.poetry]
name = "spleeter"
version = "2.1.0"
description = "The Deezer source separation library with pretrained models based on tensorflow."
authors = ["Deezer Research <spleeter@deezer.com>"]
license = "MIT License"
readme = "README.md"
repository = "https://github.com/deezer/spleeter"
homepage = "https://github.com/deezer/spleeter"
classifiers = [
"Environment :: Console",
"Environment :: MacOS X",
"Intended Audience :: Developers",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: MacOS",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Operating System :: Unix",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Artistic Software",
"Topic :: Multimedia",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
"Topic :: Multimedia :: Sound/Audio :: Conversion",
"Topic :: Multimedia :: Sound/Audio :: Sound Synthesis",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Utilities"
]
packages = [ { include = "spleeter" } ]
include = ["spleeter/resources/*.json"]
[tool.poetry.dependencies]
python = "^3.7"
ffmpeg-python = "0.2.0"
norbert = "0.2.1"
httpx = {extras = ["http2"], version = "^0.16.1"}
typer = "^0.3.2"
librosa = "0.8.0"
musdb = {version = "0.3.1", optional = true}
museval = {version = "0.3.0", optional = true}
tensorflow = "2.3.0"
pandas = "1.1.2"
numpy = "<1.19.0,>=1.16.0"
[tool.poetry.dev-dependencies]
pytest = "^6.2.1"
isort = "^5.7.0"
black = "^20.8b1"
mypy = "^0.790"
pytest-xdist = "^2.2.0"
pytest-forked = "^1.3.0"
musdb = "0.3.1"
museval = "0.3.0"
[tool.poetry.scripts]
spleeter = 'spleeter.__main__:entrypoint'
[tool.poetry.extras]
evaluation = ["musdb", "museval"]
[tool.isort]
profile = "black"
multi_line_output = 3
[tool.pytest.ini_options]
addopts = "-W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
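
The [tool.poetry.scripts] table maps the spleeter console command to spleeter.__main__:entrypoint, so Poetry generates a script that imports and calls that function. A minimal sketch of the Typer entrypoint pattern this relies on, with toy command names (the real CLI module appears later in this commit):

import typer

app = typer.Typer(add_completion=False)

@app.command()
def hello(name: str = "world") -> None:
    """ Toy command standing in for train/separate/evaluate. """
    print(f"hello {name}")

def entrypoint() -> None:
    # Poetry's generated console script resolves 'module:function'
    # from [tool.poetry.scripts] and calls this.
    app()

if __name__ == "__main__":
    entrypoint()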

spleeter/__init__.py

@@ -13,9 +13,9 @@
by providing train, evaluation and source separation action.
"""
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class SpleeterError(Exception):

spleeter/__main__.py

@@ -13,21 +13,21 @@
"""
import json
from functools import partial
from itertools import product
from glob import glob
from itertools import product
from os.path import join
from pathlib import Path
from typing import Container, Dict, List, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer
from . import SpleeterError
from .options import *
from .utils.logging import configure_logger, logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer
# pylint: enable=import-error
spleeter: Typer = Typer(add_completion=False)
@@ -36,21 +36,22 @@ spleeter: Typer = Typer(add_completion=False)
@spleeter.command()
def train(
adapter: str = AudioAdapterOption,
data: Path = TrainingDataDirectoryOption,
params_filename: str = ModelParametersOption,
verbose: bool = VerboseOption) -> None:
adapter: str = AudioAdapterOption,
data: Path = TrainingDataDirectoryOption,
params_filename: str = ModelParametersOption,
verbose: bool = VerboseOption,
) -> None:
"""
Train a source separation model
Train a source separation model
"""
import tensorflow as tf
from .audio.adapter import AudioAdapter
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .utils.configuration import load_configuration
import tensorflow as tf
configure_logger(verbose)
audio_adapter = AudioAdapter.get(adapter)
audio_path = str(data)
@@ -59,51 +60,49 @@ def train(
session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
model_dir=params["model_dir"],
params=params,
config=tf.estimator.RunConfig(
save_checkpoints_steps=params['save_checkpoints_steps'],
tf_random_seed=params['random_seed'],
save_summary_steps=params['save_summary_steps'],
save_checkpoints_steps=params["save_checkpoints_steps"],
tf_random_seed=params["random_seed"],
save_summary_steps=params["save_summary_steps"],
session_config=session_config,
log_step_count_steps=10,
keep_checkpoint_max=2))
keep_checkpoint_max=2,
),
)
input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
train_spec = tf.estimator.TrainSpec(
input_fn=input_fn,
max_steps=params['train_max_steps'])
input_fn = partial(
get_validation_dataset,
params,
audio_adapter,
audio_path)
input_fn=input_fn, max_steps=params["train_max_steps"]
)
input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
evaluation_spec = tf.estimator.EvalSpec(
input_fn=input_fn,
steps=None,
throttle_secs=params['throttle_secs'])
logger.info('Start model training')
input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
)
logger.info("Start model training")
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
ModelProvider.writeProbe(params['model_dir'])
logger.info('Model training done')
ModelProvider.writeProbe(params["model_dir"])
logger.info("Model training done")
@spleeter.command()
def separate(
deprecated_files: Optional[str] = AudioInputOption,
files: List[Path] = AudioInputArgument,
adapter: str = AudioAdapterOption,
bitrate: str = AudioBitrateOption,
codec: Codec = AudioCodecOption,
duration: float = AudioDurationOption,
offset: float = AudioOffsetOption,
output_path: Path = AudioOutputOption,
stft_backend: STFTBackend = AudioSTFTBackendOption,
filename_format: str = FilenameFormatOption,
params_filename: str = ModelParametersOption,
mwf: bool = MWFOption,
verbose: bool = VerboseOption) -> None:
deprecated_files: Optional[str] = AudioInputOption,
files: List[Path] = AudioInputArgument,
adapter: str = AudioAdapterOption,
bitrate: str = AudioBitrateOption,
codec: Codec = AudioCodecOption,
duration: float = AudioDurationOption,
offset: float = AudioOffsetOption,
output_path: Path = AudioOutputOption,
stft_backend: STFTBackend = AudioSTFTBackendOption,
filename_format: str = FilenameFormatOption,
params_filename: str = ModelParametersOption,
mwf: bool = MWFOption,
verbose: bool = VerboseOption,
) -> None:
"""
Separate audio file(s)
Separate audio file(s)
"""
from .audio.adapter import AudioAdapter
from .separator import Separator
@@ -111,14 +110,14 @@ def separate(
configure_logger(verbose)
if deprecated_files is not None:
logger.error(
'⚠️ -i option is not supported anymore, audio files must be supplied '
'using input argument instead (see spleeter separate --help)')
"⚠️ -i option is not supported anymore, audio files must be supplied "
"using input argument instead (see spleeter separate --help)"
)
raise Exit(20)
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
separator: Separator = Separator(
params_filename,
MWF=mwf,
stft_backend=stft_backend)
params_filename, MWF=mwf, stft_backend=stft_backend
)
for filename in files:
separator.separate_to_file(
str(filename),
@@ -129,66 +128,73 @@ def separate(
codec=codec,
bitrate=bitrate,
filename_format=filename_format,
synchronous=False)
synchronous=False,
)
separator.join()
EVALUATION_SPLIT: str = 'test'
EVALUATION_METRICS_DIRECTORY: str = 'metrics'
EVALUATION_INSTRUMENTS: Container[str] = ('vocals', 'drums', 'bass', 'other')
EVALUATION_METRICS: Container[str] = ('SDR', 'SAR', 'SIR', 'ISR')
EVALUATION_MIXTURE: str = 'mixture.wav'
EVALUATION_AUDIO_DIRECTORY: str = 'audio'
EVALUATION_SPLIT: str = "test"
EVALUATION_METRICS_DIRECTORY: str = "metrics"
EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
EVALUATION_MIXTURE: str = "mixture.wav"
EVALUATION_AUDIO_DIRECTORY: str = "audio"
def _compile_metrics(metrics_output_directory) -> Dict:
"""
Compiles metrics from given directory and returns results as dict.
Compiles metrics from given directory and returns results as dict.
Parameters:
metrics_output_directory (str):
Directory to get metrics from.
Parameters:
metrics_output_directory (str):
Directory to get metrics from.
Returns:
Dict:
Compiled metrics as dict.
Returns:
Dict:
Compiled metrics as dict.
"""
import pandas as pd
import numpy as np
import pandas as pd
songs = glob(join(metrics_output_directory, 'test/*.json'))
songs = glob(join(metrics_output_directory, "test/*.json"))
index = pd.MultiIndex.from_tuples(
product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
names=['instrument', 'metric'])
pd.DataFrame([], index=['config1', 'config2'], columns=index)
names=["instrument", "metric"],
)
pd.DataFrame([], index=["config1", "config2"], columns=index)
metrics = {
instrument: {k: [] for k in EVALUATION_METRICS}
for instrument in EVALUATION_INSTRUMENTS}
for instrument in EVALUATION_INSTRUMENTS
}
for song in songs:
with open(song, 'r') as stream:
with open(song, "r") as stream:
data = json.load(stream)
for target in data['targets']:
instrument = target['name']
for target in data["targets"]:
instrument = target["name"]
for metric in EVALUATION_METRICS:
sdr_med = np.median([
frame['metrics'][metric]
for frame in target['frames']
if not np.isnan(frame['metrics'][metric])])
sdr_med = np.median(
[
frame["metrics"][metric]
for frame in target["frames"]
if not np.isnan(frame["metrics"][metric])
]
)
metrics[instrument][metric].append(sdr_med)
return metrics
@spleeter.command()
def evaluate(
adapter: str = AudioAdapterOption,
output_path: Path = AudioOutputOption,
stft_backend: STFTBackend = AudioSTFTBackendOption,
params_filename: str = ModelParametersOption,
mus_dir: Path = MUSDBDirectoryOption,
mwf: bool = MWFOption,
verbose: bool = VerboseOption) -> Dict:
adapter: str = AudioAdapterOption,
output_path: Path = AudioOutputOption,
stft_backend: STFTBackend = AudioSTFTBackendOption,
params_filename: str = ModelParametersOption,
mus_dir: Path = MUSDBDirectoryOption,
mwf: bool = MWFOption,
verbose: bool = VerboseOption,
) -> Dict:
"""
Evaluate a model on the musDB test dataset
Evaluate a model on the musDB test dataset
"""
import numpy as np
@@ -197,42 +203,44 @@ def evaluate(
import musdb
import museval
except ImportError:
logger.error('Extra dependencies musdb and museval not found')
logger.error('Please install musdb and museval first, abort')
logger.error("Extra dependencies musdb and museval not found")
logger.error("Please install musdb and museval first, abort")
raise Exit(10)
# Separate musdb sources.
songs = glob(join(mus_dir, EVALUATION_SPLIT, '*/'))
songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
separate(
deprecated_files=None,
files=mixtures,
adapter=adapter,
bitrate='128k',
bitrate="128k",
codec=Codec.WAV,
duration=600.,
duration=600.0,
offset=0,
output_path=join(audio_output_directory, EVALUATION_SPLIT),
stft_backend=stft_backend,
filename_format='{foldername}/{instrument}.{codec}',
filename_format="{foldername}/{instrument}.{codec}",
params_filename=params_filename,
mwf=mwf,
verbose=verbose)
verbose=verbose,
)
# Compute metrics with musdb.
metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
logger.info('Starting musdb evaluation (this could be long) ...')
logger.info("Starting musdb evaluation (this could be long) ...")
dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
museval.eval_mus_dir(
dataset=dataset,
estimates_dir=audio_output_directory,
output_dir=metrics_output_directory)
logger.info('musdb evaluation done')
output_dir=metrics_output_directory,
)
logger.info("musdb evaluation done")
# Compute and pretty print median metrics.
metrics = _compile_metrics(metrics_output_directory)
for instrument, metric in metrics.items():
logger.info(f'{instrument}:')
logger.info(f"{instrument}:")
for metric, value in metric.items():
logger.info(f'{metric}: {np.median(value):.3f}')
logger.info(f"{metric}: {np.median(value):.3f}")
return metrics
@@ -244,5 +252,5 @@ def entrypoint():
logger.error(e)
if __name__ == '__main__':
if __name__ == "__main__":
entrypoint()
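
The separate command above is a thin wrapper around the library API visible in this diff (AudioAdapter.get, Separator, separate_to_file, join). A hedged sketch of the equivalent direct call; the 'spleeter:2stems' descriptor and file names are illustrative, not taken from this commit:

from spleeter.audio.adapter import AudioAdapter
from spleeter.separator import Separator

# An empty descriptor falls back to the default FFMPEG adapter (see adapter.py below).
audio_adapter = AudioAdapter.get("")
separator = Separator("spleeter:2stems", MWF=False)
separator.separate_to_file(
    "audio_example.mp3",      # input file (illustrative)
    "output",                 # destination directory (illustrative)
    audio_adapter=audio_adapter,
    synchronous=False,        # queue the job, as the CLI does
)
separator.join()              # wait for pending separations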

spleeter/audio/__init__.py

@@ -12,28 +12,28 @@
from enum import Enum
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class Codec(str, Enum):
""" Enumeration of supported audio codec. """
WAV: str = 'wav'
MP3: str = 'mp3'
OGG: str = 'ogg'
M4A: str = 'm4a'
WMA: str = 'wma'
FLAC: str = 'flac'
WAV: str = "wav"
MP3: str = "mp3"
OGG: str = "ogg"
M4A: str = "m4a"
WMA: str = "wma"
FLAC: str = "flac"
class STFTBackend(str, Enum):
""" Enumeration of supported STFT backend. """
AUTO: str = 'auto'
TENSORFLOW: str = 'tensorflow'
LIBROSA: str = 'librosa'
AUTO: str = "auto"
TENSORFLOW: str = "tensorflow"
LIBROSA: str = "librosa"
@classmethod
def resolve(cls: type, backend: str) -> str:
@@ -44,9 +44,9 @@ class STFTBackend(str, Enum):
import tensorflow as tf
if backend not in cls.__members__.values():
raise ValueError(f'Unsupported backend {backend}')
raise ValueError(f"Unsupported backend {backend}")
if backend == cls.AUTO:
if len(tf.config.list_physical_devices('GPU')):
if len(tf.config.list_physical_devices("GPU")):
return cls.TENSORFLOW
return cls.LIBROSA
return backend
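
A small check of the decision rule resolve encodes, assuming TensorFlow is importable: 'auto' maps to the TensorFlow backend when a GPU is visible and to librosa otherwise, while explicit values pass through after validation.

import tensorflow as tf

from spleeter.audio import STFTBackend

backend = STFTBackend.resolve(STFTBackend.AUTO)
gpu_present = bool(tf.config.list_physical_devices("GPU"))
assert backend == (STFTBackend.TENSORFLOW if gpu_present else STFTBackend.LIBROSA)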

spleeter/audio/adapter.py

@@ -6,94 +6,98 @@
from abc import ABC, abstractmethod
from importlib import import_module
from pathlib import Path
from spleeter.audio import Codec
from typing import Any, Dict, List, Optional, Union
from .. import SpleeterError
from ..types import AudioDescriptor, Signal
from ..utils.logging import logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from spleeter.audio import Codec
from .. import SpleeterError
from ..types import AudioDescriptor, Signal
from ..utils.logging import logger
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class AudioAdapter(ABC):
""" An abstract class for manipulating audio signal. """
_DEFAULT: 'AudioAdapter' = None
_DEFAULT: "AudioAdapter" = None
""" Default audio adapter singleton instance. """
@abstractmethod
def load(
self,
audio_descriptor: AudioDescriptor,
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32) -> Signal:
self,
audio_descriptor: AudioDescriptor,
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32,
) -> Signal:
"""
Loads the audio file denoted by the given audio descriptor and
returns it data as a waveform. Aims to be implemented by client.
Loads the audio file denoted by the given audio descriptor and
returns it data as a waveform. Aims to be implemented by client.
Parameters:
audio_descriptor (AudioDescriptor):
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (Optional[float]):
Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
dtype (numpy.dtype):
(Optional) Numpy data type to use, default to `float32`.
Parameters:
audio_descriptor (AudioDescriptor):
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (Optional[float]):
Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
dtype (numpy.dtype):
(Optional) Numpy data type to use, default to `float32`.
Returns:
Signal:
Loaded data as (wf, sample_rate) tuple.
Returns:
Signal:
Loaded data as (wf, sample_rate) tuple.
"""
pass
def load_tf_waveform(
self,
audio_descriptor,
offset: float = 0.0,
duration: float = 1800.,
sample_rate: int = 44100,
dtype: bytes = b'float32',
waveform_name: str = 'waveform') -> Dict[str, Any]:
self,
audio_descriptor,
offset: float = 0.0,
duration: float = 1800.0,
sample_rate: int = 44100,
dtype: bytes = b"float32",
waveform_name: str = "waveform",
) -> Dict[str, Any]:
"""
Load the audio and convert it to a tensorflow waveform.
Load the audio and convert it to a tensorflow waveform.
Parameters:
audio_descriptor ():
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (float):
Start offset to load from in seconds.
duration (float):
Duration to load in seconds.
sample_rate (float):
Sample rate to load audio with.
dtype (bytes):
(Optional)data type to use, default to `b'float32'`.
waveform_name (str):
(Optional) Name of the key in output dict, default to
`'waveform'`.
Parameters:
audio_descriptor ():
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (float):
Start offset to load from in seconds.
duration (float):
Duration to load in seconds.
sample_rate (float):
Sample rate to load audio with.
dtype (bytes):
(Optional)data type to use, default to `b'float32'`.
waveform_name (str):
(Optional) Name of the key in output dict, default to
`'waveform'`.
Returns:
Dict[str, Any]:
TF output dict with waveform as `(T x chan numpy array)`
and a boolean that tells whether there were an error while
trying to load the waveform.
Returns:
Dict[str, Any]:
TF output dict with waveform as `(T x chan numpy array)`
and a boolean that tells whether there were an error while
trying to load the waveform.
"""
# Cast parameters to TF format.
offset = tf.cast(offset, tf.float64)
@@ -101,94 +105,96 @@ class AudioAdapter(ABC):
# Defined safe loading function.
def safe_load(path, offset, duration, sample_rate, dtype):
logger.info(
f'Loading audio {path} from {offset} to {offset + duration}')
logger.info(f"Loading audio {path} from {offset} to {offset + duration}")
try:
(data, _) = self.load(
path.numpy(),
offset.numpy(),
duration.numpy(),
sample_rate.numpy(),
dtype=dtype.numpy())
logger.info('Audio data loaded successfully')
dtype=dtype.numpy(),
)
logger.info("Audio data loaded successfully")
return (data, False)
except Exception as e:
logger.exception(
'An error occurs while loading audio',
exc_info=e)
logger.exception("An error occurs while loading audio", exc_info=e)
return (np.float32(-1.0), True)
# Execute function and format results.
results = tf.py_function(
safe_load,
[audio_descriptor, offset, duration, sample_rate, dtype],
(tf.float32, tf.bool)),
results = (
tf.py_function(
safe_load,
[audio_descriptor, offset, duration, sample_rate, dtype],
(tf.float32, tf.bool),
),
)
waveform, error = results[0]
return {
waveform_name: waveform,
f'{waveform_name}_error': error}
return {waveform_name: waveform, f"{waveform_name}_error": error}
@abstractmethod
def save(
self,
path: Union[Path, str],
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None) -> None:
self,
path: Union[Path, str],
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None,
) -> None:
"""
Save the given audio data to the file denoted by the given path.
Save the given audio data to the file denoted by the given path.
Parameters:
path (Union[Path, str]):
Path like of the audio file to save data in.
data (numpy.ndarray):
Waveform data to write.
sample_rate (float):
Sample rate to write file in.
codec ():
(Optional) Writing codec to use, default to `None`.
bitrate (str):
(Optional) Bitrate of the written audio file, default to
`None`.
Parameters:
path (Union[Path, str]):
Path like of the audio file to save data in.
data (numpy.ndarray):
Waveform data to write.
sample_rate (float):
Sample rate to write file in.
codec ():
(Optional) Writing codec to use, default to `None`.
bitrate (str):
(Optional) Bitrate of the written audio file, default to
`None`.
"""
pass
@classmethod
def default(cls: type) -> 'AudioAdapter':
def default(cls: type) -> "AudioAdapter":
"""
Builds and returns a default audio adapter instance.
Builds and returns a default audio adapter instance.
Returns:
AudioAdapter:
Default adapter instance to use.
Returns:
AudioAdapter:
Default adapter instance to use.
"""
if cls._DEFAULT is None:
from .ffmpeg import FFMPEGProcessAudioAdapter
cls._DEFAULT = FFMPEGProcessAudioAdapter()
return cls._DEFAULT
@classmethod
def get(cls: type, descriptor: str) -> 'AudioAdapter':
def get(cls: type, descriptor: str) -> "AudioAdapter":
"""
Load dynamically an AudioAdapter from given class descriptor.
Load dynamically an AudioAdapter from given class descriptor.
Parameters:
descriptor (str):
Adapter class descriptor (module.Class)
Parameters:
descriptor (str):
Adapter class descriptor (module.Class)
Returns:
AudioAdapter:
Created adapter instance.
Returns:
AudioAdapter:
Created adapter instance.
"""
if not descriptor:
return cls.default()
module_path: List[str] = descriptor.split('.')
module_path: List[str] = descriptor.split(".")
adapter_class_name: str = module_path[-1]
module_path: str = '.'.join(module_path[:-1])
module_path: str = ".".join(module_path[:-1])
adapter_module = import_module(module_path)
adapter_class = getattr(adapter_module, adapter_class_name)
if not issubclass(adapter_class, AudioAdapter):
raise SpleeterError(
f'{adapter_class_name} is not a valid AudioAdapter class')
f"{adapter_class_name} is not a valid AudioAdapter class"
)
return adapter_class()
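
A usage sketch of the dynamic loading implemented by get: the descriptor is split into a module path and class name, the module is imported, and the class is subclass-checked before instantiation. The audio file name is illustrative:

from spleeter.audio.adapter import AudioAdapter

adapter = AudioAdapter.get("spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter")
# load returns a (waveform, sample_rate) tuple per the Signal type above.
waveform, sample_rate = adapter.load("audio_example.mp3", duration=10.0)
print(waveform.shape, sample_rate)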

spleeter/audio/convertor.py

@@ -3,54 +3,54 @@
""" This module provides audio data convertion functions. """
from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def to_n_channels(
waveform: tf.Tensor,
n_channels: int) -> tf.Tensor:
def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
"""
Convert a waveform to n_channels by removing or duplicating channels if
needed (in tensorflow).
Convert a waveform to n_channels by removing or duplicating channels if
needed (in tensorflow).
Parameters:
waveform (tensorflow.Tensor):
Waveform to transform.
n_channels (int):
Number of channel to reshape waveform in.
Parameters:
waveform (tensorflow.Tensor):
Waveform to transform.
n_channels (int):
Number of channel to reshape waveform in.
Returns:
tensorflow.Tensor:
Reshaped waveform.
Returns:
tensorflow.Tensor:
Reshaped waveform.
"""
return tf.cond(
tf.shape(waveform)[1] >= n_channels,
true_fn=lambda: waveform[:, :n_channels],
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels])
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels],
)
def to_stereo(waveform: np.ndarray) -> np.ndarray:
"""
Convert a waveform to stereo by duplicating if mono, or truncating
if too many channels.
Convert a waveform to stereo by duplicating if mono, or truncating
if too many channels.
Parameters:
waveform (numpy.ndarray):
a `(N, d)` numpy array.
Parameters:
waveform (numpy.ndarray):
a `(N, d)` numpy array.
Returns:
numpy.ndarray:
A stereo waveform as a `(N, 1)` numpy array.
Returns:
numpy.ndarray:
A stereo waveform as a `(N, 1)` numpy array.
"""
if waveform.shape[1] == 1:
return np.repeat(waveform, 2, axis=-1)
@@ -61,82 +61,79 @@ def to_stereo(waveform: np.ndarray) -> np.ndarray:
def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
"""
Convert from gain to decibel in tensorflow.
Convert from gain to decibel in tensorflow.
Parameters:
tensor (tensorflow.Tensor):
Tensor to convert
epsilon (float):
Operation constant.
Parameters:
tensor (tensorflow.Tensor):
Tensor to convert
epsilon (float):
Operation constant.
Returns:
tensorflow.Tensor:
Converted tensor.
Returns:
tensorflow.Tensor:
Converted tensor.
"""
return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
return 20.0 / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
"""
Convert from decibel to gain in tensorflow.
Convert from decibel to gain in tensorflow.
Parameters:
tensor (tensorflow.Tensor):
Tensor to convert
Parameters:
tensor (tensorflow.Tensor):
Tensor to convert
Returns:
tensorflow.Tensor:
Converted tensor.
Returns:
tensorflow.Tensor:
Converted tensor.
"""
return tf.pow(10., (tensor / 20.))
return tf.pow(10.0, (tensor / 20.0))
def spectrogram_to_db_uint(
spectrogram: tf.Tensor,
db_range: float = 100.,
**kwargs) -> tf.Tensor:
spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
) -> tf.Tensor:
"""
Encodes given spectrogram into uint8 using decibel scale.
Encodes given spectrogram into uint8 using decibel scale.
Parameters:
spectrogram (tensorflow.Tensor):
Spectrogram to be encoded as TF float tensor.
db_range (float):
Range in decibel for encoding.
Parameters:
spectrogram (tensorflow.Tensor):
Spectrogram to be encoded as TF float tensor.
db_range (float):
Range in decibel for encoding.
Returns:
tensorflow.Tensor:
Encoded decibel spectrogram as `uint8` tensor.
Returns:
tensorflow.Tensor:
Encoded decibel spectrogram as `uint8` tensor.
"""
db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
db_spectrogram: tf.Tensor = tf.maximum(
db_spectrogram,
max_db_spectrogram - db_range)
db_spectrogram, max_db_spectrogram - db_range
)
return from_float32_to_uint8(db_spectrogram, **kwargs)
def db_uint_spectrogram_to_gain(
db_uint_spectrogram: tf.Tensor,
min_db: tf.Tensor,
max_db: tf.Tensor) -> tf.Tensor:
db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
) -> tf.Tensor:
"""
Decode spectrogram from uint8 decibel scale.
Decode spectrogram from uint8 decibel scale.
Paramters:
db_uint_spectrogram (tensorflow.Tensor):
Decibel spectrogram to decode.
min_db (tensorflow.Tensor):
Lower bound limit for decoding.
max_db (tensorflow.Tensor):
Upper bound limit for decoding.
Paramters:
db_uint_spectrogram (tensorflow.Tensor):
Decibel spectrogram to decode.
min_db (tensorflow.Tensor):
Lower bound limit for decoding.
max_db (tensorflow.Tensor):
Upper bound limit for decoding.
Returns:
tensorflow.Tensor:
Decoded spectrogram as `float32` tensor.
Returns:
tensorflow.Tensor:
Decoded spectrogram as `float32` tensor.
"""
db_spectrogram: tf.Tensor = from_uint8_to_float32(
db_uint_spectrogram,
min_db,
max_db)
db_uint_spectrogram, min_db, max_db
)
return db_to_gain(db_spectrogram)
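
gain_to_db computes 20 / ln(10) * ln(max(x, eps)), i.e. 20 * log10(x) above the epsilon floor, and db_to_gain computes 10 ** (y / 20), so the two are inverses for values above that floor. A quick numeric check under that assumption:

import numpy as np
import tensorflow as tf

from spleeter.audio.convertor import db_to_gain, gain_to_db

x = tf.constant([1e-3, 0.1, 1.0, 2.0])
roundtrip = db_to_gain(gain_to_db(x))
np.testing.assert_allclose(roundtrip.numpy(), x.numpy(), rtol=1e-5)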

spleeter/audio/ffmpeg.py

@@ -11,86 +11,87 @@
import datetime as dt
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union
from . import Codec
from .adapter import AudioAdapter
from .. import SpleeterError
from ..types import Signal
from ..utils.logging import logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import ffmpeg
import numpy as np
from .. import SpleeterError
from ..types import Signal
from ..utils.logging import logger
from . import Codec
from .adapter import AudioAdapter
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class FFMPEGProcessAudioAdapter(AudioAdapter):
"""
An AudioAdapter implementation that use FFMPEG binary through
subprocess in order to perform I/O operation for audio processing.
An AudioAdapter implementation that use FFMPEG binary through
subprocess in order to perform I/O operation for audio processing.
When created, FFMPEG binary path will be checked and expended,
raising exception if not found. Such path could be infered using
`FFMPEG_PATH` environment variable.
When created, FFMPEG binary path will be checked and expended,
raising exception if not found. Such path could be infered using
`FFMPEG_PATH` environment variable.
"""
SUPPORTED_CODECS: Dict[Codec, str] = {
Codec.M4A: 'aac',
Codec.OGG: 'libvorbis',
Codec.WMA: 'wmav2'
Codec.M4A: "aac",
Codec.OGG: "libvorbis",
Codec.WMA: "wmav2",
}
""" FFMPEG codec name mapping. """
def __init__(_) -> None:
"""
Default constructor, ensure FFMPEG binaries are available.
Default constructor, ensure FFMPEG binaries are available.
Raises:
SpleeterError:
If ffmpeg or ffprobe is not found.
Raises:
SpleeterError:
If ffmpeg or ffprobe is not found.
"""
for binary in ('ffmpeg', 'ffprobe'):
for binary in ("ffmpeg", "ffprobe"):
if shutil.which(binary) is None:
raise SpleeterError('{} binary not found'.format(binary))
raise SpleeterError("{} binary not found".format(binary))
def load(
_,
path: Union[Path, str],
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32) -> Signal:
_,
path: Union[Path, str],
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32,
) -> Signal:
"""
Loads the audio file denoted by the given path
and returns it data as a waveform.
Loads the audio file denoted by the given path
and returns it data as a waveform.
Parameters:
path (Union[Path, str]:
Path of the audio file to load data from.
offset (Optional[float]):
Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
dtype (numpy.dtype):
(Optional) Numpy data type to use, default to `float32`.
Parameters:
path (Union[Path, str]:
Path of the audio file to load data from.
offset (Optional[float]):
Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
dtype (numpy.dtype):
(Optional) Numpy data type to use, default to `float32`.
Returns:
Signal:
Loaded data a (waveform, sample_rate) tuple.
Returns:
Signal:
Loaded data a (waveform, sample_rate) tuple.
Raises:
SpleeterError:
If any error occurs while loading audio.
Raises:
SpleeterError:
If any error occurs while loading audio.
"""
if isinstance(path, Path):
path = str(path)
@@ -100,84 +101,85 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
probe = ffmpeg.probe(path)
except ffmpeg._run.Error as e:
raise SpleeterError(
'An error occurs with ffprobe (see ffprobe output below)\n\n{}'
.format(e.stderr.decode()))
if 'streams' not in probe or len(probe['streams']) == 0:
raise SpleeterError('No stream was found with ffprobe')
"An error occurs with ffprobe (see ffprobe output below)\n\n{}".format(
e.stderr.decode()
)
)
if "streams" not in probe or len(probe["streams"]) == 0:
raise SpleeterError("No stream was found with ffprobe")
metadata = next(
stream
for stream in probe['streams']
if stream['codec_type'] == 'audio')
n_channels = metadata['channels']
stream for stream in probe["streams"] if stream["codec_type"] == "audio"
)
n_channels = metadata["channels"]
if sample_rate is None:
sample_rate = metadata['sample_rate']
output_kwargs = {'format': 'f32le', 'ar': sample_rate}
sample_rate = metadata["sample_rate"]
output_kwargs = {"format": "f32le", "ar": sample_rate}
if duration is not None:
output_kwargs['t'] = str(dt.timedelta(seconds=duration))
output_kwargs["t"] = str(dt.timedelta(seconds=duration))
if offset is not None:
output_kwargs['ss'] = str(dt.timedelta(seconds=offset))
output_kwargs["ss"] = str(dt.timedelta(seconds=offset))
process = (
ffmpeg
.input(path)
.output('pipe:', **output_kwargs)
.run_async(pipe_stdout=True, pipe_stderr=True))
ffmpeg.input(path)
.output("pipe:", **output_kwargs)
.run_async(pipe_stdout=True, pipe_stderr=True)
)
buffer, _ = process.communicate()
waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
if not waveform.dtype == np.dtype(dtype):
waveform = waveform.astype(dtype)
return (waveform, sample_rate)
def save(
self,
path: Union[Path, str],
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None) -> None:
self,
path: Union[Path, str],
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None,
) -> None:
"""
Write waveform data to the file denoted by the given path using
FFMPEG process.
Write waveform data to the file denoted by the given path using
FFMPEG process.
Parameters:
path (Union[Path, str]):
Path like of the audio file to save data in.
data (numpy.ndarray):
Waveform data to write.
sample_rate (float):
Sample rate to write file in.
codec ():
(Optional) Writing codec to use, default to `None`.
bitrate (str):
(Optional) Bitrate of the written audio file, default to
`None`.
Parameters:
path (Union[Path, str]):
Path like of the audio file to save data in.
data (numpy.ndarray):
Waveform data to write.
sample_rate (float):
Sample rate to write file in.
codec ():
(Optional) Writing codec to use, default to `None`.
bitrate (str):
(Optional) Bitrate of the written audio file, default to
`None`.
Raises:
IOError:
If any error occurs while using FFMPEG to write data.
Raises:
IOError:
If any error occurs while using FFMPEG to write data.
"""
if isinstance(path, Path):
path = str(path)
directory = os.path.dirname(path)
if not os.path.exists(directory):
raise SpleeterError(
f'output directory does not exists: {directory}')
logger.debug(f'Writing file {path}')
input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
output_kwargs = {'ar': sample_rate, 'strict': '-2'}
raise SpleeterError(f"output directory does not exists: {directory}")
logger.debug(f"Writing file {path}")
input_kwargs = {"ar": sample_rate, "ac": data.shape[1]}
output_kwargs = {"ar": sample_rate, "strict": "-2"}
if bitrate:
output_kwargs['audio_bitrate'] = bitrate
if codec is not None and codec != 'wav':
output_kwargs['codec'] = self.SUPPORTED_CODECS.get(codec, codec)
output_kwargs["audio_bitrate"] = bitrate
if codec is not None and codec != "wav":
output_kwargs["codec"] = self.SUPPORTED_CODECS.get(codec, codec)
process = (
ffmpeg
.input('pipe:', format='f32le', **input_kwargs)
ffmpeg.input("pipe:", format="f32le", **input_kwargs)
.output(path, **output_kwargs)
.overwrite_output()
.run_async(pipe_stdin=True, pipe_stderr=True, quiet=True))
.run_async(pipe_stdin=True, pipe_stderr=True, quiet=True)
)
try:
process.stdin.write(data.astype('<f4').tobytes())
process.stdin.write(data.astype("<f4").tobytes())
process.stdin.close()
process.wait()
except IOError:
raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
logger.info(f'File {path} written succesfully')
raise SpleeterError(f"FFMPEG error: {process.stderr.read()}")
logger.info(f"File {path} written succesfully")

spleeter/audio/spectrogram.py

@@ -7,43 +7,44 @@
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from tensorflow.signal import hann_window, stft
from tensorflow.signal import stft, hann_window
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def compute_spectrogram_tf(
waveform: tf.Tensor,
frame_length: int = 2048,
frame_step: int = 512,
spec_exponent: float = 1.,
window_exponent: float = 1.) -> tf.Tensor:
waveform: tf.Tensor,
frame_length: int = 2048,
frame_step: int = 512,
spec_exponent: float = 1.0,
window_exponent: float = 1.0,
) -> tf.Tensor:
"""
Compute magnitude / power spectrogram from waveform as a
`n_samples x n_channels` tensor.
Compute magnitude / power spectrogram from waveform as a
`n_samples x n_channels` tensor.
Parameters:
waveform (tensorflow.Tensor):
Input waveform as `(times x number of channels)` tensor.
frame_length (int):
Length of a STFT frame to use.
frame_step (int):
HOP between successive frames.
spec_exponent (float):
Exponent of the spectrogram (usually 1 for magnitude
spectrogram, or 2 for power spectrogram).
window_exponent (float):
Exponent applied to the Hann windowing function (may be
useful for making perfect STFT/iSTFT reconstruction).
Parameters:
waveform (tensorflow.Tensor):
Input waveform as `(times x number of channels)` tensor.
frame_length (int):
Length of a STFT frame to use.
frame_step (int):
HOP between successive frames.
spec_exponent (float):
Exponent of the spectrogram (usually 1 for magnitude
spectrogram, or 2 for power spectrogram).
window_exponent (float):
Exponent applied to the Hann windowing function (may be
useful for making perfect STFT/iSTFT reconstruction).
Returns:
tensorflow.Tensor:
Computed magnitude / power spectrogram as a
`(T x F x n_channels)` tensor.
Returns:
tensorflow.Tensor:
Computed magnitude / power spectrogram as a
`(T x F x n_channels)` tensor.
"""
stft_tensor: tf.Tensor = tf.transpose(
stft(
@@ -51,131 +52,125 @@ def compute_spectrogram_tf(
frame_length,
frame_step,
window_fn=lambda f, dtype: hann_window(
f,
periodic=True,
dtype=waveform.dtype) ** window_exponent),
perm=[1, 2, 0])
f, periodic=True, dtype=waveform.dtype
)
** window_exponent,
),
perm=[1, 2, 0],
)
return tf.abs(stft_tensor) ** spec_exponent
def time_stretch(
spectrogram: tf.Tensor,
factor: float = 1.0,
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR
) -> tf.Tensor:
spectrogram: tf.Tensor,
factor: float = 1.0,
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
) -> tf.Tensor:
"""
Time stretch a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
Time stretch a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be time stretched as tensor.
factor (float):
(Optional) Time stretch factor, must be > 0, default to `1`.
method (tensorflow.image.ResizeMethod):
(Optional) Interpolation method, default to `BILINEAR`.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be time stretched as tensor.
factor (float):
(Optional) Time stretch factor, must be > 0, default to `1`.
method (tensorflow.image.ResizeMethod):
(Optional) Interpolation method, default to `BILINEAR`.
Returns:
tensorflow.Tensor:
Time stretched spectrogram as tensor with same shape.
Returns:
tensorflow.Tensor:
Time stretched spectrogram as tensor with same shape.
"""
T = tf.shape(spectrogram)[0]
T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
F = tf.shape(spectrogram)[1]
ts_spec = tf.image.resize_images(
spectrogram,
[T_ts, F],
method=method,
align_corners=True)
spectrogram, [T_ts, F], method=method, align_corners=True
)
return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
def random_time_stretch(
spectrogram: tf.Tensor,
factor_min: float = 0.9,
factor_max: float = 1.1,
**kwargs) -> tf.Tensor:
spectrogram: tf.Tensor, factor_min: float = 0.9, factor_max: float = 1.1, **kwargs
) -> tf.Tensor:
"""
Time stretch a spectrogram preserving shape with random ratio in
tensorflow. Applies time_stretch to spectrogram with a random ratio
drawn uniformly in `[factor_min, factor_max]`.
Time stretch a spectrogram preserving shape with random ratio in
tensorflow. Applies time_stretch to spectrogram with a random ratio
drawn uniformly in `[factor_min, factor_max]`.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be time stretched as tensor.
factor_min (float):
(Optional) Min time stretch factor, default to `0.9`.
factor_max (float):
(Optional) Max time stretch factor, default to `1.1`.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be time stretched as tensor.
factor_min (float):
(Optional) Min time stretch factor, default to `0.9`.
factor_max (float):
(Optional) Max time stretch factor, default to `1.1`.
Returns:
tensorflow.Tensor:
Randomly time stretched spectrogram as tensor with same shape.
Returns:
tensorflow.Tensor:
Randomly time stretched spectrogram as tensor with same shape.
"""
factor = tf.random_uniform(
shape=(1,),
seed=0) * (factor_max - factor_min) + factor_min
factor = (
tf.random_uniform(shape=(1,), seed=0) * (factor_max - factor_min) + factor_min
)
return time_stretch(spectrogram, factor=factor, **kwargs)
def pitch_shift(
spectrogram: tf.Tensor,
semitone_shift: float = 0.0,
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR
) -> tf.Tensor:
spectrogram: tf.Tensor,
semitone_shift: float = 0.0,
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
) -> tf.Tensor:
"""
Pitch shift a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
Pitch shift a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be pitch shifted as tensor.
semitone_shift (float):
(Optional) Pitch shift in semitone, default to `0.0`.
method (tensorflow.image.ResizeMethod):
(Optional) Interpolation method, default to `BILINEAR`.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be pitch shifted as tensor.
semitone_shift (float):
(Optional) Pitch shift in semitone, default to `0.0`.
method (tensorflow.image.ResizeMethod):
(Optional) Interpolation method, default to `BILINEAR`.
Returns:
tensorflow.Tensor:
Pitch shifted spectrogram (same shape as spectrogram).
Returns:
tensorflow.Tensor:
Pitch shifted spectrogram (same shape as spectrogram).
"""
factor = 2 ** (semitone_shift / 12.)
factor = 2 ** (semitone_shift / 12.0)
T = tf.shape(spectrogram)[0]
F = tf.shape(spectrogram)[1]
F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
ps_spec = tf.image.resize_images(
spectrogram,
[T, F_ps],
method=method,
align_corners=True)
spectrogram, [T, F_ps], method=method, align_corners=True
)
paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
return tf.pad(ps_spec[:, :F, :], paddings, "CONSTANT")
def random_pitch_shift(
spectrogram: tf.Tensor,
shift_min: float = -1.,
shift_max: float = 1.,
**kwargs) -> tf.Tensor:
spectrogram: tf.Tensor, shift_min: float = -1.0, shift_max: float = 1.0, **kwargs
) -> tf.Tensor:
"""
Pitch shift a spectrogram preserving shape with random ratio in
tensorflow. Applies pitch_shift to spectrogram with a random shift
amount (expressed in semitones) drawn uniformly in
`[shift_min, shift_max]`.
Pitch shift a spectrogram preserving shape with random ratio in
tensorflow. Applies pitch_shift to spectrogram with a random shift
amount (expressed in semitones) drawn uniformly in
`[shift_min, shift_max]`.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be pitch shifted as tensor.
shift_min (float):
(Optional) Min pitch shift in semitone, default to -1.
shift_max (float):
(Optional) Max pitch shift in semitone, default to 1.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be pitch shifted as tensor.
shift_min (float):
(Optional) Min pitch shift in semitone, default to -1.
shift_max (float):
(Optional) Max pitch shift in semitone, default to 1.
Returns:
tensorflow.Tensor:
Randomly pitch shifted spectrogram (same shape as spectrogram).
Returns:
tensorflow.Tensor:
Randomly pitch shifted spectrogram (same shape as spectrogram).
"""
semitone_shift = tf.random_uniform(
shape=(1,),
seed=0) * (shift_max - shift_min) + shift_min
semitone_shift = (
tf.random_uniform(shape=(1,), seed=0) * (shift_max - shift_min) + shift_min
)
return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
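
pitch_shift converts semitones to a frequency-axis resize factor with 2 ** (semitone_shift / 12), so one octave up doubles frequency and one octave down halves it. A quick check of the values this implies:

def semitone_to_factor(semitones: float) -> float:
    # Mapping used by pitch_shift above.
    return 2 ** (semitones / 12.0)

assert semitone_to_factor(12) == 2.0
assert semitone_to_factor(-12) == 0.5
assert abs(semitone_to_factor(1) - 1.059463) < 1e-6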

spleeter/dataset.py

@@ -14,104 +14,110 @@
(ground truth)
"""
import time
import os
from os.path import exists, sep as SEPARATOR
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional
from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain
from .audio.convertor import spectrogram_to_db_uint
from .audio.spectrogram import compute_spectrogram_tf
from .audio.spectrogram import random_pitch_shift, random_time_stretch
from .utils.logging import logger
from .utils.tensor import check_tensor_shape, dataset_from_csv
from .utils.tensor import set_tensor_shape, sync_apply
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
compute_spectrogram_tf,
random_pitch_shift,
random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
check_tensor_shape,
dataset_from_csv,
set_tensor_shape,
sync_apply,
)
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS: Dict = {
'instrument_list': ('vocals', 'accompaniment'),
'mix_name': 'mix',
'sample_rate': 44100,
'frame_length': 4096,
'frame_step': 1024,
'T': 512,
'F': 1024}
"instrument_list": ("vocals", "accompaniment"),
"mix_name": "mix",
"sample_rate": 44100,
"frame_length": 4096,
"frame_step": 1024,
"T": 512,
"F": 1024,
}
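
With these defaults, a T-frame segment spans T * frame_step / sample_rate seconds, the same 11.89 s figure that the in-code comment further down derives for fs=22050 with frame_step=512. A quick check:

T, frame_step, sample_rate = 512, 1024, 44100
duration_seconds = T * frame_step / sample_rate
assert round(duration_seconds, 2) == 11.89
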
def get_training_dataset(
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str) -> Any:
audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
"""
Builds training dataset.
Builds training dataset.
Parameters:
audio_params (Dict):
Audio parameters.
audio_adapter (AudioAdapter):
Adapter to load audio from.
audio_path (str):
Path of directory containing audio.
Parameters:
audio_params (Dict):
Audio parameters.
audio_adapter (AudioAdapter):
Adapter to load audio from.
audio_path (str):
Path of directory containing audio.
Returns:
Any:
Built dataset.
Returns:
Any:
Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=audio_params.get('chunk_duration', 20.0),
random_seed=audio_params.get('random_seed', 0))
chunk_duration=audio_params.get("chunk_duration", 20.0),
random_seed=audio_params.get("random_seed", 0),
)
return builder.build(
audio_params.get('train_csv'),
cache_directory=audio_params.get('training_cache'),
batch_size=audio_params.get('batch_size'),
n_chunks_per_song=audio_params.get('n_chunks_per_song', 2),
audio_params.get("train_csv"),
cache_directory=audio_params.get("training_cache"),
batch_size=audio_params.get("batch_size"),
n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
random_data_augmentation=False,
convert_to_uint=True,
wait_for_cache=False)
wait_for_cache=False,
)
def get_validation_dataset(
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str) -> Any:
audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
"""
Builds validation dataset.
Builds validation dataset.
Parameters:
audio_params (Dict):
Audio parameters.
audio_adapter (AudioAdapter):
Adapter to load audio from.
audio_path (str):
Path of directory containing audio.
Parameters:
audio_params (Dict):
Audio parameters.
audio_adapter (AudioAdapter):
Adapter to load audio from.
audio_path (str):
Path of directory containing audio.
Returns:
Any:
Built dataset.
Returns:
Any:
Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=12.0)
audio_params, audio_adapter, audio_path, chunk_duration=12.0
)
return builder.build(
audio_params.get('validation_csv'),
batch_size=audio_params.get('batch_size'),
cache_directory=audio_params.get('validation_cache'),
audio_params.get("validation_csv"),
batch_size=audio_params.get("batch_size"),
cache_directory=audio_params.get("validation_cache"),
convert_to_uint=True,
infinite_generator=False,
n_chunks_per_song=1,
@@ -127,97 +133,132 @@ class InstrumentDatasetBuilder(object):
def __init__(self, parent, instrument) -> None:
"""
Default constructor.
Default constructor.
Parameters:
parent:
Parent dataset builder.
instrument:
Target instrument.
Parameters:
parent:
Parent dataset builder.
instrument:
Target instrument.
"""
self._parent = parent
self._instrument = instrument
self._spectrogram_key = f'{instrument}_spectrogram'
self._min_spectrogram_key = f'min_{instrument}_spectrogram'
self._max_spectrogram_key = f'max_{instrument}_spectrogram'
self._spectrogram_key = f"{instrument}_spectrogram"
self._min_spectrogram_key = f"min_{instrument}_spectrogram"
self._max_spectrogram_key = f"max_{instrument}_spectrogram"
def load_waveform(self, sample):
""" Load waveform for given sample. """
return dict(sample, **self._parent._audio_adapter.load_tf_waveform(
sample[f'{self._instrument}_path'],
offset=sample['start'],
duration=self._parent._chunk_duration,
sample_rate=self._parent._sample_rate,
waveform_name='waveform'))
return dict(
sample,
**self._parent._audio_adapter.load_tf_waveform(
sample[f"{self._instrument}_path"],
offset=sample["start"],
duration=self._parent._chunk_duration,
sample_rate=self._parent._sample_rate,
waveform_name="waveform",
),
)
def compute_spectrogram(self, sample):
""" Compute spectrogram of the given sample. """
return dict(sample, **{
self._spectrogram_key: compute_spectrogram_tf(
sample['waveform'],
frame_length=self._parent._frame_length,
frame_step=self._parent._frame_step,
spec_exponent=1.,
window_exponent=1.)})
return dict(
sample,
**{
self._spectrogram_key: compute_spectrogram_tf(
sample["waveform"],
frame_length=self._parent._frame_length,
frame_step=self._parent._frame_step,
spec_exponent=1.0,
window_exponent=1.0,
)
},
)
def filter_frequencies(self, sample):
""" """
return dict(sample, **{
self._spectrogram_key:
sample[self._spectrogram_key][:, :self._parent._F, :]})
return dict(
sample,
**{
self._spectrogram_key: sample[self._spectrogram_key][
:, : self._parent._F, :
]
},
)
def convert_to_uint(self, sample):
""" Convert given sample from float to unit. """
return dict(sample, **spectrogram_to_db_uint(
sample[self._spectrogram_key],
tensor_key=self._spectrogram_key,
min_key=self._min_spectrogram_key,
max_key=self._max_spectrogram_key))
return dict(
sample,
**spectrogram_to_db_uint(
sample[self._spectrogram_key],
tensor_key=self._spectrogram_key,
min_key=self._min_spectrogram_key,
max_key=self._max_spectrogram_key,
),
)
def filter_infinity(self, sample):
""" Filter infinity sample. """
return tf.logical_not(
tf.math.is_inf(
sample[self._min_spectrogram_key]))
return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))
def convert_to_float32(self, sample):
""" Convert given sample from unit to float. """
return dict(sample, **{
self._spectrogram_key: db_uint_spectrogram_to_gain(
sample[self._spectrogram_key],
sample[self._min_spectrogram_key],
sample[self._max_spectrogram_key])})
return dict(
sample,
**{
self._spectrogram_key: db_uint_spectrogram_to_gain(
sample[self._spectrogram_key],
sample[self._min_spectrogram_key],
sample[self._max_spectrogram_key],
)
},
)
def time_crop(self, sample):
""" """
def start(sample):
""" mid_segment_start """
return tf.cast(
tf.maximum(
tf.shape(sample[self._spectrogram_key])[0]
/ 2 - self._parent._T / 2, 0),
tf.int32)
return dict(sample, **{
self._spectrogram_key: sample[self._spectrogram_key][
start(sample):start(sample) + self._parent._T, :, :]})
tf.shape(sample[self._spectrogram_key])[0] / 2
- self._parent._T / 2,
0,
),
tf.int32,
)
return dict(
sample,
**{
self._spectrogram_key: sample[self._spectrogram_key][
start(sample) : start(sample) + self._parent._T, :, :
]
},
)
def filter_shape(self, sample):
""" Filter badly shaped sample. """
return check_tensor_shape(
sample[self._spectrogram_key], (
self._parent._T, self._parent._F, 2))
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
)
def reshape_spectrogram(self, sample):
""" Reshape given sample. """
return dict(sample, **{
self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key],
(self._parent._T, self._parent._F, 2))})
return dict(
sample,
**{
self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
)
},
)
class DatasetBuilder(object):
"""
TO BE DOCUMENTED.
TO BE DOCUMENTED.
"""
MARGIN: float = 0.5
@@ -227,37 +268,38 @@ class DatasetBuilder(object):
""" Wait period for cache (in seconds). """
def __init__(
self,
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str,
random_seed: int = 0,
chunk_duration: float = 20.0) -> None:
self,
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str,
random_seed: int = 0,
chunk_duration: float = 20.0,
) -> None:
"""
Default constructor.
Default constructor.
NOTE: Probably need for AudioAdapter.
NOTE: Probably need for AudioAdapter.
Parameters:
audio_params (Dict):
Audio parameters to use.
audio_adapter (AudioAdapter):
Audio adapter to use.
audio_path (str):
random_seed (int):
chunk_duration (float):
Parameters:
audio_params (Dict):
Audio parameters to use.
audio_adapter (AudioAdapter):
Audio adapter to use.
audio_path (str):
random_seed (int):
chunk_duration (float):
"""
# Length of segment in frames (if fs=22050 and
# frame_step=512, then T=512 corresponds to 11.89s)
self._T = audio_params['T']
self._T = audio_params["T"]
# Number of frequency bins to be used (should
# be less than frame_length/2 + 1)
self._F = audio_params['F']
self._sample_rate = audio_params['sample_rate']
self._frame_length = audio_params['frame_length']
self._frame_step = audio_params['frame_step']
self._mix_name = audio_params['mix_name']
self._instruments = [self._mix_name] + audio_params['instrument_list']
self._F = audio_params["F"]
self._sample_rate = audio_params["sample_rate"]
self._frame_length = audio_params["frame_length"]
self._frame_step = audio_params["frame_step"]
self._mix_name = audio_params["mix_name"]
self._instruments = [self._mix_name] + audio_params["instrument_list"]
self._instrument_builders = None
self._chunk_duration = chunk_duration
self._audio_adapter = audio_adapter
@@ -267,105 +309,157 @@ class DatasetBuilder(object):
    def expand_path(self, sample):
        """ Expands audio paths for the given sample. """
        return dict(
            sample,
            **{
                f"{instrument}_path": tf.strings.join(
                    (self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
                )
                for instrument in self._instruments
            },
        )

    def filter_error(self, sample):
        """ Filter out samples that errored while loading. """
        return tf.logical_not(sample["waveform_error"])

    def filter_waveform(self, sample):
        """ Remove the waveform from the sample. """
        return {k: v for k, v in sample.items() if not k == "waveform"}

    def harmonize_spectrogram(self, sample):
        """ Ensure the same size for vocals and mix spectrograms. """

        def _reduce(sample):
            return tf.reduce_min(
                [
                    tf.shape(sample[f"{instrument}_spectrogram"])[0]
                    for instrument in self._instruments
                ]
            )

        return dict(
            sample,
            **{
                f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
                    : _reduce(sample), :, :
                ]
                for instrument in self._instruments
            },
        )

    def filter_short_segments(self, sample):
        """ Filter out too short segments. """
        return tf.reduce_any(
            [
                tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
                for instrument in self._instruments
            ]
        )

    def random_time_crop(self, sample):
        """ Random time crop of 11.88s. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: tf.image.random_crop(
                    x,
                    (self._T, len(self._instruments) * self._F, 2),
                    seed=self._random_seed,
                ),
            ),
        )

    def random_time_stretch(self, sample):
        """ Randomly time stretch the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
            ),
        )

    def random_pitch_shift(self, sample):
        """ Randomly pitch shift the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
                concat_axis=0,
            ),
        )
    def map_features(self, sample):
        """ Select features and annotation of the given sample. """
        input_ = {
            f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
        }
        output = {
            f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
            for instrument in self._audio_params["instrument_list"]
        }
        return (input_, output)
    def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
        """
        Computes segments for each song of the dataset.

        Parameters:
            dataset (Any):
                Dataset to compute segments for.
            n_chunks_per_song (int):
                Number of segments per song to compute.

        Returns:
            Any:
                Segmented dataset.
        """
        if n_chunks_per_song <= 0:
            raise ValueError("n_chunks_per_song must be positive")
        datasets = []
        for k in range(n_chunks_per_song):
            if n_chunks_per_song > 1:
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                k
                                * (
                                    sample["duration"]
                                    - self._chunk_duration
                                    - 2 * self.MARGIN
                                )
                                / (n_chunks_per_song - 1)
                                + self.MARGIN,
                                0,
                            ),
                        )
                    )
                )
            elif n_chunks_per_song == 1:  # Take central segment.
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                sample["duration"] / 2 - self._chunk_duration / 2, 0
                            ),
                        )
                    )
                )
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
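To make the start-offset formula above concrete, a small sketch with assumed values (duration 180 s, chunk_duration 20 s, MARGIN 0.5, n_chunks_per_song 3) shows the starts are spread evenly between the two margins:

# Hypothetical values illustrating evenly spaced chunk starts.
duration, chunk_duration, margin, n = 180.0, 20.0, 0.5, 3
starts = [
    max(k * (duration - chunk_duration - 2 * margin) / (n - 1) + margin, 0)
    for k in range(n)
]
print(starts)  # [0.5, 80.0, 159.5] -> first, middle and last valid start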
@@ -374,47 +468,43 @@ class DatasetBuilder(object):
    @property
    def instruments(self) -> Any:
        """
        Instrument dataset builder generator.

        Yields:
            Any:
                InstrumentBuilder instance.
        """
        if self._instrument_builders is None:
            self._instrument_builders = []
            for instrument in self._instruments:
                self._instrument_builders.append(
                    InstrumentDatasetBuilder(self, instrument)
                )
        for builder in self._instrument_builders:
            yield builder

    def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
        """
        Cache the given dataset if cache is enabled. Optionally waits for
        the cache to be available (useful if another process is already
        computing it) if the provided wait flag is `True`.

        Parameters:
            dataset (Any):
                Dataset to be cached if cache is required.
            cache (str):
                Path of cache directory to be used, None if no cache.
            wait (bool):
                If caching is enabled, `True` if the cache should be
                waited for.

        Returns:
            Any:
                Cached dataset if needed, original dataset otherwise.
        """
        if cache is not None:
            if wait:
                while not exists(f"{cache}.index"):
                    logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
                    time.sleep(self.WAIT_PERIOD)
            cache_path = os.path.split(cache)[0]
            os.makedirs(cache_path, exist_ok=True)
@@ -422,20 +512,21 @@ class DatasetBuilder(object):
return dataset
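For reference, the `.index` file polled in the loop above comes from the tf.data file cache this method wraps; a minimal standalone sketch (the cache path is an assumption):

import os
import tensorflow as tf

os.makedirs("/tmp/spleeter_cache", exist_ok=True)
# Caching to "/tmp/spleeter_cache/train" writes an index file alongside data
# shards, which is what the ".index" probe above waits for.
dataset = tf.data.Dataset.range(5).cache("/tmp/spleeter_cache/train")
print(list(dataset.as_numpy_iterator()))  # first full pass populates the cache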
    def build(
        self,
        csv_path: str,
        batch_size: int = 8,
        shuffle: bool = True,
        convert_to_uint: bool = True,
        random_data_augmentation: bool = False,
        random_time_crop: bool = True,
        infinite_generator: bool = True,
        cache_directory: Optional[str] = None,
        wait_for_cache: bool = False,
        num_parallel_calls: int = 4,
        n_chunks_per_song: int = 2,
    ) -> Any:
        """
        TO BE DOCUMENTED.
        """
        dataset = dataset_from_csv(csv_path)
        dataset = self.compute_segments(dataset, n_chunks_per_song)
@@ -445,7 +536,8 @@ class DatasetBuilder(object):
                buffer_size=200000,
                seed=self._random_seed,
                # useless since it is cached:
                reshuffle_each_iteration=True,
            )
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveform, compute spectrogram, and filtering error,
@@ -453,11 +545,11 @@ class DatasetBuilder(object):
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset.map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies)
            )
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
@@ -488,26 +580,25 @@ class DatasetBuilder(object):
        # after cropping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
            )
        # Convert back to float32
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N
                )
        M = 8  # Parallel call post caching.
        # Must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
                self.random_pitch_shift, num_parallel_calls=M
            )
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_shape).map(
                instrument.reshape_spectrogram
            )
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid
View File

@@ -8,15 +8,16 @@ import importlib
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.signal import hann_window, inverse_stft, stft

from ..utils.tensor import pad_and_partition, pad_and_reshape

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
placeholder = tf.compat.v1.placeholder
@@ -24,29 +25,28 @@ placeholder = tf.compat.v1.placeholder
def get_model_function(model_type):
    """
    Get tensorflow function of the model to be applied to the input tensor.
    For instance "unet.softmax_unet" will return the softmax_unet function
    in the "unet.py" submodule of the current module (spleeter.model).

    Params:
        - model_type: str
            the relative module path to the model function.

    Returns:
        A tensorflow function to be applied to the input tensor to get the
        multitrack output.
    """
    relative_path_to_module = ".".join(model_type.split(".")[:-1])
    model_name = model_type.split(".")[-1]
    main_module = ".".join((__name__, "functions"))
    path_to_module = f"{main_module}.{relative_path_to_module}"
    module = importlib.import_module(path_to_module)
    model_function = getattr(module, model_name)
    return model_function
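The docstring above describes a dotted-path lookup; a standalone sketch of the same technique, decoupled from spleeter's package layout (the resolve helper and example path are illustrations, not part of the library):

import importlib

def resolve(dotted_path: str):
    # Split "package.module.function" into a module path and attribute name.
    module_path, _, function_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, function_name)

# e.g. resolve("json.loads") returns the json.loads function.
print(resolve("json.loads")('{"ok": true}'))  # {'ok': True}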
class InputProvider(object):
def __init__(self, params):
self.params = params
@@ -62,16 +62,16 @@ class InputProvider(object):
class WaveformInputProvider(InputProvider):
@property
def input_names(self):
return ["audio_id", "waveform"]
    def get_input_dict_placeholders(self):
        shape = (None, self.params["n_channels"])
        features = {
            "waveform": placeholder(tf.float32, shape=shape, name="waveform"),
            "audio_id": placeholder(tf.string, name="audio_id"),
        }
        return features
def get_feed_dict(self, features, waveform, audio_id):
@@ -79,7 +79,6 @@ class WaveformInputProvider(InputProvider):
class SpectralInputProvider(InputProvider):
def __init__(self, params):
super().__init__(params)
self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@@ -90,11 +89,17 @@ class SpectralInputProvider(InputProvider):
    def get_input_dict_placeholders(self):
        features = {
            self.stft_input_name: placeholder(
                tf.complex64,
                shape=(
                    None,
                    self.params["frame_length"] // 2 + 1,
                    self.params["n_channels"],
                ),
                name=self.stft_input_name,
            ),
            "audio_id": placeholder(tf.string, name="audio_id"),
        }
        return features
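As a quick check of the placeholder shape: for a real-valued frame of length frame_length, the STFT keeps frame_length // 2 + 1 non-redundant frequency bins, e.g. 4096 // 2 + 1 = 2049 bins per channel for 4096-sample frames (an assumed, commonly used value).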
def get_feed_dict(self, features, stft, audio_id):
@@ -102,11 +107,13 @@ class SpectralInputProvider(InputProvider):
class InputProviderFactory(object):

    @staticmethod
    def get(params):
        stft_backend = params["stft_backend"]
        assert stft_backend in (
            "tensorflow",
            "librosa",
        ), "Unexpected backend {}".format(stft_backend)
if stft_backend == "tensorflow":
return WaveformInputProvider(params)
else:
@@ -114,7 +121,7 @@ class InputProviderFactory(object):
class EstimatorSpecBuilder(object):
    """A builder class that allows building a multitrack unet model
    estimator. The built model estimator has a different behaviour when
    used in a train/eval mode and in predict mode.
@@ -138,22 +145,22 @@ class EstimatorSpecBuilder(object):
"""
# Supported model functions.
DEFAULT_MODEL = 'unet.unet'
DEFAULT_MODEL = "unet.unet"
# Supported loss functions.
L1_MASK = 'L1_mask'
WEIGHTED_L1_MASK = 'weighted_L1_mask'
L1_MASK = "L1_mask"
WEIGHTED_L1_MASK = "weighted_L1_mask"
# Supported optimizers.
ADADELTA = 'Adadelta'
SGD = 'SGD'
ADADELTA = "Adadelta"
SGD = "SGD"
# Math constants.
WINDOW_COMPENSATION_FACTOR = 2./3.
WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0
EPSILON = 1e-10
    def __init__(self, features, params):
        """Default constructor. Depending on built model
        usage, the provided features should be different:

        * In train/eval mode: features is a dictionary with a
@@ -170,20 +177,20 @@ class EstimatorSpecBuilder(object):
        self._features = features
        self._params = params
        # Get instrument name.
        self._mix_name = params["mix_name"]
        self._instruments = params["instrument_list"]
        # Get STFT/signals parameters
        self._n_channels = params["n_channels"]
        self._T = params["T"]
        self._F = params["F"]
        self._frame_length = params["frame_length"]
        self._frame_step = params["frame_step"]

    def include_stft_computations(self):
        return self._params["stft_backend"] == "tensorflow"

    def _build_model_outputs(self):
        """Creates a batch_size x T x F x n_channels input tensor containing
        mix magnitude spectrogram, then an output dict from it according
        to the selected model in internal parameters.
@@ -192,22 +199,21 @@ class EstimatorSpecBuilder(object):
"""
input_tensor = self.spectrogram_feature
model = self._params.get('model', None)
model = self._params.get("model", None)
if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL)
model_type = model.get("type", self.DEFAULT_MODEL)
else:
model_type = self.DEFAULT_MODEL
try:
apply_model = get_model_function(model_type)
except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found')
raise ValueError(f"No model function {model_type} found")
self._model_outputs = apply_model(
input_tensor,
self._instruments,
self._params['model']['params'])
input_tensor, self._instruments, self._params["model"]["params"]
)
def _build_loss(self, labels):
""" Construct tensorflow loss and metrics
"""Construct tensorflow loss and metrics
:param output_dict: dictionary of network outputs (key: instrument
name, value: estimated spectrogram of the instrument)
@@ -216,7 +222,7 @@ class EstimatorSpecBuilder(object):
        :returns: tensorflow (loss, metrics) tuple.
        """
        output_dict = self.model_outputs
        loss_type = self._params.get("loss_type", self.L1_MASK)
        if loss_type == self.L1_MASK:
            losses = {
                name: tf.reduce_mean(tf.abs(output - labels[name]))
@@ -225,11 +231,9 @@ class EstimatorSpecBuilder(object):
        elif loss_type == self.WEIGHTED_L1_MASK:
            losses = {
                name: tf.reduce_mean(
                    tf.reduce_mean(labels[name], axis=[1, 2, 3], keep_dims=True)
                    * tf.abs(output - labels[name])
                )
                for name, output in output_dict.items()
            }
        else:
@@ -237,20 +241,20 @@ class EstimatorSpecBuilder(object):
        loss = tf.reduce_sum(list(losses.values()))
        # Add metrics for monitoring each instrument.
        metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
        metrics["absolute_difference"] = tf.compat.v1.metrics.mean(loss)
        return loss, metrics
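In equation form (notation mine, not the source's), the plain and weighted L1 mask losses computed above are

$\mathcal{L}_{\mathrm{L1}} = \sum_{i} \mathrm{mean}\big(|\hat{S}_i - S_i|\big)$ and $\mathcal{L}_{\mathrm{wL1}} = \sum_{i} \mathrm{mean}\big(\bar{S}_i \cdot |\hat{S}_i - S_i|\big)$,

where $\hat{S}_i$ is the network output for instrument $i$, $S_i$ the target spectrogram, and $\bar{S}_i$ its per-sample mean over the time, frequency and channel axes.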
    def _build_optimizer(self):
        """Builds an optimizer instance from internal parameter values.

        Default to AdamOptimizer if not specified.

        :returns: Optimizer instance from internal configuration.
        """
        name = self._params.get("optimizer")
        if name == self.ADADELTA:
            return tf.compat.v1.train.AdadeltaOptimizer()
        rate = self._params["learning_rate"]
        if name == self.SGD:
            return tf.compat.v1.train.GradientDescentOptimizer(rate)
        return tf.compat.v1.train.AdamOptimizer(rate)
@@ -261,15 +265,15 @@ class EstimatorSpecBuilder(object):
    @property
    def stft_name(self):
        return f"{self._mix_name}_stft"

    @property
    def spectrogram_name(self):
        return f"{self._mix_name}_spectrogram"

    def _build_stft_feature(self):
        """Compute STFT of waveform and slice the STFT into segments
        with the right length to feed the network.
        """
        stft_name = self.stft_name
@@ -277,25 +281,30 @@ class EstimatorSpecBuilder(object):
        if stft_name not in self._features:
            # pad input with a frame of zeros
            waveform = tf.concat(
                [
                    tf.zeros((self._frame_length, self._n_channels)),
                    self._features["waveform"],
                ],
                0,
            )
            stft_feature = tf.transpose(
                stft(
                    tf.transpose(waveform),
                    self._frame_length,
                    self._frame_step,
                    window_fn=lambda frame_length, dtype: (
                        hann_window(frame_length, periodic=True, dtype=dtype)
                    ),
                    pad_end=True,
                ),
                perm=[1, 2, 0],
            )
            self._features[f"{self._mix_name}_stft"] = stft_feature
        if spec_name not in self._features:
            self._features[spec_name] = tf.abs(
                pad_and_partition(self._features[stft_name], self._T)
            )[:, :, : self._F, :]
@property
def model_outputs(self):
@@ -334,25 +343,29 @@ class EstimatorSpecBuilder(object):
return self._masked_stfts
    def _inverse_stft(self, stft_t, time_crop=None):
        """Inverse and reshape the given STFT

        :param stft_t: input STFT
        :returns: inverse STFT (waveform)
        """
        inversed = (
            inverse_stft(
                tf.transpose(stft_t, perm=[2, 0, 1]),
                self._frame_length,
                self._frame_step,
                window_fn=lambda frame_length, dtype: (
                    hann_window(frame_length, periodic=True, dtype=dtype)
                ),
            )
            * self.WINDOW_COMPENSATION_FACTOR
        )
        reshaped = tf.transpose(inversed)
        if time_crop is None:
            time_crop = tf.shape(self._features["waveform"])[0]
        return reshaped[self._frame_length : self._frame_length + time_crop, :]

    def _build_mwf_output_waveform(self):
        """Perform separation with multichannel Wiener Filtering using Norbert.
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
may be quite slow.
@@ -360,36 +373,42 @@ class EstimatorSpecBuilder(object):
value: estimated waveform of the instrument)
"""
import norbert # pylint: disable=import-error
output_dict = self.model_outputs
x = self.stft_feature
        v = tf.stack(
            [
                pad_and_reshape(
                    output_dict[f"{instrument}_spectrogram"],
                    self._frame_length,
                    self._F,
                )[: tf.shape(x)[0], ...]
                for instrument in self._instruments
            ],
            axis=3,
        )
        input_args = [v, x]
        stft_function = (
            tf.py_function(
                lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
                input_args,
                tf.complex64,
            ),
        )
        return {
            instrument: self._inverse_stft(stft_function[0][:, :, :, k])
            for k, instrument in enumerate(self._instruments)
        }

    def _extend_mask(self, mask):
        """Extend mask, from reduced number of frequency bins to the number of
        frequency bins in the STFT.

        :param mask: restricted mask
        :returns: extended mask
        :raise ValueError: If invalid mask_extension parameter is set.
        """
        extension = self._params["mask_extension"]
        # Extend with average
        # (dispatch according to energy in the processed band)
        if extension == "average":
@@ -398,13 +417,9 @@ class EstimatorSpecBuilder(object):
        # (avoid extension artifacts but not conservative separation)
        elif extension == "zeros":
            mask_shape = tf.shape(mask)
            extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))
        else:
            raise ValueError(f"Invalid mask_extension parameter {extension}")
        n_extra_row = self._frame_length // 2 + 1 - self._F
        extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
        return tf.concat([mask, extension], axis=2)
@@ -416,29 +431,31 @@ class EstimatorSpecBuilder(object):
"""
output_dict = self.model_outputs
stft_feature = self.stft_feature
        separation_exponent = self._params["separation_exponent"]
        output_sum = (
            tf.reduce_sum(
                [e ** separation_exponent for e in output_dict.values()], axis=0
            )
            + self.EPSILON
        )
        out = {}
        for instrument in self._instruments:
            output = output_dict[f"{instrument}_spectrogram"]
            # Compute mask with the model.
            instrument_mask = (
                output ** separation_exponent + (self.EPSILON / len(output_dict))
            ) / output_sum
            # Extend mask;
            instrument_mask = self._extend_mask(instrument_mask)
            # Stack back mask.
            old_shape = tf.shape(instrument_mask)
            new_shape = tf.concat(
                [[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0
            )
            instrument_mask = tf.reshape(instrument_mask, new_shape)
            # Remove padded part (for mask having the same size as STFT);
            instrument_mask = instrument_mask[: tf.shape(stft_feature)[0], ...]
            out[instrument] = instrument_mask
        self._masks = out
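In symbols (my notation), the ratio mask built in this loop for instrument $i$ out of $N$ instruments is

$m_i = \dfrac{|\hat{S}_i|^p + \varepsilon / N}{\sum_{j=1}^{N} |\hat{S}_j|^p + \varepsilon}$

with $p$ the separation_exponent and $\varepsilon$ the EPSILON constant, so the masks sum to one by construction.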
@@ -450,7 +467,7 @@ class EstimatorSpecBuilder(object):
self._masked_stfts = out
    def _build_manual_output_waveform(self, masked_stft):
        """Perform ratio mask separation

        :param output_dict: dictionary of estimated spectrogram (key: instrument
            name, value: estimated spectrogram of the instrument)
@@ -464,14 +481,14 @@ class EstimatorSpecBuilder(object):
return output_waveform
    def _build_output_waveform(self, masked_stft):
        """Build output waveform from given output dict in order to be used in
        prediction context. Depending on the configuration, the building
        method may use MWF.

        :returns: Built output waveform.
        """
        if self._params.get("MWF", False):
            output_waveform = self._build_mwf_output_waveform()
        else:
            output_waveform = self._build_manual_output_waveform(masked_stft)
@@ -483,11 +500,11 @@ class EstimatorSpecBuilder(object):
else:
self._outputs = self.masked_stfts
        if "audio_id" in self._features:
            self._outputs["audio_id"] = self._features["audio_id"]

    def build_predict_model(self):
        """Builder interface for creating model instance that aims to perform
        prediction / inference over given track. The output of such estimator
        will be a dictionary with a "<instrument>" key per separated
        instrument, associated to the estimated separated waveform of the
        instrument.
@@ -496,11 +513,11 @@ class EstimatorSpecBuilder(object):
"""
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.PREDICT, predictions=self.outputs
        )
    def build_evaluation_model(self, labels):
        """Builder interface for creating model instance that aims to perform
        model evaluation. The output of such estimator will be a dictionary
        with a key "<instrument>_spectrogram" per separated instrument,
        associated to the estimated separated instrument magnitude spectrogram.
@@ -510,12 +527,11 @@ class EstimatorSpecBuilder(object):
"""
        loss, metrics = self._build_loss(labels)
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics
        )
    def build_train_model(self, labels):
        """Builder interface for creating model instance that aims to perform
        model training. The output of such estimator will be a dictionary
        with a key "<instrument>_spectrogram" per separated instrument,
        associated to the estimated separated instrument magnitude spectrogram.
@@ -526,8 +542,8 @@ class EstimatorSpecBuilder(object):
        loss, metrics = self._build_loss(labels)
        optimizer = self._build_optimizer()
        train_operation = optimizer.minimize(
            loss=loss, global_step=tf.compat.v1.train.get_global_step()
        )
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
@@ -540,9 +556,9 @@ def model_fn(features, labels, mode, params, config):
"""
:param features:
:param labels:
:param labels:
:param mode: Estimator mode.
:param params:
:param params:
:param config: TF configuration (not used).
:returns: Built EstimatorSpec.
:raise ValueError: If estimator mode is not supported.
@@ -554,4 +570,4 @@ def model_fn(features, labels, mode, params, config):
        return builder.build_evaluation_model(labels)
    elif mode == tf.estimator.ModeKeys.TRAIN:
        return builder.build_train_model(labels)
    raise ValueError(f"Unknown mode {mode}")

View File

@@ -8,39 +8,40 @@ from typing import Callable, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply(
    function: Callable,
    input_tensor: tf.Tensor,
    instruments: Iterable[str],
    params: Optional[Dict] = None,
) -> Dict:
    """
    Apply given function to the input tensor.

    Parameters:
        function:
            Function to be applied to tensor.
        input_tensor (tensorflow.Tensor):
            Tensor to apply the function to.
        instruments (Iterable[str]):
            Iterable that provides a collection of instruments.
        params:
            (Optional) dict of model parameters.

    Returns:
        Created output tensor dict.
    """
    output_dict: Dict = {}
    for instrument in instruments:
        out_name = f"{instrument}_spectrogram"
        output_dict[out_name] = function(
            input_tensor, output_name=out_name, params=params or {}
        )
    return output_dict
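A minimal sketch of how this helper fans one model function out over instruments (the toy identity function and instrument names below are assumptions for illustration):

import tensorflow as tf

def identity(input_tensor, output_name="output", params=None):
    # Toy model function: returns its input unchanged.
    return tf.identity(input_tensor, name=output_name)

x = tf.zeros((1, 512, 1024, 2))
outputs = apply(identity, x, ["vocals", "accompaniment"])
print(sorted(outputs))  # ['accompaniment_spectrogram', 'vocals_spectrogram']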

View File

@@ -22,12 +22,9 @@
from typing import Dict, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import (
@@ -35,45 +32,48 @@ from tensorflow.keras.layers import (
Dense,
Flatten,
Reshape,
    TimeDistributed,
)

from . import apply

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply_blstm(
    input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
    """
    Apply BLSTM to the given input_tensor.

    Parameters:
        input_tensor (tensorflow.Tensor):
            Input of the model.
        output_name (str):
            (Optional) name of the output, default to 'output'.
        params (Optional[Dict]):
            (Optional) dict of BLSTM parameters.

    Returns:
        tensorflow.Tensor:
            Output tensor.
    """
    if params is None:
        params = {}
    units: int = params.get("lstm_units", 250)
kernel_initializer = he_uniform(seed=50)
flatten_input = TimeDistributed(Flatten())((input_tensor))
def create_bidirectional():
        return Bidirectional(
            CuDNNLSTM(
                units, kernel_initializer=kernel_initializer, return_sequences=True
            )
        )
l1 = create_bidirectional()((flatten_input))
l2 = create_bidirectional()((l1))
@@ -81,17 +81,18 @@ def apply_blstm(
    dense = TimeDistributed(
        Dense(
            int(flatten_input.shape[2]),
            activation="relu",
            kernel_initializer=kernel_initializer,
        )
    )((l3))
    output: tf.Tensor = TimeDistributed(
        Reshape(input_tensor.shape[2:]), name=output_name
    )(dense)
return output
def blstm(
    input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
""" Model function applier. """
return apply(apply_blstm, input_tensor, output_name, params)

View File

@@ -16,95 +16,95 @@
from functools import partial
from typing import Any, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.keras.layers import (
ELU,
BatchNormalization,
Concatenate,
Conv2D,
Conv2DTranspose,
Dropout,
    LeakyReLU,
    Multiply,
    ReLU,
    Softmax,
)

from . import apply

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def _get_conv_activation_layer(params: Dict) -> Any:
    """
    > To be documented.

    Parameters:
        params (Dict):

    Returns:
        Any:
            Required Activation function.
    """
    conv_activation: str = params.get("conv_activation")
    if conv_activation == "ReLU":
        return ReLU()
    elif conv_activation == "ELU":
        return ELU()
    return LeakyReLU(0.2)


def _get_deconv_activation_layer(params: Dict) -> Any:
    """
    > To be documented.

    Parameters:
        params (Dict):

    Returns:
        Any:
            Required Activation function.
    """
    deconv_activation: str = params.get("deconv_activation")
    if deconv_activation == "LeakyReLU":
        return LeakyReLU(0.2)
    elif deconv_activation == "ELU":
        return ELU()
    return ReLU()
def apply_unet(
    input_tensor: tf.Tensor,
    output_name: str = "output",
    params: Optional[Dict] = None,
    output_mask_logit: bool = False,
) -> Any:
    """
    Apply a convolutional U-net to model a single instrument (one U-net
    is used for each instrument).

    Parameters:
        input_tensor (tensorflow.Tensor):
        output_name (str):
        params (Optional[Dict]):
        output_mask_logit (bool):
    """
    logging.info(f"Apply unet for {output_name}")
    conv_n_filters = params.get("conv_n_filters", [16, 32, 64, 128, 256, 512])
    conv_activation_layer = _get_conv_activation_layer(params)
    deconv_activation_layer = _get_deconv_activation_layer(params)
    kernel_initializer = he_uniform(seed=50)
    conv2d_factory = partial(
        Conv2D, strides=(2, 2), padding="same", kernel_initializer=kernel_initializer
    )
# First layer.
conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)
batch1 = BatchNormalization(axis=-1)(conv1)
@@ -134,8 +134,9 @@ def apply_unet(
    conv2d_transpose_factory = partial(
        Conv2DTranspose,
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )
#
up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6))
up1 = deconv_activation_layer(up1)
@@ -174,60 +175,60 @@ def apply_unet(
2,
(4, 4),
dilation_rate=(2, 2),
            activation="sigmoid",
            padding="same",
            kernel_initializer=kernel_initializer,
        )((batch12))
        output = Multiply(name=output_name)([up7, input_tensor])
        return output
    return Conv2D(
        2,
        (4, 4),
        dilation_rate=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )((batch12))
def unet(
    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
    """ Model function applier. """
    return apply(apply_unet, input_tensor, instruments, params)
def softmax_unet(
    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
    """
    Apply softmax to multitrack unet in order to have masks summing to one.

    Parameters:
        input_tensor (tensorflow.Tensor):
            Tensor to apply the U-net to.
        instruments (Iterable[str]):
            Iterable that provides a collection of instruments.
        params (Optional[Dict]):
            (Optional) dict of U-net parameters.

    Returns:
        Dict:
            Created output tensor dict.
    """
logit_mask_list = []
for instrument in instruments:
        out_name = f"{instrument}_spectrogram"
        logit_mask_list.append(
            apply_unet(
                input_tensor,
                output_name=out_name,
                params=params,
                output_mask_logit=True,
            )
        )
masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
output_dict = {}
for i, instrument in enumerate(instruments):
        out_name = f"{instrument}_spectrogram"
        output_dict[out_name] = Multiply(name=out_name)([masks[..., i], input_tensor])
return output_dict
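Stacking the per-instrument logit masks on a new axis and applying Softmax along it gives $m_i = e^{z_i} / \sum_j e^{z_j}$ (my notation), which is what guarantees that the per-bin masks sum to one across instruments.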

View File

@@ -17,57 +17,57 @@ from abc import ABC, abstractmethod
from os import environ, makedirs
from os.path import exists, isabs, join, sep
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"


class ModelProvider(ABC):
    """
    A ModelProvider manages model files on disk and
    downloads them when they are not available locally.
    """

    DEFAULT_MODEL_PATH: str = environ.get("MODEL_PATH", "pretrained_models")
    MODEL_PROBE_PATH: str = ".probe"
@abstractmethod
    def download(_, name: str, path: str) -> None:
        """
        Download model denoted by the given name to disk.

        Parameters:
            name (str):
                Name of the model to download.
            path (str):
                Path of the directory to save model into.
        """
        pass
@staticmethod
    def writeProbe(directory: str) -> None:
        """
        Write a model probe file into the given directory.

        Parameters:
            directory (str):
                Directory to write probe into.
        """
        probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
        with open(probe, "w") as stream:
            stream.write("OK")

    def get(self, model_directory: str) -> str:
        """
        Ensures required model is available at given location.

        Parameters:
            model_directory (str):
                Expected model_directory to be available.

        Raises:
            IOError:
                If model can not be retrieved.
        """
        # Expand model directory if needed.
        if not isabs(model_directory):
@@ -77,20 +77,19 @@ class ModelProvider(ABC):
if not exists(model_probe):
if not exists(model_directory):
makedirs(model_directory)
            self.download(model_directory.split(sep)[-1], model_directory)
            self.writeProbe(model_directory)
        return model_directory
@classmethod
    def default(_: type) -> "ModelProvider":
        """
        Builds and returns a default model provider.

        Returns:
            ModelProvider:
                A default model provider instance to use.
        """
        from .github import GithubModelProvider

        return GithubModelProvider.from_environ()

View File

@@ -17,35 +17,35 @@
"""
import hashlib
import os
import tarfile
from os import environ
from tempfile import NamedTemporaryFile
from typing import Dict

# pyright: reportMissingImports=false
# pylint: disable=import-error
import httpx

from ...utils.logging import logger
from . import ModelProvider

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def compute_file_checksum(path):
    """Computes sha256 checksum of the file at the given path.

    :param path: Path of the file to compute checksum for.
    :returns: File checksum.
    """
    sha256 = hashlib.sha256()
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(4096), b""):
            sha256.update(chunk)
    return sha256.hexdigest()
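A quick usage sketch of the helper above (the file path is an assumption):

# Hypothetical call: hash an arbitrary local file in 4 KiB chunks.
digest = compute_file_checksum("/tmp/example.tar.gz")
print(digest)  # 64 hex characters, e.g. 'e3b0c442...'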
@@ -53,19 +53,15 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider):
""" A ModelProvider implementation backed on Github for remote storage. """
    DEFAULT_HOST: str = "https://github.com"
    DEFAULT_REPOSITORY: str = "deezer/spleeter"

    CHECKSUM_INDEX: str = "checksum.json"
    LATEST_RELEASE: str = "v1.4.0"
    RELEASE_PATH: str = "releases/download"

    def __init__(self, host: str, repository: str, release: str) -> None:
        """Default constructor.
Parameters:
host (str):
@@ -80,81 +76,81 @@ class GithubModelProvider(ModelProvider):
self._release: str = release
@classmethod
    def from_environ(cls: type) -> "GithubModelProvider":
"""
Factory method that creates provider from envvars.
Factory method that creates provider from envvars.
Returns:
GithubModelProvider:
Created instance.
Returns:
GithubModelProvider:
Created instance.
"""
return cls(
environ.get('GITHUB_HOST', cls.DEFAULT_HOST),
environ.get('GITHUB_REPOSITORY', cls.DEFAULT_REPOSITORY),
environ.get('GITHUB_RELEASE', cls.LATEST_RELEASE))
environ.get("GITHUB_HOST", cls.DEFAULT_HOST),
environ.get("GITHUB_REPOSITORY", cls.DEFAULT_REPOSITORY),
environ.get("GITHUB_RELEASE", cls.LATEST_RELEASE),
)
def checksum(self, name: str) -> str:
"""
Downloads and returns reference checksum for the given model name.
Downloads and returns reference checksum for the given model name.
Parameters:
name (str):
Name of the model to get checksum for.
Returns:
str:
Checksum of the required model.
Parameters:
name (str):
Name of the model to get checksum for.
Returns:
str:
Checksum of the required model.
Raises:
ValueError:
If the given model name is not indexed.
Raises:
ValueError:
If the given model name is not indexed.
"""
url: str = '/'.join((
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
self.CHECKSUM_INDEX))
url: str = "/".join(
(
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
self.CHECKSUM_INDEX,
)
)
response: httpx.Response = httpx.get(url)
response.raise_for_status()
index: Dict = response.json()
if name not in index:
            raise ValueError(f"No checksum for model {name}")
return index[name]
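With the default constants above, the joined URL resolves to the release's checksum index, which can be checked by hand:

# Sanity check of the checksum index URL built from the default constants.
parts = (
    "https://github.com",
    "deezer/spleeter",
    "releases/download",
    "v1.4.0",
    "checksum.json",
)
print("/".join(parts))
# -> https://github.com/deezer/spleeter/releases/download/v1.4.0/checksum.json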
def download(self, name: str, path: str) -> None:
"""
Download model denoted by the given name to disk.
Download model denoted by the given name to disk.
Parameters:
name (str):
Name of the model to download.
path (str):
Path of the directory to save model into.
Parameters:
name (str):
Name of the model to download.
path (str):
Path of the directory to save model into.
"""
url: str = '/'.join((
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
name))
url = f'{url}.tar.gz'
logger.info(f'Downloading model archive {url}')
url: str = "/".join(
(self._host, self._repository, self.RELEASE_PATH, self._release, name)
)
url = f"{url}.tar.gz"
logger.info(f"Downloading model archive {url}")
with httpx.Client(http2=True) as client:
            with client.stream("GET", url) as response:
response.raise_for_status()
archive = NamedTemporaryFile(delete=False)
try:
with archive as stream:
for chunk in response.iter_raw():
stream.write(chunk)
                    logger.info("Validating archive checksum")
                    checksum: str = compute_file_checksum(archive.name)
                    if checksum != self.checksum(name):
                        raise IOError("Downloaded file is corrupted, please retry")
                    logger.info(f"Extracting downloaded {name} archive")
with tarfile.open(name=archive.name) as tar:
tar.extractall(path=path)
finally:
os.unlink(archive.name)
        logger.info(f"{name} model file(s) extracted")

View File

@@ -3,126 +3,126 @@
""" This modules provides spleeter command as well as CLI parsing methods. """
from os.path import join
from tempfile import gettempdir

from typer import Argument, Option
from typer.models import ArgumentInfo, OptionInfo

from .audio import Codec, STFTBackend

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
AudioInputArgument: ArgumentInfo = Argument(
    ...,
    help="List of input audio file paths",
    exists=True,
    file_okay=True,
    dir_okay=False,
    readable=True,
    resolve_path=True,
)
AudioInputOption: OptionInfo = Option(
    None, "--inputs", "-i", help="(DEPRECATED) placeholder for deprecated input option"
)

AudioAdapterOption: OptionInfo = Option(
    "spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter",
    "--adapter",
    "-a",
    help="Name of the audio adapter to use for audio I/O",
)

AudioOutputOption: OptionInfo = Option(
    join(gettempdir(), "separated_audio"),
    "--output_path",
    "-o",
    help="Path of the output directory to write audio files in",
)

AudioOffsetOption: OptionInfo = Option(
    0.0, "--offset", "-s", help="Set the starting offset to separate audio from"
)
AudioDurationOption: OptionInfo = Option(
    600.0,
    "--duration",
    "-d",
    help=(
        "Set a maximum duration for processing audio "
        "(only separate offset + duration first seconds of "
        "the input file)"
    ),
)

AudioSTFTBackendOption: OptionInfo = Option(
    STFTBackend.AUTO,
    "--stft-backend",
    "-B",
    case_sensitive=False,
    help=(
        "Who should be in charge of computing the stfts. Librosa is faster "
        'than tensorflow on CPU and uses less memory. "auto" will use '
        "tensorflow when GPU acceleration is available and librosa when not"
    ),
)
AudioCodecOption: OptionInfo = Option(
    Codec.WAV, "--codec", "-c", help="Audio codec to be used for the separated output"
)

AudioBitrateOption: OptionInfo = Option(
    "128k", "--bitrate", "-b", help="Audio bitrate to be used for the separated output"
)
FilenameFormatOption: OptionInfo = Option(
    "{filename}/{instrument}.{codec}",
    "--filename_format",
    "-f",
    help=(
        "Template string that will be formatted to generate the "
        "output filename. Such template should be a Python formattable "
        "string, and could use the {filename}, {instrument}, and {codec} "
        "variables"
    ),
)
ModelParametersOption: OptionInfo = Option(
    "spleeter:2stems",
    "--params_filename",
    "-p",
    help="JSON filename that contains params",
)

MWFOption: OptionInfo = Option(
    False, "--mwf", help="Whether to use multichannel Wiener filtering for separation"
)

MUSDBDirectoryOption: OptionInfo = Option(
    ...,
    "--mus_dir",
    exists=True,
    dir_okay=True,
    file_okay=False,
    readable=True,
    resolve_path=True,
    help="Path to musDB dataset directory",
)
TrainingDataDirectoryOption: OptionInfo = Option(
    ...,
    "--data",
    "-d",
    exists=True,
    dir_okay=True,
    file_okay=False,
    readable=True,
    resolve_path=True,
    help="Path of the folder containing audio data for training",
)

VerboseOption: OptionInfo = Option(False, "--verbose", help="Enable verbose logs")

View File

@@ -3,6 +3,6 @@
""" Packages that provides static resources file for the library. """
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

View File

@@ -16,41 +16,40 @@
import atexit
import os
from multiprocessing import Pool
from os.path import basename, dirname, join, splitext
from typing import Dict, Generator, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from librosa.core import istft, stft
from scipy.signal.windows import hann
from spleeter.model.provider import ModelProvider

from . import SpleeterError
from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
from .model.provider import ModelProvider
from .types import AudioDescriptor
from .utils.configuration import load_configuration

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class DataGenerator(object):
    """
    Generator object that stores a sample and generates it once when called.
    Used to feed a tensorflow estimator without knowing the whole data at
    build time.
    """
def __init__(self) -> None:
@@ -71,28 +70,26 @@ class DataGenerator(object):
def create_estimator(params, MWF):
    """
    Initialize tensorflow estimator that will perform separation

    Params:
        - params: a dictionary of parameters for building the model

    Returns:
        a tensorflow estimator
    """
    # Load model.
    provider: ModelProvider = ModelProvider.default()
    params["model_dir"] = provider.get(params["model_dir"])
    params["MWF"] = MWF
    # Setup config
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
    config = tf.estimator.RunConfig(session_config=session_config)
    # Setup estimator
    estimator = tf.estimator.Estimator(
        model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
    )
    return estimator
@@ -100,22 +97,23 @@ class Separator(object):
""" A wrapper class for performing separation. """
    def __init__(
        self,
        params_descriptor: str,
        MWF: bool = False,
        stft_backend: STFTBackend = STFTBackend.AUTO,
        multiprocess: bool = True,
    ) -> None:
        """
        Default constructor.

        Parameters:
            params_descriptor (str):
                Descriptor for TF params to be used.
            MWF (bool):
                (Optional) `True` if MWF should be used, `False` otherwise.
        """
        self._params = load_configuration(params_descriptor)
        self._sample_rate = self._params["sample_rate"]
self._MWF = MWF
self._tf_graph = tf.Graph()
self._prediction_generator = None
@@ -129,7 +127,7 @@ class Separator(object):
else:
self._pool = None
self._tasks = []
        self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
self._data_generator = DataGenerator()
def __del__(self) -> None:
@@ -138,12 +136,12 @@ class Separator(object):
    def _get_prediction_generator(self) -> Generator:
        """
        Lazy loading access method for internal prediction generator
        returned by the predict method of a tensorflow estimator.

        Returns:
            Generator:
                Generator of prediction.
        """
if self._prediction_generator is None:
estimator = create_estimator(self._params, self._MWF)
@@ -151,25 +149,22 @@ class Separator(object):
def get_dataset():
return tf.data.Dataset.from_generator(
self._data_generator,
                output_types={"waveform": tf.float32, "audio_id": tf.string},
                output_shapes={"waveform": (None, 2), "audio_id": ()},
            )

            self._prediction_generator = estimator.predict(
                get_dataset, yield_single_examples=False
            )
return self._prediction_generator
    def join(self, timeout: int = 200) -> None:
        """
        Wait for all pending tasks to be finished.

        Parameters:
            timeout (int):
                (Optional) task waiting timeout.
        """
while len(self._tasks) > 0:
task = self._tasks.pop()
@@ -177,53 +172,51 @@ class Separator(object):
task.wait(timeout=timeout)
def _stft(
self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
) -> np.ndarray:
"""
Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The
expected input formats are: (n_samples, 2) for stft and (T, F, 2)
for istft.
Parameters:
data (numpy.array):
Array with either the waveform or the complex spectrogram
depending on the parameter inverse
inverse (bool):
(Optional) Should a stft or an istft be computed.
length (Optional[int]):
(Optional) Output waveform length when computing an istft.
Returns:
numpy.ndarray:
Stereo data as numpy array for the transform. The channels
are stored in the last dimension.
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
N = self._params["frame_length"]
H = self._params["frame_step"]
win = hann(N, sym=False)
fstft = istft if inverse else stft
win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
n_channels = data.shape[-1]
out = []
for c in range(n_channels):
d = (
np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
if not inverse
else data[:, :, c].T
)
s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
if inverse:
s = s[N : N + length]
s = np.expand_dims(s.T, 2 - inverse)
out.append(s)
if len(out) == 1:
return out[0]
return np.concatenate(out, axis=2 - inverse)
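For orientation, a shape-level sketch of the conventions `_stft` works with; the method is private, so the calls are shown commented, and all sizes are hypothetical:

```python
# Hypothetical shapes only; frame_length / frame_step come from loaded params.
import numpy as np

waveform = np.zeros((441000, 2), dtype=np.float32)  # (n_samples, 2) stereo input
# spec = separator._stft(waveform)                          # -> (T, F, 2) complex
# back = separator._stft(spec, inverse=True, length=441000) # -> (441000, 2)
```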
def _get_input_provider(self):
if self._input_provider is None:
@@ -238,32 +231,29 @@ class Separator(object):
def _get_builder(self):
if self._builder is None:
self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
return self._builder
def _get_session(self):
if self._session is None:
saver = tf.compat.v1.train.Saver()
provider = ModelProvider.default()
model_directory: str = provider.get(self._params["model_dir"])
latest_checkpoint = tf.train.latest_checkpoint(model_directory)
self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint)
return self._session
def _separate_librosa(
self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
) -> Dict:
"""
Performs separation with librosa backend for STFT.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (AudioDescriptor):
Descriptor of the waveform (e.g. filename).
"""
with self._tf_graph.as_default():
out = {}
@@ -280,108 +270,106 @@ class Separator(object):
outputs = sess.run(
outputs,
feed_dict=self._get_input_provider().get_feed_dict(
features, stft, audio_descriptor
),
)
for inst in self._get_builder().instruments:
out[inst] = self._stft(
outputs[inst], inverse=True, length=waveform.shape[0]
)
return out
def _separate_tensorflow(
self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
) -> Dict:
"""
Performs source separation over the given waveform with tensorflow
backend.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (AudioDescriptor):
Descriptor of the waveform (e.g. filename).
Returns:
Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data(
{"waveform": waveform, "audio_id": np.array(audio_descriptor)}
)
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop("audio_id")
return prediction
def separate(
self, waveform: np.ndarray, audio_descriptor: Optional[str] = None
) -> None:
"""
Performs separation on a waveform.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
"""
backend: str = self._params["stft_backend"]
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor)
elif backend == STFTBackend.LIBROSA:
return self._separate_librosa(waveform, audio_descriptor)
raise ValueError(f"Unsupported STFT backend {backend}")
def separate_to_file(
self,
audio_descriptor: AudioDescriptor,
destination: str,
audio_adapter: Optional[AudioAdapter] = None,
offset: int = 0,
duration: float = 600.0,
codec: Codec = Codec.WAV,
bitrate: str = "128k",
filename_format: str = "{filename}/{instrument}.{codec}",
synchronous: bool = True,
) -> None:
"""
Performs source separation and export result to file using
given audio adapter.
Filename format should be a Python formattable string that can
use the following parameters:
- {instrument}
- {filename}
- {foldername}
- {codec}
Parameters:
audio_descriptor (AudioDescriptor):
Describe song to separate, used by audio adapter to
retrieve and load audio data, in case of file based
audio adapter, such descriptor would be a file path.
destination (str):
Target directory to write output to.
audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
offset (int):
(Optional) Offset of loaded song.
duration (float):
(Optional) Duration of loaded song (default: 600s).
codec (Codec):
(Optional) Export codec.
bitrate (str):
(Optional) Export bitrate.
filename_format (str):
(Optional) Filename format.
synchronous (bool):
(Optional) `True` if the operation should be synchronous.
"""
if audio_adapter is None:
audio_adapter = AudioAdapter.default()
@@ -389,7 +377,8 @@ class Separator(object):
audio_descriptor,
offset=offset,
duration=duration,
sample_rate=self._sample_rate,
)
sources = self.separate(waveform, audio_descriptor)
self.save_to_file(
sources,
@@ -399,43 +388,45 @@ class Separator(object):
codec,
audio_adapter,
bitrate,
synchronous,
)
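A hedged usage sketch of file-based separation; the input path is hypothetical, and `Codec` is assumed importable from `spleeter.audio` as elsewhere in this commit:

```python
from spleeter.audio import Codec
from spleeter.separator import Separator

separator = Separator("spleeter:2stems")
separator.separate_to_file(
    "audio_example.mp3",  # hypothetical input path
    "output",
    codec=Codec.MP3,
    filename_format="{filename}/{instrument}.{codec}",
)
```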
def save_to_file(
self,
sources: Dict,
audio_descriptor: AudioDescriptor,
destination: str,
filename_format: str = "{filename}/{instrument}.{codec}",
codec: Codec = Codec.WAV,
audio_adapter: Optional[AudioAdapter] = None,
bitrate: str = "128k",
synchronous: bool = True,
) -> None:
"""
Export dictionary of sources to files.
Parameters:
sources (Dict):
Dictionary of sources to be exported. The keys are the name
of the instruments, and the values are `N x 2` numpy arrays
containing the corresponding instrument waveform, as
returned by the separate method
audio_descriptor (AudioDescriptor):
Describe song to separate, used by audio adapter to
retrieve and load audio data, in case of file based audio
adapter, such descriptor would be a file path.
destination (str):
Target directory to write output to.
filename_format (str):
(Optional) Filename format.
codec (Codec):
(Optional) Export codec.
audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
bitrate (str):
(Optional) Export bitrate.
synchronous (bool):
(Optional) `True` if the operation should be synchronous.
"""
if audio_adapter is None:
audio_adapter = AudioAdapter.default()
@@ -443,34 +434,32 @@ class Separator(object):
filename = splitext(basename(audio_descriptor))[0]
generated = []
for instrument, data in sources.items():
path = join(
destination,
filename_format.format(
filename=filename,
instrument=instrument,
foldername=foldername,
codec=codec,
),
)
directory = os.path.dirname(path)
if not os.path.exists(directory):
os.makedirs(directory)
if path in generated:
raise SpleeterError(
(
f"Separated source path conflict : {path},"
"please check your filename format"
)
)
generated.append(path)
if self._pool:
task = self._pool.apply_async(
audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
)
self._tasks.append(task)
else:
audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
if synchronous and self._pool:
self.join()
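A sketch of the asynchronous path, reusing the `separator` and `waveform` from the earlier examples: with `synchronous=False` exports run in the worker pool, and `join` waits for them.

```python
# Hedged sketch: batch exports through the pool, then wait for completion.
sources = separator.separate(waveform, "song.wav")
separator.save_to_file(sources, "song.wav", "output", synchronous=False)
separator.join(timeout=300)  # wait for pending pool tasks
```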

View File

@@ -8,6 +8,7 @@ from typing import Any, Tuple
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
# pylint: enable=import-error
AudioDescriptor: type = Any

View File

@@ -3,6 +3,6 @@
""" This package provides utility function and classes. """
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

View File

@@ -3,51 +3,49 @@
""" Module that provides configuration loading function. """
import importlib.resources as loader
import json
from os.path import exists
from typing import Dict
from .. import SpleeterError, resources
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
_EMBEDDED_CONFIGURATION_PREFIX: str = "spleeter:"
def load_configuration(descriptor: str) -> Dict:
"""
Load configuration from the given descriptor. Could be either a
`spleeter:` prefixed embedded configuration name or a file system path
to read configuration from.
Parameters:
descriptor (str):
Configuration descriptor to use for lookup.
Returns:
Dict:
Loaded configuration as dict.
Raises:
ValueError:
If required embedded configuration does not exist.
SpleeterError:
If required configuration file does not exist.
"""
# Embedded configuration reading.
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX) :]
if not loader.is_resource(resources, f"{name}.json"):
raise SpleeterError(f"No embedded configuration {name} found")
with loader.open_text(resources, f"{name}.json") as stream:
return json.load(stream)
# Standard file reading.
if not exists(descriptor):
raise SpleeterError(f"Configuration file {descriptor} not found")
with open(descriptor, "r") as stream:
return json.load(stream)
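Both descriptor forms in use (the JSON file path is hypothetical):

```python
params = load_configuration("spleeter:4stems")        # embedded resource
params = load_configuration("configs/my_model.json")  # hypothetical file on disk
print(params["sample_rate"])
```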

View File

@@ -5,19 +5,19 @@
import logging
import warnings
from os import environ
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import echo
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
class TyperLoggerHandler(logging.Handler):
@@ -27,29 +27,30 @@ class TyperLoggerHandler(logging.Handler):
echo(self.format(record))
formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
handler = TyperLoggerHandler()
handler.setFormatter(formatter)
logger: logging.Logger = logging.getLogger("spleeter")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
def configure_logger(verbose: bool) -> None:
"""
Configure application logger.
Parameters:
verbose (bool):
`True` to use verbose logger, `False` otherwise.
"""
from tensorflow import get_logger
from tensorflow.compat.v1 import logging as tf_logging
tf_logger = get_logger()
tf_logger.handlers = [handler]
if verbose:
tf_logging.set_verbosity(tf_logging.INFO)
logger.setLevel(logging.DEBUG)
else:
warnings.filterwarnings("ignore")
tf_logging.set_verbosity(tf_logging.ERROR)
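Typical wiring at CLI startup (a sketch; the `verbose` value would come from a command-line flag):

```python
configure_logger(verbose=False)
logger.info("separation started")  # routed through typer.echo via TyperLoggerHandler
```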

View File

@@ -5,50 +5,52 @@
from typing import Any, Callable, Dict
import pandas as pd
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def sync_apply(
tensor_dict: tf.Tensor, func: Callable, concat_axis: int = 1
) -> Dict[str, tf.Tensor]:
"""
Return a function that applies synchronously the provided func on the
provided dictionary of tensors. This means that func is applied to the
concatenation of the tensors in tensor_dict. This is useful for
performing random operations that need the same drawn value on multiple
tensors, such as a random time-crop on both input data and label (the
same crop should be applied to both input data and label, so random
crop cannot be applied separately on each of them).
Notes:
All tensors are assumed to be the same shape.
Parameters:
tensor_dict (Dict[str, tensorflow.Tensor]):
A dictionary of tensors.
func (Callable):
Function to be applied to the concatenation of the tensors in
`tensor_dict`.
concat_axis (int):
The axis on which to perform the concatenation.
Returns:
Dict[str, tensorflow.Tensor]:
Processed tensors dictionary with the same name (keys) as input
tensor_dict.
"""
if concat_axis not in {0, 1}:
raise NotImplementedError(
"Function only implemented for concat_axis equal to 0 or 1"
)
tensor_list = list(tensor_dict.values())
concat_tensor = tf.concat(tensor_list, concat_axis)
processed_concat_tensor = func(concat_tensor)
@@ -56,107 +58,104 @@ def sync_apply(
D = tensor_shape[concat_axis]
if concat_axis == 0:
return {
name: processed_concat_tensor[index * D : (index + 1) * D, :, :]
for index, name in enumerate(tensor_dict)
}
return {
name: processed_concat_tensor[:, index * D : (index + 1) * D, :]
for index, name in enumerate(tensor_dict)
}
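A usage sketch of the synchronized-crop case described above; shapes are hypothetical:

```python
import tensorflow as tf

# Two aligned (T, F, 2) spectrograms that must receive the same crop.
tensors = {
    "mix": tf.random.uniform((512, 128, 2)),
    "vocals": tf.random.uniform((512, 128, 2)),
}
# Concatenate along axis 1, crop once, split back: both tensors get the
# exact same time window.
cropped = sync_apply(
    tensors,
    lambda t: tf.image.random_crop(t, (256, 2 * 128, 2)),
    concat_axis=1,
)
```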
def from_float32_to_uint8(
tensor: tf.Tensor,
tensor_key: str = "tensor",
min_key: str = "min",
max_key: str = "max",
) -> tf.Tensor:
"""
Parameters:
tensor (tensorflow.Tensor):
tensor_key (str):
min_key (str):
max_key (str):
Returns:
tensorflow.Tensor:
"""
tensor_min = tf.reduce_min(tensor)
tensor_max = tf.reduce_max(tensor)
return {
tensor_key: tf.cast(
(tensor - tensor_min) / (tensor_max - tensor_min + 1e-16) * 255.9999,
dtype=tf.uint8,
),
min_key: tensor_min,
max_key: tensor_max,
}
def from_uint8_to_float32(
tensor: tf.Tensor, tensor_min: tf.Tensor, tensor_max: tf.Tensor
) -> tf.Tensor:
"""
Parameters:
tensor (tensorflow.Tensor):
tensor_min (tensorflow.Tensor):
tensor_max (tensorflow.Tensor):
Returns:
tensorflow.Tensor:
"""
return (
tf.cast(tensor, tf.float32) * (tensor_max - tensor_min) / 255.9999 + tensor_min
)
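A roundtrip sketch showing how these two helpers pair up, using the default dictionary keys:

```python
import tensorflow as tf

spec = tf.random.uniform((64, 128))
packed = from_float32_to_uint8(spec)  # {"tensor": uint8, "min": ..., "max": ...}
restored = from_uint8_to_float32(packed["tensor"], packed["min"], packed["max"])
# restored approximates spec within one quantization step, ~(max - min) / 256.
```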
def pad_and_partition(tensor: tf.Tensor, segment_len: int) -> tf.Tensor:
"""
Pad and partition a tensor into segments of length `segment_len`
along the first dimension. The tensor is padded with 0 in order
to ensure that the first dimension is a multiple of `segment_len`.
Tensor must be of known fixed rank.
Examples:
```python
>>> tensor = [[1, 2, 3], [4, 5, 6]]
>>> segment_len = 2
>>> pad_and_partition(tensor, segment_len)
[[[1, 2], [4, 5]], [[3, 0], [6, 0]]]
```
Parameters:
tensor (tensorflow.Tensor):
segment_len (int):
Returns:
tensorflow.Tensor:
"""
tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
padded = tf.pad(tensor, [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1))
split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
return tf.reshape(
padded, tf.concat([[split, segment_len], tf.shape(padded)[1:]], axis=0)
)
def pad_and_reshape(instr_spec, frame_length, F) -> Any:
"""
Parameters:
instr_spec:
frame_length:
F:
Returns:
Any:
"""
spec_shape = tf.shape(instr_spec)
extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
@@ -164,77 +163,67 @@ def pad_and_reshape(instr_spec, frame_length, F) -> Any:
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
extended_spec = tf.concat([instr_spec, extension], axis=2)
old_shape = tf.shape(extended_spec)
new_shape = tf.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
processed_instr_spec = tf.reshape(extended_spec, new_shape)
return processed_instr_spec
def dataset_from_csv(csv_path: str, **kwargs) -> Any:
"""
Load dataset from a CSV file using Pandas. kwargs if any are
forwarded to the `pandas.read_csv` function.
Parameters:
csv_path (str):
Path of the CSV file to load dataset from.
Returns:
Any:
Loaded dataset.
"""
df = pd.read_csv(csv_path, **kwargs)
dataset = tf.data.Dataset.from_tensor_slices({key: df[key].values for key in df})
return dataset
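A brief sketch of building a pipeline from a CSV; the column names here are assumptions about the data, not part of the API:

```python
dataset = dataset_from_csv("train.csv", sep=",")  # kwargs go to pandas.read_csv
for sample in dataset.take(1):
    print(sample["mix_path"], sample["duration"])  # hypothetical columns
```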
def check_tensor_shape(tensor_tf: tf.Tensor, target_shape: Any) -> bool:
"""
Return a Tensorflow boolean graph that indicates whether
sample[features_key] has the specified target shape. Only
non-None entries of target_shape are checked.
Parameters:
tensor_tf (tensorflow.Tensor):
Tensor to check shape for.
target_shape (Any):
Target shape to compare tensor to.
Returns:
bool:
`True` if shape is valid, `False` otherwise (as TF boolean).
"""
result = tf.constant(True)
for i, target_length in enumerate(target_shape):
if target_length:
result = tf.logical_and(
result, tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i])
)
return result
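A small sketch of the None-entry behaviour: asserting a waveform is stereo while leaving the time axis unconstrained.

```python
import tensorflow as tf

waveform = tf.zeros((44100, 2))
is_stereo = check_tensor_shape(waveform, (None, 2))  # None entries are skipped
# is_stereo is a boolean tensor, usable e.g. in tf.data filter predicates.
```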
def set_tensor_shape(tensor: tf.Tensor, tensor_shape: Any) -> tf.Tensor:
"""
Set shape for a tensor (not in place, as opposed to tf.set_shape)
Parameters:
tensor (tensorflow.Tensor):
Tensor to reshape.
tensor_shape (Any):
Shape to apply to the tensor.
Returns:
tensorflow.Tensor:
A reshaped tensor.
"""
# NOTE: tensor.set_shape appears to modify the tensor in place, contrary to the docstring.
tensor.set_shape(tensor_shape)