add pyproject.toml for poetry transition

Author: Félix Voituret
Date:   2021-01-08 17:32:39 +01:00
Parent: ad0171f6dd
Commit: 2b479eb683

25 changed files with 3741 additions and 1624 deletions

GitHub Actions workflow (name: pytest → test)

@@ -1,4 +1,4 @@
-name: pytest
+name: test
 on:
   pull_request:
     branches:
@@ -15,13 +15,6 @@ jobs:
       uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
-    - uses: actions/cache@v2
-      id: spleeter-pip-cache
-      with:
-        path: ~/.cache/pip
-        key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
-        restore-keys: |
-          ${{ runner.os }}-pip-
     - uses: actions/cache@v2
       env:
         model-release: 1
@@ -31,11 +24,29 @@ jobs:
         key: models-${{ env.model-release }}
         restore-keys: |
           models-${{ env.model-release }}
-    - name: Install dependencies
+    - name: Install ffmpeg
       run: |
         sudo apt-get update && sudo apt-get install -y ffmpeg
-        pip install --upgrade pip setuptools
-        pip install pytest==5.4.3 pytest-xdist==1.32.0 pytest-forked==1.1.3 musdb museval
-        python setup.py install
+    - name: Install Poetry
+      uses: dschep/install-poetry-action@v1.2
+    - name: Cache Poetry virtualenv
+      uses: actions/cache@v1
+      id: cache
+      with:
+        path: ~/.virtualenvs
+        key: poetry-${{ hashFiles('**/poetry.lock') }}
+        restore-keys: |
+          poetry-${{ hashFiles('**/poetry.lock') }}
+    - name: Set Poetry config
+      run: |
+        poetry config settings.virtualenvs.in-project false
+        poetry config settings.virtualenvs.path ~/.virtualenvs
+    - name: Install Dependencies
+      run: poetry install
+      if: steps.cache.outputs.cache-hit != 'true'
+    - name: Code quality checks
+      run: |
+        poetry run black spleeter --check
+        poetry run isort spleeter --check
     - name: Test with pytest
-      run: make test
+      run: poetry run pytest tests/
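Note that the `poetry config settings.*` keys used here are the pre-1.0 Poetry syntax; on Poetry 1.0 and later the same configuration would presumably be written without the `settings.` prefix:

    poetry config virtualenvs.in-project false
    poetry config virtualenvs.path ~/.virtualenvs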

poetry.lock (generated, new file, 1931 lines)

File diff suppressed because it is too large.

pyproject.toml (new file, 84 lines)

@@ -0,0 +1,84 @@
[tool.poetry]
name = "spleeter"
version = "2.1.0"
description = "The Deezer source separation library with pretrained models based on tensorflow."
authors = ["Deezer Research <spleeter@deezer.com>"]
license = "MIT License"
readme = "README.md"
repository = "https://github.com/deezer/spleeter"
homepage = "https://github.com/deezer/spleeter"
classifiers = [
"Environment :: Console",
"Environment :: MacOS X",
"Intended Audience :: Developers",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: MacOS",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Operating System :: Unix",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Artistic Software",
"Topic :: Multimedia",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
"Topic :: Multimedia :: Sound/Audio :: Conversion",
"Topic :: Multimedia :: Sound/Audio :: Sound Synthesis",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Utilities"
]
packages = [ { include = "spleeter" } ]
include = ["spleeter/resources/*.json"]

[tool.poetry.dependencies]
python = "^3.7"
ffmpeg-python = "0.2.0"
norbert = "0.2.1"
httpx = {extras = ["http2"], version = "^0.16.1"}
typer = "^0.3.2"
librosa = "0.8.0"
musdb = {version = "0.3.1", optional = true}
museval = {version = "0.3.0", optional = true}
tensorflow = "2.3.0"
pandas = "1.1.2"
numpy = "<1.19.0,>=1.16.0"

[tool.poetry.dev-dependencies]
pytest = "^6.2.1"
isort = "^5.7.0"
black = "^20.8b1"
mypy = "^0.790"
pytest-xdist = "^2.2.0"
pytest-forked = "^1.3.0"
musdb = "0.3.1"
museval = "0.3.0"

[tool.poetry.scripts]
spleeter = 'spleeter.__main__:entrypoint'

[tool.poetry.extras]
evaluation = ["musdb", "museval"]

[tool.isort]
profile = "black"
multi_line_output = 3

[tool.pytest.ini_options]
addopts = "-W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
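With this manifest, the old setup.py-era commands (`python setup.py install`, `make test`) map onto the Poetry workflow; a minimal sketch, assuming Poetry is already installed:

    # Install runtime and dev dependencies into a Poetry-managed virtualenv
    poetry install
    # Optionally pull in the evaluation extra (musdb, museval)
    poetry install --extras evaluation
    # Run the same checks as CI
    poetry run black spleeter --check
    poetry run isort spleeter --check
    poetry run pytest tests/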

spleeter/__init__.py

@@ -13,9 +13,9 @@
     by providing train, evaluation and source separation action.
 """

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 class SpleeterError(Exception):

spleeter/__main__.py

@@ -13,21 +13,21 @@
 """

 import json
 from functools import partial
-from itertools import product
 from glob import glob
+from itertools import product
 from os.path import join
 from pathlib import Path
 from typing import Container, Dict, List, Optional

+# pyright: reportMissingImports=false
+# pylint: disable=import-error
+from typer import Exit, Typer
+
 from . import SpleeterError
 from .options import *
 from .utils.logging import configure_logger, logger

-# pyright: reportMissingImports=false
-# pylint: disable=import-error
-from typer import Exit, Typer
 # pylint: enable=import-error

 spleeter: Typer = Typer(add_completion=False)
@@ -39,18 +39,19 @@ def train(
     adapter: str = AudioAdapterOption,
     data: Path = TrainingDataDirectoryOption,
     params_filename: str = ModelParametersOption,
-    verbose: bool = VerboseOption) -> None:
+    verbose: bool = VerboseOption,
+) -> None:
     """
     Train a source separation model
     """
+    import tensorflow as tf
+
     from .audio.adapter import AudioAdapter
     from .dataset import get_training_dataset, get_validation_dataset
     from .model import model_fn
     from .model.provider import ModelProvider
     from .utils.configuration import load_configuration
-    import tensorflow as tf

     configure_logger(verbose)
     audio_adapter = AudioAdapter.get(adapter)
     audio_path = str(data)
@@ -59,32 +60,29 @@ def train(
     session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
     estimator = tf.estimator.Estimator(
         model_fn=model_fn,
-        model_dir=params['model_dir'],
+        model_dir=params["model_dir"],
         params=params,
         config=tf.estimator.RunConfig(
-            save_checkpoints_steps=params['save_checkpoints_steps'],
-            tf_random_seed=params['random_seed'],
-            save_summary_steps=params['save_summary_steps'],
+            save_checkpoints_steps=params["save_checkpoints_steps"],
+            tf_random_seed=params["random_seed"],
+            save_summary_steps=params["save_summary_steps"],
             session_config=session_config,
             log_step_count_steps=10,
-            keep_checkpoint_max=2))
+            keep_checkpoint_max=2,
+        ),
+    )
     input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
     train_spec = tf.estimator.TrainSpec(
-        input_fn=input_fn,
-        max_steps=params['train_max_steps'])
-    input_fn = partial(
-        get_validation_dataset,
-        params,
-        audio_adapter,
-        audio_path)
+        input_fn=input_fn, max_steps=params["train_max_steps"]
+    )
+    input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
     evaluation_spec = tf.estimator.EvalSpec(
-        input_fn=input_fn,
-        steps=None,
-        throttle_secs=params['throttle_secs'])
-    logger.info('Start model training')
+        input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
+    )
+    logger.info("Start model training")
     tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
-    ModelProvider.writeProbe(params['model_dir'])
-    logger.info('Model training done')
+    ModelProvider.writeProbe(params["model_dir"])
+    logger.info("Model training done")


 @spleeter.command()
@@ -101,7 +99,8 @@ def separate(
     filename_format: str = FilenameFormatOption,
     params_filename: str = ModelParametersOption,
     mwf: bool = MWFOption,
-    verbose: bool = VerboseOption) -> None:
+    verbose: bool = VerboseOption,
+) -> None:
     """
     Separate audio file(s)
     """
@@ -111,14 +110,14 @@ def separate(
     configure_logger(verbose)
     if deprecated_files is not None:
         logger.error(
-            '⚠️ -i option is not supported anymore, audio files must be supplied '
-            'using input argument instead (see spleeter separate --help)')
+            "⚠️ -i option is not supported anymore, audio files must be supplied "
+            "using input argument instead (see spleeter separate --help)"
+        )
         raise Exit(20)
     audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
     separator: Separator = Separator(
-        params_filename,
-        MWF=mwf,
-        stft_backend=stft_backend)
+        params_filename, MWF=mwf, stft_backend=stft_backend
+    )
     for filename in files:
         separator.separate_to_file(
             str(filename),
@@ -129,16 +128,17 @@ def separate(
             codec=codec,
             bitrate=bitrate,
             filename_format=filename_format,
-            synchronous=False)
+            synchronous=False,
+        )
     separator.join()


-EVALUATION_SPLIT: str = 'test'
-EVALUATION_METRICS_DIRECTORY: str = 'metrics'
-EVALUATION_INSTRUMENTS: Container[str] = ('vocals', 'drums', 'bass', 'other')
-EVALUATION_METRICS: Container[str] = ('SDR', 'SAR', 'SIR', 'ISR')
-EVALUATION_MIXTURE: str = 'mixture.wav'
-EVALUATION_AUDIO_DIRECTORY: str = 'audio'
+EVALUATION_SPLIT: str = "test"
+EVALUATION_METRICS_DIRECTORY: str = "metrics"
+EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
+EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
+EVALUATION_MIXTURE: str = "mixture.wav"
+EVALUATION_AUDIO_DIRECTORY: str = "audio"


 def _compile_metrics(metrics_output_directory) -> Dict:
@@ -153,27 +153,32 @@ def _compile_metrics(metrics_output_directory) -> Dict:
     Dict:
         Compiled metrics as dict.
     """
-    import pandas as pd
     import numpy as np
+    import pandas as pd

-    songs = glob(join(metrics_output_directory, 'test/*.json'))
+    songs = glob(join(metrics_output_directory, "test/*.json"))
     index = pd.MultiIndex.from_tuples(
         product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
-        names=['instrument', 'metric'])
-    pd.DataFrame([], index=['config1', 'config2'], columns=index)
+        names=["instrument", "metric"],
+    )
+    pd.DataFrame([], index=["config1", "config2"], columns=index)
     metrics = {
         instrument: {k: [] for k in EVALUATION_METRICS}
-        for instrument in EVALUATION_INSTRUMENTS}
+        for instrument in EVALUATION_INSTRUMENTS
+    }
     for song in songs:
-        with open(song, 'r') as stream:
+        with open(song, "r") as stream:
             data = json.load(stream)
-        for target in data['targets']:
-            instrument = target['name']
+        for target in data["targets"]:
+            instrument = target["name"]
             for metric in EVALUATION_METRICS:
-                sdr_med = np.median([
-                    frame['metrics'][metric]
-                    for frame in target['frames']
-                    if not np.isnan(frame['metrics'][metric])])
+                sdr_med = np.median(
+                    [
+                        frame["metrics"][metric]
+                        for frame in target["frames"]
+                        if not np.isnan(frame["metrics"][metric])
+                    ]
+                )
                 metrics[instrument][metric].append(sdr_med)
     return metrics
@@ -186,7 +191,8 @@ def evaluate(
     params_filename: str = ModelParametersOption,
     mus_dir: Path = MUSDBDirectoryOption,
     mwf: bool = MWFOption,
-    verbose: bool = VerboseOption) -> Dict:
+    verbose: bool = VerboseOption,
+) -> Dict:
     """
     Evaluate a model on the musDB test dataset
     """
@@ -197,42 +203,44 @@ def evaluate(
         import musdb
         import museval
     except ImportError:
-        logger.error('Extra dependencies musdb and museval not found')
-        logger.error('Please install musdb and museval first, abort')
+        logger.error("Extra dependencies musdb and museval not found")
+        logger.error("Please install musdb and museval first, abort")
         raise Exit(10)
     # Separate musdb sources.
-    songs = glob(join(mus_dir, EVALUATION_SPLIT, '*/'))
+    songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
     mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
     audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
     separate(
         deprecated_files=None,
         files=mixtures,
         adapter=adapter,
-        bitrate='128k',
+        bitrate="128k",
         codec=Codec.WAV,
-        duration=600.,
+        duration=600.0,
         offset=0,
         output_path=join(audio_output_directory, EVALUATION_SPLIT),
         stft_backend=stft_backend,
-        filename_format='{foldername}/{instrument}.{codec}',
+        filename_format="{foldername}/{instrument}.{codec}",
         params_filename=params_filename,
         mwf=mwf,
-        verbose=verbose)
+        verbose=verbose,
+    )
     # Compute metrics with musdb.
     metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
-    logger.info('Starting musdb evaluation (this could be long) ...')
+    logger.info("Starting musdb evaluation (this could be long) ...")
     dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
     museval.eval_mus_dir(
         dataset=dataset,
         estimates_dir=audio_output_directory,
-        output_dir=metrics_output_directory)
-    logger.info('musdb evaluation done')
+        output_dir=metrics_output_directory,
+    )
+    logger.info("musdb evaluation done")
     # Compute and pretty print median metrics.
     metrics = _compile_metrics(metrics_output_directory)
     for instrument, metric in metrics.items():
-        logger.info(f'{instrument}:')
+        logger.info(f"{instrument}:")
         for metric, value in metric.items():
-            logger.info(f'{metric}: {np.median(value):.3f}')
+            logger.info(f"{metric}: {np.median(value):.3f}")
     return metrics
@@ -244,5 +252,5 @@ def entrypoint():
         logger.error(e)


-if __name__ == '__main__':
+if __name__ == "__main__":
     entrypoint()
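With `spleeter = 'spleeter.__main__:entrypoint'` declared under `[tool.poetry.scripts]` above, this Typer application becomes the installed `spleeter` console script; a minimal usage sketch (the audio file name is illustrative):

    poetry run spleeter separate -o output/ audio_example.mp3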

spleeter/audio/__init__.py

@@ -12,28 +12,28 @@
 from enum import Enum

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 class Codec(str, Enum):
     """ Enumeration of supported audio codec. """

-    WAV: str = 'wav'
-    MP3: str = 'mp3'
-    OGG: str = 'ogg'
-    M4A: str = 'm4a'
-    WMA: str = 'wma'
-    FLAC: str = 'flac'
+    WAV: str = "wav"
+    MP3: str = "mp3"
+    OGG: str = "ogg"
+    M4A: str = "m4a"
+    WMA: str = "wma"
+    FLAC: str = "flac"


 class STFTBackend(str, Enum):
     """ Enumeration of supported STFT backend. """

-    AUTO: str = 'auto'
-    TENSORFLOW: str = 'tensorflow'
-    LIBROSA: str = 'librosa'
+    AUTO: str = "auto"
+    TENSORFLOW: str = "tensorflow"
+    LIBROSA: str = "librosa"

     @classmethod
     def resolve(cls: type, backend: str) -> str:
@@ -44,9 +44,9 @@ class STFTBackend(str, Enum):
         import tensorflow as tf

         if backend not in cls.__members__.values():
-            raise ValueError(f'Unsupported backend {backend}')
+            raise ValueError(f"Unsupported backend {backend}")
         if backend == cls.AUTO:
-            if len(tf.config.list_physical_devices('GPU')):
+            if len(tf.config.list_physical_devices("GPU")):
                 return cls.TENSORFLOW
             return cls.LIBROSA
         return backend
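A short sketch of how `STFTBackend.resolve` behaves, mirroring the logic above:

    from spleeter.audio import STFTBackend

    backend = STFTBackend.resolve(STFTBackend.AUTO)
    # -> STFTBackend.TENSORFLOW when TensorFlow sees a GPU,
    #    STFTBackend.LIBROSA otherwise; an unknown name raises ValueError.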

spleeter/audio/adapter.py

@@ -6,29 +6,31 @@
 from abc import ABC, abstractmethod
 from importlib import import_module
 from pathlib import Path
-from spleeter.audio import Codec
 from typing import Any, Dict, List, Optional, Union

-from .. import SpleeterError
-from ..types import AudioDescriptor, Signal
-from ..utils.logging import logger
-
 # pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf

+from spleeter.audio import Codec
+
+from .. import SpleeterError
+from ..types import AudioDescriptor, Signal
+from ..utils.logging import logger
+
 # pylint: enable=import-error

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 class AudioAdapter(ABC):
     """ An abstract class for manipulating audio signal. """

-    _DEFAULT: 'AudioAdapter' = None
+    _DEFAULT: "AudioAdapter" = None
     """ Default audio adapter singleton instance. """

     @abstractmethod
@@ -38,7 +40,8 @@ class AudioAdapter(ABC):
         offset: Optional[float] = None,
         duration: Optional[float] = None,
         sample_rate: Optional[float] = None,
-        dtype: np.dtype = np.float32) -> Signal:
+        dtype: np.dtype = np.float32,
+    ) -> Signal:
         """
         Loads the audio file denoted by the given audio descriptor and
         returns it data as a waveform. Aims to be implemented by client.
@@ -66,10 +69,11 @@ class AudioAdapter(ABC):
         self,
         audio_descriptor,
         offset: float = 0.0,
-        duration: float = 1800.,
+        duration: float = 1800.0,
         sample_rate: int = 44100,
-        dtype: bytes = b'float32',
-        waveform_name: str = 'waveform') -> Dict[str, Any]:
+        dtype: bytes = b"float32",
+        waveform_name: str = "waveform",
+    ) -> Dict[str, Any]:
         """
         Load the audio and convert it to a tensorflow waveform.
@@ -101,32 +105,31 @@ class AudioAdapter(ABC):
         # Defined safe loading function.
         def safe_load(path, offset, duration, sample_rate, dtype):
-            logger.info(
-                f'Loading audio {path} from {offset} to {offset + duration}')
+            logger.info(f"Loading audio {path} from {offset} to {offset + duration}")
             try:
                 (data, _) = self.load(
                     path.numpy(),
                     offset.numpy(),
                     duration.numpy(),
                     sample_rate.numpy(),
-                    dtype=dtype.numpy())
-                logger.info('Audio data loaded successfully')
+                    dtype=dtype.numpy(),
+                )
+                logger.info("Audio data loaded successfully")
                 return (data, False)
             except Exception as e:
-                logger.exception(
-                    'An error occurs while loading audio',
-                    exc_info=e)
+                logger.exception("An error occurs while loading audio", exc_info=e)
                 return (np.float32(-1.0), True)

         # Execute function and format results.
-        results = tf.py_function(
-            safe_load,
-            [audio_descriptor, offset, duration, sample_rate, dtype],
-            (tf.float32, tf.bool)),
+        results = (
+            tf.py_function(
+                safe_load,
+                [audio_descriptor, offset, duration, sample_rate, dtype],
+                (tf.float32, tf.bool),
+            ),
+        )
         waveform, error = results[0]
-        return {
-            waveform_name: waveform,
-            f'{waveform_name}_error': error}
+        return {waveform_name: waveform, f"{waveform_name}_error": error}

     @abstractmethod
     def save(
@@ -135,7 +138,8 @@ class AudioAdapter(ABC):
         data: np.ndarray,
         sample_rate: float,
         codec: Codec = None,
-        bitrate: str = None) -> None:
+        bitrate: str = None,
+    ) -> None:
         """
         Save the given audio data to the file denoted by the given path.
@@ -155,7 +159,7 @@ class AudioAdapter(ABC):
         pass

     @classmethod
-    def default(cls: type) -> 'AudioAdapter':
+    def default(cls: type) -> "AudioAdapter":
         """
         Builds and returns a default audio adapter instance.
@@ -165,11 +169,12 @@ class AudioAdapter(ABC):
         """
         if cls._DEFAULT is None:
             from .ffmpeg import FFMPEGProcessAudioAdapter
+
             cls._DEFAULT = FFMPEGProcessAudioAdapter()
         return cls._DEFAULT

     @classmethod
-    def get(cls: type, descriptor: str) -> 'AudioAdapter':
+    def get(cls: type, descriptor: str) -> "AudioAdapter":
         """
         Load dynamically an AudioAdapter from given class descriptor.
@@ -183,12 +188,13 @@ class AudioAdapter(ABC):
         """
         if not descriptor:
             return cls.default()
-        module_path: List[str] = descriptor.split('.')
+        module_path: List[str] = descriptor.split(".")
         adapter_class_name: str = module_path[-1]
-        module_path: str = '.'.join(module_path[:-1])
+        module_path: str = ".".join(module_path[:-1])
         adapter_module = import_module(module_path)
         adapter_class = getattr(adapter_module, adapter_class_name)
         if not issubclass(adapter_class, AudioAdapter):
             raise SpleeterError(
-                f'{adapter_class_name} is not a valid AudioAdapter class')
+                f"{adapter_class_name} is not a valid AudioAdapter class"
+            )
         return adapter_class()
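The descriptor-based `get` above allows plugging in custom adapters; a usage sketch based on the classes shown in this commit (the input path is illustrative):

    from spleeter.audio.adapter import AudioAdapter

    # Default FFMPEG-based adapter, or the same class via its dotted path:
    adapter = AudioAdapter.default()
    adapter = AudioAdapter.get("spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter")
    waveform, rate = adapter.load("audio_example.mp3", offset=0.0, duration=10.0)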

spleeter/audio/convertor.py

@@ -3,22 +3,21 @@
 """ This module provides audio data convertion functions. """

-from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
-
 # pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf

+from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
+
 # pylint: enable=import-error

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


-def to_n_channels(
-        waveform: tf.Tensor,
-        n_channels: int) -> tf.Tensor:
+def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
     """
     Convert a waveform to n_channels by removing or duplicating channels if
     needed (in tensorflow).
@@ -36,7 +35,8 @@ def to_n_channels(
     return tf.cond(
         tf.shape(waveform)[1] >= n_channels,
         true_fn=lambda: waveform[:, :n_channels],
-        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels])
+        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels],
+    )


 def to_stereo(waveform: np.ndarray) -> np.ndarray:
@@ -73,7 +73,7 @@ def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
     tensorflow.Tensor:
         Converted tensor.
     """
-    return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
+    return 20.0 / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))


 def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
@@ -88,13 +88,12 @@ def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
     tensorflow.Tensor:
         Converted tensor.
     """
-    return tf.pow(10., (tensor / 20.))
+    return tf.pow(10.0, (tensor / 20.0))


 def spectrogram_to_db_uint(
-        spectrogram: tf.Tensor,
-        db_range: float = 100.,
-        **kwargs) -> tf.Tensor:
+    spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
+) -> tf.Tensor:
     """
     Encodes given spectrogram into uint8 using decibel scale.
@@ -111,15 +110,14 @@ def spectrogram_to_db_uint(
     db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
     max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
     db_spectrogram: tf.Tensor = tf.maximum(
-        db_spectrogram,
-        max_db_spectrogram - db_range)
+        db_spectrogram, max_db_spectrogram - db_range
+    )
     return from_float32_to_uint8(db_spectrogram, **kwargs)


 def db_uint_spectrogram_to_gain(
-        db_uint_spectrogram: tf.Tensor,
-        min_db: tf.Tensor,
-        max_db: tf.Tensor) -> tf.Tensor:
+    db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
+) -> tf.Tensor:
     """
     Decode spectrogram from uint8 decibel scale.
@@ -136,7 +134,6 @@ def db_uint_spectrogram_to_gain(
         Decoded spectrogram as `float32` tensor.
     """
     db_spectrogram: tf.Tensor = from_uint8_to_float32(
-        db_uint_spectrogram,
-        min_db,
-        max_db)
+        db_uint_spectrogram, min_db, max_db
+    )
     return db_to_gain(db_spectrogram)
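`gain_to_db` and `db_to_gain` are inverses: the former computes 20 * log10(max(x, espilon)), the latter 10 ** (x / 20). A quick round-trip sketch (eager TensorFlow assumed):

    import tensorflow as tf
    from spleeter.audio.convertor import db_to_gain, gain_to_db

    x = tf.constant([1.0, 0.5, 0.25])
    assert float(tf.reduce_max(tf.abs(db_to_gain(gain_to_db(x)) - x))) < 1e-5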

spleeter/audio/ffmpeg.py

@@ -11,25 +11,25 @@
 import datetime as dt
 import os
 import shutil
 from pathlib import Path
 from typing import Dict, Optional, Union

-from . import Codec
-from .adapter import AudioAdapter
-from .. import SpleeterError
-from ..types import Signal
-from ..utils.logging import logger
-
 # pyright: reportMissingImports=false
 # pylint: disable=import-error
 import ffmpeg
 import numpy as np

+from .. import SpleeterError
+from ..types import Signal
+from ..utils.logging import logger
+from . import Codec
+from .adapter import AudioAdapter
+
 # pylint: enable=import-error

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 class FFMPEGProcessAudioAdapter(AudioAdapter):
@@ -43,9 +43,9 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
     """

     SUPPORTED_CODECS: Dict[Codec, str] = {
-        Codec.M4A: 'aac',
-        Codec.OGG: 'libvorbis',
-        Codec.WMA: 'wmav2'
+        Codec.M4A: "aac",
+        Codec.OGG: "libvorbis",
+        Codec.WMA: "wmav2",
     }
     """ FFMPEG codec name mapping. """
@@ -57,9 +57,9 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
         SpleeterError:
             If ffmpeg or ffprobe is not found.
         """
-        for binary in ('ffmpeg', 'ffprobe'):
+        for binary in ("ffmpeg", "ffprobe"):
             if shutil.which(binary) is None:
-                raise SpleeterError('{} binary not found'.format(binary))
+                raise SpleeterError("{} binary not found".format(binary))

     def load(
         _,
@@ -67,7 +67,8 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
         offset: Optional[float] = None,
         duration: Optional[float] = None,
         sample_rate: Optional[float] = None,
-        dtype: np.dtype = np.float32) -> Signal:
+        dtype: np.dtype = np.float32,
+    ) -> Signal:
         """
         Loads the audio file denoted by the given path
         and returns it data as a waveform.
@@ -100,29 +101,30 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
             probe = ffmpeg.probe(path)
         except ffmpeg._run.Error as e:
             raise SpleeterError(
-                'An error occurs with ffprobe (see ffprobe output below)\n\n{}'
-                .format(e.stderr.decode()))
-        if 'streams' not in probe or len(probe['streams']) == 0:
-            raise SpleeterError('No stream was found with ffprobe')
+                "An error occurs with ffprobe (see ffprobe output below)\n\n{}".format(
+                    e.stderr.decode()
+                )
+            )
+        if "streams" not in probe or len(probe["streams"]) == 0:
+            raise SpleeterError("No stream was found with ffprobe")
         metadata = next(
-            stream
-            for stream in probe['streams']
-            if stream['codec_type'] == 'audio')
-        n_channels = metadata['channels']
+            stream for stream in probe["streams"] if stream["codec_type"] == "audio"
+        )
+        n_channels = metadata["channels"]
         if sample_rate is None:
-            sample_rate = metadata['sample_rate']
-        output_kwargs = {'format': 'f32le', 'ar': sample_rate}
+            sample_rate = metadata["sample_rate"]
+        output_kwargs = {"format": "f32le", "ar": sample_rate}
         if duration is not None:
-            output_kwargs['t'] = str(dt.timedelta(seconds=duration))
+            output_kwargs["t"] = str(dt.timedelta(seconds=duration))
         if offset is not None:
-            output_kwargs['ss'] = str(dt.timedelta(seconds=offset))
+            output_kwargs["ss"] = str(dt.timedelta(seconds=offset))
         process = (
-            ffmpeg
-            .input(path)
-            .output('pipe:', **output_kwargs)
-            .run_async(pipe_stdout=True, pipe_stderr=True))
+            ffmpeg.input(path)
+            .output("pipe:", **output_kwargs)
+            .run_async(pipe_stdout=True, pipe_stderr=True)
+        )
         buffer, _ = process.communicate()
-        waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
+        waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
         if not waveform.dtype == np.dtype(dtype):
             waveform = waveform.astype(dtype)
         return (waveform, sample_rate)
@@ -133,7 +135,8 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
         data: np.ndarray,
         sample_rate: float,
         codec: Codec = None,
-        bitrate: str = None) -> None:
+        bitrate: str = None,
+    ) -> None:
         """
         Write waveform data to the file denoted by the given path using
         FFMPEG process.
@@ -159,25 +162,24 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
         path = str(path)
         directory = os.path.dirname(path)
         if not os.path.exists(directory):
-            raise SpleeterError(
-                f'output directory does not exists: {directory}')
-        logger.debug(f'Writing file {path}')
-        input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
-        output_kwargs = {'ar': sample_rate, 'strict': '-2'}
+            raise SpleeterError(f"output directory does not exists: {directory}")
+        logger.debug(f"Writing file {path}")
+        input_kwargs = {"ar": sample_rate, "ac": data.shape[1]}
+        output_kwargs = {"ar": sample_rate, "strict": "-2"}
         if bitrate:
-            output_kwargs['audio_bitrate'] = bitrate
-        if codec is not None and codec != 'wav':
-            output_kwargs['codec'] = self.SUPPORTED_CODECS.get(codec, codec)
+            output_kwargs["audio_bitrate"] = bitrate
+        if codec is not None and codec != "wav":
+            output_kwargs["codec"] = self.SUPPORTED_CODECS.get(codec, codec)
         process = (
-            ffmpeg
-            .input('pipe:', format='f32le', **input_kwargs)
+            ffmpeg.input("pipe:", format="f32le", **input_kwargs)
             .output(path, **output_kwargs)
             .overwrite_output()
-            .run_async(pipe_stdin=True, pipe_stderr=True, quiet=True))
+            .run_async(pipe_stdin=True, pipe_stderr=True, quiet=True)
+        )
         try:
-            process.stdin.write(data.astype('<f4').tobytes())
+            process.stdin.write(data.astype("<f4").tobytes())
             process.stdin.close()
             process.wait()
         except IOError:
-            raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
-        logger.info(f'File {path} written succesfully')
+            raise SpleeterError(f"FFMPEG error: {process.stderr.read()}")
+        logger.info(f"File {path} written succesfully")
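For reference, the load path above reduces to a single ffmpeg-python pipeline decoding raw 32-bit floats to stdout; a standalone sketch of the same idea (the file name and stereo channel count are assumptions):

    import ffmpeg
    import numpy as np

    process = (
        ffmpeg.input("audio_example.mp3")
        .output("pipe:", format="f32le", ar=44100)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, _ = process.communicate()
    waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, 2)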

spleeter/audio/spectrogram.py

@@ -7,21 +7,22 @@
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
-from tensorflow.signal import stft, hann_window
+from tensorflow.signal import hann_window, stft

 # pylint: enable=import-error

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 def compute_spectrogram_tf(
     waveform: tf.Tensor,
     frame_length: int = 2048,
     frame_step: int = 512,
-    spec_exponent: float = 1.,
-    window_exponent: float = 1.) -> tf.Tensor:
+    spec_exponent: float = 1.0,
+    window_exponent: float = 1.0,
+) -> tf.Tensor:
     """
     Compute magnitude / power spectrogram from waveform as a
     `n_samples x n_channels` tensor.
@@ -51,18 +52,20 @@ def compute_spectrogram_tf(
             frame_length,
             frame_step,
             window_fn=lambda f, dtype: hann_window(
-                f,
-                periodic=True,
-                dtype=waveform.dtype) ** window_exponent),
-        perm=[1, 2, 0])
+                f, periodic=True, dtype=waveform.dtype
+            )
+            ** window_exponent,
+        ),
+        perm=[1, 2, 0],
+    )
     return tf.abs(stft_tensor) ** spec_exponent


 def time_stretch(
     spectrogram: tf.Tensor,
     factor: float = 1.0,
-    method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR
+    method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
 ) -> tf.Tensor:
     """
     Time stretch a spectrogram preserving shape in tensorflow. Note that
     this is an approximation in the frequency domain.
@@ -83,18 +86,14 @@ def time_stretch(
     T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
     F = tf.shape(spectrogram)[1]
     ts_spec = tf.image.resize_images(
-        spectrogram,
-        [T_ts, F],
-        method=method,
-        align_corners=True)
+        spectrogram, [T_ts, F], method=method, align_corners=True
+    )
     return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)


 def random_time_stretch(
-    spectrogram: tf.Tensor,
-    factor_min: float = 0.9,
-    factor_max: float = 1.1,
-    **kwargs) -> tf.Tensor:
+    spectrogram: tf.Tensor, factor_min: float = 0.9, factor_max: float = 1.1, **kwargs
+) -> tf.Tensor:
     """
     Time stretch a spectrogram preserving shape with random ratio in
     tensorflow. Applies time_stretch to spectrogram with a random ratio
@@ -112,17 +111,17 @@ def random_time_stretch(
     tensorflow.Tensor:
         Randomly time stretched spectrogram as tensor with same shape.
     """
-    factor = tf.random_uniform(
-        shape=(1,),
-        seed=0) * (factor_max - factor_min) + factor_min
+    factor = (
+        tf.random_uniform(shape=(1,), seed=0) * (factor_max - factor_min) + factor_min
+    )
     return time_stretch(spectrogram, factor=factor, **kwargs)


 def pitch_shift(
     spectrogram: tf.Tensor,
     semitone_shift: float = 0.0,
-    method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR
+    method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
 ) -> tf.Tensor:
     """
     Pitch shift a spectrogram preserving shape in tensorflow. Note that
     this is an approximation in the frequency domain.
@@ -139,24 +138,20 @@ def pitch_shift(
     tensorflow.Tensor:
         Pitch shifted spectrogram (same shape as spectrogram).
     """
-    factor = 2 ** (semitone_shift / 12.)
+    factor = 2 ** (semitone_shift / 12.0)
     T = tf.shape(spectrogram)[0]
     F = tf.shape(spectrogram)[1]
     F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
     ps_spec = tf.image.resize_images(
-        spectrogram,
-        [T, F_ps],
-        method=method,
-        align_corners=True)
+        spectrogram, [T, F_ps], method=method, align_corners=True
+    )
     paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
-    return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
+    return tf.pad(ps_spec[:, :F, :], paddings, "CONSTANT")


 def random_pitch_shift(
-    spectrogram: tf.Tensor,
-    shift_min: float = -1.,
-    shift_max: float = 1.,
-    **kwargs) -> tf.Tensor:
+    spectrogram: tf.Tensor, shift_min: float = -1.0, shift_max: float = 1.0, **kwargs
+) -> tf.Tensor:
     """
     Pitch shift a spectrogram preserving shape with random ratio in
     tensorflow. Applies pitch_shift to spectrogram with a random shift
@@ -175,7 +170,7 @@ def random_pitch_shift(
     tensorflow.Tensor:
         Randomly pitch shifted spectrogram (same shape as spectrogram).
     """
-    semitone_shift = tf.random_uniform(
-        shape=(1,),
-        seed=0) * (shift_max - shift_min) + shift_min
+    semitone_shift = (
+        tf.random_uniform(shape=(1,), seed=0) * (shift_max - shift_min) + shift_min
+    )
     return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
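The shift-to-factor conversion in `pitch_shift` follows equal temperament: factor = 2 ** (semitone_shift / 12), so +12 semitones doubles the frequency axis and -12 halves it:

    for shift in (-12.0, 0.0, 12.0):
        print(shift, 2 ** (shift / 12.0))  # 0.5, 1.0, 2.0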

View File

@@ -14,45 +14,52 @@
(ground truth) (ground truth)
""" """
import time
import os import os
import time
from os.path import exists, sep as SEPARATOR from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain
from .audio.convertor import spectrogram_to_db_uint
from .audio.spectrogram import compute_spectrogram_tf
from .audio.spectrogram import random_pitch_shift, random_time_stretch
from .utils.logging import logger
from .utils.tensor import check_tensor_shape, dataset_from_csv
from .utils.tensor import set_tensor_shape, sync_apply
# pyright: reportMissingImports=false # pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import tensorflow as tf import tensorflow as tf
from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
compute_spectrogram_tf,
random_pitch_shift,
random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
check_tensor_shape,
dataset_from_csv,
set_tensor_shape,
sync_apply,
)
# pylint: enable=import-error # pylint: enable=import-error
__email__ = 'spleeter@deezer.com' __email__ = "spleeter@deezer.com"
__author__ = 'Deezer Research' __author__ = "Deezer Research"
__license__ = 'MIT License' __license__ = "MIT License"
# Default audio parameters to use. # Default audio parameters to use.
DEFAULT_AUDIO_PARAMS: Dict = { DEFAULT_AUDIO_PARAMS: Dict = {
'instrument_list': ('vocals', 'accompaniment'), "instrument_list": ("vocals", "accompaniment"),
'mix_name': 'mix', "mix_name": "mix",
'sample_rate': 44100, "sample_rate": 44100,
'frame_length': 4096, "frame_length": 4096,
'frame_step': 1024, "frame_step": 1024,
'T': 512, "T": 512,
'F': 1024} "F": 1024,
}
def get_training_dataset( def get_training_dataset(
audio_params: Dict, audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
audio_adapter: AudioAdapter, ) -> Any:
audio_path: str) -> Any:
""" """
Builds training dataset. Builds training dataset.
@@ -72,22 +79,23 @@ def get_training_dataset(
audio_params, audio_params,
audio_adapter, audio_adapter,
audio_path, audio_path,
chunk_duration=audio_params.get('chunk_duration', 20.0), chunk_duration=audio_params.get("chunk_duration", 20.0),
random_seed=audio_params.get('random_seed', 0)) random_seed=audio_params.get("random_seed", 0),
)
return builder.build( return builder.build(
audio_params.get('train_csv'), audio_params.get("train_csv"),
cache_directory=audio_params.get('training_cache'), cache_directory=audio_params.get("training_cache"),
batch_size=audio_params.get('batch_size'), batch_size=audio_params.get("batch_size"),
n_chunks_per_song=audio_params.get('n_chunks_per_song', 2), n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
random_data_augmentation=False, random_data_augmentation=False,
convert_to_uint=True, convert_to_uint=True,
wait_for_cache=False) wait_for_cache=False,
)
def get_validation_dataset( def get_validation_dataset(
audio_params: Dict, audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
audio_adapter: AudioAdapter, ) -> Any:
audio_path: str) -> Any:
""" """
Builds validation dataset. Builds validation dataset.
@@ -104,14 +112,12 @@ def get_validation_dataset(
Built dataset. Built dataset.
""" """
builder = DatasetBuilder( builder = DatasetBuilder(
audio_params, audio_params, audio_adapter, audio_path, chunk_duration=12.0
audio_adapter, )
audio_path,
chunk_duration=12.0)
return builder.build( return builder.build(
audio_params.get('validation_csv'), audio_params.get("validation_csv"),
batch_size=audio_params.get('batch_size'), batch_size=audio_params.get("batch_size"),
cache_directory=audio_params.get('validation_cache'), cache_directory=audio_params.get("validation_cache"),
convert_to_uint=True, convert_to_uint=True,
infinite_generator=False, infinite_generator=False,
n_chunks_per_song=1, n_chunks_per_song=1,
@@ -137,82 +143,117 @@ class InstrumentDatasetBuilder(object):
""" """
self._parent = parent self._parent = parent
self._instrument = instrument self._instrument = instrument
self._spectrogram_key = f'{instrument}_spectrogram' self._spectrogram_key = f"{instrument}_spectrogram"
self._min_spectrogram_key = f'min_{instrument}_spectrogram' self._min_spectrogram_key = f"min_{instrument}_spectrogram"
self._max_spectrogram_key = f'max_{instrument}_spectrogram' self._max_spectrogram_key = f"max_{instrument}_spectrogram"
def load_waveform(self, sample): def load_waveform(self, sample):
""" Load waveform for given sample. """ """ Load waveform for given sample. """
return dict(sample, **self._parent._audio_adapter.load_tf_waveform( return dict(
sample[f'{self._instrument}_path'], sample,
offset=sample['start'], **self._parent._audio_adapter.load_tf_waveform(
sample[f"{self._instrument}_path"],
offset=sample["start"],
duration=self._parent._chunk_duration, duration=self._parent._chunk_duration,
sample_rate=self._parent._sample_rate, sample_rate=self._parent._sample_rate,
waveform_name='waveform')) waveform_name="waveform",
),
)
def compute_spectrogram(self, sample): def compute_spectrogram(self, sample):
""" Compute spectrogram of the given sample. """ """ Compute spectrogram of the given sample. """
return dict(sample, **{ return dict(
sample,
**{
self._spectrogram_key: compute_spectrogram_tf( self._spectrogram_key: compute_spectrogram_tf(
sample['waveform'], sample["waveform"],
frame_length=self._parent._frame_length, frame_length=self._parent._frame_length,
frame_step=self._parent._frame_step, frame_step=self._parent._frame_step,
spec_exponent=1., spec_exponent=1.0,
window_exponent=1.)}) window_exponent=1.0,
)
},
)
def filter_frequencies(self, sample): def filter_frequencies(self, sample):
""" """ """ """
return dict(sample, **{ return dict(
self._spectrogram_key: sample,
sample[self._spectrogram_key][:, :self._parent._F, :]}) **{
self._spectrogram_key: sample[self._spectrogram_key][
:, : self._parent._F, :
]
},
)
def convert_to_uint(self, sample): def convert_to_uint(self, sample):
""" Convert given sample from float to unit. """ """ Convert given sample from float to unit. """
return dict(sample, **spectrogram_to_db_uint( return dict(
sample,
**spectrogram_to_db_uint(
sample[self._spectrogram_key], sample[self._spectrogram_key],
tensor_key=self._spectrogram_key, tensor_key=self._spectrogram_key,
min_key=self._min_spectrogram_key, min_key=self._min_spectrogram_key,
max_key=self._max_spectrogram_key)) max_key=self._max_spectrogram_key,
),
)
def filter_infinity(self, sample): def filter_infinity(self, sample):
""" Filter infinity sample. """ """ Filter infinity sample. """
return tf.logical_not( return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))
tf.math.is_inf(
sample[self._min_spectrogram_key]))
def convert_to_float32(self, sample): def convert_to_float32(self, sample):
""" Convert given sample from unit to float. """ """ Convert given sample from unit to float. """
return dict(sample, **{ return dict(
sample,
**{
self._spectrogram_key: db_uint_spectrogram_to_gain( self._spectrogram_key: db_uint_spectrogram_to_gain(
sample[self._spectrogram_key], sample[self._spectrogram_key],
sample[self._min_spectrogram_key], sample[self._min_spectrogram_key],
sample[self._max_spectrogram_key])}) sample[self._max_spectrogram_key],
)
},
)
def time_crop(self, sample): def time_crop(self, sample):
""" """ """ """
def start(sample): def start(sample):
""" mid_segment_start """ """ mid_segment_start """
return tf.cast( return tf.cast(
tf.maximum( tf.maximum(
tf.shape(sample[self._spectrogram_key])[0] tf.shape(sample[self._spectrogram_key])[0] / 2
/ 2 - self._parent._T / 2, 0), - self._parent._T / 2,
tf.int32) 0,
return dict(sample, **{ ),
tf.int32,
)
return dict(
sample,
**{
self._spectrogram_key: sample[self._spectrogram_key][ self._spectrogram_key: sample[self._spectrogram_key][
start(sample):start(sample) + self._parent._T, :, :]}) start(sample) : start(sample) + self._parent._T, :, :
]
},
)
def filter_shape(self, sample): def filter_shape(self, sample):
""" Filter badly shaped sample. """ """ Filter badly shaped sample. """
return check_tensor_shape( return check_tensor_shape(
sample[self._spectrogram_key], ( sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
self._parent._T, self._parent._F, 2)) )
def reshape_spectrogram(self, sample): def reshape_spectrogram(self, sample):
""" Reshape given sample. """ """ Reshape given sample. """
return dict(sample, **{ return dict(
sample,
**{
self._spectrogram_key: set_tensor_shape( self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key], sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
(self._parent._T, self._parent._F, 2))}) )
},
)
class DatasetBuilder(object): class DatasetBuilder(object):
@@ -232,7 +273,8 @@ class DatasetBuilder(object):
audio_adapter: AudioAdapter, audio_adapter: AudioAdapter,
audio_path: str, audio_path: str,
random_seed: int = 0, random_seed: int = 0,
chunk_duration: float = 20.0) -> None: chunk_duration: float = 20.0,
) -> None:
""" """
Default constructor. Default constructor.
@@ -249,15 +291,15 @@ class DatasetBuilder(object):
""" """
# Length of segment in frames (if fs=22050 and # Length of segment in frames (if fs=22050 and
# frame_step=512, then T=512 corresponds to 11.89s) # frame_step=512, then T=512 corresponds to 11.89s)
self._T = audio_params['T'] self._T = audio_params["T"]
# Number of frequency bins to be used (should # Number of frequency bins to be used (should
# be less than frame_length/2 + 1) # be less than frame_length/2 + 1)
self._F = audio_params['F'] self._F = audio_params["F"]
self._sample_rate = audio_params['sample_rate'] self._sample_rate = audio_params["sample_rate"]
self._frame_length = audio_params['frame_length'] self._frame_length = audio_params["frame_length"]
self._frame_step = audio_params['frame_step'] self._frame_step = audio_params["frame_step"]
self._mix_name = audio_params['mix_name'] self._mix_name = audio_params["mix_name"]
self._instruments = [self._mix_name] + audio_params['instrument_list'] self._instruments = [self._mix_name] + audio_params["instrument_list"]
self._instrument_builders = None self._instrument_builders = None
self._chunk_duration = chunk_duration self._chunk_duration = chunk_duration
self._audio_adapter = audio_adapter self._audio_adapter = audio_adapter
@@ -267,76 +309,110 @@ class DatasetBuilder(object):
def expand_path(self, sample): def expand_path(self, sample):
""" Expands audio paths for the given sample. """ """ Expands audio paths for the given sample. """
return dict(sample, **{f'{instrument}_path': tf.strings.join( return dict(
(self._audio_path, sample[f'{instrument}_path']), SEPARATOR) sample,
for instrument in self._instruments}) **{
f"{instrument}_path": tf.strings.join(
(self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
)
for instrument in self._instruments
},
)
def filter_error(self, sample): def filter_error(self, sample):
""" Filter errored sample. """ """ Filter errored sample. """
return tf.logical_not(sample['waveform_error']) return tf.logical_not(sample["waveform_error"])
def filter_waveform(self, sample): def filter_waveform(self, sample):
""" Filter waveform from sample. """ """ Filter waveform from sample. """
return {k: v for k, v in sample.items() if not k == 'waveform'} return {k: v for k, v in sample.items() if not k == "waveform"}
def harmonize_spectrogram(self, sample): def harmonize_spectrogram(self, sample):
""" Ensure same size for vocals and mix spectrograms. """ """ Ensure same size for vocals and mix spectrograms. """
def _reduce(sample): def _reduce(sample):
return tf.reduce_min([ return tf.reduce_min(
tf.shape(sample[f'{instrument}_spectrogram'])[0] [
for instrument in self._instruments]) tf.shape(sample[f"{instrument}_spectrogram"])[0]
return dict(sample, **{ for instrument in self._instruments
f'{instrument}_spectrogram': ]
sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :] )
for instrument in self._instruments})
return dict(
sample,
**{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
: _reduce(sample), :, :
]
for instrument in self._instruments
},
)
def filter_short_segments(self, sample): def filter_short_segments(self, sample):
""" Filter out too short segment. """ """ Filter out too short segment. """
return tf.reduce_any([ return tf.reduce_any(
tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T [
for instrument in self._instruments]) tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
for instrument in self._instruments
]
)
def random_time_crop(self, sample): def random_time_crop(self, sample):
""" Random time crop of 11.88s. """ """ Random time crop of 11.88s. """
return dict(sample, **sync_apply({ return dict(
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram'] sample,
for instrument in self._instruments}, **sync_apply(
{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._instruments
},
lambda x: tf.image.random_crop( lambda x: tf.image.random_crop(
x, (self._T, len(self._instruments) * self._F, 2), x,
seed=self._random_seed))) (self._T, len(self._instruments) * self._F, 2),
seed=self._random_seed,
),
),
)
def random_time_stretch(self, sample): def random_time_stretch(self, sample):
""" Randomly time stretch the given sample. """ """ Randomly time stretch the given sample. """
return dict(sample, **sync_apply({ return dict(
f'{instrument}_spectrogram': sample,
sample[f'{instrument}_spectrogram'] **sync_apply(
for instrument in self._instruments}, {
lambda x: random_time_stretch( f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
x, factor_min=0.9, factor_max=1.1))) for instrument in self._instruments
},
lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
),
)
def random_pitch_shift(self, sample): def random_pitch_shift(self, sample):
""" Randomly pitch shift the given sample. """ """ Randomly pitch shift the given sample. """
return dict(sample, **sync_apply({ return dict(
f'{instrument}_spectrogram': sample,
sample[f'{instrument}_spectrogram'] **sync_apply(
for instrument in self._instruments}, {
lambda x: random_pitch_shift( f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
x, shift_min=-1.0, shift_max=1.0), concat_axis=0)) for instrument in self._instruments
},
lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
concat_axis=0,
),
)
def map_features(self, sample): def map_features(self, sample):
""" Select features and annotation of the given sample. """ """ Select features and annotation of the given sample. """
input_ = { input_ = {
f'{self._mix_name}_spectrogram': f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
sample[f'{self._mix_name}_spectrogram']} }
output = { output = {
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram'] f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._audio_params['instrument_list']} for instrument in self._audio_params["instrument_list"]
}
return (input_, output) return (input_, output)
def compute_segments( def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
self,
dataset: Any,
n_chunks_per_song: int) -> Any:
""" """
Computes segments for each song of the dataset. Computes segments for each song of the dataset.
@@ -351,21 +427,39 @@ class DatasetBuilder(object):
Segmented dataset. Segmented dataset.
""" """
        if n_chunks_per_song <= 0:
            raise ValueError("n_chunks_per_song must be positive")
        datasets = []
        for k in range(n_chunks_per_song):
            if n_chunks_per_song > 1:
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                k
                                * (
                                    sample["duration"]
                                    - self._chunk_duration
                                    - 2 * self.MARGIN
                                )
                                / (n_chunks_per_song - 1)
                                + self.MARGIN,
                                0,
                            ),
                        )
                    )
                )
            elif n_chunks_per_song == 1:  # Take central segment.
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                sample["duration"] / 2 - self._chunk_duration / 2, 0
                            ),
                        )
                    )
                )
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
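Editor's note: a quick numeric check of the start-offset formula above, with illustrative values (not from the diff); chunks are spread so the first starts at MARGIN and the last ends at duration - MARGIN.

MARGIN = 0.5            # hypothetical margin, in seconds
chunk_duration = 20.0   # hypothetical chunk length
duration = 180.0        # hypothetical song length
n_chunks_per_song = 2

for k in range(n_chunks_per_song):
    start = max(
        k * (duration - chunk_duration - 2 * MARGIN) / (n_chunks_per_song - 1)
        + MARGIN,
        0,
    )
    print(k, start)  # k=0 -> 0.5, k=1 -> 159.5 (i.e. 180 - 20 - 0.5)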
@@ -384,15 +478,12 @@ class DatasetBuilder(object):
        self._instrument_builders = []
        for instrument in self._instruments:
            self._instrument_builders.append(
                InstrumentDatasetBuilder(self, instrument)
            )
        for builder in self._instrument_builders:
            yield builder
    def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
        """
        Cache the given dataset if cache is enabled. Optionally waits for
        the cache to be available (useful if another process is already
@@ -412,9 +503,8 @@ class DatasetBuilder(object):
""" """
if cache is not None: if cache is not None:
if wait: if wait:
while not exists(f'{cache}.index'): while not exists(f"{cache}.index"):
logger.info( logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
f'Cache not available, wait {self.WAIT_PERIOD}')
time.sleep(self.WAIT_PERIOD) time.sleep(self.WAIT_PERIOD)
cache_path = os.path.split(cache)[0] cache_path = os.path.split(cache)[0]
os.makedirs(cache_path, exist_ok=True) os.makedirs(cache_path, exist_ok=True)
@@ -433,7 +523,8 @@ class DatasetBuilder(object):
        cache_directory: Optional[str] = None,
        wait_for_cache: bool = False,
        num_parallel_calls: int = 4,
        n_chunks_per_song: float = 2,
    ) -> Any:
""" """
TO BE DOCUMENTED. TO BE DOCUMENTED.
""" """
@@ -445,7 +536,8 @@ class DatasetBuilder(object):
            buffer_size=200000,
            seed=self._random_seed,
            # useless since it is cached:
            reshuffle_each_iteration=True,
        )
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveform, compute spectrogram, and filtering error,
@@ -453,11 +545,11 @@ class DatasetBuilder(object):
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset.map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies)
            )
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
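Editor's note: the hunk ends on the `if convert_to_uint:` guard; the conversion itself happens in the instrument builders, outside this diff. The point of the uint cache is footprint: uint16 halves float32 storage. A hypothetical round trip, assuming a dB-scaled magnitude spectrogram (the real helpers are not part of this hunk and may differ):

import numpy as np

def to_uint16(spec_db, db_range=80.0):
    # Map [-db_range, 0] dB onto [0, 65535] for compact caching.
    clipped = np.clip(spec_db, -db_range, 0.0)
    return ((clipped + db_range) / db_range * 65535.0).astype(np.uint16)

def to_float32(spec_u16, db_range=80.0):
    # Inverse mapping back to the dB scale.
    return spec_u16.astype(np.float32) / 65535.0 * db_range - db_range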
@@ -488,26 +580,25 @@ class DatasetBuilder(object):
        # after cropping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
            )
        # Convert back to float32.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N
                )
        M = 8  # Parallel calls post caching.
        # Must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
                self.random_pitch_shift, num_parallel_calls=M
            )
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_shape).map(
                instrument.reshape_spectrogram
            )
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid
View File
@@ -8,15 +8,16 @@ import importlib
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.signal import hann_window, inverse_stft, stft

from ..utils.tensor import pad_and_partition, pad_and_reshape

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

placeholder = tf.compat.v1.placeholder
@@ -36,17 +37,16 @@ def get_model_function(model_type):
        A tensorflow function to be applied to the input tensor to get the
        multitrack output.
    """
    relative_path_to_module = ".".join(model_type.split(".")[:-1])
    model_name = model_type.split(".")[-1]
    main_module = ".".join((__name__, "functions"))
    path_to_module = f"{main_module}.{relative_path_to_module}"
    module = importlib.import_module(path_to_module)
    model_function = getattr(module, model_name)
    return model_function
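Editor's note: a usage sketch for the resolver above, assuming spleeter is importable; since __name__ here is spleeter.model, a model_type of "unet.softmax_unet" resolves to the softmax_unet callable in spleeter.model.functions.unet:

import importlib

# Equivalent of get_model_function("unet.softmax_unet"):
module = importlib.import_module("spleeter.model.functions.unet")
softmax_unet = getattr(module, "softmax_unet")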
class InputProvider(object):
    def __init__(self, params):
        self.params = params
@@ -62,16 +62,16 @@ class InputProvider(object):
class WaveformInputProvider(InputProvider):
    @property
    def input_names(self):
        return ["audio_id", "waveform"]

    def get_input_dict_placeholders(self):
        shape = (None, self.params["n_channels"])
        features = {
            "waveform": placeholder(tf.float32, shape=shape, name="waveform"),
            "audio_id": placeholder(tf.string, name="audio_id"),
        }
        return features

    def get_feed_dict(self, features, waveform, audio_id):
@@ -79,7 +79,6 @@ class WaveformInputProvider(InputProvider):
class SpectralInputProvider(InputProvider):
    def __init__(self, params):
        super().__init__(params)
        self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@@ -90,11 +89,17 @@ class SpectralInputProvider(InputProvider):
    def get_input_dict_placeholders(self):
        features = {
            self.stft_input_name: placeholder(
                tf.complex64,
                shape=(
                    None,
                    self.params["frame_length"] // 2 + 1,
                    self.params["n_channels"],
                ),
                name=self.stft_input_name,
            ),
            "audio_id": placeholder(tf.string, name="audio_id"),
        }
        return features

    def get_feed_dict(self, features, stft, audio_id):
@@ -102,11 +107,13 @@ class SpectralInputProvider(InputProvider):
class InputProviderFactory(object):
    @staticmethod
    def get(params):
        stft_backend = params["stft_backend"]
        assert stft_backend in (
            "tensorflow",
            "librosa",
        ), "Unexpected backend {}".format(stft_backend)
        if stft_backend == "tensorflow":
            return WaveformInputProvider(params)
        else:
@@ -114,7 +121,7 @@ class InputProviderFactory(object):
class EstimatorSpecBuilder(object):
    """A builder class that allows building a multitrack unet model
    estimator. The built model estimator behaves differently when
    used in train/eval mode and in predict mode.
@@ -138,22 +145,22 @@ class EstimatorSpecBuilder(object):
""" """
# Supported model functions. # Supported model functions.
DEFAULT_MODEL = 'unet.unet' DEFAULT_MODEL = "unet.unet"
# Supported loss functions. # Supported loss functions.
L1_MASK = 'L1_mask' L1_MASK = "L1_mask"
WEIGHTED_L1_MASK = 'weighted_L1_mask' WEIGHTED_L1_MASK = "weighted_L1_mask"
# Supported optimizers. # Supported optimizers.
ADADELTA = 'Adadelta' ADADELTA = "Adadelta"
SGD = 'SGD' SGD = "SGD"
# Math constants. # Math constants.
WINDOW_COMPENSATION_FACTOR = 2./3. WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0
EPSILON = 1e-10 EPSILON = 1e-10
    def __init__(self, features, params):
        """Default constructor. Depending on built model
        usage, the provided features should be different:

        * In train/eval mode: features is a dictionary with a
@@ -170,20 +177,20 @@ class EstimatorSpecBuilder(object):
        self._features = features
        self._params = params
        # Get instrument name.
        self._mix_name = params["mix_name"]
        self._instruments = params["instrument_list"]
        # Get STFT/signals parameters.
        self._n_channels = params["n_channels"]
        self._T = params["T"]
        self._F = params["F"]
        self._frame_length = params["frame_length"]
        self._frame_step = params["frame_step"]

    def include_stft_computations(self):
        return self._params["stft_backend"] == "tensorflow"
    def _build_model_outputs(self):
        """Creates a batch_size x T x F x n_channels input tensor containing
        mix magnitude spectrogram, then an output dict from it according
        to the selected model in internal parameters.
@@ -192,22 +199,21 @@ class EstimatorSpecBuilder(object):
""" """
input_tensor = self.spectrogram_feature input_tensor = self.spectrogram_feature
model = self._params.get('model', None) model = self._params.get("model", None)
if model is not None: if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL) model_type = model.get("type", self.DEFAULT_MODEL)
else: else:
model_type = self.DEFAULT_MODEL model_type = self.DEFAULT_MODEL
try: try:
apply_model = get_model_function(model_type) apply_model = get_model_function(model_type)
except ModuleNotFoundError: except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found') raise ValueError(f"No model function {model_type} found")
self._model_outputs = apply_model( self._model_outputs = apply_model(
input_tensor, input_tensor, self._instruments, self._params["model"]["params"]
self._instruments, )
self._params['model']['params'])
    def _build_loss(self, labels):
        """Construct tensorflow loss and metrics.

        :param output_dict: dictionary of network outputs (key: instrument
            name, value: estimated spectrogram of the instrument)
@@ -216,7 +222,7 @@ class EstimatorSpecBuilder(object):
        :returns: tensorflow (loss, metrics) tuple.
        """
        output_dict = self.model_outputs
        loss_type = self._params.get("loss_type", self.L1_MASK)
        if loss_type == self.L1_MASK:
            losses = {
                name: tf.reduce_mean(tf.abs(output - labels[name]))
@@ -225,11 +231,9 @@ class EstimatorSpecBuilder(object):
        elif loss_type == self.WEIGHTED_L1_MASK:
            losses = {
                name: tf.reduce_mean(
                    tf.reduce_mean(labels[name], axis=[1, 2, 3], keep_dims=True)
                    * tf.abs(output - labels[name])
                )
                for name, output in output_dict.items()
            }
        else:
@@ -237,20 +241,20 @@ class EstimatorSpecBuilder(object):
        loss = tf.reduce_sum(list(losses.values()))
        # Add metrics for monitoring each instrument.
        metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
        metrics["absolute_difference"] = tf.compat.v1.metrics.mean(loss)
        return loss, metrics
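Editor's note: a toy check of the default L1_MASK branch on dummy tensors, independent of the estimator plumbing (instrument name and shapes are illustrative):

import tensorflow as tf

labels = {"vocals_spectrogram": tf.ones((1, 4, 4, 2))}
outputs = {"vocals_spectrogram": tf.zeros((1, 4, 4, 2))}
losses = {
    name: tf.reduce_mean(tf.abs(output - labels[name]))
    for name, output in outputs.items()
}
loss = tf.reduce_sum(list(losses.values()))  # -> 1.0 for this toy input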
    def _build_optimizer(self):
        """Builds an optimizer instance from internal parameter values.
        Defaults to AdamOptimizer if not specified.

        :returns: Optimizer instance from internal configuration.
        """
        name = self._params.get("optimizer")
        if name == self.ADADELTA:
            return tf.compat.v1.train.AdadeltaOptimizer()
        rate = self._params["learning_rate"]
        if name == self.SGD:
            return tf.compat.v1.train.GradientDescentOptimizer(rate)
        return tf.compat.v1.train.AdamOptimizer(rate)
@@ -261,14 +265,14 @@ class EstimatorSpecBuilder(object):
    @property
    def stft_name(self):
        return f"{self._mix_name}_stft"

    @property
    def spectrogram_name(self):
        return f"{self._mix_name}_spectrogram"

    def _build_stft_feature(self):
        """Compute STFT of waveform and slice the STFT in segments
        with the right length to feed the network.
        """
@@ -277,11 +281,12 @@ class EstimatorSpecBuilder(object):
        if stft_name not in self._features:
            # Pad input with a frame of zeros.
            waveform = tf.concat(
                [
                    tf.zeros((self._frame_length, self._n_channels)),
                    self._features["waveform"],
                ],
                0,
            )
            stft_feature = tf.transpose(
                stft(
@@ -289,13 +294,17 @@ class EstimatorSpecBuilder(object):
                    self._frame_length,
                    self._frame_step,
                    window_fn=lambda frame_length, dtype: (
                        hann_window(frame_length, periodic=True, dtype=dtype)
                    ),
                    pad_end=True,
                ),
                perm=[1, 2, 0],
            )
            self._features[f"{self._mix_name}_stft"] = stft_feature
        if spec_name not in self._features:
            self._features[spec_name] = tf.abs(
                pad_and_partition(self._features[stft_name], self._T)
            )[:, :, : self._F, :]
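Editor's note: shape arithmetic for the feature built above, using the stock 4096/1024 STFT configuration as an assumed example:

frame_length = 4096
n_bins = frame_length // 2 + 1  # 2049 frequency bins in the STFT
F = 1024                        # only the first F bins feed the network
T = 512                         # frames per partitioned segment
# STFT feature: (n_frames, 2049, n_channels); after pad_and_partition
# and the [:, :, :F, :] slice: (n_segments, 512, 1024, n_channels).
print(n_bins, (T, F))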
    @property
    def model_outputs(self):
@@ -334,25 +343,29 @@ class EstimatorSpecBuilder(object):
        return self._masked_stfts

    def _inverse_stft(self, stft_t, time_crop=None):
        """Inverse and reshape the given STFT.

        :param stft_t: input STFT
        :returns: inverse STFT (waveform)
        """
        inversed = (
            inverse_stft(
                tf.transpose(stft_t, perm=[2, 0, 1]),
                self._frame_length,
                self._frame_step,
                window_fn=lambda frame_length, dtype: (
                    hann_window(frame_length, periodic=True, dtype=dtype)
                ),
            )
            * self.WINDOW_COMPENSATION_FACTOR
        )
        reshaped = tf.transpose(inversed)
        if time_crop is None:
            time_crop = tf.shape(self._features["waveform"])[0]
        return reshaped[self._frame_length : self._frame_length + time_crop, :]
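Editor's note: the 2/3 WINDOW_COMPENSATION_FACTOR matches a periodic Hann analysis/synthesis pair at 75% overlap (frame_step = frame_length / 4), where the overlap-add of the squared window sums to 1.5; a numeric check, as a sketch:

import numpy as np

frame_length, frame_step = 4096, 1024
n = np.arange(frame_length)
window = 0.5 * (1 - np.cos(2 * np.pi * n / frame_length))  # periodic Hann
acc = np.zeros(frame_length * 2)
for start in range(0, frame_length + 1, frame_step):
    acc[start:start + frame_length] += window ** 2
print(acc[frame_length])  # ~1.5, hence the 2/3 compensation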
    def _build_mwf_output_waveform(self):
        """Perform separation with multichannel Wiener Filtering using Norbert.
        Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
        may be quite slow.
@@ -360,36 +373,42 @@ class EstimatorSpecBuilder(object):
            value: estimated waveform of the instrument)
        """
        import norbert  # pylint: disable=import-error

        output_dict = self.model_outputs
        x = self.stft_feature
        v = tf.stack(
            [
                pad_and_reshape(
                    output_dict[f"{instrument}_spectrogram"],
                    self._frame_length,
                    self._F,
                )[: tf.shape(x)[0], ...]
                for instrument in self._instruments
            ],
            axis=3,
        )
        input_args = [v, x]
        stft_function = (
            tf.py_function(
                lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
                input_args,
                tf.complex64,
            ),
        )
        return {
            instrument: self._inverse_stft(stft_function[0][:, :, :, k])
            for k, instrument in enumerate(self._instruments)
        }
    def _extend_mask(self, mask):
        """Extend mask from the reduced number of frequency bins to the
        number of frequency bins in the STFT.

        :param mask: restricted mask
        :returns: extended mask
        :raise ValueError: If an invalid mask_extension parameter is set.
        """
        extension = self._params["mask_extension"]
        # Extend with average
        # (dispatch according to energy in the processed band).
        if extension == "average":
@@ -398,13 +417,9 @@ class EstimatorSpecBuilder(object):
        # (avoid extension artifacts but not conservative separation).
        elif extension == "zeros":
            mask_shape = tf.shape(mask)
            extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))
        else:
            raise ValueError(f"Invalid mask_extension parameter {extension}")
        n_extra_row = self._frame_length // 2 + 1 - self._F
        extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
        return tf.concat([mask, extension], axis=2)
@@ -416,29 +431,31 @@ class EstimatorSpecBuilder(object):
""" """
output_dict = self.model_outputs output_dict = self.model_outputs
stft_feature = self.stft_feature stft_feature = self.stft_feature
separation_exponent = self._params['separation_exponent'] separation_exponent = self._params["separation_exponent"]
output_sum = tf.reduce_sum( output_sum = (
[e ** separation_exponent for e in output_dict.values()], tf.reduce_sum(
axis=0 [e ** separation_exponent for e in output_dict.values()], axis=0
) + self.EPSILON )
+ self.EPSILON
)
out = {} out = {}
for instrument in self._instruments: for instrument in self._instruments:
output = output_dict[f'{instrument}_spectrogram'] output = output_dict[f"{instrument}_spectrogram"]
# Compute mask with the model. # Compute mask with the model.
instrument_mask = (output ** separation_exponent instrument_mask = (
+ (self.EPSILON / len(output_dict))) / output_sum output ** separation_exponent + (self.EPSILON / len(output_dict))
) / output_sum
# Extend mask; # Extend mask;
instrument_mask = self._extend_mask(instrument_mask) instrument_mask = self._extend_mask(instrument_mask)
# Stack back mask. # Stack back mask.
old_shape = tf.shape(instrument_mask) old_shape = tf.shape(instrument_mask)
new_shape = tf.concat( new_shape = tf.concat(
[[old_shape[0] * old_shape[1]], old_shape[2:]], [[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0
axis=0) )
instrument_mask = tf.reshape(instrument_mask, new_shape) instrument_mask = tf.reshape(instrument_mask, new_shape)
# Remove padded part (for mask having the same size as STFT); # Remove padded part (for mask having the same size as STFT);
instrument_mask = instrument_mask[ instrument_mask = instrument_mask[: tf.shape(stft_feature)[0], ...]
:tf.shape(stft_feature)[0], ...]
out[instrument] = instrument_mask out[instrument] = instrument_mask
self._masks = out self._masks = out
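Editor's note: for two stems and separation exponent p, the mask above reduces to m_i = (s_i^p + eps/2) / (s_1^p + s_2^p + eps), so the masks sum to one by construction. A toy check, assuming p = 2 (the value shipped in the stock configurations):

import numpy as np

EPSILON = 1e-10
p = 2
vocals, accompaniment = 3.0, 1.0  # toy magnitude values
output_sum = vocals ** p + accompaniment ** p + EPSILON
mask_vocals = (vocals ** p + EPSILON / 2) / output_sum                # ~0.9
mask_accompaniment = (accompaniment ** p + EPSILON / 2) / output_sum  # ~0.1
print(mask_vocals + mask_accompaniment)  # ~1.0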
@@ -450,7 +467,7 @@ class EstimatorSpecBuilder(object):
        self._masked_stfts = out

    def _build_manual_output_waveform(self, masked_stft):
        """Perform ratio mask separation.

        :param output_dict: dictionary of estimated spectrogram (key: instrument
            name, value: estimated spectrogram of the instrument)
@@ -464,14 +481,14 @@ class EstimatorSpecBuilder(object):
        return output_waveform

    def _build_output_waveform(self, masked_stft):
        """Build output waveform from the given output dict in order to be
        used in prediction context. Depending on the configuration, the
        build may use MWF.

        :returns: Built output waveform.
        """
        if self._params.get("MWF", False):
            output_waveform = self._build_mwf_output_waveform()
        else:
            output_waveform = self._build_manual_output_waveform(masked_stft)
@@ -483,11 +500,11 @@ class EstimatorSpecBuilder(object):
        else:
            self._outputs = self.masked_stfts
        if "audio_id" in self._features:
            self._outputs["audio_id"] = self._features["audio_id"]

    def build_predict_model(self):
        """Builder interface for creating model instance that aims to perform
        prediction / inference over given track. The output of such estimator
        will be a dictionary with a "<instrument>" key per separated instrument,
        associated to the estimated separated waveform of the instrument.
@@ -496,11 +513,11 @@ class EstimatorSpecBuilder(object):
""" """
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT, tf.estimator.ModeKeys.PREDICT, predictions=self.outputs
predictions=self.outputs) )
def build_evaluation_model(self, labels): def build_evaluation_model(self, labels):
""" Builder interface for creating model instance that aims to perform """Builder interface for creating model instance that aims to perform
model evaluation. The output of such estimator will be a dictionary model evaluation. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument, with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram. associated to the estimated separated instrument magnitude spectrogram.
@@ -510,12 +527,11 @@ class EstimatorSpecBuilder(object):
""" """
loss, metrics = self._build_loss(labels) loss, metrics = self._build_loss(labels)
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics
loss=loss, )
eval_metric_ops=metrics)
def build_train_model(self, labels): def build_train_model(self, labels):
""" Builder interface for creating model instance that aims to perform """Builder interface for creating model instance that aims to perform
model training. The output of such estimator will be a dictionary model training. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument, with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram. associated to the estimated separated instrument magnitude spectrogram.
@@ -526,8 +542,8 @@ class EstimatorSpecBuilder(object):
        loss, metrics = self._build_loss(labels)
        optimizer = self._build_optimizer()
        train_operation = optimizer.minimize(
            loss=loss, global_step=tf.compat.v1.train.get_global_step()
        )
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
@@ -554,4 +570,4 @@ def model_fn(features, labels, mode, params, config):
        return builder.build_evaluation_model(labels)
    elif mode == tf.estimator.ModeKeys.TRAIN:
        return builder.build_train_model(labels)
    raise ValueError(f"Unknown mode {mode}")
View File
@@ -8,18 +8,20 @@ from typing import Callable, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply(
    function: Callable,
    input_tensor: tf.Tensor,
    instruments: Iterable[str],
    params: Optional[Dict] = None,
) -> Dict:
    """
    Apply given function to the input tensor.
@@ -38,9 +40,8 @@ def apply(
""" """
output_dict: Dict = {} output_dict: Dict = {}
for instrument in instruments: for instrument in instruments:
out_name = f'{instrument}_spectrogram' out_name = f"{instrument}_spectrogram"
output_dict[out_name] = function( output_dict[out_name] = function(
input_tensor, input_tensor, output_name=out_name, params=params or {}
output_name=out_name, )
params=params or {})
return output_dict return output_dict
View File
@@ -22,12 +22,9 @@
from typing import Dict, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import (
@@ -35,18 +32,21 @@ from tensorflow.keras.layers import (
    Dense,
    Flatten,
    Reshape,
    TimeDistributed,
)

from . import apply

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply_blstm(
    input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
    """
    Apply BLSTM to the given input_tensor.
@@ -64,16 +64,16 @@ def apply_blstm(
""" """
if params is None: if params is None:
params = {} params = {}
units: int = params.get('lstm_units', 250) units: int = params.get("lstm_units", 250)
kernel_initializer = he_uniform(seed=50) kernel_initializer = he_uniform(seed=50)
flatten_input = TimeDistributed(Flatten())((input_tensor)) flatten_input = TimeDistributed(Flatten())((input_tensor))
def create_bidirectional(): def create_bidirectional():
return Bidirectional( return Bidirectional(
CuDNNLSTM( CuDNNLSTM(
units, units, kernel_initializer=kernel_initializer, return_sequences=True
kernel_initializer=kernel_initializer, )
return_sequences=True)) )
l1 = create_bidirectional()((flatten_input)) l1 = create_bidirectional()((flatten_input))
l2 = create_bidirectional()((l1)) l2 = create_bidirectional()((l1))
@@ -81,17 +81,18 @@ def apply_blstm(
    dense = TimeDistributed(
        Dense(
            int(flatten_input.shape[2]),
            activation="relu",
            kernel_initializer=kernel_initializer,
        )
    )((l3))
    output: tf.Tensor = TimeDistributed(
        Reshape(input_tensor.shape[2:]), name=output_name
    )(dense)
    return output


def blstm(
    input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
    """ Model function applier. """
    return apply(apply_blstm, input_tensor, output_name, params)
View File
@@ -16,30 +16,31 @@
from functools import partial
from typing import Any, Dict, Iterable, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.keras.layers import (
    ELU,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Conv2DTranspose,
    Dropout,
    LeakyReLU,
    Multiply,
    ReLU,
    Softmax,
)

from . import apply

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def _get_conv_activation_layer(params: Dict) -> Any:
@@ -53,10 +54,10 @@ def _get_conv_activation_layer(params: Dict) -> Any:
        Any:
            Required Activation function.
    """
    conv_activation: str = params.get("conv_activation")
    if conv_activation == "ReLU":
        return ReLU()
    elif conv_activation == "ELU":
        return ELU()
    return LeakyReLU(0.2)
@@ -72,19 +73,20 @@ def _get_deconv_activation_layer(params: Dict) -> Any:
        Any:
            Required Activation function.
    """
    deconv_activation: str = params.get("deconv_activation")
    if deconv_activation == "LeakyReLU":
        return LeakyReLU(0.2)
    elif deconv_activation == "ELU":
        return ELU()
    return ReLU()


def apply_unet(
    input_tensor: tf.Tensor,
    output_name: str = "output",
    params: Optional[Dict] = None,
    output_mask_logit: bool = False,
) -> Any:
""" """
Apply a convolutionnal U-net to model a single instrument (one U-net Apply a convolutionnal U-net to model a single instrument (one U-net
is used for each instrument). is used for each instrument).
@@ -95,16 +97,14 @@ def apply_unet(
        params (Optional[Dict]):
        output_mask_logit (bool):
    """
    logging.info(f"Apply unet for {output_name}")
    conv_n_filters = params.get("conv_n_filters", [16, 32, 64, 128, 256, 512])
    conv_activation_layer = _get_conv_activation_layer(params)
    deconv_activation_layer = _get_deconv_activation_layer(params)
    kernel_initializer = he_uniform(seed=50)
    conv2d_factory = partial(
        Conv2D, strides=(2, 2), padding="same", kernel_initializer=kernel_initializer
    )
    # First layer.
    conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)
    batch1 = BatchNormalization(axis=-1)(conv1)
@@ -134,8 +134,9 @@ def apply_unet(
    conv2d_transpose_factory = partial(
        Conv2DTranspose,
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )
    #
    up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6))
    up1 = deconv_activation_layer(up1)
@@ -174,31 +175,31 @@ def apply_unet(
            2,
            (4, 4),
            dilation_rate=(2, 2),
            activation="sigmoid",
            padding="same",
            kernel_initializer=kernel_initializer,
        )((batch12))
        output = Multiply(name=output_name)([up7, input_tensor])
        return output
    return Conv2D(
        2,
        (4, 4),
        dilation_rate=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )((batch12))


def unet(
    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
    """ Model function applier. """
    return apply(apply_unet, input_tensor, instruments, params)


def softmax_unet(
    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
    """
    Apply softmax to multitrack unet in order to have masks summing to one.
@@ -216,18 +217,18 @@ def softmax_unet(
""" """
logit_mask_list = [] logit_mask_list = []
for instrument in instruments: for instrument in instruments:
out_name = f'{instrument}_spectrogram' out_name = f"{instrument}_spectrogram"
logit_mask_list.append( logit_mask_list.append(
apply_unet( apply_unet(
input_tensor, input_tensor,
output_name=out_name, output_name=out_name,
params=params, params=params,
output_mask_logit=True)) output_mask_logit=True,
)
)
masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4)) masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
output_dict = {} output_dict = {}
for i, instrument in enumerate(instruments): for i, instrument in enumerate(instruments):
out_name = f'{instrument}_spectrogram' out_name = f"{instrument}_spectrogram"
output_dict[out_name] = Multiply(name=out_name)([ output_dict[out_name] = Multiply(name=out_name)([masks[..., i], input_tensor])
masks[..., i],
input_tensor])
return output_dict return output_dict
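Editor's note: the Softmax over the stacked logits is what makes the per-instrument masks sum to one at every time-frequency bin; a minimal check with dummy logits (shapes are illustrative):

import tensorflow as tf

logits = [tf.random.normal((1, 512, 1024, 2)) for _ in range(2)]  # two stems
masks = tf.nn.softmax(tf.stack(logits, axis=4), axis=4)
total = tf.reduce_sum(masks, axis=4)
print(float(tf.reduce_max(tf.abs(total - 1.0))))  # ~0.0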
View File
@@ -17,9 +17,9 @@ from abc import ABC, abstractmethod
from os import environ, makedirs
from os.path import exists, isabs, join, sep

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"


class ModelProvider(ABC):
@@ -28,8 +28,8 @@ class ModelProvider(ABC):
    file download is not available.
    """

    DEFAULT_MODEL_PATH: str = environ.get("MODEL_PATH", "pretrained_models")
    MODEL_PROBE_PATH: str = ".probe"

    @abstractmethod
    def download(_, name: str, path: str) -> None:
@@ -54,8 +54,8 @@ class ModelProvider(ABC):
                Directory to write probe into.
        """
        probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
        with open(probe, "w") as stream:
            stream.write("OK")

    def get(self, model_directory: str) -> str:
        """
@@ -77,14 +77,12 @@ class ModelProvider(ABC):
        if not exists(model_probe):
            if not exists(model_directory):
                makedirs(model_directory)
                self.download(model_directory.split(sep)[-1], model_directory)
                self.writeProbe(model_directory)
        return model_directory

    @classmethod
    def default(_: type) -> "ModelProvider":
        """
        Builds and returns a default model provider.
@@ -93,4 +91,5 @@ class ModelProvider(ABC):
            A default model provider instance to use.
        """
        from .github import GithubModelProvider

        return GithubModelProvider.from_environ()
View File
@@ -17,35 +17,35 @@
""" """
import hashlib import hashlib
import tarfile
import os import os
import tarfile
from os import environ from os import environ
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Dict from typing import Dict
from . import ModelProvider
from ...utils.logging import logger
# pyright: reportMissingImports=false # pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import httpx import httpx
from ...utils.logging import logger
from . import ModelProvider
# pylint: enable=import-error # pylint: enable=import-error
__email__ = 'spleeter@deezer.com' __email__ = "spleeter@deezer.com"
__author__ = 'Deezer Research' __author__ = "Deezer Research"
__license__ = 'MIT License' __license__ = "MIT License"
def compute_file_checksum(path):
    """Computes given path file sha256.

    :param path: Path of the file to compute checksum for.
    :returns: File checksum.
    """
    sha256 = hashlib.sha256()
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(4096), b""):
            sha256.update(chunk)
    return sha256.hexdigest()
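Editor's note: a usage sketch for the helper above, with a throwaway file (path and content are illustrative):

import os
import tempfile

with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"dummy archive bytes")
    path = tmp.name

digest = compute_file_checksum(path)
print(len(digest))  # 64 hex characters for sha256
os.unlink(path)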
@@ -53,19 +53,15 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider):
    """ A ModelProvider implementation backed on Github for remote storage. """

    DEFAULT_HOST: str = "https://github.com"
    DEFAULT_REPOSITORY: str = "deezer/spleeter"
    CHECKSUM_INDEX: str = "checksum.json"
    LATEST_RELEASE: str = "v1.4.0"
    RELEASE_PATH: str = "releases/download"

    def __init__(self, host: str, repository: str, release: str) -> None:
        """Default constructor.

        Parameters:
            host (str):
@@ -80,7 +76,7 @@ class GithubModelProvider(ModelProvider):
        self._release: str = release

    @classmethod
    def from_environ(cls: type) -> "GithubModelProvider":
        """
        Factory method that creates provider from envvars.
@@ -89,9 +85,10 @@ class GithubModelProvider(ModelProvider):
            Created instance.
        """
        return cls(
            environ.get("GITHUB_HOST", cls.DEFAULT_HOST),
            environ.get("GITHUB_REPOSITORY", cls.DEFAULT_REPOSITORY),
            environ.get("GITHUB_RELEASE", cls.LATEST_RELEASE),
        )

    def checksum(self, name: str) -> str:
        """
@@ -108,17 +105,20 @@ class GithubModelProvider(ModelProvider):
            ValueError:
                If the given model name is not indexed.
        """
        url: str = "/".join(
            (
                self._host,
                self._repository,
                self.RELEASE_PATH,
                self._release,
                self.CHECKSUM_INDEX,
            )
        )
        response: httpx.Response = httpx.get(url)
        response.raise_for_status()
        index: Dict = response.json()
        if name not in index:
            raise ValueError(f"No checksum for model {name}")
        return index[name]

    def download(self, name: str, path: str) -> None:
@@ -131,30 +131,26 @@ class GithubModelProvider(ModelProvider):
            path (str):
                Path of the directory to save model into.
        """
        url: str = "/".join(
            (self._host, self._repository, self.RELEASE_PATH, self._release, name)
        )
        url = f"{url}.tar.gz"
        logger.info(f"Downloading model archive {url}")
        with httpx.Client(http2=True) as client:
            with client.stream("GET", url) as response:
                response.raise_for_status()
                archive = NamedTemporaryFile(delete=False)
                try:
                    with archive as stream:
                        for chunk in response.iter_raw():
                            stream.write(chunk)
                    logger.info("Validating archive checksum")
                    checksum: str = compute_file_checksum(archive.name)
                    if checksum != self.checksum(name):
                        raise IOError("Downloaded file is corrupted, please retry")
                    logger.info(f"Extracting downloaded {name} archive")
                    with tarfile.open(name=archive.name) as tar:
                        tar.extractall(path=path)
                finally:
                    os.unlink(archive.name)
        logger.info(f"{name} model file(s) extracted")
View File
@@ -3,126 +3,126 @@
""" This modules provides spleeter command as well as CLI parsing methods. """ """ This modules provides spleeter command as well as CLI parsing methods. """
from tempfile import gettempdir
from os.path import join from os.path import join
from tempfile import gettempdir
from .audio import Codec, STFTBackend
from typer import Argument, Option from typer import Argument, Option
from typer.models import ArgumentInfo, OptionInfo from typer.models import ArgumentInfo, OptionInfo
__email__ = 'spleeter@deezer.com' from .audio import Codec, STFTBackend
__author__ = 'Deezer Research'
__license__ = 'MIT License' __email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
AudioInputArgument: ArgumentInfo = Argument(
    ...,
    help="List of input audio file path",
    exists=True,
    file_okay=True,
    dir_okay=False,
    readable=True,
    resolve_path=True,
)

AudioInputOption: OptionInfo = Option(
    None, "--inputs", "-i", help="(DEPRECATED) placeholder for deprecated input option"
)

AudioAdapterOption: OptionInfo = Option(
    "spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter",
    "--adapter",
    "-a",
    help="Name of the audio adapter to use for audio I/O",
)

AudioOutputOption: OptionInfo = Option(
    join(gettempdir(), "separated_audio"),
    "--output_path",
    "-o",
    help="Path of the output directory to write audio files in",
)

AudioOffsetOption: OptionInfo = Option(
    0.0, "--offset", "-s", help="Set the starting offset to separate audio from"
)

AudioDurationOption: OptionInfo = Option(
    600.0,
    "--duration",
    "-d",
    help=(
        "Set a maximum duration for processing audio "
        "(only separate offset + duration first seconds of "
        "the input file)"
    ),
)
AudioSTFTBackendOption: OptionInfo = Option(
    STFTBackend.AUTO,
    "--stft-backend",
    "-B",
    case_sensitive=False,
    help=(
        "Who should be in charge of computing the stfts. Librosa is faster "
        'than tensorflow on CPU and uses less memory. "auto" will use '
        "tensorflow when GPU acceleration is available and librosa when not"
    ),
)

AudioCodecOption: OptionInfo = Option(
    Codec.WAV, "--codec", "-c", help="Audio codec to be used for the separated output"
)

AudioBitrateOption: OptionInfo = Option(
    "128k", "--bitrate", "-b", help="Audio bitrate to be used for the separated output"
)

FilenameFormatOption: OptionInfo = Option(
    "{filename}/{instrument}.{codec}",
    "--filename_format",
    "-f",
    help=(
        "Template string that will be formatted to generate "
        "output filenames. Such template should be a Python formattable "
        "string, and could use {filename}, {instrument}, and {codec} "
        "variables"
    ),
)

ModelParametersOption: OptionInfo = Option(
    "spleeter:2stems",
    "--params_filename",
    "-p",
    help="JSON filename that contains params",
)

MWFOption: OptionInfo = Option(
    False, "--mwf", help="Whether to use multichannel Wiener filtering for separation"
)
MUSDBDirectoryOption: OptionInfo = Option(
    ...,
    "--mus_dir",
    exists=True,
    dir_okay=True,
    file_okay=False,
    readable=True,
    resolve_path=True,
    help="Path to musDB dataset directory",
)

TrainingDataDirectoryOption: OptionInfo = Option(
    ...,
    "--data",
    "-d",
    exists=True,
    dir_okay=True,
    file_okay=False,
    readable=True,
    resolve_path=True,
    help="Path of the folder containing audio data for training",
)

VerboseOption: OptionInfo = Option(False, "--verbose", help="Enable verbose logs")
View File
@@ -3,6 +3,6 @@
""" Packages that provides static resources file for the library. """ """ Packages that provides static resources file for the library. """
__email__ = 'spleeter@deezer.com' __email__ = "spleeter@deezer.com"
__author__ = 'Deezer Research' __author__ = "Deezer Research"
__license__ = 'MIT License' __license__ = "MIT License"
View File
@@ -16,34 +16,33 @@
import atexit
import os
from multiprocessing import Pool
from os.path import basename, dirname, join, splitext
from typing import Dict, Generator, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from librosa.core import istft, stft
from scipy.signal.windows import hann

from spleeter.model.provider import ModelProvider

from . import SpleeterError
from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
from .model.provider import ModelProvider
from .types import AudioDescriptor
from .utils.configuration import load_configuration

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class DataGenerator(object): class DataGenerator(object):
@@ -81,18 +80,16 @@ def create_estimator(params, MWF):
    """
    # Load model.
    provider: ModelProvider = ModelProvider.default()
    params["model_dir"] = provider.get(params["model_dir"])
    params["MWF"] = MWF
    # Setup config
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
    config = tf.estimator.RunConfig(session_config=session_config)
    # Setup estimator
    estimator = tf.estimator.Estimator(
        model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
    )
    return estimator
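A hedged usage sketch for create_estimator, assuming a params dictionary like the one returned by load_configuration further below:

# Illustrative only: resolve the embedded 2stems configuration, then build
# the tf.estimator.Estimator that predictions run against.
params = load_configuration("spleeter:2stems")
estimator = create_estimator(params, MWF=False)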
@@ -104,7 +101,8 @@ class Separator(object):
        params_descriptor: str,
        MWF: bool = False,
        stft_backend: STFTBackend = STFTBackend.AUTO,
        multiprocess: bool = True,
    ) -> None:
        """
        Default constructor.
@@ -115,7 +113,7 @@ class Separator(object):
                (Optional) `True` if MWF should be used, `False` otherwise.
        """
        self._params = load_configuration(params_descriptor)
        self._sample_rate = self._params["sample_rate"]
        self._MWF = MWF
        self._tf_graph = tf.Graph()
        self._prediction_generator = None
@@ -129,7 +127,7 @@ class Separator(object):
        else:
            self._pool = None
        self._tasks = []
        self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
        self._data_generator = DataGenerator()

    def __del__(self) -> None:
@@ -151,16 +149,13 @@ class Separator(object):
        def get_dataset():
            return tf.data.Dataset.from_generator(
                self._data_generator,
                output_types={"waveform": tf.float32, "audio_id": tf.string},
                output_shapes={"waveform": (None, 2), "audio_id": ()},
            )

        self._prediction_generator = estimator.predict(
            get_dataset, yield_single_examples=False
        )
        return self._prediction_generator
    def join(self, timeout: int = 200) -> None:
@@ -177,10 +172,8 @@ class Separator(object):
            task.wait(timeout=timeout)

    def _stft(
        self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
    ) -> np.ndarray:
        """
        Single entrypoint for both stft and istft. This computes stft and
        istft with librosa on stereo data. The two channels are processed
@@ -203,27 +196,27 @@ class Separator(object):
        """
        assert not (inverse and length is None)
        data = np.asfortranarray(data)
        N = self._params["frame_length"]
        H = self._params["frame_step"]
        win = hann(N, sym=False)
        fstft = istft if inverse else stft
        win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
        n_channels = data.shape[-1]
        out = []
        for c in range(n_channels):
            d = (
                np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
                if not inverse
                else data[:, :, c].T
            )
            s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
            if inverse:
                s = s[N : N + length]
            s = np.expand_dims(s.T, 2 - inverse)
            out.append(s)
        if len(out) == 1:
            return out[0]
        return np.concatenate(out, axis=2 - inverse)
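A standalone sketch of the forward pass `_stft` performs per channel, written against librosa directly; N and H stand in for the configuration's `frame_length` and `frame_step` (the values here are assumptions):

import numpy as np
from librosa.core import stft
from scipy.signal.windows import hann

N, H = 4096, 1024  # assumed frame_length / frame_step
win = hann(N, sym=False)
waveform = np.zeros((8 * N, 2), dtype=np.float32)  # stereo input
out = []
for c in range(2):
    # Zero-pad N samples on each side, as _stft does before the forward STFT.
    d = np.concatenate((np.zeros((N,)), waveform[:, c], np.zeros((N,))))
    s = stft(d, n_fft=N, hop_length=H, window=win, center=False)
    out.append(np.expand_dims(s.T, 2))
spec = np.concatenate(out, axis=2)  # (frames, 1 + N // 2, 2)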
    def _get_input_provider(self):
        if self._input_provider is None:
@@ -238,25 +231,22 @@ class Separator(object):
    def _get_builder(self):
        if self._builder is None:
            self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
        return self._builder

    def _get_session(self):
        if self._session is None:
            saver = tf.compat.v1.train.Saver()
            provider = ModelProvider.default()
            model_directory: str = provider.get(self._params["model_dir"])
            latest_checkpoint = tf.train.latest_checkpoint(model_directory)
            self._session = tf.compat.v1.Session()
            saver.restore(self._session, latest_checkpoint)
        return self._session
    def _separate_librosa(
        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
    ) -> Dict:
        """
        Performs separation with librosa backend for STFT.
@@ -280,20 +270,18 @@ class Separator(object):
            outputs = sess.run(
                outputs,
                feed_dict=self._get_input_provider().get_feed_dict(
                    features, stft, audio_descriptor
                ),
            )
            for inst in self._get_builder().instruments:
                out[inst] = self._stft(
                    outputs[inst], inverse=True, length=waveform.shape[0]
                )
        return out

    def _separate_tensorflow(
        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
    ) -> Dict:
        """
        Performs source separation over the given waveform with tensorflow
        backend.
@@ -310,18 +298,17 @@ class Separator(object):
            waveform = to_stereo(waveform)
        prediction_generator = self._get_prediction_generator()
        # NOTE: update data in generator before performing separation.
        self._data_generator.update_data(
            {"waveform": waveform, "audio_id": np.array(audio_descriptor)}
        )
        # NOTE: perform separation.
        prediction = next(prediction_generator)
        prediction.pop("audio_id")
        return prediction

    def separate(
        self, waveform: np.ndarray, audio_descriptor: Optional[str] = None
    ) -> None:
        """
        Performs separation on a waveform.
@@ -331,12 +318,12 @@ class Separator(object):
            audio_descriptor (str):
                (Optional) string describing the waveform (e.g. filename).
        """
        backend: str = self._params["stft_backend"]
        if backend == STFTBackend.TENSORFLOW:
            return self._separate_tensorflow(waveform, audio_descriptor)
        elif backend == STFTBackend.LIBROSA:
            return self._separate_librosa(waveform, audio_descriptor)
        raise ValueError(f"Unsupported STFT backend {backend}")
    def separate_to_file(
        self,
@@ -344,11 +331,12 @@ class Separator(object):
        destination: str,
        audio_adapter: Optional[AudioAdapter] = None,
        offset: int = 0,
        duration: float = 600.0,
        codec: Codec = Codec.WAV,
        bitrate: str = "128k",
        filename_format: str = "{filename}/{instrument}.{codec}",
        synchronous: bool = True,
    ) -> None:
        """
        Performs source separation and exports result to file using
        given audio adapter.
@@ -389,7 +377,8 @@ class Separator(object):
            audio_descriptor,
            offset=offset,
            duration=duration,
            sample_rate=self._sample_rate,
        )
        sources = self.separate(waveform, audio_descriptor)
        self.save_to_file(
            sources,
@@ -399,18 +388,20 @@ class Separator(object):
            codec,
            audio_adapter,
            bitrate,
            synchronous,
        )

    def save_to_file(
        self,
        sources: Dict,
        audio_descriptor: AudioDescriptor,
        destination: str,
        filename_format: str = "{filename}/{instrument}.{codec}",
        codec: Codec = Codec.WAV,
        audio_adapter: Optional[AudioAdapter] = None,
        bitrate: str = "128k",
        synchronous: bool = True,
    ) -> None:
        """
        Export dictionary of sources to files.
@@ -443,34 +434,32 @@ class Separator(object):
        filename = splitext(basename(audio_descriptor))[0]
        generated = []
        for instrument, data in sources.items():
            path = join(
                destination,
                filename_format.format(
                    filename=filename,
                    instrument=instrument,
                    foldername=foldername,
                    codec=codec,
                ),
            )
            directory = os.path.dirname(path)
            if not os.path.exists(directory):
                os.makedirs(directory)
            if path in generated:
                raise SpleeterError(
                    f"Separated source path conflict: {path}, "
                    "please check your filename format"
                )
            generated.append(path)
            if self._pool:
                task = self._pool.apply_async(
                    audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
                )
                self._tasks.append(task)
            else:
                audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
        if synchronous and self._pool:
            self.join()
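None of the reformatting above changes the public API; typical end-to-end usage stays as documented in the project README:

from spleeter.separator import Separator

# Split an input file into vocals / accompaniment stems under /tmp/output.
separator = Separator("spleeter:2stems")
separator.separate_to_file("audio_example.mp3", "/tmp/output")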
View File
@@ -8,6 +8,7 @@ from typing import Any, Tuple
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np

# pylint: enable=import-error

AudioDescriptor: type = Any
View File
@@ -3,6 +3,6 @@
""" This package provides utility functions and classes. """

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
View File
@@ -3,20 +3,18 @@
""" Module that provides configuration loading function. """

import importlib.resources as loader
import json
from os.path import exists
from typing import Dict

from .. import SpleeterError, resources

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

_EMBEDDED_CONFIGURATION_PREFIX: str = "spleeter:"


def load_configuration(descriptor: str) -> Dict:
@@ -41,13 +39,13 @@ def load_configuration(descriptor: str) -> Dict:
    """
    # Embedded configuration reading.
    if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
        name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX) :]
        if not loader.is_resource(resources, f"{name}.json"):
            raise SpleeterError(f"No embedded configuration {name} found")
        with loader.open_text(resources, f"{name}.json") as stream:
            return json.load(stream)
    # Standard file reading.
    if not exists(descriptor):
        raise SpleeterError(f"Configuration file {descriptor} not found")
    with open(descriptor, "r") as stream:
        return json.load(stream)
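Both branches of load_configuration in a short sketch; the "sample_rate" key is one of those the Separator constructor reads, and the second descriptor is a hypothetical path:

# Embedded descriptor: loaded from the packaged JSON resources.
params = load_configuration("spleeter:2stems")
print(params["sample_rate"])

# Plain filesystem descriptor: loaded from disk instead.
params = load_configuration("configs/my_model.json")  # hypothetical path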
View File
@@ -5,19 +5,19 @@
import logging
import warnings
from os import environ

# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import echo

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


class TyperLoggerHandler(logging.Handler):
@@ -27,10 +27,10 @@ class TyperLoggerHandler(logging.Handler):
        echo(self.format(record))


formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
handler = TyperLoggerHandler()
handler.setFormatter(formatter)
logger: logging.Logger = logging.getLogger("spleeter")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
@@ -45,11 +45,12 @@ def configure_logger(verbose: bool) -> None:
    """
    from tensorflow import get_logger
    from tensorflow.compat.v1 import logging as tf_logging

    tf_logger = get_logger()
    tf_logger.handlers = [handler]
    if verbose:
        tf_logging.set_verbosity(tf_logging.INFO)
        logger.setLevel(logging.DEBUG)
    else:
        warnings.filterwarnings("ignore")
        tf_logging.set_verbosity(tf_logging.ERROR)
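Expected call pattern, assuming this module is importable as spleeter.utils.logging (the file path is not shown in this diff):

from spleeter.utils.logging import configure_logger

# Route spleeter and tensorflow logs through the typer echo handler.
configure_logger(verbose=True)   # DEBUG level, TF at INFO
configure_logger(verbose=False)  # warnings silenced, TF at ERROR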
View File
@@ -5,21 +5,22 @@
from typing import Any, Callable, Dict

# pyright: reportMissingImports=false
# pylint: disable=import-error
import pandas as pd
import tensorflow as tf

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def sync_apply(
    tensor_dict: tf.Tensor, func: Callable, concat_axis: int = 1
) -> Dict[str, tf.Tensor]:
    """
    Apply the provided func synchronously on the provided
    dictionary of tensors. This means that func is applied to the
@@ -48,7 +49,8 @@ def sync_apply(
    """
    if concat_axis not in {0, 1}:
        raise NotImplementedError(
            "Function only implemented for concat_axis equal to 0 or 1"
        )
    tensor_list = list(tensor_dict.values())
    concat_tensor = tf.concat(tensor_list, concat_axis)
    processed_concat_tensor = func(concat_tensor)
@@ -56,18 +58,21 @@ def sync_apply(
    D = tensor_shape[concat_axis]
    if concat_axis == 0:
        return {
            name: processed_concat_tensor[index * D : (index + 1) * D, :, :]
            for index, name in enumerate(tensor_dict)
        }
    return {
        name: processed_concat_tensor[:, index * D : (index + 1) * D, :]
        for index, name in enumerate(tensor_dict)
    }
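A minimal sketch of sync_apply, assuming each entry is a (frames, bins, channels) spectrogram so concatenation runs along the frequency axis:

import tensorflow as tf

specs = {
    "vocals": tf.zeros((100, 128, 2)),
    "accompaniment": tf.zeros((100, 128, 2)),
}
# func runs once on the concatenated (100, 256, 2) tensor; the result is then
# split back per key, each value keeping the original (100, 128, 2) shape.
out = sync_apply(specs, lambda t: t * 2.0, concat_axis=1)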
def from_float32_to_uint8(
    tensor: tf.Tensor,
    tensor_key: str = "tensor",
    min_key: str = "min",
    max_key: str = "max",
) -> tf.Tensor:
    """

    Parameters:
@@ -83,16 +88,17 @@ def from_float32_to_uint8(
    tensor_max = tf.reduce_max(tensor)
    return {
        tensor_key: tf.cast(
            (tensor - tensor_min) / (tensor_max - tensor_min + 1e-16) * 255.9999,
            dtype=tf.uint8,
        ),
        min_key: tensor_min,
        max_key: tensor_max,
    }


def from_uint8_to_float32(
    tensor: tf.Tensor, tensor_min: tf.Tensor, tensor_max: tf.Tensor
) -> tf.Tensor:
    """

    Parameters:
@@ -104,14 +110,11 @@ def from_uint8_to_float32(
        tensorflow.Tensor:
    """
    return (
        tf.cast(tensor, tf.float32) * (tensor_max - tensor_min) / 255.9999 + tensor_min
    )
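Round-trip sketch for the two quantization helpers above:

import tensorflow as tf

x = tf.random.uniform((4, 4), minval=-1.0, maxval=1.0)
q = from_float32_to_uint8(x)  # {"tensor": uint8 data, "min": ..., "max": ...}
y = from_uint8_to_float32(q["tensor"], q["min"], q["max"])
# y approximates x up to the 8-bit quantization step.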
def pad_and_partition(tensor: tf.Tensor, segment_len: int) -> tf.Tensor:
    """
    Pad and partition a tensor into segments of length `segment_len`
    along the first dimension. The tensor is padded with 0 in order
@@ -137,15 +140,11 @@ def pad_and_partition(
    """
    tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
    pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
    padded = tf.pad(tensor, [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1))
    split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
    return tf.reshape(
        padded, tf.concat([[split, segment_len], tf.shape(padded)[1:]], axis=0)
    )
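Shape sketch for pad_and_partition:

import tensorflow as tf

t = tf.zeros((11, 513, 2))
# 11 frames are zero-padded up to 20, then reshaped into one 20-frame segment.
parts = pad_and_partition(t, 20)  # shape: (1, 20, 513, 2)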
def pad_and_reshape(instr_spec, frame_length, F) -> Any:
@@ -164,10 +163,7 @@ def pad_and_reshape(instr_spec, frame_length, F) -> Any:
    extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
    extended_spec = tf.concat([instr_spec, extension], axis=2)
    old_shape = tf.shape(extended_spec)
    new_shape = tf.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
    processed_instr_spec = tf.reshape(extended_spec, new_shape)
    return processed_instr_spec
@@ -186,16 +182,11 @@ def dataset_from_csv(csv_path: str, **kwargs) -> Any:
        Loaded dataset.
    """
    df = pd.read_csv(csv_path, **kwargs)
    dataset = tf.data.Dataset.from_tensor_slices({key: df[key].values for key in df})
    return dataset


def check_tensor_shape(tensor_tf: tf.Tensor, target_shape: Any) -> bool:
    """
    Return a Tensorflow boolean graph that indicates whether
    sample[features_key] has the specified target shape. Only check
@@ -215,14 +206,12 @@ def check_tensor_shape(
    for i, target_length in enumerate(target_shape):
        if target_length:
            result = tf.logical_and(
                result, tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i])
            )
    return result


def set_tensor_shape(tensor: tf.Tensor, tensor_shape: Any) -> tf.Tensor:
    """
    Set shape for a tensor (not in place, as opposed to tf.set_shape)