add pyproject.toml for poetry transition

Félix Voituret
2021-01-08 17:32:39 +01:00
parent ad0171f6dd
commit 2b479eb683
25 changed files with 3741 additions and 1624 deletions

.github/workflows/pytest.yml

@@ -1,4 +1,4 @@
-name: pytest
+name: test
 on:
   pull_request:
     branches:
@@ -15,13 +15,6 @@ jobs:
         uses: actions/setup-python@v2
         with:
           python-version: ${{ matrix.python-version }}
-      - uses: actions/cache@v2
-        id: spleeter-pip-cache
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
       - uses: actions/cache@v2
         env:
           model-release: 1
@@ -31,11 +24,29 @@ jobs:
           key: models-${{ env.model-release }}
           restore-keys: |
             models-${{ env.model-release }}
-      - name: Install dependencies
+      - name: Install ffmpeg
         run: |
           sudo apt-get update && sudo apt-get install -y ffmpeg
-          pip install --upgrade pip setuptools
-          pip install pytest==5.4.3 pytest-xdist==1.32.0 pytest-forked==1.1.3 musdb museval
-          python setup.py install
+      - name: Install Poetry
+        uses: dschep/install-poetry-action@v1.2
+      - name: Cache Poetry virtualenv
+        uses: actions/cache@v1
+        id: cache
+        with:
+          path: ~/.virtualenvs
+          key: poetry-${{ hashFiles('**/poetry.lock') }}
+          restore-keys: |
+            poetry-${{ hashFiles('**/poetry.lock') }}
+      - name: Set Poetry config
+        run: |
+          poetry config settings.virtualenvs.in-project false
+          poetry config settings.virtualenvs.path ~/.virtualenvs
+      - name: Install Dependencies
+        run: poetry install
+        if: steps.cache.outputs.cache-hit != 'true'
+      - name: Code quality checks
+        run: |
+          poetry run black spleeter --check
+          poetry run isort spleeter --check
       - name: Test with pytest
-        run: make test
+        run: poetry run pytest tests/
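As a side note for readers following the transition, a minimal sketch of running the same quality gate locally. The helper itself is hypothetical, not part of the commit; it assumes Poetry and the dev dependencies (black, isort, pytest) are installed:

```python
import subprocess
import sys

# Hypothetical local runner mirroring the new CI steps above.
CHECKS = [
    ["poetry", "run", "black", "spleeter", "--check"],
    ["poetry", "run", "isort", "spleeter", "--check"],
    ["poetry", "run", "pytest", "tests/"],
]

for command in CHECKS:
    result = subprocess.run(command)
    if result.returncode != 0:
        sys.exit(result.returncode)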

poetry.lock (generated, new file, 1931 lines): diff suppressed because it is too large.

pyproject.toml (new file, 84 lines):

@@ -0,0 +1,84 @@
[tool.poetry]
name = "spleeter"
version = "2.1.0"
description = "The Deezer source separation library with pretrained models based on tensorflow."
authors = ["Deezer Research <spleeter@deezer.com>"]
license = "MIT License"
readme = "README.md"
repository = "https://github.com/deezer/spleeter"
homepage = "https://github.com/deezer/spleeter"
classifiers = [
"Environment :: Console",
"Environment :: MacOS X",
"Intended Audience :: Developers",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: MacOS",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Operating System :: Unix",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Artistic Software",
"Topic :: Multimedia",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
"Topic :: Multimedia :: Sound/Audio :: Conversion",
"Topic :: Multimedia :: Sound/Audio :: Sound Synthesis",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Utilities"
]
packages = [ { include = "spleeter" } ]
include = ["spleeter/resources/*.json"]
[tool.poetry.dependencies]
python = "^3.7"
ffmpeg-python = "0.2.0"
norbert = "0.2.1"
httpx = {extras = ["http2"], version = "^0.16.1"}
typer = "^0.3.2"
librosa = "0.8.0"
musdb = {version = "0.3.1", optional = true}
museval = {version = "0.3.0", optional = true}
tensorflow = "2.3.0"
pandas = "1.1.2"
numpy = "<1.19.0,>=1.16.0"
[tool.poetry.dev-dependencies]
pytest = "^6.2.1"
isort = "^5.7.0"
black = "^20.8b1"
mypy = "^0.790"
pytest-xdist = "^2.2.0"
pytest-forked = "^1.3.0"
musdb = "0.3.1"
museval = "0.3.0"
[tool.poetry.scripts]
spleeter = 'spleeter.__main__:entrypoint'
[tool.poetry.extras]
evaluation = ["musdb", "museval"]
[tool.isort]
profile = "black"
multi_line_output = 3
[tool.pytest.ini_options]
addopts = "-W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

spleeter/__init__.py

@@ -13,9 +13,9 @@
by providing train, evaluation and source separation action.
"""
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class SpleeterError(Exception):

spleeter/__main__.py

@@ -13,21 +13,21 @@
"""
import json
from functools import partial
from itertools import product
from glob import glob
from itertools import product
from os.path import join
from pathlib import Path
from typing import Container, Dict, List, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer
from . import SpleeterError
from .options import *
from .utils.logging import configure_logger, logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer
# pylint: enable=import-error
spleeter: Typer = Typer(add_completion=False)
@@ -39,18 +39,19 @@ def train(
adapter: str = AudioAdapterOption,
data: Path = TrainingDataDirectoryOption,
params_filename: str = ModelParametersOption,
verbose: bool = VerboseOption) -> None:
verbose: bool = VerboseOption,
) -> None:
"""
Train a source separation model
"""
import tensorflow as tf
from .audio.adapter import AudioAdapter
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .utils.configuration import load_configuration
import tensorflow as tf
configure_logger(verbose)
audio_adapter = AudioAdapter.get(adapter)
audio_path = str(data)
@@ -59,32 +60,29 @@ def train(
session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
model_dir=params["model_dir"],
params=params,
config=tf.estimator.RunConfig(
save_checkpoints_steps=params['save_checkpoints_steps'],
tf_random_seed=params['random_seed'],
save_summary_steps=params['save_summary_steps'],
save_checkpoints_steps=params["save_checkpoints_steps"],
tf_random_seed=params["random_seed"],
save_summary_steps=params["save_summary_steps"],
session_config=session_config,
log_step_count_steps=10,
keep_checkpoint_max=2))
keep_checkpoint_max=2,
),
)
input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
train_spec = tf.estimator.TrainSpec(
input_fn=input_fn,
max_steps=params['train_max_steps'])
input_fn = partial(
get_validation_dataset,
params,
audio_adapter,
audio_path)
input_fn=input_fn, max_steps=params["train_max_steps"]
)
input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
evaluation_spec = tf.estimator.EvalSpec(
input_fn=input_fn,
steps=None,
throttle_secs=params['throttle_secs'])
logger.info('Start model training')
input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
)
logger.info("Start model training")
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
ModelProvider.writeProbe(params['model_dir'])
logger.info('Model training done')
ModelProvider.writeProbe(params["model_dir"])
logger.info("Model training done")
@spleeter.command()
@@ -101,7 +99,8 @@ def separate(
filename_format: str = FilenameFormatOption,
params_filename: str = ModelParametersOption,
mwf: bool = MWFOption,
verbose: bool = VerboseOption) -> None:
verbose: bool = VerboseOption,
) -> None:
"""
Separate audio file(s)
"""
@@ -111,14 +110,14 @@ def separate(
configure_logger(verbose)
if deprecated_files is not None:
logger.error(
'⚠️ -i option is not supported anymore, audio files must be supplied '
'using input argument instead (see spleeter separate --help)')
"⚠️ -i option is not supported anymore, audio files must be supplied "
"using input argument instead (see spleeter separate --help)"
)
raise Exit(20)
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
separator: Separator = Separator(
params_filename,
MWF=mwf,
stft_backend=stft_backend)
params_filename, MWF=mwf, stft_backend=stft_backend
)
for filename in files:
separator.separate_to_file(
str(filename),
@@ -129,16 +128,17 @@ def separate(
codec=codec,
bitrate=bitrate,
filename_format=filename_format,
synchronous=False)
synchronous=False,
)
separator.join()
EVALUATION_SPLIT: str = 'test'
EVALUATION_METRICS_DIRECTORY: str = 'metrics'
EVALUATION_INSTRUMENTS: Container[str] = ('vocals', 'drums', 'bass', 'other')
EVALUATION_METRICS: Container[str] = ('SDR', 'SAR', 'SIR', 'ISR')
EVALUATION_MIXTURE: str = 'mixture.wav'
EVALUATION_AUDIO_DIRECTORY: str = 'audio'
EVALUATION_SPLIT: str = "test"
EVALUATION_METRICS_DIRECTORY: str = "metrics"
EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
EVALUATION_MIXTURE: str = "mixture.wav"
EVALUATION_AUDIO_DIRECTORY: str = "audio"
def _compile_metrics(metrics_output_directory) -> Dict:
@@ -153,27 +153,32 @@ def _compile_metrics(metrics_output_directory) -> Dict:
Dict:
Compiled metrics as dict.
"""
import pandas as pd
import numpy as np
import pandas as pd
songs = glob(join(metrics_output_directory, 'test/*.json'))
songs = glob(join(metrics_output_directory, "test/*.json"))
index = pd.MultiIndex.from_tuples(
product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
names=['instrument', 'metric'])
pd.DataFrame([], index=['config1', 'config2'], columns=index)
names=["instrument", "metric"],
)
pd.DataFrame([], index=["config1", "config2"], columns=index)
metrics = {
instrument: {k: [] for k in EVALUATION_METRICS}
for instrument in EVALUATION_INSTRUMENTS}
for instrument in EVALUATION_INSTRUMENTS
}
for song in songs:
with open(song, 'r') as stream:
with open(song, "r") as stream:
data = json.load(stream)
for target in data['targets']:
instrument = target['name']
for target in data["targets"]:
instrument = target["name"]
for metric in EVALUATION_METRICS:
sdr_med = np.median([
frame['metrics'][metric]
for frame in target['frames']
if not np.isnan(frame['metrics'][metric])])
sdr_med = np.median(
[
frame["metrics"][metric]
for frame in target["frames"]
if not np.isnan(frame["metrics"][metric])
]
)
metrics[instrument][metric].append(sdr_med)
return metrics
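A worked sketch of the median computation in `_compile_metrics`: NaN frames are filtered out before `np.median` is applied:

```python
import numpy as np

# NaN frames are dropped before taking the median, as in the loop above.
frame_values = [3.1, float("nan"), 2.7, 4.0]
sdr_med = np.median([v for v in frame_values if not np.isnan(v)])
print(sdr_med)  # 3.1
```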
@@ -186,7 +191,8 @@ def evaluate(
params_filename: str = ModelParametersOption,
mus_dir: Path = MUSDBDirectoryOption,
mwf: bool = MWFOption,
verbose: bool = VerboseOption) -> Dict:
verbose: bool = VerboseOption,
) -> Dict:
"""
Evaluate a model on the musDB test dataset
"""
@@ -197,42 +203,44 @@ def evaluate(
import musdb
import museval
except ImportError:
logger.error('Extra dependencies musdb and museval not found')
logger.error('Please install musdb and museval first, abort')
logger.error("Extra dependencies musdb and museval not found")
logger.error("Please install musdb and museval first, abort")
raise Exit(10)
# Separate musdb sources.
songs = glob(join(mus_dir, EVALUATION_SPLIT, '*/'))
songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
separate(
deprecated_files=None,
files=mixtures,
adapter=adapter,
bitrate='128k',
bitrate="128k",
codec=Codec.WAV,
duration=600.,
duration=600.0,
offset=0,
output_path=join(audio_output_directory, EVALUATION_SPLIT),
stft_backend=stft_backend,
filename_format='{foldername}/{instrument}.{codec}',
filename_format="{foldername}/{instrument}.{codec}",
params_filename=params_filename,
mwf=mwf,
verbose=verbose)
verbose=verbose,
)
# Compute metrics with musdb.
metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
logger.info('Starting musdb evaluation (this could be long) ...')
logger.info("Starting musdb evaluation (this could be long) ...")
dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
museval.eval_mus_dir(
dataset=dataset,
estimates_dir=audio_output_directory,
output_dir=metrics_output_directory)
logger.info('musdb evaluation done')
output_dir=metrics_output_directory,
)
logger.info("musdb evaluation done")
# Compute and pretty print median metrics.
metrics = _compile_metrics(metrics_output_directory)
for instrument, metric in metrics.items():
logger.info(f'{instrument}:')
logger.info(f"{instrument}:")
for metric, value in metric.items():
logger.info(f'{metric}: {np.median(value):.3f}')
logger.info(f"{metric}: {np.median(value):.3f}")
return metrics
@@ -244,5 +252,5 @@ def entrypoint():
logger.error(e)
if __name__ == '__main__':
if __name__ == "__main__":
entrypoint()
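For context, a minimal Typer app in the same declarative style as the CLI above (toy names, not spleeter's actual commands):

```python
from typer import Typer

# Toy app, not spleeter's CLI; option parsing is derived from the signature.
app: Typer = Typer(add_completion=False)

@app.command()
def separate(verbose: bool = False) -> None:
    """Illustrative command: Typer exposes --verbose automatically."""
    print(f"verbose={verbose}")

if __name__ == "__main__":
    app()
```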

spleeter/audio/__init__.py

@@ -12,28 +12,28 @@
from enum import Enum
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class Codec(str, Enum):
""" Enumeration of supported audio codec. """
WAV: str = 'wav'
MP3: str = 'mp3'
OGG: str = 'ogg'
M4A: str = 'm4a'
WMA: str = 'wma'
FLAC: str = 'flac'
WAV: str = "wav"
MP3: str = "mp3"
OGG: str = "ogg"
M4A: str = "m4a"
WMA: str = "wma"
FLAC: str = "flac"
class STFTBackend(str, Enum):
""" Enumeration of supported STFT backend. """
AUTO: str = 'auto'
TENSORFLOW: str = 'tensorflow'
LIBROSA: str = 'librosa'
AUTO: str = "auto"
TENSORFLOW: str = "tensorflow"
LIBROSA: str = "librosa"
@classmethod
def resolve(cls: type, backend: str) -> str:
@@ -44,9 +44,9 @@ class STFTBackend(str, Enum):
import tensorflow as tf
if backend not in cls.__members__.values():
raise ValueError(f'Unsupported backend {backend}')
raise ValueError(f"Unsupported backend {backend}")
if backend == cls.AUTO:
if len(tf.config.list_physical_devices('GPU')):
if len(tf.config.list_physical_devices("GPU")):
return cls.TENSORFLOW
return cls.LIBROSA
return backend
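Usage sketch for the resolver above; it assumes tensorflow is importable, and `AUTO` selects the TensorFlow backend only when a GPU is visible:

```python
from spleeter.audio import STFTBackend

# AUTO resolves to the TensorFlow backend only when
# tf.config.list_physical_devices("GPU") is non-empty.
backend = STFTBackend.resolve(STFTBackend.AUTO)
print(backend)  # STFTBackend.LIBROSA on a CPU-only machine
```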

spleeter/audio/adapter.py

@@ -6,29 +6,31 @@
from abc import ABC, abstractmethod
from importlib import import_module
from pathlib import Path
from spleeter.audio import Codec
from typing import Any, Dict, List, Optional, Union
from .. import SpleeterError
from ..types import AudioDescriptor, Signal
from ..utils.logging import logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from spleeter.audio import Codec
from .. import SpleeterError
from ..types import AudioDescriptor, Signal
from ..utils.logging import logger
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class AudioAdapter(ABC):
""" An abstract class for manipulating audio signal. """
_DEFAULT: 'AudioAdapter' = None
_DEFAULT: "AudioAdapter" = None
""" Default audio adapter singleton instance. """
@abstractmethod
@@ -38,7 +40,8 @@ class AudioAdapter(ABC):
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32) -> Signal:
dtype: np.dtype = np.float32,
) -> Signal:
"""
Loads the audio file denoted by the given audio descriptor and
returns it data as a waveform. Aims to be implemented by client.
@@ -66,10 +69,11 @@ class AudioAdapter(ABC):
self,
audio_descriptor,
offset: float = 0.0,
duration: float = 1800.,
duration: float = 1800.0,
sample_rate: int = 44100,
dtype: bytes = b'float32',
waveform_name: str = 'waveform') -> Dict[str, Any]:
dtype: bytes = b"float32",
waveform_name: str = "waveform",
) -> Dict[str, Any]:
"""
Load the audio and convert it to a tensorflow waveform.
@@ -101,32 +105,31 @@ class AudioAdapter(ABC):
# Defined safe loading function.
def safe_load(path, offset, duration, sample_rate, dtype):
logger.info(
f'Loading audio {path} from {offset} to {offset + duration}')
logger.info(f"Loading audio {path} from {offset} to {offset + duration}")
try:
(data, _) = self.load(
path.numpy(),
offset.numpy(),
duration.numpy(),
sample_rate.numpy(),
dtype=dtype.numpy())
logger.info('Audio data loaded successfully')
dtype=dtype.numpy(),
)
logger.info("Audio data loaded successfully")
return (data, False)
except Exception as e:
logger.exception(
'An error occurs while loading audio',
exc_info=e)
logger.exception("An error occurs while loading audio", exc_info=e)
return (np.float32(-1.0), True)
# Execute function and format results.
results = tf.py_function(
results = (
tf.py_function(
safe_load,
[audio_descriptor, offset, duration, sample_rate, dtype],
(tf.float32, tf.bool)),
(tf.float32, tf.bool),
),
)
waveform, error = results[0]
return {
waveform_name: waveform,
f'{waveform_name}_error': error}
return {waveform_name: waveform, f"{waveform_name}_error": error}
@abstractmethod
def save(
@@ -135,7 +138,8 @@ class AudioAdapter(ABC):
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None) -> None:
bitrate: str = None,
) -> None:
"""
Save the given audio data to the file denoted by the given path.
@@ -155,7 +159,7 @@ class AudioAdapter(ABC):
pass
@classmethod
def default(cls: type) -> 'AudioAdapter':
def default(cls: type) -> "AudioAdapter":
"""
Builds and returns a default audio adapter instance.
@@ -165,11 +169,12 @@ class AudioAdapter(ABC):
"""
if cls._DEFAULT is None:
from .ffmpeg import FFMPEGProcessAudioAdapter
cls._DEFAULT = FFMPEGProcessAudioAdapter()
return cls._DEFAULT
@classmethod
def get(cls: type, descriptor: str) -> 'AudioAdapter':
def get(cls: type, descriptor: str) -> "AudioAdapter":
"""
Load dynamically an AudioAdapter from given class descriptor.
@@ -183,12 +188,13 @@ class AudioAdapter(ABC):
"""
if not descriptor:
return cls.default()
module_path: List[str] = descriptor.split('.')
module_path: List[str] = descriptor.split(".")
adapter_class_name: str = module_path[-1]
module_path: str = '.'.join(module_path[:-1])
module_path: str = ".".join(module_path[:-1])
adapter_module = import_module(module_path)
adapter_class = getattr(adapter_module, adapter_class_name)
if not issubclass(adapter_class, AudioAdapter):
raise SpleeterError(
f'{adapter_class_name} is not a valid AudioAdapter class')
f"{adapter_class_name} is not a valid AudioAdapter class"
)
return adapter_class()
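Usage sketch for the descriptor-based loading above (the FFMPEG default additionally requires ffmpeg and ffprobe on the PATH):

```python
from spleeter.audio.adapter import AudioAdapter

# Empty descriptor falls back to the FFMPEG-based default adapter.
default_adapter = AudioAdapter.default()

# Dotted descriptor: module path plus class name, checked against AudioAdapter.
custom_adapter = AudioAdapter.get("spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter")
```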

spleeter/audio/convertor.py

@@ -3,22 +3,21 @@
""" This module provides audio data convertion functions. """
from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def to_n_channels(
waveform: tf.Tensor,
n_channels: int) -> tf.Tensor:
def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
"""
Convert a waveform to n_channels by removing or duplicating channels if
needed (in tensorflow).
@@ -36,7 +35,8 @@ def to_n_channels(
return tf.cond(
tf.shape(waveform)[1] >= n_channels,
true_fn=lambda: waveform[:, :n_channels],
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels])
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels],
)
def to_stereo(waveform: np.ndarray) -> np.ndarray:
@@ -73,7 +73,7 @@ def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
tensorflow.Tensor:
Converted tensor.
"""
return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
return 20.0 / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
@@ -88,13 +88,12 @@ def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
tensorflow.Tensor:
Converted tensor.
"""
return tf.pow(10., (tensor / 20.))
return tf.pow(10.0, (tensor / 20.0))
def spectrogram_to_db_uint(
spectrogram: tf.Tensor,
db_range: float = 100.,
**kwargs) -> tf.Tensor:
spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
) -> tf.Tensor:
"""
Encodes given spectrogram into uint8 using decibel scale.
@@ -111,15 +110,14 @@ def spectrogram_to_db_uint(
db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
db_spectrogram: tf.Tensor = tf.maximum(
db_spectrogram,
max_db_spectrogram - db_range)
db_spectrogram, max_db_spectrogram - db_range
)
return from_float32_to_uint8(db_spectrogram, **kwargs)
def db_uint_spectrogram_to_gain(
db_uint_spectrogram: tf.Tensor,
min_db: tf.Tensor,
max_db: tf.Tensor) -> tf.Tensor:
db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
) -> tf.Tensor:
"""
Decode spectrogram from uint8 decibel scale.
@@ -136,7 +134,6 @@ def db_uint_spectrogram_to_gain(
Decoded spectrogram as `float32` tensor.
"""
db_spectrogram: tf.Tensor = from_uint8_to_float32(
db_uint_spectrogram,
min_db,
max_db)
db_uint_spectrogram, min_db, max_db
)
return db_to_gain(db_spectrogram)
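A quick numeric check of the dB helpers above (a sketch; runs eagerly under TF2): `gain_to_db` computes 20 * log10(x) clamped by epsilon, and `db_to_gain` computes 10 ** (x / 20), so a round trip preserves the input:

```python
import numpy as np
import tensorflow as tf

from spleeter.audio.convertor import db_to_gain, gain_to_db

gain = tf.constant([1.0, 0.5, 0.1])
roundtrip = db_to_gain(gain_to_db(gain))  # 10 ** ((20 * log10(gain)) / 20)
print(np.allclose(gain.numpy(), roundtrip.numpy()))  # True
```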

spleeter/audio/ffmpeg.py

@@ -11,25 +11,25 @@
import datetime as dt
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union
from . import Codec
from .adapter import AudioAdapter
from .. import SpleeterError
from ..types import Signal
from ..utils.logging import logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import ffmpeg
import numpy as np
from .. import SpleeterError
from ..types import Signal
from ..utils.logging import logger
from . import Codec
from .adapter import AudioAdapter
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class FFMPEGProcessAudioAdapter(AudioAdapter):
@@ -43,9 +43,9 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
"""
SUPPORTED_CODECS: Dict[Codec, str] = {
Codec.M4A: 'aac',
Codec.OGG: 'libvorbis',
Codec.WMA: 'wmav2'
Codec.M4A: "aac",
Codec.OGG: "libvorbis",
Codec.WMA: "wmav2",
}
""" FFMPEG codec name mapping. """
@@ -57,9 +57,9 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
SpleeterError:
If ffmpeg or ffprobe is not found.
"""
for binary in ('ffmpeg', 'ffprobe'):
for binary in ("ffmpeg", "ffprobe"):
if shutil.which(binary) is None:
raise SpleeterError('{} binary not found'.format(binary))
raise SpleeterError("{} binary not found".format(binary))
def load(
_,
@@ -67,7 +67,8 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32) -> Signal:
dtype: np.dtype = np.float32,
) -> Signal:
"""
Loads the audio file denoted by the given path
and returns it data as a waveform.
@@ -100,29 +101,30 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
probe = ffmpeg.probe(path)
except ffmpeg._run.Error as e:
raise SpleeterError(
'An error occurs with ffprobe (see ffprobe output below)\n\n{}'
.format(e.stderr.decode()))
if 'streams' not in probe or len(probe['streams']) == 0:
raise SpleeterError('No stream was found with ffprobe')
"An error occurs with ffprobe (see ffprobe output below)\n\n{}".format(
e.stderr.decode()
)
)
if "streams" not in probe or len(probe["streams"]) == 0:
raise SpleeterError("No stream was found with ffprobe")
metadata = next(
stream
for stream in probe['streams']
if stream['codec_type'] == 'audio')
n_channels = metadata['channels']
stream for stream in probe["streams"] if stream["codec_type"] == "audio"
)
n_channels = metadata["channels"]
if sample_rate is None:
sample_rate = metadata['sample_rate']
output_kwargs = {'format': 'f32le', 'ar': sample_rate}
sample_rate = metadata["sample_rate"]
output_kwargs = {"format": "f32le", "ar": sample_rate}
if duration is not None:
output_kwargs['t'] = str(dt.timedelta(seconds=duration))
output_kwargs["t"] = str(dt.timedelta(seconds=duration))
if offset is not None:
output_kwargs['ss'] = str(dt.timedelta(seconds=offset))
output_kwargs["ss"] = str(dt.timedelta(seconds=offset))
process = (
ffmpeg
.input(path)
.output('pipe:', **output_kwargs)
.run_async(pipe_stdout=True, pipe_stderr=True))
ffmpeg.input(path)
.output("pipe:", **output_kwargs)
.run_async(pipe_stdout=True, pipe_stderr=True)
)
buffer, _ = process.communicate()
waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
if not waveform.dtype == np.dtype(dtype):
waveform = waveform.astype(dtype)
return (waveform, sample_rate)
@@ -133,7 +135,8 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None) -> None:
bitrate: str = None,
) -> None:
"""
Write waveform data to the file denoted by the given path using
FFMPEG process.
@@ -159,25 +162,24 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
path = str(path)
directory = os.path.dirname(path)
if not os.path.exists(directory):
raise SpleeterError(
f'output directory does not exists: {directory}')
logger.debug(f'Writing file {path}')
input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
output_kwargs = {'ar': sample_rate, 'strict': '-2'}
raise SpleeterError(f"output directory does not exists: {directory}")
logger.debug(f"Writing file {path}")
input_kwargs = {"ar": sample_rate, "ac": data.shape[1]}
output_kwargs = {"ar": sample_rate, "strict": "-2"}
if bitrate:
output_kwargs['audio_bitrate'] = bitrate
if codec is not None and codec != 'wav':
output_kwargs['codec'] = self.SUPPORTED_CODECS.get(codec, codec)
output_kwargs["audio_bitrate"] = bitrate
if codec is not None and codec != "wav":
output_kwargs["codec"] = self.SUPPORTED_CODECS.get(codec, codec)
process = (
ffmpeg
.input('pipe:', format='f32le', **input_kwargs)
ffmpeg.input("pipe:", format="f32le", **input_kwargs)
.output(path, **output_kwargs)
.overwrite_output()
.run_async(pipe_stdin=True, pipe_stderr=True, quiet=True))
.run_async(pipe_stdin=True, pipe_stderr=True, quiet=True)
)
try:
process.stdin.write(data.astype('<f4').tobytes())
process.stdin.write(data.astype("<f4").tobytes())
process.stdin.close()
process.wait()
except IOError:
raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
logger.info(f'File {path} written succesfully')
raise SpleeterError(f"FFMPEG error: {process.stderr.read()}")
logger.info(f"File {path} written succesfully")

spleeter/audio/spectrogram.py

@@ -7,21 +7,22 @@
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from tensorflow.signal import hann_window, stft
from tensorflow.signal import stft, hann_window
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def compute_spectrogram_tf(
waveform: tf.Tensor,
frame_length: int = 2048,
frame_step: int = 512,
spec_exponent: float = 1.,
window_exponent: float = 1.) -> tf.Tensor:
spec_exponent: float = 1.0,
window_exponent: float = 1.0,
) -> tf.Tensor:
"""
Compute magnitude / power spectrogram from waveform as a
`n_samples x n_channels` tensor.
@@ -51,18 +52,20 @@ def compute_spectrogram_tf(
frame_length,
frame_step,
window_fn=lambda f, dtype: hann_window(
f,
periodic=True,
dtype=waveform.dtype) ** window_exponent),
perm=[1, 2, 0])
f, periodic=True, dtype=waveform.dtype
)
** window_exponent,
),
perm=[1, 2, 0],
)
return tf.abs(stft_tensor) ** spec_exponent
def time_stretch(
spectrogram: tf.Tensor,
factor: float = 1.0,
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR
) -> tf.Tensor:
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
) -> tf.Tensor:
"""
Time stretch a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
@@ -83,18 +86,14 @@ def time_stretch(
T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
F = tf.shape(spectrogram)[1]
ts_spec = tf.image.resize_images(
spectrogram,
[T_ts, F],
method=method,
align_corners=True)
spectrogram, [T_ts, F], method=method, align_corners=True
)
return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
def random_time_stretch(
spectrogram: tf.Tensor,
factor_min: float = 0.9,
factor_max: float = 1.1,
**kwargs) -> tf.Tensor:
spectrogram: tf.Tensor, factor_min: float = 0.9, factor_max: float = 1.1, **kwargs
) -> tf.Tensor:
"""
Time stretch a spectrogram preserving shape with random ratio in
tensorflow. Applies time_stretch to spectrogram with a random ratio
@@ -112,17 +111,17 @@ def random_time_stretch(
tensorflow.Tensor:
Randomly time stretched spectrogram as tensor with same shape.
"""
factor = tf.random_uniform(
shape=(1,),
seed=0) * (factor_max - factor_min) + factor_min
factor = (
tf.random_uniform(shape=(1,), seed=0) * (factor_max - factor_min) + factor_min
)
return time_stretch(spectrogram, factor=factor, **kwargs)
def pitch_shift(
spectrogram: tf.Tensor,
semitone_shift: float = 0.0,
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR
) -> tf.Tensor:
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
) -> tf.Tensor:
"""
Pitch shift a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
@@ -139,24 +138,20 @@ def pitch_shift(
tensorflow.Tensor:
Pitch shifted spectrogram (same shape as spectrogram).
"""
factor = 2 ** (semitone_shift / 12.)
factor = 2 ** (semitone_shift / 12.0)
T = tf.shape(spectrogram)[0]
F = tf.shape(spectrogram)[1]
F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
ps_spec = tf.image.resize_images(
spectrogram,
[T, F_ps],
method=method,
align_corners=True)
spectrogram, [T, F_ps], method=method, align_corners=True
)
paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
return tf.pad(ps_spec[:, :F, :], paddings, "CONSTANT")
def random_pitch_shift(
spectrogram: tf.Tensor,
shift_min: float = -1.,
shift_max: float = 1.,
**kwargs) -> tf.Tensor:
spectrogram: tf.Tensor, shift_min: float = -1.0, shift_max: float = 1.0, **kwargs
) -> tf.Tensor:
"""
Pitch shift a spectrogram preserving shape with random ratio in
tensorflow. Applies pitch_shift to spectrogram with a random shift
@@ -175,7 +170,7 @@ def random_pitch_shift(
tensorflow.Tensor:
Randomly pitch shifted spectrogram (same shape as spectrogram).
"""
semitone_shift = tf.random_uniform(
shape=(1,),
seed=0) * (shift_max - shift_min) + shift_min
semitone_shift = (
tf.random_uniform(shape=(1,), seed=0) * (shift_max - shift_min) + shift_min
)
return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
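A one-line check of the semitone-to-ratio formula used by `pitch_shift`:

```python
# +12 semitones doubles the frequency axis, -12 halves it, 0 is identity.
for semitone_shift in (-12.0, 0.0, 12.0):
    print(semitone_shift, 2 ** (semitone_shift / 12.0))  # 0.5, 1.0, 2.0
```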

spleeter/dataset.py

@@ -14,45 +14,52 @@
(ground truth)
"""
import time
import os
from os.path import exists, sep as SEPARATOR
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional
from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain
from .audio.convertor import spectrogram_to_db_uint
from .audio.spectrogram import compute_spectrogram_tf
from .audio.spectrogram import random_pitch_shift, random_time_stretch
from .utils.logging import logger
from .utils.tensor import check_tensor_shape, dataset_from_csv
from .utils.tensor import set_tensor_shape, sync_apply
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
compute_spectrogram_tf,
random_pitch_shift,
random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
check_tensor_shape,
dataset_from_csv,
set_tensor_shape,
sync_apply,
)
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS: Dict = {
'instrument_list': ('vocals', 'accompaniment'),
'mix_name': 'mix',
'sample_rate': 44100,
'frame_length': 4096,
'frame_step': 1024,
'T': 512,
'F': 1024}
"instrument_list": ("vocals", "accompaniment"),
"mix_name": "mix",
"sample_rate": 44100,
"frame_length": 4096,
"frame_step": 1024,
"T": 512,
"F": 1024,
}
def get_training_dataset(
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str) -> Any:
audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
"""
Builds training dataset.
@@ -72,22 +79,23 @@ def get_training_dataset(
audio_params,
audio_adapter,
audio_path,
chunk_duration=audio_params.get('chunk_duration', 20.0),
random_seed=audio_params.get('random_seed', 0))
chunk_duration=audio_params.get("chunk_duration", 20.0),
random_seed=audio_params.get("random_seed", 0),
)
return builder.build(
audio_params.get('train_csv'),
cache_directory=audio_params.get('training_cache'),
batch_size=audio_params.get('batch_size'),
n_chunks_per_song=audio_params.get('n_chunks_per_song', 2),
audio_params.get("train_csv"),
cache_directory=audio_params.get("training_cache"),
batch_size=audio_params.get("batch_size"),
n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
random_data_augmentation=False,
convert_to_uint=True,
wait_for_cache=False)
wait_for_cache=False,
)
def get_validation_dataset(
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str) -> Any:
audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
"""
Builds validation dataset.
@@ -104,14 +112,12 @@ def get_validation_dataset(
Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=12.0)
audio_params, audio_adapter, audio_path, chunk_duration=12.0
)
return builder.build(
audio_params.get('validation_csv'),
batch_size=audio_params.get('batch_size'),
cache_directory=audio_params.get('validation_cache'),
audio_params.get("validation_csv"),
batch_size=audio_params.get("batch_size"),
cache_directory=audio_params.get("validation_cache"),
convert_to_uint=True,
infinite_generator=False,
n_chunks_per_song=1,
@@ -137,82 +143,117 @@ class InstrumentDatasetBuilder(object):
"""
self._parent = parent
self._instrument = instrument
self._spectrogram_key = f'{instrument}_spectrogram'
self._min_spectrogram_key = f'min_{instrument}_spectrogram'
self._max_spectrogram_key = f'max_{instrument}_spectrogram'
self._spectrogram_key = f"{instrument}_spectrogram"
self._min_spectrogram_key = f"min_{instrument}_spectrogram"
self._max_spectrogram_key = f"max_{instrument}_spectrogram"
def load_waveform(self, sample):
""" Load waveform for given sample. """
return dict(sample, **self._parent._audio_adapter.load_tf_waveform(
sample[f'{self._instrument}_path'],
offset=sample['start'],
return dict(
sample,
**self._parent._audio_adapter.load_tf_waveform(
sample[f"{self._instrument}_path"],
offset=sample["start"],
duration=self._parent._chunk_duration,
sample_rate=self._parent._sample_rate,
waveform_name='waveform'))
waveform_name="waveform",
),
)
def compute_spectrogram(self, sample):
""" Compute spectrogram of the given sample. """
return dict(sample, **{
return dict(
sample,
**{
self._spectrogram_key: compute_spectrogram_tf(
sample['waveform'],
sample["waveform"],
frame_length=self._parent._frame_length,
frame_step=self._parent._frame_step,
spec_exponent=1.,
window_exponent=1.)})
spec_exponent=1.0,
window_exponent=1.0,
)
},
)
def filter_frequencies(self, sample):
""" """
return dict(sample, **{
self._spectrogram_key:
sample[self._spectrogram_key][:, :self._parent._F, :]})
return dict(
sample,
**{
self._spectrogram_key: sample[self._spectrogram_key][
:, : self._parent._F, :
]
},
)
def convert_to_uint(self, sample):
""" Convert given sample from float to unit. """
return dict(sample, **spectrogram_to_db_uint(
return dict(
sample,
**spectrogram_to_db_uint(
sample[self._spectrogram_key],
tensor_key=self._spectrogram_key,
min_key=self._min_spectrogram_key,
max_key=self._max_spectrogram_key))
max_key=self._max_spectrogram_key,
),
)
def filter_infinity(self, sample):
""" Filter infinity sample. """
return tf.logical_not(
tf.math.is_inf(
sample[self._min_spectrogram_key]))
return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))
def convert_to_float32(self, sample):
""" Convert given sample from unit to float. """
return dict(sample, **{
return dict(
sample,
**{
self._spectrogram_key: db_uint_spectrogram_to_gain(
sample[self._spectrogram_key],
sample[self._min_spectrogram_key],
sample[self._max_spectrogram_key])})
sample[self._max_spectrogram_key],
)
},
)
def time_crop(self, sample):
""" """
def start(sample):
""" mid_segment_start """
return tf.cast(
tf.maximum(
tf.shape(sample[self._spectrogram_key])[0]
/ 2 - self._parent._T / 2, 0),
tf.int32)
return dict(sample, **{
tf.shape(sample[self._spectrogram_key])[0] / 2
- self._parent._T / 2,
0,
),
tf.int32,
)
return dict(
sample,
**{
self._spectrogram_key: sample[self._spectrogram_key][
start(sample):start(sample) + self._parent._T, :, :]})
start(sample) : start(sample) + self._parent._T, :, :
]
},
)
def filter_shape(self, sample):
""" Filter badly shaped sample. """
return check_tensor_shape(
sample[self._spectrogram_key], (
self._parent._T, self._parent._F, 2))
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
)
def reshape_spectrogram(self, sample):
""" Reshape given sample. """
return dict(sample, **{
return dict(
sample,
**{
self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key],
(self._parent._T, self._parent._F, 2))})
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
)
},
)
class DatasetBuilder(object):
@@ -232,7 +273,8 @@ class DatasetBuilder(object):
audio_adapter: AudioAdapter,
audio_path: str,
random_seed: int = 0,
chunk_duration: float = 20.0) -> None:
chunk_duration: float = 20.0,
) -> None:
"""
Default constructor.
@@ -249,15 +291,15 @@ class DatasetBuilder(object):
"""
# Length of segment in frames (if fs=22050 and
# frame_step=512, then T=512 corresponds to 11.89s)
self._T = audio_params['T']
self._T = audio_params["T"]
# Number of frequency bins to be used (should
# be less than frame_length/2 + 1)
self._F = audio_params['F']
self._sample_rate = audio_params['sample_rate']
self._frame_length = audio_params['frame_length']
self._frame_step = audio_params['frame_step']
self._mix_name = audio_params['mix_name']
self._instruments = [self._mix_name] + audio_params['instrument_list']
self._F = audio_params["F"]
self._sample_rate = audio_params["sample_rate"]
self._frame_length = audio_params["frame_length"]
self._frame_step = audio_params["frame_step"]
self._mix_name = audio_params["mix_name"]
self._instruments = [self._mix_name] + audio_params["instrument_list"]
self._instrument_builders = None
self._chunk_duration = chunk_duration
self._audio_adapter = audio_adapter
@@ -267,76 +309,110 @@ class DatasetBuilder(object):
def expand_path(self, sample):
""" Expands audio paths for the given sample. """
return dict(sample, **{f'{instrument}_path': tf.strings.join(
(self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
for instrument in self._instruments})
return dict(
sample,
**{
f"{instrument}_path": tf.strings.join(
(self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
)
for instrument in self._instruments
},
)
def filter_error(self, sample):
""" Filter errored sample. """
return tf.logical_not(sample['waveform_error'])
return tf.logical_not(sample["waveform_error"])
def filter_waveform(self, sample):
""" Filter waveform from sample. """
return {k: v for k, v in sample.items() if not k == 'waveform'}
return {k: v for k, v in sample.items() if not k == "waveform"}
def harmonize_spectrogram(self, sample):
""" Ensure same size for vocals and mix spectrograms. """
def _reduce(sample):
return tf.reduce_min([
tf.shape(sample[f'{instrument}_spectrogram'])[0]
for instrument in self._instruments])
return dict(sample, **{
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :]
for instrument in self._instruments})
return tf.reduce_min(
[
tf.shape(sample[f"{instrument}_spectrogram"])[0]
for instrument in self._instruments
]
)
return dict(
sample,
**{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
: _reduce(sample), :, :
]
for instrument in self._instruments
},
)
def filter_short_segments(self, sample):
""" Filter out too short segment. """
return tf.reduce_any([
tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T
for instrument in self._instruments])
return tf.reduce_any(
[
tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
for instrument in self._instruments
]
)
def random_time_crop(self, sample):
""" Random time crop of 11.88s. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
return dict(
sample,
**sync_apply(
{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._instruments
},
lambda x: tf.image.random_crop(
x, (self._T, len(self._instruments) * self._F, 2),
seed=self._random_seed)))
x,
(self._T, len(self._instruments) * self._F, 2),
seed=self._random_seed,
),
),
)
def random_time_stretch(self, sample):
""" Randomly time stretch the given sample. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: random_time_stretch(
x, factor_min=0.9, factor_max=1.1)))
return dict(
sample,
**sync_apply(
{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._instruments
},
lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
),
)
def random_pitch_shift(self, sample):
""" Randomly pitch shift the given sample. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: random_pitch_shift(
x, shift_min=-1.0, shift_max=1.0), concat_axis=0))
return dict(
sample,
**sync_apply(
{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._instruments
},
lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
concat_axis=0,
),
)
def map_features(self, sample):
""" Select features and annotation of the given sample. """
input_ = {
f'{self._mix_name}_spectrogram':
sample[f'{self._mix_name}_spectrogram']}
f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
}
output = {
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
for instrument in self._audio_params['instrument_list']}
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._audio_params["instrument_list"]
}
return (input_, output)
def compute_segments(
self,
dataset: Any,
n_chunks_per_song: int) -> Any:
def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
"""
Computes segments for each song of the dataset.
@@ -351,21 +427,39 @@ class DatasetBuilder(object):
Segmented dataset.
"""
if n_chunks_per_song <= 0:
raise ValueError('n_chunks_per_song must be positif')
raise ValueError("n_chunks_per_song must be positif")
datasets = []
for k in range(n_chunks_per_song):
if n_chunks_per_song > 1:
datasets.append(
dataset.map(lambda sample: dict(sample, start=tf.maximum(
k * (
sample['duration'] - self._chunk_duration - 2
* self.MARGIN) / (n_chunks_per_song - 1)
+ self.MARGIN, 0))))
dataset.map(
lambda sample: dict(
sample,
start=tf.maximum(
k
* (
sample["duration"]
- self._chunk_duration
- 2 * self.MARGIN
)
/ (n_chunks_per_song - 1)
+ self.MARGIN,
0,
),
)
)
)
elif n_chunks_per_song == 1: # Take central segment.
datasets.append(
dataset.map(lambda sample: dict(sample, start=tf.maximum(
sample['duration'] / 2 - self._chunk_duration / 2,
0))))
dataset.map(
lambda sample: dict(
sample,
start=tf.maximum(
sample["duration"] / 2 - self._chunk_duration / 2, 0
),
)
)
)
dataset = datasets[-1]
for d in datasets[:-1]:
dataset = dataset.concatenate(d)
@@ -384,15 +478,12 @@ class DatasetBuilder(object):
self._instrument_builders = []
for instrument in self._instruments:
self._instrument_builders.append(
InstrumentDatasetBuilder(self, instrument))
InstrumentDatasetBuilder(self, instrument)
)
for builder in self._instrument_builders:
yield builder
def cache(
self,
dataset: Any,
cache: str,
wait: bool) -> Any:
def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
"""
Cache the given dataset if cache is enabled. Eventually waits for
cache to be available (useful if another process is already
@@ -412,9 +503,8 @@ class DatasetBuilder(object):
"""
if cache is not None:
if wait:
while not exists(f'{cache}.index'):
logger.info(
f'Cache not available, wait {self.WAIT_PERIOD}')
while not exists(f"{cache}.index"):
logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
time.sleep(self.WAIT_PERIOD)
cache_path = os.path.split(cache)[0]
os.makedirs(cache_path, exist_ok=True)
@@ -433,7 +523,8 @@ class DatasetBuilder(object):
cache_directory: Optional[str] = None,
wait_for_cache: bool = False,
num_parallel_calls: int = 4,
n_chunks_per_song: float = 2,) -> Any:
n_chunks_per_song: float = 2,
) -> Any:
"""
TO BE DOCUMENTED.
"""
@@ -445,7 +536,8 @@ class DatasetBuilder(object):
buffer_size=200000,
seed=self._random_seed,
# useless since it is cached :
reshuffle_each_iteration=True)
reshuffle_each_iteration=True,
)
# Expand audio path.
dataset = dataset.map(self.expand_path)
# Load waveform, compute spectrogram, and filtering error,
@@ -453,11 +545,11 @@ class DatasetBuilder(object):
N = num_parallel_calls
for instrument in self.instruments:
dataset = (
dataset
.map(instrument.load_waveform, num_parallel_calls=N)
dataset.map(instrument.load_waveform, num_parallel_calls=N)
.filter(self.filter_error)
.map(instrument.compute_spectrogram, num_parallel_calls=N)
.map(instrument.filter_frequencies))
.map(instrument.filter_frequencies)
)
dataset = dataset.map(self.filter_waveform)
# Convert to uint before caching in order to save space.
if convert_to_uint:
@@ -488,26 +580,25 @@ class DatasetBuilder(object):
# after croping but before converting back to float.
if shuffle:
dataset = dataset.shuffle(
buffer_size=256, seed=self._random_seed,
reshuffle_each_iteration=True)
buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
)
# Convert back to float32
if convert_to_uint:
for instrument in self.instruments:
dataset = dataset.map(
instrument.convert_to_float32, num_parallel_calls=N)
instrument.convert_to_float32, num_parallel_calls=N
)
M = 8 # Parallel call post caching.
# Must be applied with the same factor on mix and vocals.
if random_data_augmentation:
dataset = (
dataset
.map(self.random_time_stretch, num_parallel_calls=M)
.map(self.random_pitch_shift, num_parallel_calls=M))
dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
self.random_pitch_shift, num_parallel_calls=M
)
# Filter by shape (remove badly shaped tensors).
for instrument in self.instruments:
dataset = (
dataset
.filter(instrument.filter_shape)
.map(instrument.reshape_spectrogram))
dataset = dataset.filter(instrument.filter_shape).map(
instrument.reshape_spectrogram
)
# Select features and annotation.
dataset = dataset.map(self.map_features)
# Make batch (done after selection to avoid

spleeter/model/__init__.py

@@ -8,15 +8,16 @@ import importlib
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.signal import stft, inverse_stft, hann_window
# pylint: enable=import-error
from tensorflow.signal import hann_window, inverse_stft, stft
from ..utils.tensor import pad_and_partition, pad_and_reshape
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
placeholder = tf.compat.v1.placeholder
@@ -36,17 +37,16 @@ def get_model_function(model_type):
A tensorflow function to be applied to the input tensor to get the
multitrack output.
"""
relative_path_to_module = '.'.join(model_type.split('.')[:-1])
model_name = model_type.split('.')[-1]
main_module = '.'.join((__name__, 'functions'))
path_to_module = f'{main_module}.{relative_path_to_module}'
relative_path_to_module = ".".join(model_type.split(".")[:-1])
model_name = model_type.split(".")[-1]
main_module = ".".join((__name__, "functions"))
path_to_module = f"{main_module}.{relative_path_to_module}"
module = importlib.import_module(path_to_module)
model_function = getattr(module, model_name)
return model_function
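A worked trace of the dotted-path resolution above for the default `unet.unet` model type:

```python
# Worked trace for the default model type resolved above.
model_type = "unet.unet"
relative_path_to_module = ".".join(model_type.split(".")[:-1])  # "unet"
model_name = model_type.split(".")[-1]                          # "unet"
# __name__ is the spleeter.model package, so the import target becomes:
print(f"spleeter.model.functions.{relative_path_to_module}")    # importlib target
print(model_name)                                               # getattr target
```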
class InputProvider(object):
def __init__(self, params):
self.params = params
@@ -62,16 +62,16 @@ class InputProvider(object):
class WaveformInputProvider(InputProvider):
@property
def input_names(self):
return ["audio_id", "waveform"]
def get_input_dict_placeholders(self):
shape = (None, self.params['n_channels'])
shape = (None, self.params["n_channels"])
features = {
'waveform': placeholder(tf.float32, shape=shape, name="waveform"),
'audio_id': placeholder(tf.string, name="audio_id")}
"waveform": placeholder(tf.float32, shape=shape, name="waveform"),
"audio_id": placeholder(tf.string, name="audio_id"),
}
return features
def get_feed_dict(self, features, waveform, audio_id):
@@ -79,7 +79,6 @@ class WaveformInputProvider(InputProvider):
class SpectralInputProvider(InputProvider):
def __init__(self, params):
super().__init__(params)
self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@@ -90,11 +89,17 @@ class SpectralInputProvider(InputProvider):
def get_input_dict_placeholders(self):
features = {
self.stft_input_name: placeholder(tf.complex64,
shape=(None, self.params["frame_length"]//2+1,
self.params['n_channels']),
name=self.stft_input_name),
'audio_id': placeholder(tf.string, name="audio_id")}
self.stft_input_name: placeholder(
tf.complex64,
shape=(
None,
self.params["frame_length"] // 2 + 1,
self.params["n_channels"],
),
name=self.stft_input_name,
),
"audio_id": placeholder(tf.string, name="audio_id"),
}
return features
def get_feed_dict(self, features, stft, audio_id):
@@ -102,11 +107,13 @@ class SpectralInputProvider(InputProvider):
class InputProviderFactory(object):
@staticmethod
def get(params):
stft_backend = params["stft_backend"]
assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
assert stft_backend in (
"tensorflow",
"librosa",
), "Unexpected backend {}".format(stft_backend)
if stft_backend == "tensorflow":
return WaveformInputProvider(params)
else:
@@ -114,7 +121,7 @@ class InputProviderFactory(object):
class EstimatorSpecBuilder(object):
""" A builder class that allows to builds a multitrack unet model
"""A builder class that allows to builds a multitrack unet model
estimator. The built model estimator has a different behaviour when
used in a train/eval mode and in predict mode.
@@ -138,22 +145,22 @@ class EstimatorSpecBuilder(object):
"""
# Supported model functions.
DEFAULT_MODEL = 'unet.unet'
DEFAULT_MODEL = "unet.unet"
# Supported loss functions.
L1_MASK = 'L1_mask'
WEIGHTED_L1_MASK = 'weighted_L1_mask'
L1_MASK = "L1_mask"
WEIGHTED_L1_MASK = "weighted_L1_mask"
# Supported optimizers.
ADADELTA = 'Adadelta'
SGD = 'SGD'
ADADELTA = "Adadelta"
SGD = "SGD"
# Math constants.
WINDOW_COMPENSATION_FACTOR = 2./3.
WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0
EPSILON = 1e-10
def __init__(self, features, params):
""" Default constructor. Depending on built model
"""Default constructor. Depending on built model
usage, the provided features should be different:
* In train/eval mode: features is a dictionary with a
@@ -170,20 +177,20 @@ class EstimatorSpecBuilder(object):
self._features = features
self._params = params
# Get instrument name.
self._mix_name = params['mix_name']
self._instruments = params['instrument_list']
self._mix_name = params["mix_name"]
self._instruments = params["instrument_list"]
# Get STFT/signals parameters
self._n_channels = params['n_channels']
self._T = params['T']
self._F = params['F']
self._frame_length = params['frame_length']
self._frame_step = params['frame_step']
self._n_channels = params["n_channels"]
self._T = params["T"]
self._F = params["F"]
self._frame_length = params["frame_length"]
self._frame_step = params["frame_step"]
def include_stft_computations(self):
return self._params["stft_backend"] == "tensorflow"
def _build_model_outputs(self):
""" Created a batch_sizexTxFxn_channels input tensor containing
"""Created a batch_sizexTxFxn_channels input tensor containing
mix magnitude spectrogram, then an output dict from it according
to the selected model in internal parameters.
@@ -192,22 +199,21 @@ class EstimatorSpecBuilder(object):
"""
input_tensor = self.spectrogram_feature
model = self._params.get('model', None)
model = self._params.get("model", None)
if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL)
model_type = model.get("type", self.DEFAULT_MODEL)
else:
model_type = self.DEFAULT_MODEL
try:
apply_model = get_model_function(model_type)
except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found')
raise ValueError(f"No model function {model_type} found")
self._model_outputs = apply_model(
input_tensor,
self._instruments,
self._params['model']['params'])
input_tensor, self._instruments, self._params["model"]["params"]
)
def _build_loss(self, labels):
""" Construct tensorflow loss and metrics
"""Construct tensorflow loss and metrics
:param output_dict: dictionary of network outputs (key: instrument
name, value: estimated spectrogram of the instrument)
@@ -216,7 +222,7 @@ class EstimatorSpecBuilder(object):
:returns: tensorflow (loss, metrics) tuple.
"""
output_dict = self.model_outputs
loss_type = self._params.get('loss_type', self.L1_MASK)
loss_type = self._params.get("loss_type", self.L1_MASK)
if loss_type == self.L1_MASK:
losses = {
name: tf.reduce_mean(tf.abs(output - labels[name]))
@@ -225,11 +231,9 @@ class EstimatorSpecBuilder(object):
elif loss_type == self.WEIGHTED_L1_MASK:
losses = {
name: tf.reduce_mean(
tf.reduce_mean(
labels[name],
axis=[1, 2, 3],
keep_dims=True) *
tf.abs(output - labels[name]))
tf.reduce_mean(labels[name], axis=[1, 2, 3], keep_dims=True)
* tf.abs(output - labels[name])
)
for name, output in output_dict.items()
}
else:
@@ -237,20 +241,20 @@ class EstimatorSpecBuilder(object):
loss = tf.reduce_sum(list(losses.values()))
# Add metrics for monitoring each instrument.
metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
metrics['absolute_difference'] = tf.compat.v1.metrics.mean(loss)
metrics["absolute_difference"] = tf.compat.v1.metrics.mean(loss)
return loss, metrics
def _build_optimizer(self):
""" Builds an optimizer instance from internal parameter values.
"""Builds an optimizer instance from internal parameter values.
Default to AdamOptimizer if not specified.
:returns: Optimizer instance from internal configuration.
"""
name = self._params.get('optimizer')
name = self._params.get("optimizer")
if name == self.ADADELTA:
return tf.compat.v1.train.AdadeltaOptimizer()
rate = self._params['learning_rate']
rate = self._params["learning_rate"]
if name == self.SGD:
return tf.compat.v1.train.GradientDescentOptimizer(rate)
return tf.compat.v1.train.AdamOptimizer(rate)
@@ -261,14 +265,14 @@ class EstimatorSpecBuilder(object):
@property
def stft_name(self):
return f'{self._mix_name}_stft'
return f"{self._mix_name}_stft"
@property
def spectrogram_name(self):
return f'{self._mix_name}_spectrogram'
return f"{self._mix_name}_spectrogram"
def _build_stft_feature(self):
""" Compute STFT of waveform and slice the STFT in segment
"""Compute STFT of waveform and slice the STFT in segment
with the right length to feed the network.
"""
@@ -277,11 +281,12 @@ class EstimatorSpecBuilder(object):
if stft_name not in self._features:
# pad input with a frame of zeros
waveform = tf.concat([
waveform = tf.concat(
[
tf.zeros((self._frame_length, self._n_channels)),
self._features['waveform']
self._features["waveform"],
],
0
0,
)
stft_feature = tf.transpose(
stft(
@@ -289,13 +294,17 @@ class EstimatorSpecBuilder(object):
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype)),
pad_end=True),
perm=[1, 2, 0])
self._features[f'{self._mix_name}_stft'] = stft_feature
hann_window(frame_length, periodic=True, dtype=dtype)
),
pad_end=True,
),
perm=[1, 2, 0],
)
self._features[f"{self._mix_name}_stft"] = stft_feature
if spec_name not in self._features:
self._features[spec_name] = tf.abs(
pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :]
pad_and_partition(self._features[stft_name], self._T)
)[:, :, : self._F, :]
@property
def model_outputs(self):
@@ -334,25 +343,29 @@ class EstimatorSpecBuilder(object):
return self._masked_stfts
def _inverse_stft(self, stft_t, time_crop=None):
""" Inverse and reshape the given STFT
"""Inverse and reshape the given STFT
:param stft_t: input STFT
:returns: inverse STFT (waveform)
"""
inversed = inverse_stft(
inversed = (
inverse_stft(
tf.transpose(stft_t, perm=[2, 0, 1]),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype))
) * self.WINDOW_COMPENSATION_FACTOR
hann_window(frame_length, periodic=True, dtype=dtype)
),
)
* self.WINDOW_COMPENSATION_FACTOR
)
reshaped = tf.transpose(inversed)
if time_crop is None:
time_crop = tf.shape(self._features['waveform'])[0]
return reshaped[self._frame_length:self._frame_length+time_crop, :]
time_crop = tf.shape(self._features["waveform"])[0]
return reshaped[self._frame_length : self._frame_length + time_crop, :]
def _build_mwf_output_waveform(self):
""" Perform separation with multichannel Wiener Filtering using Norbert.
"""Perform separation with multichannel Wiener Filtering using Norbert.
Note: multichannel Wiener Filtering is not implemented in Tensorflow and thus
may be quite slow.
@@ -360,36 +373,42 @@ class EstimatorSpecBuilder(object):
value: estimated waveform of the instrument)
"""
import norbert # pylint: disable=import-error
output_dict = self.model_outputs
x = self.stft_feature
v = tf.stack(
[
pad_and_reshape(
output_dict[f'{instrument}_spectrogram'],
output_dict[f"{instrument}_spectrogram"],
self._frame_length,
self._F)[:tf.shape(x)[0], ...]
self._F,
)[: tf.shape(x)[0], ...]
for instrument in self._instruments
],
axis=3)
axis=3,
)
input_args = [v, x]
stft_function = tf.py_function(
stft_function = (
tf.py_function(
lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
input_args,
tf.complex64),
tf.complex64,
),
)
return {
instrument: self._inverse_stft(stft_function[0][:, :, :, k])
for k, instrument in enumerate(self._instruments)
}
def _extend_mask(self, mask):
""" Extend mask, from reduced number of frequency bin to the number of
"""Extend mask, from reduced number of frequency bin to the number of
frequency bin in the STFT.
:param mask: restricted mask
:returns: extended mask
:raise ValueError: If an invalid mask_extension parameter is set.
"""
extension = self._params['mask_extension']
extension = self._params["mask_extension"]
# Extend with average
# (dispatch according to energy in the processed band)
if extension == "average":
@@ -398,13 +417,9 @@ class EstimatorSpecBuilder(object):
# (avoid extension artifacts but not conservative separation)
elif extension == "zeros":
mask_shape = tf.shape(mask)
extension_row = tf.zeros((
mask_shape[0],
mask_shape[1],
1,
mask_shape[-1]))
extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))
else:
raise ValueError(f'Invalid mask_extension parameter {extension}')
raise ValueError(f"Invalid mask_extension parameter {extension}")
n_extra_row = self._frame_length // 2 + 1 - self._F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
return tf.concat([mask, extension], axis=2)
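A small sketch of the `zeros` extension path above; `frame_length` and `F` are assumed values mirroring common spleeter configurations, not taken from this diff:

```python
import tensorflow as tf

mask = tf.random.uniform((4, 512, 1024, 2))  # (segments, time, F, channels)
frame_length, F = 4096, 1024                 # assumed configuration values
n_extra_row = frame_length // 2 + 1 - F      # 1025 extra frequency rows

mask_shape = tf.shape(mask)
extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
extended = tf.concat([mask, extension], axis=2)  # -> (4, 512, 2049, 2)
```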
@@ -416,29 +431,31 @@ class EstimatorSpecBuilder(object):
"""
output_dict = self.model_outputs
stft_feature = self.stft_feature
separation_exponent = self._params['separation_exponent']
output_sum = tf.reduce_sum(
[e ** separation_exponent for e in output_dict.values()],
axis=0
) + self.EPSILON
separation_exponent = self._params["separation_exponent"]
output_sum = (
tf.reduce_sum(
[e ** separation_exponent for e in output_dict.values()], axis=0
)
+ self.EPSILON
)
out = {}
for instrument in self._instruments:
output = output_dict[f'{instrument}_spectrogram']
output = output_dict[f"{instrument}_spectrogram"]
# Compute mask with the model.
instrument_mask = (output ** separation_exponent
+ (self.EPSILON / len(output_dict))) / output_sum
instrument_mask = (
output ** separation_exponent + (self.EPSILON / len(output_dict))
) / output_sum
# Extend mask;
instrument_mask = self._extend_mask(instrument_mask)
# Stack back mask.
old_shape = tf.shape(instrument_mask)
new_shape = tf.concat(
[[old_shape[0] * old_shape[1]], old_shape[2:]],
axis=0)
[[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0
)
instrument_mask = tf.reshape(instrument_mask, new_shape)
# Remove padded part (for mask having the same size as STFT);
instrument_mask = instrument_mask[
:tf.shape(stft_feature)[0], ...]
instrument_mask = instrument_mask[: tf.shape(stft_feature)[0], ...]
out[instrument] = instrument_mask
self._masks = out
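The ratio-mask formula can be sanity-checked in isolation. A sketch assuming two instruments; the `EPSILON` and `separation_exponent` values are illustrative assumptions:

```python
import tensorflow as tf

EPSILON = 1e-10          # assumed regularisation constant
separation_exponent = 2  # illustrative exponent

outputs = {
    "vocals": tf.random.uniform((1, 512, 1024, 2)),
    "accompaniment": tf.random.uniform((1, 512, 1024, 2)),
}
output_sum = (
    tf.reduce_sum([e ** separation_exponent for e in outputs.values()], axis=0)
    + EPSILON
)
masks = {
    name: (out ** separation_exponent + EPSILON / len(outputs)) / output_sum
    for name, out in outputs.items()
}
# The masks sum to one (up to the epsilon regularisation).
total = tf.add_n(list(masks.values()))
```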
@@ -450,7 +467,7 @@ class EstimatorSpecBuilder(object):
self._masked_stfts = out
def _build_manual_output_waveform(self, masked_stft):
""" Perform ratio mask separation
"""Perform ratio mask separation
:param masked_stft: dictionary of estimated spectrograms (key: instrument
name, value: estimated spectrogram of the instrument)
@@ -464,14 +481,14 @@ class EstimatorSpecBuilder(object):
return output_waveform
def _build_output_waveform(self, masked_stft):
""" Build output waveform from given output dict in order to be used in
"""Build output waveform from given output dict in order to be used in
prediction context. Regarding of the configuration building method will
be using MWF.
:returns: Built output waveform.
"""
if self._params.get('MWF', False):
if self._params.get("MWF", False):
output_waveform = self._build_mwf_output_waveform()
else:
output_waveform = self._build_manual_output_waveform(masked_stft)
@@ -483,11 +500,11 @@ class EstimatorSpecBuilder(object):
else:
self._outputs = self.masked_stfts
if 'audio_id' in self._features:
self._outputs['audio_id'] = self._features['audio_id']
if "audio_id" in self._features:
self._outputs["audio_id"] = self._features["audio_id"]
def build_predict_model(self):
""" Builder interface for creating model instance that aims to perform
"""Builder interface for creating model instance that aims to perform
prediction / inference over given track. The output of such estimator
will be a dictionary with a "<instrument>" key per separated instrument
, associated to the estimated separated waveform of the instrument.
@@ -496,11 +513,11 @@ class EstimatorSpecBuilder(object):
"""
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT,
predictions=self.outputs)
tf.estimator.ModeKeys.PREDICT, predictions=self.outputs
)
def build_evaluation_model(self, labels):
""" Builder interface for creating model instance that aims to perform
"""Builder interface for creating model instance that aims to perform
model evaluation. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram.
@@ -510,12 +527,11 @@ class EstimatorSpecBuilder(object):
"""
loss, metrics = self._build_loss(labels)
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops=metrics)
tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics
)
def build_train_model(self, labels):
""" Builder interface for creating model instance that aims to perform
"""Builder interface for creating model instance that aims to perform
model training. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram.
@@ -526,8 +542,8 @@ class EstimatorSpecBuilder(object):
loss, metrics = self._build_loss(labels)
optimizer = self._build_optimizer()
train_operation = optimizer.minimize(
loss=loss,
global_step=tf.compat.v1.train.get_global_step())
loss=loss, global_step=tf.compat.v1.train.get_global_step()
)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
@@ -554,4 +570,4 @@ def model_fn(features, labels, mode, params, config):
return builder.build_evaluation_model(labels)
elif mode == tf.estimator.ModeKeys.TRAIN:
return builder.build_train_model(labels)
raise ValueError(f'Unknown mode {mode}')
raise ValueError(f"Unknown mode {mode}")

View File

@@ -8,18 +8,20 @@ from typing import Callable, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply(
function: Callable,
input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None) -> Dict:
params: Optional[Dict] = None,
) -> Dict:
"""
Apply the given function to the input tensor.
@@ -38,9 +40,8 @@ def apply(
"""
output_dict: Dict = {}
for instrument in instruments:
out_name = f'{instrument}_spectrogram'
out_name = f"{instrument}_spectrogram"
output_dict[out_name] = function(
input_tensor,
output_name=out_name,
params=params or {})
input_tensor, output_name=out_name, params=params or {}
)
return output_dict
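For instance, feeding `apply` a toy per-instrument function yields one `<instrument>_spectrogram` entry per instrument. The identity stand-in and the import path are illustrative assumptions:

```python
import tensorflow as tf

from spleeter.model.functions import apply  # assumed module path

def identity(input_tensor, output_name="output", params=None):
    # Toy stand-in for a real model function such as a U-net applier.
    return tf.identity(input_tensor, name=output_name)

x = tf.random.uniform((1, 512, 1024, 2))
outputs = apply(identity, x, ["vocals", "accompaniment"])
# -> {"vocals_spectrogram": ..., "accompaniment_spectrogram": ...}
```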

View File

@@ -22,12 +22,9 @@
from typing import Dict, Optional
from . import apply
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import (
@@ -35,18 +32,21 @@ from tensorflow.keras.layers import (
Dense,
Flatten,
Reshape,
TimeDistributed)
TimeDistributed,
)
from . import apply
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply_blstm(
input_tensor: tf.Tensor,
output_name: str = 'output',
params: Optional[Dict] = None) -> tf.Tensor:
input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
"""
Apply BLSTM to the given input_tensor.
@@ -64,16 +64,16 @@ def apply_blstm(
"""
if params is None:
params = {}
units: int = params.get('lstm_units', 250)
units: int = params.get("lstm_units", 250)
kernel_initializer = he_uniform(seed=50)
flatten_input = TimeDistributed(Flatten())((input_tensor))
def create_bidirectional():
return Bidirectional(
CuDNNLSTM(
units,
kernel_initializer=kernel_initializer,
return_sequences=True))
units, kernel_initializer=kernel_initializer, return_sequences=True
)
)
l1 = create_bidirectional()((flatten_input))
l2 = create_bidirectional()((l1))
@@ -81,17 +81,18 @@ def apply_blstm(
dense = TimeDistributed(
Dense(
int(flatten_input.shape[2]),
activation='relu',
kernel_initializer=kernel_initializer))((l3))
activation="relu",
kernel_initializer=kernel_initializer,
)
)((l3))
output: tf.Tensor = TimeDistributed(
Reshape(input_tensor.shape[2:]),
name=output_name)(dense)
Reshape(input_tensor.shape[2:]), name=output_name
)(dense)
return output
def blstm(
input_tensor: tf.Tensor,
output_name: str = 'output',
params: Optional[Dict] = None) -> tf.Tensor:
input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
""" Model function applier. """
return apply(apply_blstm, input_tensor, output_name, params)
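A rough functional sketch of the stack built by `apply_blstm`, substituting the portable `tf.keras.layers.LSTM` for `CuDNNLSTM` (which requires a GPU); all shapes are illustrative:

```python
import tensorflow as tf
from tensorflow.keras.layers import (
    LSTM,
    Bidirectional,
    Dense,
    Flatten,
    Reshape,
    TimeDistributed,
)

x = tf.random.uniform((1, 64, 128, 2))  # (batch, time, frequency, channels)
flat = TimeDistributed(Flatten())(x)    # -> (1, 64, 256)
h = flat
for _ in range(3):                      # three stacked BLSTM layers
    h = Bidirectional(LSTM(250, return_sequences=True))(h)
dense = TimeDistributed(Dense(int(flat.shape[2]), activation="relu"))(h)
out = TimeDistributed(Reshape((128, 2)))(dense)  # back to the input shape
```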

View File

@@ -16,30 +16,31 @@
from functools import partial
from typing import Any, Dict, Iterable, Optional
from . import apply
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.keras.layers import (
ELU,
BatchNormalization,
Concatenate,
Conv2D,
Conv2DTranspose,
Dropout,
ELU,
LeakyReLU,
Multiply,
ReLU,
Softmax)
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
Softmax,
)
from . import apply
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def _get_conv_activation_layer(params: Dict) -> Any:
@@ -53,10 +54,10 @@ def _get_conv_activation_layer(params: Dict) -> Any:
Any:
Required Activation function.
"""
conv_activation: str = params.get('conv_activation')
if conv_activation == 'ReLU':
conv_activation: str = params.get("conv_activation")
if conv_activation == "ReLU":
return ReLU()
elif conv_activation == 'ELU':
elif conv_activation == "ELU":
return ELU()
return LeakyReLU(0.2)
@@ -72,19 +73,20 @@ def _get_deconv_activation_layer(params: Dict) -> Any:
Any:
Required Activation function.
"""
deconv_activation: str = params.get('deconv_activation')
if deconv_activation == 'LeakyReLU':
deconv_activation: str = params.get("deconv_activation")
if deconv_activation == "LeakyReLU":
return LeakyReLU(0.2)
elif deconv_activation == 'ELU':
elif deconv_activation == "ELU":
return ELU()
return ReLU()
def apply_unet(
input_tensor: tf.Tensor,
output_name: str = 'output',
output_name: str = "output",
params: Optional[Dict] = None,
output_mask_logit: bool = False) -> Any:
output_mask_logit: bool = False,
) -> Any:
"""
Apply a convolutional U-net to model a single instrument (one U-net
is used for each instrument).
@@ -95,16 +97,14 @@ def apply_unet(
params (Optional[Dict]):
output_mask_logit (bool):
"""
logging.info(f'Apply unet for {output_name}')
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
logging.info(f"Apply unet for {output_name}")
conv_n_filters = params.get("conv_n_filters", [16, 32, 64, 128, 256, 512])
conv_activation_layer = _get_conv_activation_layer(params)
deconv_activation_layer = _get_deconv_activation_layer(params)
kernel_initializer = he_uniform(seed=50)
conv2d_factory = partial(
Conv2D,
strides=(2, 2),
padding='same',
kernel_initializer=kernel_initializer)
Conv2D, strides=(2, 2), padding="same", kernel_initializer=kernel_initializer
)
# First layer.
conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)
batch1 = BatchNormalization(axis=-1)(conv1)
@@ -134,8 +134,9 @@ def apply_unet(
conv2d_transpose_factory = partial(
Conv2DTranspose,
strides=(2, 2),
padding='same',
kernel_initializer=kernel_initializer)
padding="same",
kernel_initializer=kernel_initializer,
)
#
up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6))
up1 = deconv_activation_layer(up1)
@@ -174,31 +175,31 @@ def apply_unet(
2,
(4, 4),
dilation_rate=(2, 2),
activation='sigmoid',
padding='same',
kernel_initializer=kernel_initializer)((batch12))
activation="sigmoid",
padding="same",
kernel_initializer=kernel_initializer,
)((batch12))
output = Multiply(name=output_name)([up7, input_tensor])
return output
return Conv2D(
2,
(4, 4),
dilation_rate=(2, 2),
padding='same',
kernel_initializer=kernel_initializer)((batch12))
padding="same",
kernel_initializer=kernel_initializer,
)((batch12))
def unet(
input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None) -> Dict:
input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
""" Model function applier. """
return apply(apply_unet, input_tensor, instruments, params)
def softmax_unet(
input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None) -> Dict:
input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
"""
Apply softmax to the multitrack U-net outputs so that the masks sum to one.
@@ -216,18 +217,18 @@ def softmax_unet(
"""
logit_mask_list = []
for instrument in instruments:
out_name = f'{instrument}_spectrogram'
out_name = f"{instrument}_spectrogram"
logit_mask_list.append(
apply_unet(
input_tensor,
output_name=out_name,
params=params,
output_mask_logit=True))
output_mask_logit=True,
)
)
masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
output_dict = {}
for i, instrument in enumerate(instruments):
out_name = f'{instrument}_spectrogram'
output_dict[out_name] = Multiply(name=out_name)([
masks[..., i],
input_tensor])
out_name = f"{instrument}_spectrogram"
output_dict[out_name] = Multiply(name=out_name)([masks[..., i], input_tensor])
return output_dict
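To see why the `Softmax` over stacked logits makes the per-instrument masks sum to one, a toy two-instrument check with illustrative shapes:

```python
import tensorflow as tf
from tensorflow.keras.layers import Softmax

logit_vocals = tf.random.uniform((1, 64, 128, 2))
logit_accompaniment = tf.random.uniform((1, 64, 128, 2))

masks = Softmax(axis=4)(tf.stack([logit_vocals, logit_accompaniment], axis=4))
total = tf.reduce_sum(masks, axis=4)  # ~1.0 everywhere
```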

View File

@@ -17,9 +17,9 @@ from abc import ABC, abstractmethod
from os import environ, makedirs
from os.path import exists, isabs, join, sep
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class ModelProvider(ABC):
@@ -28,8 +28,8 @@ class ModelProvider(ABC):
file download is not available.
"""
DEFAULT_MODEL_PATH: str = environ.get('MODEL_PATH', 'pretrained_models')
MODEL_PROBE_PATH: str = '.probe'
DEFAULT_MODEL_PATH: str = environ.get("MODEL_PATH", "pretrained_models")
MODEL_PROBE_PATH: str = ".probe"
@abstractmethod
def download(_, name: str, path: str) -> None:
@@ -54,8 +54,8 @@ class ModelProvider(ABC):
Directory to write probe into.
"""
probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
with open(probe, 'w') as stream:
stream.write('OK')
with open(probe, "w") as stream:
stream.write("OK")
def get(self, model_directory: str) -> str:
"""
@@ -77,14 +77,12 @@ class ModelProvider(ABC):
if not exists(model_probe):
if not exists(model_directory):
makedirs(model_directory)
self.download(
model_directory.split(sep)[-1],
model_directory)
self.download(model_directory.split(sep)[-1], model_directory)
self.writeProbe(model_directory)
return model_directory
@classmethod
def default(_: type) -> 'ModelProvider':
def default(_: type) -> "ModelProvider":
"""
Builds and returns a default model provider.
@@ -93,4 +91,5 @@ class ModelProvider(ABC):
A default model provider instance to use.
"""
from .github import GithubModelProvider
return GithubModelProvider.from_environ()
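In practice the probe file makes `get` idempotent: the download only runs when `.probe` is missing. A minimal usage sketch (the directory name is illustrative):

```python
from spleeter.model.provider import ModelProvider

provider = ModelProvider.default()
# Downloads and extracts the model on first call, no-op afterwards.
model_path = provider.get("pretrained_models/2stems")
```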

View File

@@ -17,35 +17,35 @@
"""
import hashlib
import tarfile
import os
import tarfile
from os import environ
from tempfile import NamedTemporaryFile
from typing import Dict
from . import ModelProvider
from ...utils.logging import logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import httpx
from ...utils.logging import logger
from . import ModelProvider
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def compute_file_checksum(path):
""" Computes given path file sha256.
"""Computes given path file sha256.
:param path: Path of the file to compute checksum for.
:returns: File checksum.
"""
sha256 = hashlib.sha256()
with open(path, 'rb') as stream:
for chunk in iter(lambda: stream.read(4096), b''):
with open(path, "rb") as stream:
for chunk in iter(lambda: stream.read(4096), b""):
sha256.update(chunk)
return sha256.hexdigest()
@@ -53,19 +53,15 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider):
""" A ModelProvider implementation backed on Github for remote storage. """
DEFAULT_HOST: str = 'https://github.com'
DEFAULT_REPOSITORY: str = 'deezer/spleeter'
DEFAULT_HOST: str = "https://github.com"
DEFAULT_REPOSITORY: str = "deezer/spleeter"
CHECKSUM_INDEX: str = 'checksum.json'
LATEST_RELEASE: str = 'v1.4.0'
RELEASE_PATH: str = 'releases/download'
CHECKSUM_INDEX: str = "checksum.json"
LATEST_RELEASE: str = "v1.4.0"
RELEASE_PATH: str = "releases/download"
def __init__(
self,
host: str,
repository: str,
release: str) -> None:
""" Default constructor.
def __init__(self, host: str, repository: str, release: str) -> None:
"""Default constructor.
Parameters:
host (str):
@@ -80,7 +76,7 @@ class GithubModelProvider(ModelProvider):
self._release: str = release
@classmethod
def from_environ(cls: type) -> 'GithubModelProvider':
def from_environ(cls: type) -> "GithubModelProvider":
"""
Factory method that creates provider from envvars.
@@ -89,9 +85,10 @@ class GithubModelProvider(ModelProvider):
Created instance.
"""
return cls(
environ.get('GITHUB_HOST', cls.DEFAULT_HOST),
environ.get('GITHUB_REPOSITORY', cls.DEFAULT_REPOSITORY),
environ.get('GITHUB_RELEASE', cls.LATEST_RELEASE))
environ.get("GITHUB_HOST", cls.DEFAULT_HOST),
environ.get("GITHUB_REPOSITORY", cls.DEFAULT_REPOSITORY),
environ.get("GITHUB_RELEASE", cls.LATEST_RELEASE),
)
def checksum(self, name: str) -> str:
"""
@@ -108,17 +105,20 @@ class GithubModelProvider(ModelProvider):
ValueError:
If the given model name is not indexed.
"""
url: str = '/'.join((
url: str = "/".join(
(
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
self.CHECKSUM_INDEX))
self.CHECKSUM_INDEX,
)
)
response: httpx.Response = httpx.get(url)
response.raise_for_status()
index: Dict = response.json()
if name not in index:
raise ValueError(f'No checksum for model {name}')
raise ValueError(f"No checksum for model {name}")
return index[name]
def download(self, name: str, path: str) -> None:
@@ -131,30 +131,26 @@ class GithubModelProvider(ModelProvider):
path (str):
Path of the directory to save model into.
"""
url: str = '/'.join((
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
name))
url = f'{url}.tar.gz'
logger.info(f'Downloading model archive {url}')
url: str = "/".join(
(self._host, self._repository, self.RELEASE_PATH, self._release, name)
)
url = f"{url}.tar.gz"
logger.info(f"Downloading model archive {url}")
with httpx.Client(http2=True) as client:
with client.stream('GET', url) as response:
with client.stream("GET", url) as response:
response.raise_for_status()
archive = NamedTemporaryFile(delete=False)
try:
with archive as stream:
for chunk in response.iter_raw():
stream.write(chunk)
logger.info('Validating archive checksum')
logger.info("Validating archive checksum")
checksum: str = compute_file_checksum(archive.name)
if checksum != self.checksum(name):
raise IOError(
'Downloaded file is corrupted, please retry')
logger.info(f'Extracting downloaded {name} archive')
raise IOError("Downloaded file is corrupted, please retry")
logger.info(f"Extracting downloaded {name} archive")
with tarfile.open(name=archive.name) as tar:
tar.extractall(path=path)
finally:
os.unlink(archive.name)
logger.info(f'{name} model file(s) extracted')
logger.info(f"{name} model file(s) extracted")

View File

@@ -3,126 +3,126 @@
""" This modules provides spleeter command as well as CLI parsing methods. """
from tempfile import gettempdir
from os.path import join
from .audio import Codec, STFTBackend
from tempfile import gettempdir
from typer import Argument, Option
from typer.models import ArgumentInfo, OptionInfo
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
from .audio import Codec, STFTBackend
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
AudioInputArgument: ArgumentInfo = Argument(
...,
help='List of input audio file path',
help="List of input audio file path",
exists=True,
file_okay=True,
dir_okay=False,
readable=True,
resolve_path=True)
resolve_path=True,
)
AudioInputOption: OptionInfo = Option(
None,
'--inputs',
'-i',
help='(DEPRECATED) placeholder for deprecated input option')
None, "--inputs", "-i", help="(DEPRECATED) placeholder for deprecated input option"
)
AudioAdapterOption: OptionInfo = Option(
'spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter',
'--adapter',
'-a',
help='Name of the audio adapter to use for audio I/O')
"spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter",
"--adapter",
"-a",
help="Name of the audio adapter to use for audio I/O",
)
AudioOutputOption: OptionInfo = Option(
join(gettempdir(), 'separated_audio'),
'--output_path',
'-o',
help='Path of the output directory to write audio files in')
join(gettempdir(), "separated_audio"),
"--output_path",
"-o",
help="Path of the output directory to write audio files in",
)
AudioOffsetOption: OptionInfo = Option(
0.,
'--offset',
'-s',
help='Set the starting offset to separate audio from')
0.0, "--offset", "-s", help="Set the starting offset to separate audio from"
)
AudioDurationOption: OptionInfo = Option(
600.,
'--duration',
'-d',
600.0,
"--duration",
"-d",
help=(
'Set a maximum duration for processing audio '
'(only separate offset + duration first seconds of '
'the input file)'))
"Set a maximum duration for processing audio "
"(only separate offset + duration first seconds of "
"the input file)"
),
)
AudioSTFTBackendOption: OptionInfo = Option(
STFTBackend.AUTO,
'--stft-backend',
'-B',
"--stft-backend",
"-B",
case_sensitive=False,
help=(
'Who should be in charge of computing the stfts. Librosa is faster '
"Who should be in charge of computing the stfts. Librosa is faster "
'than tensorflow on CPU and uses less memory. "auto" will use '
'tensorflow when GPU acceleration is available and librosa when not'))
"tensorflow when GPU acceleration is available and librosa when not"
),
)
AudioCodecOption: OptionInfo = Option(
Codec.WAV,
'--codec',
'-c',
help='Audio codec to be used for the separated output')
Codec.WAV, "--codec", "-c", help="Audio codec to be used for the separated output"
)
AudioBitrateOption: OptionInfo = Option(
'128k',
'--bitrate',
'-b',
help='Audio bitrate to be used for the separated output')
"128k", "--bitrate", "-b", help="Audio bitrate to be used for the separated output"
)
FilenameFormatOption: OptionInfo = Option(
'{filename}/{instrument}.{codec}',
'--filename_format',
'-f',
"{filename}/{instrument}.{codec}",
"--filename_format",
"-f",
help=(
'Template string that will be formatted to generated'
'output filename. Such template should be Python formattable'
'string, and could use {filename}, {instrument}, and {codec}'
'variables'))
"Template string that will be formatted to generated"
"output filename. Such template should be Python formattable"
"string, and could use {filename}, {instrument}, and {codec}"
"variables"
),
)
ModelParametersOption: OptionInfo = Option(
'spleeter:2stems',
'--params_filename',
'-p',
help='JSON filename that contains params')
"spleeter:2stems",
"--params_filename",
"-p",
help="JSON filename that contains params",
)
MWFOption: OptionInfo = Option(
False,
'--mwf',
help='Whether to use multichannel Wiener filtering for separation')
False, "--mwf", help="Whether to use multichannel Wiener filtering for separation"
)
MUSDBDirectoryOption: OptionInfo = Option(
...,
'--mus_dir',
"--mus_dir",
exists=True,
dir_okay=True,
file_okay=False,
readable=True,
resolve_path=True,
help='Path to musDB dataset directory')
help="Path to musDB dataset directory",
)
TrainingDataDirectoryOption: OptionInfo = Option(
...,
'--data',
'-d',
"--data",
"-d",
exists=True,
dir_okay=True,
file_okay=False,
readable=True,
resolve_path=True,
help='Path of the folder containing audio data for training')
help="Path of the folder containing audio data for training",
)
VerboseOption: OptionInfo = Option(
False,
'--verbose',
help='Enable verbose logs')
VerboseOption: OptionInfo = Option(False, "--verbose", help="Enable verbose logs")

View File

@@ -3,6 +3,6 @@
""" Packages that provides static resources file for the library. """
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

View File

@@ -16,34 +16,33 @@
import atexit
import os
from multiprocessing import Pool
from os.path import basename, join, splitext, dirname
from spleeter.model.provider import ModelProvider
from os.path import basename, dirname, join, splitext
from typing import Dict, Generator, Optional
from . import SpleeterError
from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo
from .model import model_fn
from .model import EstimatorSpecBuilder, InputProviderFactory
from .model.provider import ModelProvider
from .types import AudioDescriptor
from .utils.configuration import load_configuration
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from librosa.core import stft, istft
from librosa.core import istft, stft
from scipy.signal.windows import hann
from spleeter.model.provider import ModelProvider
from . import SpleeterError
from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
from .model.provider import ModelProvider
from .types import AudioDescriptor
from .utils.configuration import load_configuration
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class DataGenerator(object):
@@ -81,18 +80,16 @@ def create_estimator(params, MWF):
"""
# Load model.
provider: ModelProvider = ModelProvider.default()
params['model_dir'] = provider.get(params['model_dir'])
params['MWF'] = MWF
params["model_dir"] = provider.get(params["model_dir"])
params["MWF"] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config)
model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
)
return estimator
@@ -104,7 +101,8 @@ class Separator(object):
params_descriptor: str,
MWF: bool = False,
stft_backend: STFTBackend = STFTBackend.AUTO,
multiprocess: bool = True) -> None:
multiprocess: bool = True,
) -> None:
"""
Default constructor.
@@ -115,7 +113,7 @@ class Separator(object):
(Optional) `True` if MWF should be used, `False` otherwise.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._sample_rate = self._params["sample_rate"]
self._MWF = MWF
self._tf_graph = tf.Graph()
self._prediction_generator = None
@@ -129,7 +127,7 @@ class Separator(object):
else:
self._pool = None
self._tasks = []
self._params['stft_backend'] = STFTBackend.resolve(stft_backend)
self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
self._data_generator = DataGenerator()
def __del__(self) -> None:
@@ -151,16 +149,13 @@ class Separator(object):
def get_dataset():
return tf.data.Dataset.from_generator(
self._data_generator,
output_types={
'waveform': tf.float32,
'audio_id': tf.string},
output_shapes={
'waveform': (None, 2),
'audio_id': ()})
output_types={"waveform": tf.float32, "audio_id": tf.string},
output_shapes={"waveform": (None, 2), "audio_id": ()},
)
self._prediction_generator = estimator.predict(
get_dataset,
yield_single_examples=False)
get_dataset, yield_single_examples=False
)
return self._prediction_generator
def join(self, timeout: int = 200) -> None:
@@ -177,10 +172,8 @@ class Separator(object):
task.wait(timeout=timeout)
def _stft(
self,
data: np.ndarray,
inverse: bool = False,
length: Optional[int] = None) -> np.ndarray:
self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
) -> np.ndarray:
"""
Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
@@ -203,27 +196,27 @@ class Separator(object):
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
N = self._params['frame_length']
H = self._params['frame_step']
N = self._params["frame_length"]
H = self._params["frame_step"]
win = hann(N, sym=False)
fstft = istft if inverse else stft
win_len_arg = {
'win_length': None,
'length': None} if inverse else {'n_fft': N}
win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
n_channels = data.shape[-1]
out = []
for c in range(n_channels):
d = np.concatenate(
(np.zeros((N, )), data[:, c], np.zeros((N, )))
) if not inverse else data[:, :, c].T
d = (
np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
if not inverse
else data[:, :, c].T
)
s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
if inverse:
s = s[N:N+length]
s = np.expand_dims(s.T, 2-inverse)
s = s[N : N + length]
s = np.expand_dims(s.T, 2 - inverse)
out.append(s)
if len(out) == 1:
return out[0]
return np.concatenate(out, axis=2-inverse)
return np.concatenate(out, axis=2 - inverse)
def _get_input_provider(self):
if self._input_provider is None:
@@ -238,25 +231,22 @@ class Separator(object):
def _get_builder(self):
if self._builder is None:
self._builder = EstimatorSpecBuilder(
self._get_features(),
self._params)
self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
return self._builder
def _get_session(self):
if self._session is None:
saver = tf.compat.v1.train.Saver()
provider = ModelProvider.default()
model_directory: str = provider.get(self._params['model_dir'])
model_directory: str = provider.get(self._params["model_dir"])
latest_checkpoint = tf.train.latest_checkpoint(model_directory)
self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint)
return self._session
def _separate_librosa(
self,
waveform: np.ndarray,
audio_descriptor: AudioDescriptor) -> Dict:
self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
) -> Dict:
"""
Performs separation with librosa backend for STFT.
@@ -280,20 +270,18 @@ class Separator(object):
outputs = sess.run(
outputs,
feed_dict=self._get_input_provider().get_feed_dict(
features,
stft,
audio_descriptor))
features, stft, audio_descriptor
),
)
for inst in self._get_builder().instruments:
out[inst] = self._stft(
outputs[inst],
inverse=True,
length=waveform.shape[0])
outputs[inst], inverse=True, length=waveform.shape[0]
)
return out
def _separate_tensorflow(
self,
waveform: np.ndarray,
audio_descriptor: AudioDescriptor) -> Dict:
self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
) -> Dict:
"""
Performs source separation over the given waveform with tensorflow
backend.
@@ -310,18 +298,17 @@ class Separator(object):
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
self._data_generator.update_data(
{"waveform": waveform, "audio_id": np.array(audio_descriptor)}
)
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
prediction.pop("audio_id")
return prediction
def separate(
self,
waveform: np.ndarray,
audio_descriptor: Optional[str] = None) -> None:
self, waveform: np.ndarray, audio_descriptor: Optional[str] = None
) -> None:
"""
Performs separation on a waveform.
@@ -331,12 +318,12 @@ class Separator(object):
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
"""
backend: str = self._params['stft_backend']
backend: str = self._params["stft_backend"]
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor)
elif backend == STFTBackend.LIBROSA:
return self._separate_librosa(waveform, audio_descriptor)
raise ValueError(f'Unsupported STFT backend {backend}')
raise ValueError(f"Unsupported STFT backend {backend}")
def separate_to_file(
self,
@@ -344,11 +331,12 @@ class Separator(object):
destination: str,
audio_adapter: Optional[AudioAdapter] = None,
offset: int = 0,
duration: float = 600.,
duration: float = 600.0,
codec: Codec = Codec.WAV,
bitrate: str = '128k',
filename_format: str = '{filename}/{instrument}.{codec}',
synchronous: bool = True) -> None:
bitrate: str = "128k",
filename_format: str = "{filename}/{instrument}.{codec}",
synchronous: bool = True,
) -> None:
"""
Performs source separation and exports the result to file using
the given audio adapter.
@@ -389,7 +377,8 @@ class Separator(object):
audio_descriptor,
offset=offset,
duration=duration,
sample_rate=self._sample_rate)
sample_rate=self._sample_rate,
)
sources = self.separate(waveform, audio_descriptor)
self.save_to_file(
sources,
@@ -399,18 +388,20 @@ class Separator(object):
codec,
audio_adapter,
bitrate,
synchronous)
synchronous,
)
def save_to_file(
self,
sources: Dict,
audio_descriptor: AudioDescriptor,
destination: str,
filename_format: str = '{filename}/{instrument}.{codec}',
filename_format: str = "{filename}/{instrument}.{codec}",
codec: Codec = Codec.WAV,
audio_adapter: Optional[AudioAdapter] = None,
bitrate: str = '128k',
synchronous: bool = True) -> None:
bitrate: str = "128k",
synchronous: bool = True,
) -> None:
"""
Export dictionary of sources to files.
@@ -443,34 +434,32 @@ class Separator(object):
filename = splitext(basename(audio_descriptor))[0]
generated = []
for instrument, data in sources.items():
path = join(destination, filename_format.format(
path = join(
destination,
filename_format.format(
filename=filename,
instrument=instrument,
foldername=foldername,
codec=codec,
))
),
)
directory = os.path.dirname(path)
if not os.path.exists(directory):
os.makedirs(directory)
if path in generated:
raise SpleeterError((
f'Separated source path conflict : {path},'
'please check your filename format'))
raise SpleeterError(
(
f"Separated source path conflict : {path},"
"please check your filename format"
)
)
generated.append(path)
if self._pool:
task = self._pool.apply_async(audio_adapter.save, (
path,
data,
self._sample_rate,
codec,
bitrate))
task = self._pool.apply_async(
audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
)
self._tasks.append(task)
else:
audio_adapter.save(
path,
data,
self._sample_rate,
codec,
bitrate)
audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
if synchronous and self._pool:
self.join()
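End to end, the separator is typically driven as follows; the audio path is illustrative, while `spleeter:2stems` is one of the embedded configurations:

```python
from spleeter.separator import Separator

separator = Separator("spleeter:2stems")
separator.separate_to_file("audio_example.mp3", "output/")
```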

View File

@@ -8,6 +8,7 @@ from typing import Any, Tuple
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
# pylint: enable=import-error
AudioDescriptor: type = Any

View File

@@ -3,6 +3,6 @@
""" This package provides utility function and classes. """
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

View File

@@ -3,20 +3,18 @@
""" Module that provides configuration loading function. """
import json
import importlib.resources as loader
import json
from os.path import exists
from typing import Dict
from .. import resources, SpleeterError
from .. import SpleeterError, resources
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_EMBEDDED_CONFIGURATION_PREFIX: str = 'spleeter:'
_EMBEDDED_CONFIGURATION_PREFIX: str = "spleeter:"
def load_configuration(descriptor: str) -> Dict:
@@ -41,13 +39,13 @@ def load_configuration(descriptor: str) -> Dict:
"""
# Embedded configuration reading.
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):]
if not loader.is_resource(resources, f'{name}.json'):
raise SpleeterError(f'No embedded configuration {name} found')
with loader.open_text(resources, f'{name}.json') as stream:
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX) :]
if not loader.is_resource(resources, f"{name}.json"):
raise SpleeterError(f"No embedded configuration {name} found")
with loader.open_text(resources, f"{name}.json") as stream:
return json.load(stream)
# Standard file reading.
if not exists(descriptor):
raise SpleeterError(f'Configuration file {descriptor} not found')
with open(descriptor, 'r') as stream:
raise SpleeterError(f"Configuration file {descriptor} not found")
with open(descriptor, "r") as stream:
return json.load(stream)
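Both descriptor forms resolve through the same entry point; for example (the JSON path is illustrative):

```python
from spleeter.utils.configuration import load_configuration

params = load_configuration("spleeter:2stems")         # embedded resource
# params = load_configuration("my_custom_config.json") # or a JSON file on disk
```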

View File

@@ -5,19 +5,19 @@
import logging
import warnings
from os import environ
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import echo
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
class TyperLoggerHandler(logging.Handler):
@@ -27,10 +27,10 @@ class TyperLoggerHandler(logging.Handler):
echo(self.format(record))
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
handler = TyperLoggerHandler()
handler.setFormatter(formatter)
logger: logging.Logger = logging.getLogger('spleeter')
logger: logging.Logger = logging.getLogger("spleeter")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
@@ -45,11 +45,12 @@ def configure_logger(verbose: bool) -> None:
"""
from tensorflow import get_logger
from tensorflow.compat.v1 import logging as tf_logging
tf_logger = get_logger()
tf_logger.handlers = [handler]
if verbose:
tf_logging.set_verbosity(tf_logging.INFO)
logger.setLevel(logging.DEBUG)
else:
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore")
tf_logging.set_verbosity(tf_logging.ERROR)
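A minimal usage sketch of the verbosity toggle above (assuming the module path `spleeter.utils.logging`):

```python
from spleeter.utils.logging import configure_logger

configure_logger(verbose=True)   # DEBUG logs plus verbose tensorflow logging
configure_logger(verbose=False)  # warnings silenced, tensorflow errors only
```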

View File

@@ -5,21 +5,22 @@
from typing import Any, Callable, Dict
import pandas as pd
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
import pandas as pd
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def sync_apply(
tensor_dict: tf.Tensor,
func: Callable,
concat_axis: int = 1) -> Dict[str, tf.Tensor]:
tensor_dict: tf.Tensor, func: Callable, concat_axis: int = 1
) -> Dict[str, tf.Tensor]:
"""
Apply the provided func synchronously to the provided
dictionary of tensors. This means that func is applied to the
@@ -48,7 +49,8 @@ def sync_apply(
"""
if concat_axis not in {0, 1}:
raise NotImplementedError(
'Function only implemented for concat_axis equal to 0 or 1')
"Function only implemented for concat_axis equal to 0 or 1"
)
tensor_list = list(tensor_dict.values())
concat_tensor = tf.concat(tensor_list, concat_axis)
processed_concat_tensor = func(concat_tensor)
@@ -56,18 +58,21 @@ def sync_apply(
D = tensor_shape[concat_axis]
if concat_axis == 0:
return {
name: processed_concat_tensor[index * D:(index + 1) * D, :, :]
for index, name in enumerate(tensor_dict)}
name: processed_concat_tensor[index * D : (index + 1) * D, :, :]
for index, name in enumerate(tensor_dict)
}
return {
name: processed_concat_tensor[:, index * D:(index + 1) * D, :]
for index, name in enumerate(tensor_dict)}
name: processed_concat_tensor[:, index * D : (index + 1) * D, :]
for index, name in enumerate(tensor_dict)
}
def from_float32_to_uint8(
tensor: tf.Tensor,
tensor_key: str = 'tensor',
min_key: str = 'min',
max_key: str = 'max') -> tf.Tensor:
tensor_key: str = "tensor",
min_key: str = "min",
max_key: str = "max",
) -> tf.Tensor:
"""
Parameters:
@@ -83,16 +88,17 @@ def from_float32_to_uint8(
tensor_max = tf.reduce_max(tensor)
return {
tensor_key: tf.cast(
(tensor - tensor_min) / (tensor_max - tensor_min + 1e-16)
* 255.9999, dtype=tf.uint8),
(tensor - tensor_min) / (tensor_max - tensor_min + 1e-16) * 255.9999,
dtype=tf.uint8,
),
min_key: tensor_min,
max_key: tensor_max}
max_key: tensor_max,
}
def from_uint8_to_float32(
tensor: tf.Tensor,
tensor_min: tf.Tensor,
tensor_max: tf.Tensor) -> tf.Tensor:
tensor: tf.Tensor, tensor_min: tf.Tensor, tensor_max: tf.Tensor
) -> tf.Tensor:
"""
Parameters:
@@ -104,14 +110,11 @@ def from_uint8_to_float32(
tensorflow.Tensor:
"""
return (
tf.cast(tensor, tf.float32)
* (tensor_max - tensor_min)
/ 255.9999 + tensor_min)
tf.cast(tensor, tf.float32) * (tensor_max - tensor_min) / 255.9999 + tensor_min
)
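The two conversion helpers above are near-inverses; a quick round-trip check (the import path is an assumption):

```python
import tensorflow as tf

from spleeter.utils.tensor import (  # assumed module path
    from_float32_to_uint8,
    from_uint8_to_float32,
)

x = tf.random.uniform((4, 4), minval=-1.0, maxval=1.0)
packed = from_float32_to_uint8(x)
restored = from_uint8_to_float32(packed["tensor"], packed["min"], packed["max"])
# Reconstruction error is bounded by the 8-bit quantisation step.
error = tf.reduce_max(tf.abs(x - restored))
```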
def pad_and_partition(
tensor: tf.Tensor,
segment_len: int) -> tf.Tensor:
def pad_and_partition(tensor: tf.Tensor, segment_len: int) -> tf.Tensor:
"""
Pad and partition a tensor into segments of length `segment_len`
along the first dimension. The tensor is padded with 0 in order
@@ -137,15 +140,11 @@ def pad_and_partition(
"""
tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
padded = tf.pad(
tensor,
[[0, pad_size]] + [[0, 0]] * (len(tensor.shape)-1))
padded = tf.pad(tensor, [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1))
split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
return tf.reshape(
padded,
tf.concat(
[[split, segment_len], tf.shape(padded)[1:]],
axis=0))
padded, tf.concat([[split, segment_len], tf.shape(padded)[1:]], axis=0)
)
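A behaviour check for `pad_and_partition`: a length-9 first dimension with `segment_len=4` is zero-padded to 12 and reshaped into 3 segments (import path assumed):

```python
import tensorflow as tf

from spleeter.utils.tensor import pad_and_partition  # assumed module path

t = tf.ones((9, 2))
segments = pad_and_partition(t, 4)
# segments.shape == (3, 4, 2); the last segment ends with three zero rows.
```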
def pad_and_reshape(instr_spec, frame_length, F) -> Any:
@@ -164,10 +163,7 @@ def pad_and_reshape(instr_spec, frame_length, F) -> Any:
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
extended_spec = tf.concat([instr_spec, extension], axis=2)
old_shape = tf.shape(extended_spec)
new_shape = tf.concat([
[old_shape[0] * old_shape[1]],
old_shape[2:]],
axis=0)
new_shape = tf.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
processed_instr_spec = tf.reshape(extended_spec, new_shape)
return processed_instr_spec
@@ -186,16 +182,11 @@ def dataset_from_csv(csv_path: str, **kwargs) -> Any:
Loaded dataset.
"""
df = pd.read_csv(csv_path, **kwargs)
dataset = (
tf.data.Dataset.from_tensor_slices(
{key: df[key].values for key in df})
)
dataset = tf.data.Dataset.from_tensor_slices({key: df[key].values for key in df})
return dataset
def check_tensor_shape(
tensor_tf: tf.Tensor,
target_shape: Any) -> bool:
def check_tensor_shape(tensor_tf: tf.Tensor, target_shape: Any) -> bool:
"""
Return a Tensorflow boolean graph that indicates whether
the given tensor has the specified target shape. Only check
@@ -215,14 +206,12 @@ def check_tensor_shape(
for i, target_length in enumerate(target_shape):
if target_length:
result = tf.logical_and(
result,
tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i]))
result, tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i])
)
return result
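Dimensions given as `None` (or any falsy value) are skipped by the loop above, so partial shapes can be checked (import path assumed):

```python
import tensorflow as tf

from spleeter.utils.tensor import check_tensor_shape  # assumed module path

x = tf.zeros((100, 2049, 2))
ok = check_tensor_shape(x, (None, 2049, 2))  # True: first dimension ignored
```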
def set_tensor_shape(
tensor: tf.Tensor,
tensor_shape: Any) -> tf.Tensor:
def set_tensor_shape(tensor: tf.Tensor, tensor_shape: Any) -> tf.Tensor:
"""
Set shape for a tensor (not in place, as opposed to tf.set_shape)