Merge pull request #529 from deezer/typer

Spleeter 2.1.0
This commit is contained in:
Félix Voituret
2021-01-08 18:40:12 +01:00
committed by GitHub
46 changed files with 4482 additions and 2176 deletions

View File

@@ -1,8 +1,6 @@
name: conda
on:
push:
branches:
- master
- workflow_dispatch
env:
ANACONDA_USERNAME: ${{ secrets.ANACONDA_USERNAME }}
ANACONDA_PASSWORD: ${{ secrets.ANACONDA_PASSWORD }}
@@ -11,7 +9,7 @@ jobs:
strategy:
matrix:
python: [3.7, 3.8]
package: [spleeter]
package: [spleeter, spleeter-gpu]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2

View File

@@ -4,40 +4,22 @@ on:
branches:
- master
env:
TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }}
TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
jobs:
package-and-deploy:
strategy:
matrix:
platform: [cpu, gpu]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.7
- uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- uses: actions/cache@v2
with:
path: ${{ env.GITHUB_WORKSPACE }}/dist
key: sdist-${{ matrix.platform }}-${{ hashFiles('**/setup.py') }}
restore-keys: |
sdist-${{ matrix.platform }}-${{ hashFiles('**/setup.py') }}
sdist-${{ matrix.platform }}
sdist-
- name: Install dependencies
run: pip install --upgrade pip setuptools twine
- if: ${{ matrix.platform == 'cpu' }}
name: Package CPU distribution
run: make build
- if: ${{ matrix.platform == 'gpu' }}
name: Package GPU distribution)
run: make build-gpu
- name: Install Poetry
run: |
pip install poetry
poetry config virtualenvs.in-project false
poetry config virtualenvs.path ~/.virtualenvs
poetry config pypi-token.pypi $PYPI_TOKEN
- name: Deploy to pypi
run: make deploy
run: |
poetry build
poetry publish

View File

@@ -1,41 +0,0 @@
name: pytest
on:
pull_request:
branches:
- master
jobs:
tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v2
id: spleeter-pip-cache
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- uses: actions/cache@v2
env:
model-release: 1
id: spleeter-model-cache
with:
path: ${{ env.GITHUB_WORKSPACE }}/pretrained_models
key: models-${{ env.model-release }}
restore-keys: |
models-${{ env.model-release }}
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get install -y ffmpeg
pip install --upgrade pip setuptools
pip install pytest==5.4.3 pytest-xdist==1.32.0 pytest-forked==1.1.3 musdb museval
python setup.py install
- name: Test with pytest
run: make test

51
.github/workflows/test.yml vendored Normal file
View File

@@ -0,0 +1,51 @@
name: test
on:
pull_request:
branches:
- master
jobs:
tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v2
env:
model-release: 1
id: spleeter-model-cache
with:
path: ${{ env.GITHUB_WORKSPACE }}/pretrained_models
key: models-${{ env.model-release }}
restore-keys: |
models-${{ env.model-release }}
- name: Install ffmpeg
run: |
sudo apt-get update && sudo apt-get install -y ffmpeg
- name: Install Poetry
run: |
pip install poetry
poetry config virtualenvs.in-project false
poetry config virtualenvs.path ~/.virtualenvs
- name: Cache Poetry virtualenv
uses: actions/cache@v1
id: cache
with:
path: ~/.virtualenvs
key: poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install Dependencies
run: poetry install
if: steps.cache.outputs.cache-hit != 'true'
- name: Code quality checks
run: |
poetry run black spleeter --check
poetry run isort spleeter --check
- name: Test with pytest
run: poetry run pytest tests/

View File

@@ -1,5 +1,27 @@
# Changelog History
## 2.1.0
This version introduce design related changes, especially transition to Typer for CLI managment and Poetry as
library build backend.
* `-i` option is now deprecated and replaced by traditional CLI input argument listing
* Project is now built using Poetry
* Project requires code formatting using Black and iSort
* Dedicated GPU package `spleeter-gpu` is not supported anymore, `spleeter` package will support both CPU and GPU hardware
### API changes:
* function `get_default_audio_adapter` is now available as `default()` class method within `AudioAdapter` class
* function `get_default_model_provider` is now available as `default()` class method within `ModelProvider` class
* `STFTBackend` and `Codec` are now string enum
* `GithubModelProvider` now use `httpx` with HTTP/2 support
* Commands are now located in `__main__` module, wrapped as simple function using Typer options module provide specification for each available option and argument
* `types` module provide custom type specification and must be enhanced in future release to provide more robust typing support with MyPy
* `utils.logging` module has been cleaned, logger instance is now a module singleton, and a single function is used to configure it with verbose parameter
* Added a custom logger handler (see tiangolo/typer#203 discussion)
## 2.0
First release, October 9th 2020

View File

@@ -1,3 +0,0 @@
include spleeter/resources/*.json
include README.md
include LICENSE

View File

@@ -1,34 +0,0 @@
# =======================================================
# Library lifecycle management.
#
# @author Deezer Research <spleeter@deezer.com>
# @licence MIT Licence
# =======================================================
FEEDSTOCK = spleeter-feedstock
FEEDSTOCK_REPOSITORY = https://github.com/deezer/$(FEEDSTOCK)
FEEDSTOCK_RECIPE = $(FEEDSTOCK)/recipe/spleeter/meta.yaml
PYTEST_CMD = pytest -W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked
all: clean build test deploy
clean:
rm -Rf *.egg-info
rm -Rf dist
build: clean
sed -i "s/project_name = '[^']*'/project_name = 'spleeter'/g" setup.py
sed -i "s/tensorflow_dependency = '[^']*'/tensorflow_dependency = 'tensorflow'/g" setup.py
python3 setup.py sdist
build-gpu: clean
sed -i "s/project_name = '[^']*'/project_name = 'spleeter-gpu'/g" setup.py
sed -i "s/tensorflow_dependency = '[^']*'/tensorflow_dependency = 'tensorflow-gpu'/g" setup.py
python3 setup.py sdist
test:
$(PYTEST_CMD) tests/
deploy:
pip install twine
twine upload --skip-existing dist/*

View File

@@ -2,6 +2,9 @@
[![Github actions](https://github.com/deezer/spleeter/workflows/pytest/badge.svg)](https://github.com/deezer/spleeter/actions) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/spleeter) [![PyPI version](https://badge.fury.io/py/spleeter.svg)](https://badge.fury.io/py/spleeter) [![Conda](https://img.shields.io/conda/vn/conda-forge/spleeter)](https://anaconda.org/conda-forge/spleeter) [![Docker Pulls](https://img.shields.io/docker/pulls/researchdeezer/spleeter)](https://hub.docker.com/r/researchdeezer/spleeter) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deezer/spleeter/blob/master/spleeter.ipynb) [![Gitter chat](https://badges.gitter.im/gitterHQ/gitter.png)](https://gitter.im/spleeter/community) [![status](https://joss.theoj.org/papers/259e5efe669945a343bad6eccb89018b/status.svg)](https://joss.theoj.org/papers/259e5efe669945a343bad6eccb89018b)
> :warning: [Spleeter 2.1.0](https://pypi.org/project/spleeter/) release introduces some breaking changes, including new CLI option naming for input, and the drop
> of dedicated GPU package. Please read [CHANGELOG](CHANGELOG.md) for more details.
## About
**Spleeter** is [Deezer](https://www.deezer.com/) source separation library with pretrained models
@@ -46,7 +49,7 @@ conda install -c conda-forge spleeter
# download an example audio file (if you don't have wget, use another tool for downloading)
wget https://github.com/deezer/spleeter/raw/master/audio_example.mp3
# separate the example audio into two components
spleeter separate -i audio_example.mp3 -p spleeter:2stems -o output
spleeter separate -p spleeter:2stems -o output audio_example.mp3
```
You should get two separated audio files (`vocals.wav` and `accompaniment.wav`) in the `output/audio_example` folder.
@@ -55,13 +58,18 @@ For a detailed documentation, please check the [repository wiki](https://github.
## Development and Testing
The following set of commands will clone this repository, create a virtual environment provisioned with the dependencies and run the tests (will take a few minutes):
This project is managed using [Poetry](https://python-poetry.org/docs/basic-usage/), to run test suite you
can execute the following set of commands:
```bash
# Clone spleeter repository
git clone https://github.com/Deezer/spleeter && cd spleeter
python -m venv spleeterenv && source spleeterenv/bin/activate
pip install . && pip install pytest pytest-xdist
make test
# Install poetry
pip install poetry
# Install spleeter dependencies
poetry install
# Run unit test suite
poetry run pytest tests/
```
## Reference

View File

@@ -0,0 +1,52 @@
{% set name = "spleeter-gpu" %}
{% set version = "2.0.2" %}
package:
name: {{ name|lower }}
version: {{ version }}
source:
- url: https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz
sha256: ecd3518a98f9978b9088d1cb2ef98f766401fd9007c2bf72a34e5b5bc5a6fdc3
build:
number: 0
script: {{ PYTHON }} -m pip install . -vv
skip: True # [osx]
entry_points:
- spleeter = spleeter.__main__:entrypoint
requirements:
host:
- python {{ python }}
- pip
run:
- python {{ python }}
- tensorflow-gpu ==2.2.0 # [linux]
- tensorflow-gpu ==23.0 # [win]
- pandas
- ffmpeg-python
- norbert
- librosa
test:
imports:
- spleeter
- spleeter.commands
- spleeter.model
- spleeter.utils
- spleeter.separator
about:
home: https://github.com/deezer/spleeter
license: MIT
license_family: MIT
license_file: LICENSE
summary: The Deezer source separation library with pretrained models based on tensorflow.
doc_url: https://github.com/deezer/spleeter/wiki
dev_url: https://github.com/deezer/spleeter
extra:
recipe-maintainers:
- Faylixe
- romi1502

View File

@@ -1,3 +0,0 @@
python:
- 3.7
- 3.8

1880
poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

83
pyproject.toml Normal file
View File

@@ -0,0 +1,83 @@
[tool.poetry]
name = "spleeter"
version = "2.1.0"
description = "The Deezer source separation library with pretrained models based on tensorflow."
authors = ["Deezer Research <spleeter@deezer.com>"]
license = "MIT License"
readme = "README.md"
repository = "https://github.com/deezer/spleeter"
homepage = "https://github.com/deezer/spleeter"
classifiers = [
"Environment :: Console",
"Environment :: MacOS X",
"Intended Audience :: Developers",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: MacOS",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX :: Linux",
"Operating System :: Unix",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Artistic Software",
"Topic :: Multimedia",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
"Topic :: Multimedia :: Sound/Audio :: Conversion",
"Topic :: Multimedia :: Sound/Audio :: Sound Synthesis",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Utilities"
]
packages = [ { include = "spleeter" } ]
include = ["LICENSE", "spleeter/resources/*.json"]
[tool.poetry.dependencies]
python = "^3.7"
ffmpeg-python = "0.2.0"
norbert = "0.2.1"
httpx = {extras = ["http2"], version = "^0.16.1"}
typer = "^0.3.2"
librosa = "0.8.0"
musdb = {version = "0.3.1", optional = true}
museval = {version = "0.3.0", optional = true}
tensorflow = "2.3.0"
pandas = "1.1.2"
numpy = "<1.19.0,>=1.16.0"
[tool.poetry.dev-dependencies]
pytest = "^6.2.1"
isort = "^5.7.0"
black = "^20.8b1"
mypy = "^0.790"
pytest-forked = "^1.3.0"
musdb = "0.3.1"
museval = "0.3.0"
[tool.poetry.scripts]
spleeter = 'spleeter.__main__:entrypoint'
[tool.poetry.extras]
evaluation = ["musdb", "museval"]
[tool.isort]
profile = "black"
multi_line_output = 3
[tool.pytest.ini_options]
addopts = "-W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

102
setup.py
View File

@@ -1,102 +0,0 @@
#!/usr/bin/env python
# coding: utf8
""" Distribution script. """
import sys
from os import path
from setuptools import setup
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# Default project values.
project_name = 'spleeter'
project_version = '2.0.2'
tensorflow_dependency = 'tensorflow'
tensorflow_version = '2.3.0'
here = path.abspath(path.dirname(__file__))
readme_path = path.join(here, 'README.md')
with open(readme_path, 'r') as stream:
readme = stream.read()
# Package setup entrypoint.
setup(
name=project_name,
version=project_version,
description='''
The Deezer source separation library with
pretrained models based on tensorflow.
''',
long_description=readme,
long_description_content_type='text/markdown',
author='Deezer Research',
author_email='spleeter@deezer.com',
url='https://github.com/deezer/spleeter',
license='MIT License',
packages=[
'spleeter',
'spleeter.audio',
'spleeter.commands',
'spleeter.model',
'spleeter.model.functions',
'spleeter.model.provider',
'spleeter.resources',
'spleeter.utils',
],
package_data={'spleeter.resources': ['*.json']},
python_requires='>=3.6, <3.9',
include_package_data=True,
install_requires=[
'ffmpeg-python==0.2.0',
'importlib_resources ; python_version<"3.7"',
'norbert==0.2.1',
'numpy<1.19.0,>=1.16.0',
'pandas==1.1.2',
'requests',
'scipy==1.4.1',
'setuptools>=41.0.0',
'librosa==0.8.0',
'{}=={}'.format(tensorflow_dependency, tensorflow_version),
],
extras_require={
'evaluation': ['musdb==0.3.1', 'museval==0.3.0']
},
entry_points={
'console_scripts': ['spleeter=spleeter.__main__:entrypoint']
},
classifiers=[
'Environment :: Console',
'Environment :: MacOS X',
'Intended Audience :: Developers',
'Intended Audience :: Information Technology',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
'Operating System :: MacOS',
'Operating System :: Microsoft :: Windows',
'Operating System :: POSIX :: Linux',
'Operating System :: Unix',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: Implementation :: CPython',
'Topic :: Artistic Software',
'Topic :: Multimedia',
'Topic :: Multimedia :: Sound/Audio',
'Topic :: Multimedia :: Sound/Audio :: Analysis',
'Topic :: Multimedia :: Sound/Audio :: Conversion',
'Topic :: Multimedia :: Sound/Audio :: Sound Synthesis',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Information Analysis',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
'Topic :: Utilities']
)

View File

@@ -13,9 +13,9 @@
by providing train, evaluation and source separation action.
"""
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class SpleeterError(Exception):

View File

@@ -5,54 +5,252 @@
Python oneliner script usage.
USAGE: python -m spleeter {train,evaluate,separate} ...
Notes:
All critical import involving TF, numpy or Pandas are deported to
command function scope to avoid heavy import on CLI evaluation,
leading to large bootstraping time.
"""
import sys
import warnings
import json
from functools import partial
from glob import glob
from itertools import product
from os.path import join
from pathlib import Path
from typing import Container, Dict, List, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import Exit, Typer
from . import SpleeterError
from .commands import create_argument_parser
from .utils.configuration import load_configuration
from .utils.logging import (
enable_logging,
enable_tensorflow_logging,
get_logger)
from .options import *
from .utils.logging import configure_logger, logger
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
spleeter: Typer = Typer(add_completion=False)
""" CLI application. """
def main(argv):
""" Spleeter runner. Parse provided command line arguments
and run entrypoint for required command (either train,
evaluate or separate).
:param argv: Provided command line arguments.
@spleeter.command()
def train(
adapter: str = AudioAdapterOption,
data: Path = TrainingDataDirectoryOption,
params_filename: str = ModelParametersOption,
verbose: bool = VerboseOption,
) -> None:
"""
Train a source separation model
"""
import tensorflow as tf
from .audio.adapter import AudioAdapter
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .utils.configuration import load_configuration
configure_logger(verbose)
audio_adapter = AudioAdapter.get(adapter)
audio_path = str(data)
params = load_configuration(params_filename)
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params["model_dir"],
params=params,
config=tf.estimator.RunConfig(
save_checkpoints_steps=params["save_checkpoints_steps"],
tf_random_seed=params["random_seed"],
save_summary_steps=params["save_summary_steps"],
session_config=session_config,
log_step_count_steps=10,
keep_checkpoint_max=2,
),
)
input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
train_spec = tf.estimator.TrainSpec(
input_fn=input_fn, max_steps=params["train_max_steps"]
)
input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
evaluation_spec = tf.estimator.EvalSpec(
input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
)
logger.info("Start model training")
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
ModelProvider.writeProbe(params["model_dir"])
logger.info("Model training done")
@spleeter.command()
def separate(
deprecated_files: Optional[str] = AudioInputOption,
files: List[Path] = AudioInputArgument,
adapter: str = AudioAdapterOption,
bitrate: str = AudioBitrateOption,
codec: Codec = AudioCodecOption,
duration: float = AudioDurationOption,
offset: float = AudioOffsetOption,
output_path: Path = AudioOutputOption,
stft_backend: STFTBackend = AudioSTFTBackendOption,
filename_format: str = FilenameFormatOption,
params_filename: str = ModelParametersOption,
mwf: bool = MWFOption,
verbose: bool = VerboseOption,
) -> None:
"""
Separate audio file(s)
"""
from .audio.adapter import AudioAdapter
from .separator import Separator
configure_logger(verbose)
if deprecated_files is not None:
logger.error(
"⚠️ -i option is not supported anymore, audio files must be supplied "
"using input argument instead (see spleeter separate --help)"
)
raise Exit(20)
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
separator: Separator = Separator(
params_filename, MWF=mwf, stft_backend=stft_backend
)
for filename in files:
separator.separate_to_file(
str(filename),
str(output_path),
audio_adapter=audio_adapter,
offset=offset,
duration=duration,
codec=codec,
bitrate=bitrate,
filename_format=filename_format,
synchronous=False,
)
separator.join()
EVALUATION_SPLIT: str = "test"
EVALUATION_METRICS_DIRECTORY: str = "metrics"
EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
EVALUATION_MIXTURE: str = "mixture.wav"
EVALUATION_AUDIO_DIRECTORY: str = "audio"
def _compile_metrics(metrics_output_directory) -> Dict:
"""
Compiles metrics from given directory and returns results as dict.
Parameters:
metrics_output_directory (str):
Directory to get metrics from.
Returns:
Dict:
Compiled metrics as dict.
"""
import numpy as np
import pandas as pd
songs = glob(join(metrics_output_directory, "test/*.json"))
index = pd.MultiIndex.from_tuples(
product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
names=["instrument", "metric"],
)
pd.DataFrame([], index=["config1", "config2"], columns=index)
metrics = {
instrument: {k: [] for k in EVALUATION_METRICS}
for instrument in EVALUATION_INSTRUMENTS
}
for song in songs:
with open(song, "r") as stream:
data = json.load(stream)
for target in data["targets"]:
instrument = target["name"]
for metric in EVALUATION_METRICS:
sdr_med = np.median(
[
frame["metrics"][metric]
for frame in target["frames"]
if not np.isnan(frame["metrics"][metric])
]
)
metrics[instrument][metric].append(sdr_med)
return metrics
@spleeter.command()
def evaluate(
adapter: str = AudioAdapterOption,
output_path: Path = AudioOutputOption,
stft_backend: STFTBackend = AudioSTFTBackendOption,
params_filename: str = ModelParametersOption,
mus_dir: Path = MUSDBDirectoryOption,
mwf: bool = MWFOption,
verbose: bool = VerboseOption,
) -> Dict:
"""
Evaluate a model on the musDB test dataset
"""
import numpy as np
configure_logger(verbose)
try:
parser = create_argument_parser()
arguments = parser.parse_args(argv[1:])
enable_logging()
if arguments.verbose:
enable_tensorflow_logging()
if arguments.command == 'separate':
from .commands.separate import entrypoint
elif arguments.command == 'train':
from .commands.train import entrypoint
elif arguments.command == 'evaluate':
from .commands.evaluate import entrypoint
params = load_configuration(arguments.configuration)
entrypoint(arguments, params)
except SpleeterError as e:
get_logger().error(e)
import musdb
import museval
except ImportError:
logger.error("Extra dependencies musdb and museval not found")
logger.error("Please install musdb and museval first, abort")
raise Exit(10)
# Separate musdb sources.
songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
separate(
deprecated_files=None,
files=mixtures,
adapter=adapter,
bitrate="128k",
codec=Codec.WAV,
duration=600.0,
offset=0,
output_path=join(audio_output_directory, EVALUATION_SPLIT),
stft_backend=stft_backend,
filename_format="{foldername}/{instrument}.{codec}",
params_filename=params_filename,
mwf=mwf,
verbose=verbose,
)
# Compute metrics with musdb.
metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
logger.info("Starting musdb evaluation (this could be long) ...")
dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
museval.eval_mus_dir(
dataset=dataset,
estimates_dir=audio_output_directory,
output_dir=metrics_output_directory,
)
logger.info("musdb evaluation done")
# Compute and pretty print median metrics.
metrics = _compile_metrics(metrics_output_directory)
for instrument, metric in metrics.items():
logger.info(f"{instrument}:")
for metric, value in metric.items():
logger.info(f"{metric}: {np.median(value):.3f}")
return metrics
def entrypoint():
""" Command line entrypoint. """
warnings.filterwarnings('ignore')
main(sys.argv)
""" Application entrypoint. """
try:
spleeter()
except SpleeterError as e:
logger.error(e)
if __name__ == '__main__':
if __name__ == "__main__":
entrypoint()

View File

@@ -10,6 +10,43 @@
- Waveform convertion and transforming functions.
"""
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
from enum import Enum
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class Codec(str, Enum):
""" Enumeration of supported audio codec. """
WAV: str = "wav"
MP3: str = "mp3"
OGG: str = "ogg"
M4A: str = "m4a"
WMA: str = "wma"
FLAC: str = "flac"
class STFTBackend(str, Enum):
""" Enumeration of supported STFT backend. """
AUTO: str = "auto"
TENSORFLOW: str = "tensorflow"
LIBROSA: str = "librosa"
@classmethod
def resolve(cls: type, backend: str) -> str:
# NOTE: import is resolved here to avoid performance issues on command
# evaluation.
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
if backend not in cls.__members__.values():
raise ValueError(f"Unsupported backend {backend}")
if backend == cls.AUTO:
if len(tf.config.list_physical_devices("GPU")):
return cls.TENSORFLOW
return cls.LIBROSA
return backend

View File

@@ -3,70 +3,101 @@
""" AudioAdapter class defintion. """
import subprocess
from abc import ABC, abstractmethod
from importlib import import_module
from os.path import exists
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from tensorflow.signal import stft, hann_window
# pylint: enable=import-error
from spleeter.audio import Codec
from .. import SpleeterError
from ..utils.logging import get_logger
from ..types import AudioDescriptor, Signal
from ..utils.logging import logger
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class AudioAdapter(ABC):
""" An abstract class for manipulating audio signal. """
# Default audio adapter singleton instance.
DEFAULT = None
_DEFAULT: "AudioAdapter" = None
""" Default audio adapter singleton instance. """
@abstractmethod
def load(
self, audio_descriptor, offset, duration,
sample_rate, dtype=np.float32):
""" Loads the audio file denoted by the given audio descriptor
and returns it data as a waveform. Aims to be implemented
by client.
self,
audio_descriptor: AudioDescriptor,
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32,
) -> Signal:
"""
Loads the audio file denoted by the given audio descriptor and
returns it data as a waveform. Aims to be implemented by client.
:param audio_descriptor: Describe song to load, in case of file
based audio adapter, such descriptor would
be a file path.
:param offset: Start offset to load from in seconds.
:param duration: Duration to load in seconds.
:param sample_rate: Sample rate to load audio with.
:param dtype: Numpy data type to use, default to float32.
:returns: Loaded data as (wf, sample_rate) tuple.
Parameters:
audio_descriptor (AudioDescriptor):
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (Optional[float]):
Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
dtype (numpy.dtype):
(Optional) Numpy data type to use, default to `float32`.
Returns:
Signal:
Loaded data as (wf, sample_rate) tuple.
"""
pass
def load_tf_waveform(
self, audio_descriptor,
offset=0.0, duration=1800., sample_rate=44100,
dtype=b'float32', waveform_name='waveform'):
""" Load the audio and convert it to a tensorflow waveform.
self,
audio_descriptor,
offset: float = 0.0,
duration: float = 1800.0,
sample_rate: int = 44100,
dtype: bytes = b"float32",
waveform_name: str = "waveform",
) -> Dict[str, Any]:
"""
Load the audio and convert it to a tensorflow waveform.
:param audio_descriptor: Describe song to load, in case of file
based audio adapter, such descriptor would
be a file path.
:param offset: Start offset to load from in seconds.
:param duration: Duration to load in seconds.
:param sample_rate: Sample rate to load audio with.
:param dtype: Numpy data type to use, default to float32.
:param waveform_name: (Optional) Name of the key in output dict.
:returns: TF output dict with waveform as
(T x chan numpy array) and a boolean that
tells whether there were an error while
trying to load the waveform.
Parameters:
audio_descriptor ():
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (float):
Start offset to load from in seconds.
duration (float):
Duration to load in seconds.
sample_rate (float):
Sample rate to load audio with.
dtype (bytes):
(Optional)data type to use, default to `b'float32'`.
waveform_name (str):
(Optional) Name of the key in output dict, default to
`'waveform'`.
Returns:
Dict[str, Any]:
TF output dict with waveform as `(T x chan numpy array)`
and a boolean that tells whether there were an error while
trying to load the waveform.
"""
# Cast parameters to TF format.
offset = tf.cast(offset, tf.float64)
@@ -74,76 +105,96 @@ class AudioAdapter(ABC):
# Defined safe loading function.
def safe_load(path, offset, duration, sample_rate, dtype):
logger = get_logger()
logger.info(
f'Loading audio {path} from {offset} to {offset + duration}')
logger.info(f"Loading audio {path} from {offset} to {offset + duration}")
try:
(data, _) = self.load(
path.numpy(),
offset.numpy(),
duration.numpy(),
sample_rate.numpy(),
dtype=dtype.numpy())
logger.info('Audio data loaded successfully')
dtype=dtype.numpy(),
)
logger.info("Audio data loaded successfully")
return (data, False)
except Exception as e:
logger.exception(
'An error occurs while loading audio',
exc_info=e)
logger.exception("An error occurs while loading audio", exc_info=e)
return (np.float32(-1.0), True)
# Execute function and format results.
results = tf.py_function(
safe_load,
[audio_descriptor, offset, duration, sample_rate, dtype],
(tf.float32, tf.bool)),
results = (
tf.py_function(
safe_load,
[audio_descriptor, offset, duration, sample_rate, dtype],
(tf.float32, tf.bool),
),
)
waveform, error = results[0]
return {
waveform_name: waveform,
f'{waveform_name}_error': error
}
return {waveform_name: waveform, f"{waveform_name}_error": error}
@abstractmethod
def save(
self, path, data, sample_rate,
codec=None, bitrate=None):
""" Save the given audio data to the file denoted by
the given path.
self,
path: Union[Path, str],
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None,
) -> None:
"""
Save the given audio data to the file denoted by the given path.
:param path: Path of the audio file to save data in.
:param data: Waveform data to write.
:param sample_rate: Sample rate to write file in.
:param codec: (Optional) Writing codec to use.
:param bitrate: (Optional) Bitrate of the written audio file.
Parameters:
path (Union[Path, str]):
Path like of the audio file to save data in.
data (numpy.ndarray):
Waveform data to write.
sample_rate (float):
Sample rate to write file in.
codec ():
(Optional) Writing codec to use, default to `None`.
bitrate (str):
(Optional) Bitrate of the written audio file, default to
`None`.
"""
pass
@classmethod
def default(cls: type) -> "AudioAdapter":
"""
Builds and returns a default audio adapter instance.
def get_default_audio_adapter():
""" Builds and returns a default audio adapter instance.
Returns:
AudioAdapter:
Default adapter instance to use.
"""
if cls._DEFAULT is None:
from .ffmpeg import FFMPEGProcessAudioAdapter
:returns: An audio adapter instance.
"""
if AudioAdapter.DEFAULT is None:
from .ffmpeg import FFMPEGProcessAudioAdapter
AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
return AudioAdapter.DEFAULT
cls._DEFAULT = FFMPEGProcessAudioAdapter()
return cls._DEFAULT
@classmethod
def get(cls: type, descriptor: str) -> "AudioAdapter":
"""
Load dynamically an AudioAdapter from given class descriptor.
def get_audio_adapter(descriptor):
""" Load dynamically an AudioAdapter from given class descriptor.
Parameters:
descriptor (str):
Adapter class descriptor (module.Class)
:param descriptor: Adapter class descriptor (module.Class)
:returns: Created adapter instance.
"""
if descriptor is None:
return get_default_audio_adapter()
module_path = descriptor.split('.')
adapter_class_name = module_path[-1]
module_path = '.'.join(module_path[:-1])
adapter_module = import_module(module_path)
adapter_class = getattr(adapter_module, adapter_class_name)
if not isinstance(adapter_class, AudioAdapter):
raise SpleeterError(
f'{adapter_class_name} is not a valid AudioAdapter class')
return adapter_class()
Returns:
AudioAdapter:
Created adapter instance.
"""
if not descriptor:
return cls.default()
module_path: List[str] = descriptor.split(".")
adapter_class_name: str = module_path[-1]
module_path: str = ".".join(module_path[:-1])
adapter_module = import_module(module_path)
adapter_class = getattr(adapter_module, adapter_class_name)
if not issubclass(adapter_class, AudioAdapter):
raise SpleeterError(
f"{adapter_class_name} is not a valid AudioAdapter class"
)
return adapter_class()

View File

@@ -3,39 +3,54 @@
""" This module provides audio data convertion functions. """
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def to_n_channels(waveform, n_channels):
""" Convert a waveform to n_channels by removing or
duplicating channels if needed (in tensorflow).
def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
"""
Convert a waveform to n_channels by removing or duplicating channels if
needed (in tensorflow).
:param waveform: Waveform to transform.
:param n_channels: Number of channel to reshape waveform in.
:returns: Reshaped waveform.
Parameters:
waveform (tensorflow.Tensor):
Waveform to transform.
n_channels (int):
Number of channel to reshape waveform in.
Returns:
tensorflow.Tensor:
Reshaped waveform.
"""
return tf.cond(
tf.shape(waveform)[1] >= n_channels,
true_fn=lambda: waveform[:, :n_channels],
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels]
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels],
)
def to_stereo(waveform):
""" Convert a waveform to stereo by duplicating if mono,
or truncating if too many channels.
def to_stereo(waveform: np.ndarray) -> np.ndarray:
"""
Convert a waveform to stereo by duplicating if mono, or truncating
if too many channels.
:param waveform: a (N, d) numpy array.
:returns: A stereo waveform as a (N, 1) numpy array.
Parameters:
waveform (numpy.ndarray):
a `(N, d)` numpy array.
Returns:
numpy.ndarray:
A stereo waveform as a `(N, 1)` numpy array.
"""
if waveform.shape[1] == 1:
return np.repeat(waveform, 2, axis=-1)
@@ -44,45 +59,81 @@ def to_stereo(waveform):
return waveform
def gain_to_db(tensor, espilon=10e-10):
""" Convert from gain to decibel in tensorflow.
:param tensor: Tensor to convert.
:param epsilon: Operation constant.
:returns: Converted tensor.
def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
"""
return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
Convert from gain to decibel in tensorflow.
Parameters:
tensor (tensorflow.Tensor):
Tensor to convert
epsilon (float):
Operation constant.
def db_to_gain(tensor):
""" Convert from decibel to gain in tensorflow.
:param tensor_db: Tensor to convert.
:returns: Converted tensor.
Returns:
tensorflow.Tensor:
Converted tensor.
"""
return tf.pow(10., (tensor / 20.))
return 20.0 / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
""" Encodes given spectrogram into uint8 using decibel scale.
:param spectrogram: Spectrogram to be encoded as TF float tensor.
:param db_range: Range in decibel for encoding.
:returns: Encoded decibel spectrogram as uint8 tensor.
def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
"""
db_spectrogram = gain_to_db(spectrogram)
max_db_spectrogram = tf.reduce_max(db_spectrogram)
db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range)
Convert from decibel to gain in tensorflow.
Parameters:
tensor (tensorflow.Tensor):
Tensor to convert
Returns:
tensorflow.Tensor:
Converted tensor.
"""
return tf.pow(10.0, (tensor / 20.0))
def spectrogram_to_db_uint(
spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
) -> tf.Tensor:
"""
Encodes given spectrogram into uint8 using decibel scale.
Parameters:
spectrogram (tensorflow.Tensor):
Spectrogram to be encoded as TF float tensor.
db_range (float):
Range in decibel for encoding.
Returns:
tensorflow.Tensor:
Encoded decibel spectrogram as `uint8` tensor.
"""
db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
db_spectrogram: tf.Tensor = tf.maximum(
db_spectrogram, max_db_spectrogram - db_range
)
return from_float32_to_uint8(db_spectrogram, **kwargs)
def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
""" Decode spectrogram from uint8 decibel scale.
:param db_uint_spectrogram: Decibel pectrogram to decode.
:param min_db: Lower bound limit for decoding.
:param max_db: Upper bound limit for decoding.
:returns: Decoded spectrogram as float2 tensor.
def db_uint_spectrogram_to_gain(
db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
) -> tf.Tensor:
"""
db_spectrogram = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
Decode spectrogram from uint8 decibel scale.
Paramters:
db_uint_spectrogram (tensorflow.Tensor):
Decibel spectrogram to decode.
min_db (tensorflow.Tensor):
Lower bound limit for decoding.
max_db (tensorflow.Tensor):
Upper bound limit for decoding.
Returns:
tensorflow.Tensor:
Decoded spectrogram as `float32` tensor.
"""
db_spectrogram: tf.Tensor = from_uint8_to_float32(
db_uint_spectrogram, min_db, max_db
)
return db_to_gain(db_spectrogram)

View File

@@ -8,143 +8,178 @@
used within this library.
"""
import datetime as dt
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union
# pyright: reportMissingImports=false
# pylint: disable=import-error
import ffmpeg
import numpy as np
from .. import SpleeterError
from ..types import Signal
from ..utils.logging import logger
from . import Codec
from .adapter import AudioAdapter
# pylint: enable=import-error
from .adapter import AudioAdapter
from .. import SpleeterError
from ..utils.logging import get_logger
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def _check_ffmpeg_install():
""" Ensure FFMPEG binaries are available.
:raise SpleeterError: If ffmpeg or ffprobe is not found.
"""
for binary in ('ffmpeg', 'ffprobe'):
if shutil.which(binary) is None:
raise SpleeterError('{} binary not found'.format(binary))
def _to_ffmpeg_time(n):
""" Format number of seconds to time expected by FFMPEG.
:param n: Time in seconds to format.
:returns: Formatted time in FFMPEG format.
"""
m, s = divmod(n, 60)
h, m = divmod(m, 60)
return '%d:%02d:%09.6f' % (h, m, s)
def _to_ffmpeg_codec(codec):
ffmpeg_codecs = {
'm4a': 'aac',
'ogg': 'libvorbis',
'wma': 'wmav2',
}
return ffmpeg_codecs.get(codec) or codec
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class FFMPEGProcessAudioAdapter(AudioAdapter):
""" An AudioAdapter implementation that use FFMPEG binary through
"""
An AudioAdapter implementation that use FFMPEG binary through
subprocess in order to perform I/O operation for audio processing.
When created, FFMPEG binary path will be checked and expended,
raising exception if not found. Such path could be infered using
FFMPEG_PATH environment variable.
`FFMPEG_PATH` environment variable.
"""
SUPPORTED_CODECS: Dict[Codec, str] = {
Codec.M4A: "aac",
Codec.OGG: "libvorbis",
Codec.WMA: "wmav2",
}
""" FFMPEG codec name mapping. """
def __init__(_) -> None:
"""
Default constructor, ensure FFMPEG binaries are available.
Raises:
SpleeterError:
If ffmpeg or ffprobe is not found.
"""
for binary in ("ffmpeg", "ffprobe"):
if shutil.which(binary) is None:
raise SpleeterError("{} binary not found".format(binary))
def load(
self, path, offset=None, duration=None,
sample_rate=None, dtype=np.float32):
""" Loads the audio file denoted by the given path
_,
path: Union[Path, str],
offset: Optional[float] = None,
duration: Optional[float] = None,
sample_rate: Optional[float] = None,
dtype: np.dtype = np.float32,
) -> Signal:
"""
Loads the audio file denoted by the given path
and returns it data as a waveform.
:param path: Path of the audio file to load data from.
:param offset: (Optional) Start offset to load from in seconds.
:param duration: (Optional) Duration to load in seconds.
:param sample_rate: (Optional) Sample rate to load audio with.
:param dtype: (Optional) Numpy data type to use, default to float32.
:returns: Loaded data a (waveform, sample_rate) tuple.
:raise SpleeterError: If any error occurs while loading audio.
Parameters:
path (Union[Path, str]:
Path of the audio file to load data from.
offset (Optional[float]):
Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
dtype (numpy.dtype):
(Optional) Numpy data type to use, default to `float32`.
Returns:
Signal:
Loaded data a (waveform, sample_rate) tuple.
Raises:
SpleeterError:
If any error occurs while loading audio.
"""
_check_ffmpeg_install()
if isinstance(path, Path):
path = str(path)
if not isinstance(path, str):
path = path.decode()
try:
probe = ffmpeg.probe(path)
except ffmpeg._run.Error as e:
raise SpleeterError(
'An error occurs with ffprobe (see ffprobe output below)\n\n{}'
.format(e.stderr.decode()))
if 'streams' not in probe or len(probe['streams']) == 0:
raise SpleeterError('No stream was found with ffprobe')
"An error occurs with ffprobe (see ffprobe output below)\n\n{}".format(
e.stderr.decode()
)
)
if "streams" not in probe or len(probe["streams"]) == 0:
raise SpleeterError("No stream was found with ffprobe")
metadata = next(
stream
for stream in probe['streams']
if stream['codec_type'] == 'audio')
n_channels = metadata['channels']
stream for stream in probe["streams"] if stream["codec_type"] == "audio"
)
n_channels = metadata["channels"]
if sample_rate is None:
sample_rate = metadata['sample_rate']
output_kwargs = {'format': 'f32le', 'ar': sample_rate}
sample_rate = metadata["sample_rate"]
output_kwargs = {"format": "f32le", "ar": sample_rate}
if duration is not None:
output_kwargs['t'] = _to_ffmpeg_time(duration)
output_kwargs["t"] = str(dt.timedelta(seconds=duration))
if offset is not None:
output_kwargs['ss'] = _to_ffmpeg_time(offset)
output_kwargs["ss"] = str(dt.timedelta(seconds=offset))
process = (
ffmpeg
.input(path)
.output('pipe:', **output_kwargs)
.run_async(pipe_stdout=True, pipe_stderr=True))
ffmpeg.input(path)
.output("pipe:", **output_kwargs)
.run_async(pipe_stdout=True, pipe_stderr=True)
)
buffer, _ = process.communicate()
waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
if not waveform.dtype == np.dtype(dtype):
waveform = waveform.astype(dtype)
return (waveform, sample_rate)
def save(
self, path, data, sample_rate,
codec=None, bitrate=None):
""" Write waveform data to the file denoted by the given path
using FFMPEG process.
:param path: Path of the audio file to save data in.
:param data: Waveform data to write.
:param sample_rate: Sample rate to write file in.
:param codec: (Optional) Writing codec to use.
:param bitrate: (Optional) Bitrate of the written audio file.
:raise IOError: If any error occurs while using FFMPEG to write data.
self,
path: Union[Path, str],
data: np.ndarray,
sample_rate: float,
codec: Codec = None,
bitrate: str = None,
) -> None:
"""
_check_ffmpeg_install()
Write waveform data to the file denoted by the given path using
FFMPEG process.
Parameters:
path (Union[Path, str]):
Path like of the audio file to save data in.
data (numpy.ndarray):
Waveform data to write.
sample_rate (float):
Sample rate to write file in.
codec ():
(Optional) Writing codec to use, default to `None`.
bitrate (str):
(Optional) Bitrate of the written audio file, default to
`None`.
Raises:
IOError:
If any error occurs while using FFMPEG to write data.
"""
if isinstance(path, Path):
path = str(path)
directory = os.path.dirname(path)
if not os.path.exists(directory):
raise SpleeterError(f'output directory does not exists: {directory}')
get_logger().debug('Writing file %s', path)
input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
output_kwargs = {'ar': sample_rate, 'strict': '-2'}
raise SpleeterError(f"output directory does not exists: {directory}")
logger.debug(f"Writing file {path}")
input_kwargs = {"ar": sample_rate, "ac": data.shape[1]}
output_kwargs = {"ar": sample_rate, "strict": "-2"}
if bitrate:
output_kwargs['audio_bitrate'] = bitrate
if codec is not None and codec != 'wav':
output_kwargs['codec'] = _to_ffmpeg_codec(codec)
output_kwargs["audio_bitrate"] = bitrate
if codec is not None and codec != "wav":
output_kwargs["codec"] = self.SUPPORTED_CODECS.get(codec, codec)
process = (
ffmpeg
.input('pipe:', format='f32le', **input_kwargs)
ffmpeg.input("pipe:", format="f32le", **input_kwargs)
.output(path, **output_kwargs)
.overwrite_output()
.run_async(pipe_stdin=True, pipe_stderr=True, quiet=True))
.run_async(pipe_stdin=True, pipe_stderr=True, quiet=True)
)
try:
process.stdin.write(data.astype('<f4').tobytes())
process.stdin.write(data.astype("<f4").tobytes())
process.stdin.close()
process.wait()
except IOError:
raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
get_logger().info('File %s written succesfully', path)
raise SpleeterError(f"FFMPEG error: {process.stderr.read()}")
logger.info(f"File {path} written succesfully")

View File

@@ -1,128 +1,176 @@
#!/usr/bin/env python
# coding: utf8
""" Spectrogram specific data augmentation """
""" Spectrogram specific data augmentation. """
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from tensorflow.signal import hann_window, stft
from tensorflow.signal import stft, hann_window
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def compute_spectrogram_tf(
waveform,
frame_length=2048, frame_step=512,
spec_exponent=1., window_exponent=1.):
""" Compute magnitude / power spectrogram from waveform as
a n_samples x n_channels tensor.
:param waveform: Input waveform as (times x number of channels)
tensor.
:param frame_length: Length of a STFT frame to use.
:param frame_step: HOP between successive frames.
:param spec_exponent: Exponent of the spectrogram (usually 1 for
magnitude spectrogram, or 2 for power spectrogram).
:param window_exponent: Exponent applied to the Hann windowing function
(may be useful for making perfect STFT/iSTFT
reconstruction).
:returns: Computed magnitude / power spectrogram as a
(T x F x n_channels) tensor.
waveform: tf.Tensor,
frame_length: int = 2048,
frame_step: int = 512,
spec_exponent: float = 1.0,
window_exponent: float = 1.0,
) -> tf.Tensor:
"""
stft_tensor = tf.transpose(
Compute magnitude / power spectrogram from waveform as a
`n_samples x n_channels` tensor.
Parameters:
waveform (tensorflow.Tensor):
Input waveform as `(times x number of channels)` tensor.
frame_length (int):
Length of a STFT frame to use.
frame_step (int):
HOP between successive frames.
spec_exponent (float):
Exponent of the spectrogram (usually 1 for magnitude
spectrogram, or 2 for power spectrogram).
window_exponent (float):
Exponent applied to the Hann windowing function (may be
useful for making perfect STFT/iSTFT reconstruction).
Returns:
tensorflow.Tensor:
Computed magnitude / power spectrogram as a
`(T x F x n_channels)` tensor.
"""
stft_tensor: tf.Tensor = tf.transpose(
stft(
tf.transpose(waveform),
frame_length,
frame_step,
window_fn=lambda f, dtype: hann_window(
f,
periodic=True,
dtype=waveform.dtype) ** window_exponent),
perm=[1, 2, 0])
f, periodic=True, dtype=waveform.dtype
)
** window_exponent,
),
perm=[1, 2, 0],
)
return tf.abs(stft_tensor) ** spec_exponent
def time_stretch(
spectrogram,
factor=1.0,
method=tf.image.ResizeMethod.BILINEAR):
""" Time stretch a spectrogram preserving shape in tensorflow. Note that
spectrogram: tf.Tensor,
factor: float = 1.0,
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
) -> tf.Tensor:
"""
Time stretch a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
:param spectrogram: Input spectrogram to be time stretched as tensor.
:param factor: (Optional) Time stretch factor, must be >0, default to 1.
:param mehtod: (Optional) Interpolation method, default to BILINEAR.
:returns: Time stretched spectrogram as tensor with same shape.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be time stretched as tensor.
factor (float):
(Optional) Time stretch factor, must be > 0, default to `1`.
method (tensorflow.image.ResizeMethod):
(Optional) Interpolation method, default to `BILINEAR`.
Returns:
tensorflow.Tensor:
Time stretched spectrogram as tensor with same shape.
"""
T = tf.shape(spectrogram)[0]
T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
F = tf.shape(spectrogram)[1]
ts_spec = tf.image.resize_images(
spectrogram,
[T_ts, F],
method=method,
align_corners=True)
spectrogram, [T_ts, F], method=method, align_corners=True
)
return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs):
""" Time stretch a spectrogram preserving shape with random ratio in
tensorflow. Applies time_stretch to spectrogram with a random ratio drawn
uniformly in [factor_min, factor_max].
:param spectrogram: Input spectrogram to be time stretched as tensor.
:param factor_min: (Optional) Min time stretch factor, default to 0.9.
:param factor_max: (Optional) Max time stretch factor, default to 1.1.
:returns: Randomly time stretched spectrogram as tensor with same shape.
def random_time_stretch(
spectrogram: tf.Tensor, factor_min: float = 0.9, factor_max: float = 1.1, **kwargs
) -> tf.Tensor:
"""
factor = tf.random_uniform(
shape=(1,),
seed=0) * (factor_max - factor_min) + factor_min
Time stretch a spectrogram preserving shape with random ratio in
tensorflow. Applies time_stretch to spectrogram with a random ratio
drawn uniformly in `[factor_min, factor_max]`.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be time stretched as tensor.
factor_min (float):
(Optional) Min time stretch factor, default to `0.9`.
factor_max (float):
(Optional) Max time stretch factor, default to `1.1`.
Returns:
tensorflow.Tensor:
Randomly time stretched spectrogram as tensor with same shape.
"""
factor = (
tf.random_uniform(shape=(1,), seed=0) * (factor_max - factor_min) + factor_min
)
return time_stretch(spectrogram, factor=factor, **kwargs)
def pitch_shift(
spectrogram,
semitone_shift=0.0,
method=tf.image.ResizeMethod.BILINEAR):
""" Pitch shift a spectrogram preserving shape in tensorflow. Note that
spectrogram: tf.Tensor,
semitone_shift: float = 0.0,
method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
) -> tf.Tensor:
"""
Pitch shift a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
:param spectrogram: Input spectrogram to be pitch shifted as tensor.
:param semitone_shift: (Optional) Pitch shift in semitone, default to 0.0.
:param mehtod: (Optional) Interpolation method, default to BILINEAR.
:returns: Pitch shifted spectrogram (same shape as spectrogram).
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be pitch shifted as tensor.
semitone_shift (float):
(Optional) Pitch shift in semitone, default to `0.0`.
method (tensorflow.image.ResizeMethod):
(Optional) Interpolation method, default to `BILINEAR`.
Returns:
tensorflow.Tensor:
Pitch shifted spectrogram (same shape as spectrogram).
"""
factor = 2 ** (semitone_shift / 12.)
factor = 2 ** (semitone_shift / 12.0)
T = tf.shape(spectrogram)[0]
F = tf.shape(spectrogram)[1]
F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
ps_spec = tf.image.resize_images(
spectrogram,
[T, F_ps],
method=method,
align_corners=True)
spectrogram, [T, F_ps], method=method, align_corners=True
)
paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
return tf.pad(ps_spec[:, :F, :], paddings, "CONSTANT")
def random_pitch_shift(spectrogram, shift_min=-1., shift_max=1., **kwargs):
""" Pitch shift a spectrogram preserving shape with random ratio in
tensorflow. Applies pitch_shift to spectrogram with a random shift
amount (expressed in semitones) drawn uniformly in [shift_min, shift_max].
:param spectrogram: Input spectrogram to be pitch shifted as tensor.
:param shift_min: (Optional) Min pitch shift in semitone, default to -1.
:param shift_max: (Optional) Max pitch shift in semitone, default to 1.
:returns: Randomly pitch shifted spectrogram (same shape as spectrogram).
def random_pitch_shift(
spectrogram: tf.Tensor, shift_min: float = -1.0, shift_max: float = 1.0, **kwargs
) -> tf.Tensor:
"""
semitone_shift = tf.random_uniform(
shape=(1,),
seed=0) * (shift_max - shift_min) + shift_min
Pitch shift a spectrogram preserving shape with random ratio in
tensorflow. Applies pitch_shift to spectrogram with a random shift
amount (expressed in semitones) drawn uniformly in
`[shift_min, shift_max]`.
Parameters:
spectrogram (tensorflow.Tensor):
Input spectrogram to be pitch shifted as tensor.
shift_min (float):
(Optional) Min pitch shift in semitone, default to -1.
shift_max (float):
(Optional) Max pitch shift in semitone, default to 1.
Returns:
tensorflow.Tensor:
Randomly pitch shifted spectrogram (same shape as spectrogram).
"""
semitone_shift = (
tf.random_uniform(shape=(1,), seed=0) * (shift_max - shift_min) + shift_min
)
return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)

View File

@@ -1,209 +0,0 @@
#!/usr/bin/env python
# coding: utf8
""" This modules provides spleeter command as well as CLI parsing methods. """
import json
import logging
from argparse import ArgumentParser
from tempfile import gettempdir
from os.path import exists, join
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# -i opt specification (separate).
OPT_INPUT = {
'dest': 'inputs',
'nargs': '+',
'help': 'List of input audio filenames',
'required': True
}
# -o opt specification (evaluate and separate).
OPT_OUTPUT = {
'dest': 'output_path',
'default': join(gettempdir(), 'separated_audio'),
'help': 'Path of the output directory to write audio files in'
}
# -f opt specification (separate).
OPT_FORMAT = {
'dest': 'filename_format',
'default': '{filename}/{instrument}.{codec}',
'help': (
'Template string that will be formatted to generated'
'output filename. Such template should be Python formattable'
'string, and could use {filename}, {instrument}, and {codec}'
'variables.'
)
}
# -p opt specification (train, evaluate and separate).
OPT_PARAMS = {
'dest': 'configuration',
'default': 'spleeter:2stems',
'type': str,
'action': 'store',
'help': 'JSON filename that contains params'
}
# -s opt specification (separate).
OPT_OFFSET = {
'dest': 'offset',
'type': float,
'default': 0.,
'help': 'Set the starting offset to separate audio from.'
}
# -d opt specification (separate).
OPT_DURATION = {
'dest': 'duration',
'type': float,
'default': 600.,
'help': (
'Set a maximum duration for processing audio '
'(only separate offset + duration first seconds of '
'the input file)')
}
# -w opt specification (separate)
OPT_STFT_BACKEND = {
'dest': 'stft_backend',
'type': str,
'choices' : ["tensorflow", "librosa", "auto"],
'default': "auto",
'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
}
# -c opt specification (separate).
OPT_CODEC = {
'dest': 'codec',
'choices': ('wav', 'mp3', 'ogg', 'm4a', 'wma', 'flac'),
'default': 'wav',
'help': 'Audio codec to be used for the separated output'
}
# -b opt specification (separate).
OPT_BITRATE = {
'dest': 'bitrate',
'default': '128k',
'help': 'Audio bitrate to be used for the separated output'
}
# -m opt specification (evaluate and separate).
OPT_MWF = {
'dest': 'MWF',
'action': 'store_const',
'const': True,
'default': False,
'help': 'Whether to use multichannel Wiener filtering for separation',
}
# --mus_dir opt specification (evaluate).
OPT_MUSDB = {
'dest': 'mus_dir',
'type': str,
'required': True,
'help': 'Path to folder with musDB'
}
# -d opt specification (train).
OPT_DATA = {
'dest': 'audio_path',
'type': str,
'required': True,
'help': 'Path of the folder containing audio data for training'
}
# -a opt specification (train, evaluate and separate).
OPT_ADAPTER = {
'dest': 'audio_adapter',
'type': str,
'help': 'Name of the audio adapter to use for audio I/O'
}
# -a opt specification (train, evaluate and separate).
OPT_VERBOSE = {
'action': 'store_true',
'help': 'Shows verbose logs'
}
def _add_common_options(parser):
""" Add common option to the given parser.
:param parser: Parser to add common opt to.
"""
parser.add_argument('-a', '--adapter', **OPT_ADAPTER)
parser.add_argument('-p', '--params_filename', **OPT_PARAMS)
parser.add_argument('--verbose', **OPT_VERBOSE)
def _create_train_parser(parser_factory):
""" Creates an argparser for training command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory('train', help='Train a source separation model')
_add_common_options(parser)
parser.add_argument('-d', '--data', **OPT_DATA)
return parser
def _create_evaluate_parser(parser_factory):
""" Creates an argparser for evaluation command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory(
'evaluate',
help='Evaluate a model on the musDB test dataset')
_add_common_options(parser)
parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
parser.add_argument('--mus_dir', **OPT_MUSDB)
parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
return parser
def _create_separate_parser(parser_factory):
""" Creates an argparser for separation command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory('separate', help='Separate audio files')
_add_common_options(parser)
parser.add_argument('-i', '--inputs', **OPT_INPUT)
parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
parser.add_argument('-f', '--filename_format', **OPT_FORMAT)
parser.add_argument('-d', '--duration', **OPT_DURATION)
parser.add_argument('-s', '--offset', **OPT_OFFSET)
parser.add_argument('-c', '--codec', **OPT_CODEC)
parser.add_argument('-b', '--birate', **OPT_BITRATE)
parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
return parser
def create_argument_parser():
""" Creates overall command line parser for Spleeter.
:returns: Created argument parser.
"""
parser = ArgumentParser(prog='spleeter')
subparsers = parser.add_subparsers()
subparsers.dest = 'command'
subparsers.required = True
_create_separate_parser(subparsers.add_parser)
_create_train_parser(subparsers.add_parser)
_create_evaluate_parser(subparsers.add_parser)
return parser

View File

@@ -1,167 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model evaluation.
Evaluation is performed against musDB dataset.
USAGE: python -m spleeter evaluate \
-p /path/to/params \
-o /path/to/output/dir \
[-m] \
--mus_dir /path/to/musdb dataset
"""
import sys
import json
from argparse import Namespace
from itertools import product
from glob import glob
from os.path import join, exists
# pylint: disable=import-error
import numpy as np
import pandas as pd
# pylint: enable=import-error
from .separate import entrypoint as separate_entrypoint
from ..utils.logging import get_logger
try:
import musdb
import museval
except ImportError:
logger = get_logger()
logger.error('Extra dependencies musdb and museval not found')
logger.error('Please install musdb and museval first, abort')
sys.exit(1)
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_SPLIT = 'test'
_MIXTURE = 'mixture.wav'
_AUDIO_DIRECTORY = 'audio'
_METRICS_DIRECTORY = 'metrics'
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
""" Performs audio separation on the musdb dataset from
the given directory and params.
:param arguments: Entrypoint arguments.
:param musdb_root_directory: Directory to retrieve dataset from.
:param params: Spleeter configuration to apply to separation.
:returns: Separation output directory path.
"""
songs = glob(join(musdb_root_directory, _SPLIT, '*/'))
mixtures = [join(song, _MIXTURE) for song in songs]
audio_output_directory = join(
arguments.output_path,
_AUDIO_DIRECTORY)
separate_entrypoint(
Namespace(
audio_adapter=arguments.audio_adapter,
configuration=arguments.configuration,
inputs=mixtures,
output_path=join(audio_output_directory, _SPLIT),
filename_format='{foldername}/{instrument}.{codec}',
codec='wav',
duration=600.,
offset=0.,
bitrate='128k',
MWF=arguments.MWF,
verbose=arguments.verbose,
stft_backend=arguments.stft_backend),
params)
return audio_output_directory
def _compute_musdb_metrics(
arguments,
musdb_root_directory,
audio_output_directory):
""" Generates musdb metrics fro previsouly computed audio estimation.
:param arguments: Entrypoint arguments.
:param audio_output_directory: Directory to get audio estimation from.
:returns: Path of generated metrics directory.
"""
metrics_output_directory = join(
arguments.output_path,
_METRICS_DIRECTORY)
get_logger().info('Starting musdb evaluation (this could be long) ...')
dataset = musdb.DB(
root=musdb_root_directory,
is_wav=True,
subsets=[_SPLIT])
museval.eval_mus_dir(
dataset=dataset,
estimates_dir=audio_output_directory,
output_dir=metrics_output_directory)
get_logger().info('musdb evaluation done')
return metrics_output_directory
def _compile_metrics(metrics_output_directory):
""" Compiles metrics from given directory and returns
results as dict.
:param metrics_output_directory: Directory to get metrics from.
:returns: Compiled metrics as dict.
"""
songs = glob(join(metrics_output_directory, 'test/*.json'))
index = pd.MultiIndex.from_tuples(
product(_INSTRUMENTS, _METRICS),
names=['instrument', 'metric'])
pd.DataFrame([], index=['config1', 'config2'], columns=index)
metrics = {
instrument: {k: [] for k in _METRICS}
for instrument in _INSTRUMENTS}
for song in songs:
with open(song, 'r') as stream:
data = json.load(stream)
for target in data['targets']:
instrument = target['name']
for metric in _METRICS:
sdr_med = np.median([
frame['metrics'][metric]
for frame in target['frames']
if not np.isnan(frame['metrics'][metric])])
metrics[instrument][metric].append(sdr_med)
return metrics
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
# Parse and check musdb directory.
musdb_root_directory = arguments.mus_dir
if not exists(musdb_root_directory):
raise IOError(f'musdb directory {musdb_root_directory} not found')
# Separate musdb sources.
audio_output_directory = _separate_evaluation_dataset(
arguments,
musdb_root_directory,
params)
# Compute metrics with musdb.
metrics_output_directory = _compute_musdb_metrics(
arguments,
musdb_root_directory,
audio_output_directory)
# Compute and pretty print median metrics.
metrics = _compile_metrics(metrics_output_directory)
for instrument, metric in metrics.items():
get_logger().info('%s:', instrument)
for metric, value in metric.items():
get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
return metrics

View File

@@ -1,47 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing source separation.
USAGE: python -m spleeter separate \
-p /path/to/params \
-i inputfile1 inputfile2 ... inputfilen
-o /path/to/output/dir \
-i /path/to/audio1.wav /path/to/audio2.mp3
"""
from ..audio.adapter import get_audio_adapter
from ..separator import Separator
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
# TODO: check with output naming.
audio_adapter = get_audio_adapter(arguments.audio_adapter)
separator = Separator(
arguments.configuration,
MWF=arguments.MWF,
stft_backend=arguments.stft_backend)
for filename in arguments.inputs:
separator.separate_to_file(
filename,
arguments.output_path,
audio_adapter=audio_adapter,
offset=arguments.offset,
duration=arguments.duration,
codec=arguments.codec,
bitrate=arguments.bitrate,
filename_format=arguments.filename_format,
synchronous=False
)
separator.join()

View File

@@ -1,100 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model training.
USAGE: python -m spleeter train -p /path/to/params
"""
from functools import partial
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
from ..audio.adapter import get_audio_adapter
from ..dataset import get_training_dataset, get_validation_dataset
from ..model import model_fn
from ..model.provider import ModelProvider
from ..utils.logging import get_logger
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def _create_estimator(params):
""" Creates estimator.
:param params: TF params to build estimator from.
:returns: Built estimator.
"""
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=tf.estimator.RunConfig(
save_checkpoints_steps=params['save_checkpoints_steps'],
tf_random_seed=params['random_seed'],
save_summary_steps=params['save_summary_steps'],
session_config=session_config,
log_step_count_steps=10,
keep_checkpoint_max=2))
return estimator
def _create_train_spec(params, audio_adapter, audio_path):
""" Creates train spec.
:param params: TF params to build spec from.
:returns: Built train spec.
"""
input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
train_spec = tf.estimator.TrainSpec(
input_fn=input_fn,
max_steps=params['train_max_steps'])
return train_spec
def _create_evaluation_spec(params, audio_adapter, audio_path):
""" Setup eval spec evaluating ever n seconds
:param params: TF params to build spec from.
:returns: Built evaluation spec.
"""
input_fn = partial(
get_validation_dataset,
params,
audio_adapter,
audio_path)
evaluation_spec = tf.estimator.EvalSpec(
input_fn=input_fn,
steps=None,
throttle_secs=params['throttle_secs'])
return evaluation_spec
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
audio_adapter = get_audio_adapter(arguments.audio_adapter)
audio_path = arguments.audio_path
estimator = _create_estimator(params)
train_spec = _create_train_spec(params, audio_adapter, audio_path)
evaluation_spec = _create_evaluation_spec(
params,
audio_adapter,
audio_path)
get_logger().info('Start model training')
tf.estimator.train_and_evaluate(
estimator,
train_spec,
evaluation_spec)
ModelProvider.writeProbe(params['model_dir'])
get_logger().info('Model training done')

View File

@@ -14,87 +14,110 @@
(ground truth)
"""
import time
import os
from os.path import exists, join, sep as SEPARATOR
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import pandas as pd
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from .audio.convertor import (
db_uint_spectrogram_to_gain,
spectrogram_to_db_uint)
from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
compute_spectrogram_tf,
random_pitch_shift,
random_time_stretch)
from .utils.logging import get_logger
random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
check_tensor_shape,
dataset_from_csv,
set_tensor_shape,
sync_apply)
sync_apply,
)
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS = {
'instrument_list': ('vocals', 'accompaniment'),
'mix_name': 'mix',
'sample_rate': 44100,
'frame_length': 4096,
'frame_step': 1024,
'T': 512,
'F': 1024
DEFAULT_AUDIO_PARAMS: Dict = {
"instrument_list": ("vocals", "accompaniment"),
"mix_name": "mix",
"sample_rate": 44100,
"frame_length": 4096,
"frame_step": 1024,
"T": 512,
"F": 1024,
}
def get_training_dataset(audio_params, audio_adapter, audio_path):
""" Builds training dataset.
def get_training_dataset(
audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
"""
Builds training dataset.
:param audio_params: Audio parameters.
:param audio_adapter: Adapter to load audio from.
:param audio_path: Path of directory containing audio.
:returns: Built dataset.
Parameters:
audio_params (Dict):
Audio parameters.
audio_adapter (AudioAdapter):
Adapter to load audio from.
audio_path (str):
Path of directory containing audio.
Returns:
Any:
Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=audio_params.get('chunk_duration', 20.0),
random_seed=audio_params.get('random_seed', 0))
chunk_duration=audio_params.get("chunk_duration", 20.0),
random_seed=audio_params.get("random_seed", 0),
)
return builder.build(
audio_params.get('train_csv'),
cache_directory=audio_params.get('training_cache'),
batch_size=audio_params.get('batch_size'),
n_chunks_per_song=audio_params.get('n_chunks_per_song', 2),
audio_params.get("train_csv"),
cache_directory=audio_params.get("training_cache"),
batch_size=audio_params.get("batch_size"),
n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
random_data_augmentation=False,
convert_to_uint=True,
wait_for_cache=False)
wait_for_cache=False,
)
def get_validation_dataset(audio_params, audio_adapter, audio_path):
""" Builds validation dataset.
def get_validation_dataset(
audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
"""
Builds validation dataset.
:param audio_params: Audio parameters.
:param audio_adapter: Adapter to load audio from.
:param audio_path: Path of directory containing audio.
:returns: Built dataset.
Parameters:
audio_params (Dict):
Audio parameters.
audio_adapter (AudioAdapter):
Adapter to load audio from.
audio_path (str):
Path of directory containing audio.
Returns:
Any:
Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=12.0)
audio_params, audio_adapter, audio_path, chunk_duration=12.0
)
return builder.build(
audio_params.get('validation_csv'),
batch_size=audio_params.get('batch_size'),
cache_directory=audio_params.get('validation_cache'),
audio_params.get("validation_csv"),
batch_size=audio_params.get("batch_size"),
cache_directory=audio_params.get("validation_cache"),
convert_to_uint=True,
infinite_generator=False,
n_chunks_per_song=1,
@@ -108,127 +131,175 @@ def get_validation_dataset(audio_params, audio_adapter, audio_path):
class InstrumentDatasetBuilder(object):
""" Instrument based filter and mapper provider. """
def __init__(self, parent, instrument):
""" Default constructor.
def __init__(self, parent, instrument) -> None:
"""
Default constructor.
:param parent: Parent dataset builder.
:param instrument: Target instrument.
Parameters:
parent:
Parent dataset builder.
instrument:
Target instrument.
"""
self._parent = parent
self._instrument = instrument
self._spectrogram_key = f'{instrument}_spectrogram'
self._min_spectrogram_key = f'min_{instrument}_spectrogram'
self._max_spectrogram_key = f'max_{instrument}_spectrogram'
self._spectrogram_key = f"{instrument}_spectrogram"
self._min_spectrogram_key = f"min_{instrument}_spectrogram"
self._max_spectrogram_key = f"max_{instrument}_spectrogram"
def load_waveform(self, sample):
""" Load waveform for given sample. """
return dict(sample, **self._parent._audio_adapter.load_tf_waveform(
sample[f'{self._instrument}_path'],
offset=sample['start'],
duration=self._parent._chunk_duration,
sample_rate=self._parent._sample_rate,
waveform_name='waveform'))
return dict(
sample,
**self._parent._audio_adapter.load_tf_waveform(
sample[f"{self._instrument}_path"],
offset=sample["start"],
duration=self._parent._chunk_duration,
sample_rate=self._parent._sample_rate,
waveform_name="waveform",
),
)
def compute_spectrogram(self, sample):
""" Compute spectrogram of the given sample. """
return dict(sample, **{
self._spectrogram_key: compute_spectrogram_tf(
sample['waveform'],
frame_length=self._parent._frame_length,
frame_step=self._parent._frame_step,
spec_exponent=1.,
window_exponent=1.)})
return dict(
sample,
**{
self._spectrogram_key: compute_spectrogram_tf(
sample["waveform"],
frame_length=self._parent._frame_length,
frame_step=self._parent._frame_step,
spec_exponent=1.0,
window_exponent=1.0,
)
},
)
def filter_frequencies(self, sample):
""" """
return dict(sample, **{
self._spectrogram_key:
sample[self._spectrogram_key][:, :self._parent._F, :]})
return dict(
sample,
**{
self._spectrogram_key: sample[self._spectrogram_key][
:, : self._parent._F, :
]
},
)
def convert_to_uint(self, sample):
""" Convert given sample from float to unit. """
return dict(sample, **spectrogram_to_db_uint(
sample[self._spectrogram_key],
tensor_key=self._spectrogram_key,
min_key=self._min_spectrogram_key,
max_key=self._max_spectrogram_key))
return dict(
sample,
**spectrogram_to_db_uint(
sample[self._spectrogram_key],
tensor_key=self._spectrogram_key,
min_key=self._min_spectrogram_key,
max_key=self._max_spectrogram_key,
),
)
def filter_infinity(self, sample):
""" Filter infinity sample. """
return tf.logical_not(
tf.math.is_inf(
sample[self._min_spectrogram_key]))
return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))
def convert_to_float32(self, sample):
""" Convert given sample from unit to float. """
return dict(sample, **{
self._spectrogram_key: db_uint_spectrogram_to_gain(
sample[self._spectrogram_key],
sample[self._min_spectrogram_key],
sample[self._max_spectrogram_key])})
return dict(
sample,
**{
self._spectrogram_key: db_uint_spectrogram_to_gain(
sample[self._spectrogram_key],
sample[self._min_spectrogram_key],
sample[self._max_spectrogram_key],
)
},
)
def time_crop(self, sample):
""" """
def start(sample):
""" mid_segment_start """
return tf.cast(
tf.maximum(
tf.shape(sample[self._spectrogram_key])[0]
/ 2 - self._parent._T / 2, 0),
tf.int32)
return dict(sample, **{
self._spectrogram_key: sample[self._spectrogram_key][
start(sample):start(sample) + self._parent._T, :, :]})
tf.shape(sample[self._spectrogram_key])[0] / 2
- self._parent._T / 2,
0,
),
tf.int32,
)
return dict(
sample,
**{
self._spectrogram_key: sample[self._spectrogram_key][
start(sample) : start(sample) + self._parent._T, :, :
]
},
)
def filter_shape(self, sample):
""" Filter badly shaped sample. """
return check_tensor_shape(
sample[self._spectrogram_key], (
self._parent._T, self._parent._F, 2))
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
)
def reshape_spectrogram(self, sample):
""" """
return dict(sample, **{
self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key],
(self._parent._T, self._parent._F, 2))})
""" Reshape given sample. """
return dict(
sample,
**{
self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
)
},
)
class DatasetBuilder(object):
"""
TO BE DOCUMENTED.
"""
# Margin at beginning and end of songs in seconds.
MARGIN = 0.5
MARGIN: float = 0.5
""" Margin at beginning and end of songs in seconds. """
# Wait period for cache (in seconds).
WAIT_PERIOD = 60
WAIT_PERIOD: int = 60
""" Wait period for cache (in seconds). """
def __init__(
self,
audio_params, audio_adapter, audio_path,
random_seed=0, chunk_duration=20.0):
""" Default constructor.
self,
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str,
random_seed: int = 0,
chunk_duration: float = 20.0,
) -> None:
"""
Default constructor.
NOTE: Probably need for AudioAdapter.
:param audio_params: Audio parameters to use.
:param audio_adapter: Audio adapter to use.
:param audio_path:
:param random_seed:
:param chunk_duration:
Parameters:
audio_params (Dict):
Audio parameters to use.
audio_adapter (AudioAdapter):
Audio adapter to use.
audio_path (str):
random_seed (int):
chunk_duration (float):
"""
# Length of segment in frames (if fs=22050 and
# frame_step=512, then T=512 corresponds to 11.89s)
self._T = audio_params['T']
self._T = audio_params["T"]
# Number of frequency bins to be used (should
# be less than frame_length/2 + 1)
self._F = audio_params['F']
self._sample_rate = audio_params['sample_rate']
self._frame_length = audio_params['frame_length']
self._frame_step = audio_params['frame_step']
self._mix_name = audio_params['mix_name']
self._instruments = [self._mix_name] + audio_params['instrument_list']
self._F = audio_params["F"]
self._sample_rate = audio_params["sample_rate"]
self._frame_length = audio_params["frame_length"]
self._frame_step = audio_params["frame_step"]
self._mix_name = audio_params["mix_name"]
self._instruments = [self._mix_name] + audio_params["instrument_list"]
self._instrument_builders = None
self._chunk_duration = chunk_duration
self._audio_adapter = audio_adapter
@@ -238,130 +309,202 @@ class DatasetBuilder(object):
def expand_path(self, sample):
""" Expands audio paths for the given sample. """
return dict(sample, **{f'{instrument}_path': tf.strings.join(
(self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
for instrument in self._instruments})
return dict(
sample,
**{
f"{instrument}_path": tf.strings.join(
(self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
)
for instrument in self._instruments
},
)
def filter_error(self, sample):
""" Filter errored sample. """
return tf.logical_not(sample['waveform_error'])
return tf.logical_not(sample["waveform_error"])
def filter_waveform(self, sample):
""" Filter waveform from sample. """
return {k: v for k, v in sample.items() if not k == 'waveform'}
return {k: v for k, v in sample.items() if not k == "waveform"}
def harmonize_spectrogram(self, sample):
""" Ensure same size for vocals and mix spectrograms. """
def _reduce(sample):
return tf.reduce_min([
tf.shape(sample[f'{instrument}_spectrogram'])[0]
for instrument in self._instruments])
return dict(sample, **{
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :]
for instrument in self._instruments})
return tf.reduce_min(
[
tf.shape(sample[f"{instrument}_spectrogram"])[0]
for instrument in self._instruments
]
)
return dict(
sample,
**{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
: _reduce(sample), :, :
]
for instrument in self._instruments
},
)
def filter_short_segments(self, sample):
""" Filter out too short segment. """
return tf.reduce_any([
tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T
for instrument in self._instruments])
return tf.reduce_any(
[
tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
for instrument in self._instruments
]
)
def random_time_crop(self, sample):
""" Random time crop of 11.88s. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: tf.image.random_crop(
x, (self._T, len(self._instruments) * self._F, 2),
seed=self._random_seed)))
return dict(
sample,
**sync_apply(
{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._instruments
},
lambda x: tf.image.random_crop(
x,
(self._T, len(self._instruments) * self._F, 2),
seed=self._random_seed,
),
),
)
def random_time_stretch(self, sample):
""" Randomly time stretch the given sample. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: random_time_stretch(
x, factor_min=0.9, factor_max=1.1)))
return dict(
sample,
**sync_apply(
{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._instruments
},
lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
),
)
def random_pitch_shift(self, sample):
""" Randomly pitch shift the given sample. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: random_pitch_shift(
x, shift_min=-1.0, shift_max=1.0), concat_axis=0))
return dict(
sample,
**sync_apply(
{
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._instruments
},
lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
concat_axis=0,
),
)
def map_features(self, sample):
""" Select features and annotation of the given sample. """
input_ = {
f'{self._mix_name}_spectrogram':
sample[f'{self._mix_name}_spectrogram']}
f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
}
output = {
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
for instrument in self._audio_params['instrument_list']}
f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
for instrument in self._audio_params["instrument_list"]
}
return (input_, output)
def compute_segments(self, dataset, n_chunks_per_song):
""" Computes segments for each song of the dataset.
def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
"""
Computes segments for each song of the dataset.
:param dataset: Dataset to compute segments for.
:param n_chunks_per_song: Number of segment per song to compute.
:returns: Segmented dataset.
Parameters:
dataset (Any):
Dataset to compute segments for.
n_chunks_per_song (int):
Number of segment per song to compute.
Returns:
Any:
Segmented dataset.
"""
if n_chunks_per_song <= 0:
raise ValueError('n_chunks_per_song must be positif')
raise ValueError("n_chunks_per_song must be positif")
datasets = []
for k in range(n_chunks_per_song):
if n_chunks_per_song > 1:
datasets.append(
dataset.map(lambda sample: dict(sample, start=tf.maximum(
k * (
sample['duration'] - self._chunk_duration - 2
* self.MARGIN) / (n_chunks_per_song - 1)
+ self.MARGIN, 0))))
dataset.map(
lambda sample: dict(
sample,
start=tf.maximum(
k
* (
sample["duration"]
- self._chunk_duration
- 2 * self.MARGIN
)
/ (n_chunks_per_song - 1)
+ self.MARGIN,
0,
),
)
)
)
elif n_chunks_per_song == 1: # Take central segment.
datasets.append(
dataset.map(lambda sample: dict(sample, start=tf.maximum(
sample['duration'] / 2 - self._chunk_duration / 2,
0))))
dataset.map(
lambda sample: dict(
sample,
start=tf.maximum(
sample["duration"] / 2 - self._chunk_duration / 2, 0
),
)
)
)
dataset = datasets[-1]
for d in datasets[:-1]:
dataset = dataset.concatenate(d)
return dataset
@property
def instruments(self):
""" Instrument dataset builder generator.
def instruments(self) -> Any:
"""
Instrument dataset builder generator.
:yield InstrumentBuilder instance.
Yields:
Any:
InstrumentBuilder instance.
"""
if self._instrument_builders is None:
self._instrument_builders = []
for instrument in self._instruments:
self._instrument_builders.append(
InstrumentDatasetBuilder(self, instrument))
InstrumentDatasetBuilder(self, instrument)
)
for builder in self._instrument_builders:
yield builder
def cache(self, dataset, cache, wait):
""" Cache the given dataset if cache is enabled. Eventually waits for
cache to be available (useful if another process is already computing
cache) if provided wait flag is True.
def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
"""
Cache the given dataset if cache is enabled. Eventually waits for
cache to be available (useful if another process is already
computing cache) if provided wait flag is `True`.
:param dataset: Dataset to be cached if cache is required.
:param cache: Path of cache directory to be used, None if no cache.
:param wait: If caching is enabled, True is cache should be waited.
:returns: Cached dataset if needed, original dataset otherwise.
Parameters:
dataset (Any):
Dataset to be cached if cache is required.
cache (str):
Path of cache directory to be used, None if no cache.
wait (bool):
If caching is enabled, True is cache should be waited.
Returns:
Any:
Cached dataset if needed, original dataset otherwise.
"""
if cache is not None:
if wait:
while not exists(f'{cache}.index'):
get_logger().info(
'Cache not available, wait %s',
self.WAIT_PERIOD)
while not exists(f"{cache}.index"):
logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
time.sleep(self.WAIT_PERIOD)
cache_path = os.path.split(cache)[0]
os.makedirs(cache_path, exist_ok=True)
@@ -369,11 +512,19 @@ class DatasetBuilder(object):
return dataset
def build(
self, csv_path,
batch_size=8, shuffle=True, convert_to_uint=True,
random_data_augmentation=False, random_time_crop=True,
infinite_generator=True, cache_directory=None,
wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,):
self,
csv_path: str,
batch_size: int = 8,
shuffle: bool = True,
convert_to_uint: bool = True,
random_data_augmentation: bool = False,
random_time_crop: bool = True,
infinite_generator: bool = True,
cache_directory: Optional[str] = None,
wait_for_cache: bool = False,
num_parallel_calls: int = 4,
n_chunks_per_song: float = 2,
) -> Any:
"""
TO BE DOCUMENTED.
"""
@@ -385,7 +536,8 @@ class DatasetBuilder(object):
buffer_size=200000,
seed=self._random_seed,
# useless since it is cached :
reshuffle_each_iteration=True)
reshuffle_each_iteration=True,
)
# Expand audio path.
dataset = dataset.map(self.expand_path)
# Load waveform, compute spectrogram, and filtering error,
@@ -393,11 +545,11 @@ class DatasetBuilder(object):
N = num_parallel_calls
for instrument in self.instruments:
dataset = (
dataset
.map(instrument.load_waveform, num_parallel_calls=N)
dataset.map(instrument.load_waveform, num_parallel_calls=N)
.filter(self.filter_error)
.map(instrument.compute_spectrogram, num_parallel_calls=N)
.map(instrument.filter_frequencies))
.map(instrument.filter_frequencies)
)
dataset = dataset.map(self.filter_waveform)
# Convert to uint before caching in order to save space.
if convert_to_uint:
@@ -428,26 +580,25 @@ class DatasetBuilder(object):
# after croping but before converting back to float.
if shuffle:
dataset = dataset.shuffle(
buffer_size=256, seed=self._random_seed,
reshuffle_each_iteration=True)
buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
)
# Convert back to float32
if convert_to_uint:
for instrument in self.instruments:
dataset = dataset.map(
instrument.convert_to_float32, num_parallel_calls=N)
instrument.convert_to_float32, num_parallel_calls=N
)
M = 8 # Parallel call post caching.
# Must be applied with the same factor on mix and vocals.
if random_data_augmentation:
dataset = (
dataset
.map(self.random_time_stretch, num_parallel_calls=M)
.map(self.random_pitch_shift, num_parallel_calls=M))
dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
self.random_pitch_shift, num_parallel_calls=M
)
# Filter by shape (remove badly shaped tensors).
for instrument in self.instruments:
dataset = (
dataset
.filter(instrument.filter_shape)
.map(instrument.reshape_spectrogram))
dataset = dataset.filter(instrument.filter_shape).map(
instrument.reshape_spectrogram
)
# Select features and annotation.
dataset = dataset.map(self.map_features)
# Make batch (done after selection to avoid

View File

@@ -5,17 +5,19 @@
import importlib
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.signal import stft, inverse_stft, hann_window
# pylint: enable=import-error
from tensorflow.signal import hann_window, inverse_stft, stft
from ..utils.tensor import pad_and_partition, pad_and_reshape
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
placeholder = tf.compat.v1.placeholder
@@ -23,29 +25,28 @@ placeholder = tf.compat.v1.placeholder
def get_model_function(model_type):
"""
Get tensorflow function of the model to be applied to the input tensor.
For instance "unet.softmax_unet" will return the softmax_unet function
in the "unet.py" submodule of the current module (spleeter.model).
Get tensorflow function of the model to be applied to the input tensor.
For instance "unet.softmax_unet" will return the softmax_unet function
in the "unet.py" submodule of the current module (spleeter.model).
Params:
- model_type: str
the relative module path to the model function.
Params:
- model_type: str
the relative module path to the model function.
Returns:
A tensorflow function to be applied to the input tensor to get the
multitrack output.
Returns:
A tensorflow function to be applied to the input tensor to get the
multitrack output.
"""
relative_path_to_module = '.'.join(model_type.split('.')[:-1])
model_name = model_type.split('.')[-1]
main_module = '.'.join((__name__, 'functions'))
path_to_module = f'{main_module}.{relative_path_to_module}'
relative_path_to_module = ".".join(model_type.split(".")[:-1])
model_name = model_type.split(".")[-1]
main_module = ".".join((__name__, "functions"))
path_to_module = f"{main_module}.{relative_path_to_module}"
module = importlib.import_module(path_to_module)
model_function = getattr(module, model_name)
return model_function
class InputProvider(object):
def __init__(self, params):
self.params = params
@@ -61,16 +62,16 @@ class InputProvider(object):
class WaveformInputProvider(InputProvider):
@property
def input_names(self):
return ["audio_id", "waveform"]
def get_input_dict_placeholders(self):
shape = (None, self.params['n_channels'])
shape = (None, self.params["n_channels"])
features = {
'waveform': placeholder(tf.float32, shape=shape, name="waveform"),
'audio_id': placeholder(tf.string, name="audio_id")}
"waveform": placeholder(tf.float32, shape=shape, name="waveform"),
"audio_id": placeholder(tf.string, name="audio_id"),
}
return features
def get_feed_dict(self, features, waveform, audio_id):
@@ -78,7 +79,6 @@ class WaveformInputProvider(InputProvider):
class SpectralInputProvider(InputProvider):
def __init__(self, params):
super().__init__(params)
self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@@ -89,11 +89,17 @@ class SpectralInputProvider(InputProvider):
def get_input_dict_placeholders(self):
features = {
self.stft_input_name: placeholder(tf.complex64,
shape=(None, self.params["frame_length"]//2+1,
self.params['n_channels']),
name=self.stft_input_name),
'audio_id': placeholder(tf.string, name="audio_id")}
self.stft_input_name: placeholder(
tf.complex64,
shape=(
None,
self.params["frame_length"] // 2 + 1,
self.params["n_channels"],
),
name=self.stft_input_name,
),
"audio_id": placeholder(tf.string, name="audio_id"),
}
return features
def get_feed_dict(self, features, stft, audio_id):
@@ -101,11 +107,13 @@ class SpectralInputProvider(InputProvider):
class InputProviderFactory(object):
@staticmethod
def get(params):
stft_backend = params["stft_backend"]
assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
assert stft_backend in (
"tensorflow",
"librosa",
), "Unexpected backend {}".format(stft_backend)
if stft_backend == "tensorflow":
return WaveformInputProvider(params)
else:
@@ -113,7 +121,7 @@ class InputProviderFactory(object):
class EstimatorSpecBuilder(object):
""" A builder class that allows to builds a multitrack unet model
"""A builder class that allows to builds a multitrack unet model
estimator. The built model estimator has a different behaviour when
used in a train/eval mode and in predict mode.
@@ -137,22 +145,22 @@ class EstimatorSpecBuilder(object):
"""
# Supported model functions.
DEFAULT_MODEL = 'unet.unet'
DEFAULT_MODEL = "unet.unet"
# Supported loss functions.
L1_MASK = 'L1_mask'
WEIGHTED_L1_MASK = 'weighted_L1_mask'
L1_MASK = "L1_mask"
WEIGHTED_L1_MASK = "weighted_L1_mask"
# Supported optimizers.
ADADELTA = 'Adadelta'
SGD = 'SGD'
ADADELTA = "Adadelta"
SGD = "SGD"
# Math constants.
WINDOW_COMPENSATION_FACTOR = 2./3.
WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0
EPSILON = 1e-10
def __init__(self, features, params):
""" Default constructor. Depending on built model
"""Default constructor. Depending on built model
usage, the provided features should be different:
* In train/eval mode: features is a dictionary with a
@@ -169,20 +177,20 @@ class EstimatorSpecBuilder(object):
self._features = features
self._params = params
# Get instrument name.
self._mix_name = params['mix_name']
self._instruments = params['instrument_list']
self._mix_name = params["mix_name"]
self._instruments = params["instrument_list"]
# Get STFT/signals parameters
self._n_channels = params['n_channels']
self._T = params['T']
self._F = params['F']
self._frame_length = params['frame_length']
self._frame_step = params['frame_step']
self._n_channels = params["n_channels"]
self._T = params["T"]
self._F = params["F"]
self._frame_length = params["frame_length"]
self._frame_step = params["frame_step"]
def include_stft_computations(self):
return self._params["stft_backend"] == "tensorflow"
def _build_model_outputs(self):
""" Created a batch_sizexTxFxn_channels input tensor containing
"""Created a batch_sizexTxFxn_channels input tensor containing
mix magnitude spectrogram, then an output dict from it according
to the selected model in internal parameters.
@@ -191,22 +199,21 @@ class EstimatorSpecBuilder(object):
"""
input_tensor = self.spectrogram_feature
model = self._params.get('model', None)
model = self._params.get("model", None)
if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL)
model_type = model.get("type", self.DEFAULT_MODEL)
else:
model_type = self.DEFAULT_MODEL
try:
apply_model = get_model_function(model_type)
except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found')
raise ValueError(f"No model function {model_type} found")
self._model_outputs = apply_model(
input_tensor,
self._instruments,
self._params['model']['params'])
input_tensor, self._instruments, self._params["model"]["params"]
)
def _build_loss(self, labels):
""" Construct tensorflow loss and metrics
"""Construct tensorflow loss and metrics
:param output_dict: dictionary of network outputs (key: instrument
name, value: estimated spectrogram of the instrument)
@@ -215,7 +222,7 @@ class EstimatorSpecBuilder(object):
:returns: tensorflow (loss, metrics) tuple.
"""
output_dict = self.model_outputs
loss_type = self._params.get('loss_type', self.L1_MASK)
loss_type = self._params.get("loss_type", self.L1_MASK)
if loss_type == self.L1_MASK:
losses = {
name: tf.reduce_mean(tf.abs(output - labels[name]))
@@ -224,11 +231,9 @@ class EstimatorSpecBuilder(object):
elif loss_type == self.WEIGHTED_L1_MASK:
losses = {
name: tf.reduce_mean(
tf.reduce_mean(
labels[name],
axis=[1, 2, 3],
keep_dims=True) *
tf.abs(output - labels[name]))
tf.reduce_mean(labels[name], axis=[1, 2, 3], keep_dims=True)
* tf.abs(output - labels[name])
)
for name, output in output_dict.items()
}
else:
@@ -236,20 +241,20 @@ class EstimatorSpecBuilder(object):
loss = tf.reduce_sum(list(losses.values()))
# Add metrics for monitoring each instrument.
metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
metrics['absolute_difference'] = tf.compat.v1.metrics.mean(loss)
metrics["absolute_difference"] = tf.compat.v1.metrics.mean(loss)
return loss, metrics
def _build_optimizer(self):
""" Builds an optimizer instance from internal parameter values.
"""Builds an optimizer instance from internal parameter values.
Default to AdamOptimizer if not specified.
:returns: Optimizer instance from internal configuration.
"""
name = self._params.get('optimizer')
name = self._params.get("optimizer")
if name == self.ADADELTA:
return tf.compat.v1.train.AdadeltaOptimizer()
rate = self._params['learning_rate']
rate = self._params["learning_rate"]
if name == self.SGD:
return tf.compat.v1.train.GradientDescentOptimizer(rate)
return tf.compat.v1.train.AdamOptimizer(rate)
@@ -260,15 +265,15 @@ class EstimatorSpecBuilder(object):
@property
def stft_name(self):
return f'{self._mix_name}_stft'
return f"{self._mix_name}_stft"
@property
def spectrogram_name(self):
return f'{self._mix_name}_spectrogram'
return f"{self._mix_name}_spectrogram"
def _build_stft_feature(self):
""" Compute STFT of waveform and slice the STFT in segment
with the right length to feed the network.
"""Compute STFT of waveform and slice the STFT in segment
with the right length to feed the network.
"""
stft_name = self.stft_name
@@ -276,25 +281,30 @@ class EstimatorSpecBuilder(object):
if stft_name not in self._features:
# pad input with a frame of zeros
waveform = tf.concat([
tf.zeros((self._frame_length, self._n_channels)),
self._features['waveform']
],
0
)
waveform = tf.concat(
[
tf.zeros((self._frame_length, self._n_channels)),
self._features["waveform"],
],
0,
)
stft_feature = tf.transpose(
stft(
tf.transpose(waveform),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype)),
pad_end=True),
perm=[1, 2, 0])
self._features[f'{self._mix_name}_stft'] = stft_feature
hann_window(frame_length, periodic=True, dtype=dtype)
),
pad_end=True,
),
perm=[1, 2, 0],
)
self._features[f"{self._mix_name}_stft"] = stft_feature
if spec_name not in self._features:
self._features[spec_name] = tf.abs(
pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :]
pad_and_partition(self._features[stft_name], self._T)
)[:, :, : self._F, :]
@property
def model_outputs(self):
@@ -333,25 +343,29 @@ class EstimatorSpecBuilder(object):
return self._masked_stfts
def _inverse_stft(self, stft_t, time_crop=None):
""" Inverse and reshape the given STFT
"""Inverse and reshape the given STFT
:param stft_t: input STFT
:returns: inverse STFT (waveform)
"""
inversed = inverse_stft(
tf.transpose(stft_t, perm=[2, 0, 1]),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype))
) * self.WINDOW_COMPENSATION_FACTOR
inversed = (
inverse_stft(
tf.transpose(stft_t, perm=[2, 0, 1]),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype)
),
)
* self.WINDOW_COMPENSATION_FACTOR
)
reshaped = tf.transpose(inversed)
if time_crop is None:
time_crop = tf.shape(self._features['waveform'])[0]
return reshaped[self._frame_length:self._frame_length+time_crop, :]
time_crop = tf.shape(self._features["waveform"])[0]
return reshaped[self._frame_length : self._frame_length + time_crop, :]
def _build_mwf_output_waveform(self):
""" Perform separation with multichannel Wiener Filtering using Norbert.
"""Perform separation with multichannel Wiener Filtering using Norbert.
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
may be quite slow.
@@ -359,36 +373,42 @@ class EstimatorSpecBuilder(object):
value: estimated waveform of the instrument)
"""
import norbert # pylint: disable=import-error
output_dict = self.model_outputs
x = self.stft_feature
v = tf.stack(
[
pad_and_reshape(
output_dict[f'{instrument}_spectrogram'],
output_dict[f"{instrument}_spectrogram"],
self._frame_length,
self._F)[:tf.shape(x)[0], ...]
self._F,
)[: tf.shape(x)[0], ...]
for instrument in self._instruments
],
axis=3)
axis=3,
)
input_args = [v, x]
stft_function = tf.py_function(
lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
input_args,
tf.complex64),
stft_function = (
tf.py_function(
lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
input_args,
tf.complex64,
),
)
return {
instrument: self._inverse_stft(stft_function[0][:, :, :, k])
for k, instrument in enumerate(self._instruments)
}
def _extend_mask(self, mask):
""" Extend mask, from reduced number of frequency bin to the number of
"""Extend mask, from reduced number of frequency bin to the number of
frequency bin in the STFT.
:param mask: restricted mask
:returns: extended mask
:raise ValueError: If invalid mask_extension parameter is set.
"""
extension = self._params['mask_extension']
extension = self._params["mask_extension"]
# Extend with average
# (dispatch according to energy in the processed band)
if extension == "average":
@@ -397,13 +417,9 @@ class EstimatorSpecBuilder(object):
# (avoid extension artifacts but not conservative separation)
elif extension == "zeros":
mask_shape = tf.shape(mask)
extension_row = tf.zeros((
mask_shape[0],
mask_shape[1],
1,
mask_shape[-1]))
extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))
else:
raise ValueError(f'Invalid mask_extension parameter {extension}')
raise ValueError(f"Invalid mask_extension parameter {extension}")
n_extra_row = self._frame_length // 2 + 1 - self._F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
return tf.concat([mask, extension], axis=2)
@@ -415,29 +431,31 @@ class EstimatorSpecBuilder(object):
"""
output_dict = self.model_outputs
stft_feature = self.stft_feature
separation_exponent = self._params['separation_exponent']
output_sum = tf.reduce_sum(
[e ** separation_exponent for e in output_dict.values()],
axis=0
) + self.EPSILON
separation_exponent = self._params["separation_exponent"]
output_sum = (
tf.reduce_sum(
[e ** separation_exponent for e in output_dict.values()], axis=0
)
+ self.EPSILON
)
out = {}
for instrument in self._instruments:
output = output_dict[f'{instrument}_spectrogram']
output = output_dict[f"{instrument}_spectrogram"]
# Compute mask with the model.
instrument_mask = (output ** separation_exponent
+ (self.EPSILON / len(output_dict))) / output_sum
instrument_mask = (
output ** separation_exponent + (self.EPSILON / len(output_dict))
) / output_sum
# Extend mask;
instrument_mask = self._extend_mask(instrument_mask)
# Stack back mask.
old_shape = tf.shape(instrument_mask)
new_shape = tf.concat(
[[old_shape[0] * old_shape[1]], old_shape[2:]],
axis=0)
[[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0
)
instrument_mask = tf.reshape(instrument_mask, new_shape)
# Remove padded part (for mask having the same size as STFT);
instrument_mask = instrument_mask[
:tf.shape(stft_feature)[0], ...]
instrument_mask = instrument_mask[: tf.shape(stft_feature)[0], ...]
out[instrument] = instrument_mask
self._masks = out
@@ -449,7 +467,7 @@ class EstimatorSpecBuilder(object):
self._masked_stfts = out
def _build_manual_output_waveform(self, masked_stft):
""" Perform ratio mask separation
"""Perform ratio mask separation
:param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument)
@@ -463,14 +481,14 @@ class EstimatorSpecBuilder(object):
return output_waveform
def _build_output_waveform(self, masked_stft):
""" Build output waveform from given output dict in order to be used in
"""Build output waveform from given output dict in order to be used in
prediction context. Regarding of the configuration building method will
be using MWF.
:returns: Built output waveform.
"""
if self._params.get('MWF', False):
if self._params.get("MWF", False):
output_waveform = self._build_mwf_output_waveform()
else:
output_waveform = self._build_manual_output_waveform(masked_stft)
@@ -482,11 +500,11 @@ class EstimatorSpecBuilder(object):
else:
self._outputs = self.masked_stfts
if 'audio_id' in self._features:
self._outputs['audio_id'] = self._features['audio_id']
if "audio_id" in self._features:
self._outputs["audio_id"] = self._features["audio_id"]
def build_predict_model(self):
""" Builder interface for creating model instance that aims to perform
"""Builder interface for creating model instance that aims to perform
prediction / inference over given track. The output of such estimator
will be a dictionary with a "<instrument>" key per separated instrument
, associated to the estimated separated waveform of the instrument.
@@ -495,11 +513,11 @@ class EstimatorSpecBuilder(object):
"""
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT,
predictions=self.outputs)
tf.estimator.ModeKeys.PREDICT, predictions=self.outputs
)
def build_evaluation_model(self, labels):
""" Builder interface for creating model instance that aims to perform
"""Builder interface for creating model instance that aims to perform
model evaluation. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram.
@@ -509,12 +527,11 @@ class EstimatorSpecBuilder(object):
"""
loss, metrics = self._build_loss(labels)
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops=metrics)
tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics
)
def build_train_model(self, labels):
""" Builder interface for creating model instance that aims to perform
"""Builder interface for creating model instance that aims to perform
model training. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram.
@@ -525,8 +542,8 @@ class EstimatorSpecBuilder(object):
loss, metrics = self._build_loss(labels)
optimizer = self._build_optimizer()
train_operation = optimizer.minimize(
loss=loss,
global_step=tf.compat.v1.train.get_global_step())
loss=loss, global_step=tf.compat.v1.train.get_global_step()
)
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
@@ -539,9 +556,9 @@ def model_fn(features, labels, mode, params, config):
"""
:param features:
:param labels:
:param labels:
:param mode: Estimator mode.
:param params:
:param params:
:param config: TF configuration (not used).
:returns: Built EstimatorSpec.
:raise ValueError: If estimator mode is not supported.
@@ -553,4 +570,4 @@ def model_fn(features, labels, mode, params, config):
return builder.build_evaluation_model(labels)
elif mode == tf.estimator.ModeKeys.TRAIN:
return builder.build_train_model(labels)
raise ValueError(f'Unknown mode {mode}')
raise ValueError(f"Unknown mode {mode}")

View File

@@ -3,25 +3,45 @@
""" This package provide model functions. """
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
from typing import Callable, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply(function, input_tensor, instruments, params={}):
""" Apply given function to the input tensor.
:param function: Function to be applied to tensor.
:param input_tensor: Tensor to apply blstm to.
:param instruments: Iterable that provides a collection of instruments.
:param params: (Optional) dict of BLSTM parameters.
:returns: Created output tensor dict.
def apply(
function: Callable,
input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None,
) -> Dict:
"""
output_dict = {}
Apply given function to the input tensor.
Parameters:
function:
Function to be applied to tensor.
input_tensor (tensorflow.Tensor):
Tensor to apply blstm to.
instruments (Iterable[str]):
Iterable that provides a collection of instruments.
params:
(Optional) dict of BLSTM parameters.
Returns:
Created output tensor dict.
"""
output_dict: Dict = {}
for instrument in instruments:
out_name = f'{instrument}_spectrogram'
out_name = f"{instrument}_spectrogram"
output_dict[out_name] = function(
input_tensor,
output_name=out_name,
params=params)
input_tensor, output_name=out_name, params=params or {}
)
return output_dict

View File

@@ -20,7 +20,11 @@
selection (LSTM layer dropout rate, regularization strength).
"""
from typing import Dict, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import (
@@ -28,34 +32,48 @@ from tensorflow.keras.layers import (
Dense,
Flatten,
Reshape,
TimeDistributed)
# pylint: enable=import-error
TimeDistributed,
)
from . import apply
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply_blstm(input_tensor, output_name='output', params={}):
""" Apply BLSTM to the given input_tensor.
:param input_tensor: Input of the model.
:param output_name: (Optional) name of the output, default to 'output'.
:param params: (Optional) dict of BLSTM parameters.
:returns: Output tensor.
def apply_blstm(
input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
"""
units = params.get('lstm_units', 250)
Apply BLSTM to the given input_tensor.
Parameters:
input_tensor (tensorflow.Tensor):
Input of the model.
output_name (str):
(Optional) name of the output, default to 'output'.
params (Optional[Dict]):
(Optional) dict of BLSTM parameters.
Returns:
tensorflow.Tensor:
Output tensor.
"""
if params is None:
params = {}
units: int = params.get("lstm_units", 250)
kernel_initializer = he_uniform(seed=50)
flatten_input = TimeDistributed(Flatten())((input_tensor))
def create_bidirectional():
return Bidirectional(
CuDNNLSTM(
units,
kernel_initializer=kernel_initializer,
return_sequences=True))
units, kernel_initializer=kernel_initializer, return_sequences=True
)
)
l1 = create_bidirectional()((flatten_input))
l2 = create_bidirectional()((l1))
@@ -63,14 +81,18 @@ def apply_blstm(input_tensor, output_name='output', params={}):
dense = TimeDistributed(
Dense(
int(flatten_input.shape[2]),
activation='relu',
kernel_initializer=kernel_initializer))((l3))
output = TimeDistributed(
Reshape(input_tensor.shape[2:]),
name=output_name)(dense)
activation="relu",
kernel_initializer=kernel_initializer,
)
)((l3))
output: tf.Tensor = TimeDistributed(
Reshape(input_tensor.shape[2:]), name=output_name
)(dense)
return output
def blstm(input_tensor, output_name='output', params={}):
def blstm(
input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
""" Model function applier. """
return apply(apply_blstm, input_tensor, output_name, params)

View File

@@ -2,92 +2,109 @@
# coding: utf8
"""
This module contains building functions for U-net source
separation models in a similar way as in A. Jansson et al. "Singing
voice separation with deep u-net convolutional networks", ISMIR 2017.
Each instrument is modeled by a single U-net convolutional
/ deconvolutional network that take a mix spectrogram as input and the
estimated sound spectrogram as output.
This module contains building functions for U-net source
separation models in a similar way as in A. Jansson et al. :
"Singing voice separation with deep u-net convolutional networks",
ISMIR 2017
Each instrument is modeled by a single U-net
convolutional / deconvolutional network that take a mix spectrogram
as input and the estimated sound spectrogram as output.
"""
from functools import partial
from typing import Any, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.keras.layers import (
ELU,
BatchNormalization,
Concatenate,
Conv2D,
Conv2DTranspose,
Dropout,
ELU,
LeakyReLU,
Multiply,
ReLU,
Softmax)
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
# pylint: enable=import-error
Softmax,
)
from . import apply
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def _get_conv_activation_layer(params):
def _get_conv_activation_layer(params: Dict) -> Any:
"""
> To be documented.
:param params:
:returns: Required Activation function.
Parameters:
params (Dict):
Returns:
Any:
Required Activation function.
"""
conv_activation = params.get('conv_activation')
if conv_activation == 'ReLU':
conv_activation: str = params.get("conv_activation")
if conv_activation == "ReLU":
return ReLU()
elif conv_activation == 'ELU':
elif conv_activation == "ELU":
return ELU()
return LeakyReLU(0.2)
def _get_deconv_activation_layer(params):
def _get_deconv_activation_layer(params: Dict) -> Any:
"""
> To be documented.
:param params:
:returns: Required Activation function.
Parameters:
params (Dict):
Returns:
Any:
Required Activation function.
"""
deconv_activation = params.get('deconv_activation')
if deconv_activation == 'LeakyReLU':
deconv_activation: str = params.get("deconv_activation")
if deconv_activation == "LeakyReLU":
return LeakyReLU(0.2)
elif deconv_activation == 'ELU':
elif deconv_activation == "ELU":
return ELU()
return ReLU()
def apply_unet(
input_tensor,
output_name='output',
params={},
output_mask_logit=False):
""" Apply a convolutionnal U-net to model a single instrument (one U-net
input_tensor: tf.Tensor,
output_name: str = "output",
params: Optional[Dict] = None,
output_mask_logit: bool = False,
) -> Any:
"""
Apply a convolutionnal U-net to model a single instrument (one U-net
is used for each instrument).
:param input_tensor:
:param output_name: (Optional) , default to 'output'
:param params: (Optional) , default to empty dict.
:param output_mask_logit: (Optional) , default to False.
Parameters:
input_tensor (tensorflow.Tensor):
output_name (str):
params (Optional[Dict]):
output_mask_logit (bool):
"""
logging.info(f'Apply unet for {output_name}')
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
logging.info(f"Apply unet for {output_name}")
conv_n_filters = params.get("conv_n_filters", [16, 32, 64, 128, 256, 512])
conv_activation_layer = _get_conv_activation_layer(params)
deconv_activation_layer = _get_deconv_activation_layer(params)
kernel_initializer = he_uniform(seed=50)
conv2d_factory = partial(
Conv2D,
strides=(2, 2),
padding='same',
kernel_initializer=kernel_initializer)
Conv2D, strides=(2, 2), padding="same", kernel_initializer=kernel_initializer
)
# First layer.
conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)
batch1 = BatchNormalization(axis=-1)(conv1)
@@ -117,8 +134,9 @@ def apply_unet(
conv2d_transpose_factory = partial(
Conv2DTranspose,
strides=(2, 2),
padding='same',
kernel_initializer=kernel_initializer)
padding="same",
kernel_initializer=kernel_initializer,
)
#
up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6))
up1 = deconv_activation_layer(up1)
@@ -157,46 +175,60 @@ def apply_unet(
2,
(4, 4),
dilation_rate=(2, 2),
activation='sigmoid',
padding='same',
kernel_initializer=kernel_initializer)((batch12))
activation="sigmoid",
padding="same",
kernel_initializer=kernel_initializer,
)((batch12))
output = Multiply(name=output_name)([up7, input_tensor])
return output
return Conv2D(
2,
(4, 4),
dilation_rate=(2, 2),
padding='same',
kernel_initializer=kernel_initializer)((batch12))
padding="same",
kernel_initializer=kernel_initializer,
)((batch12))
def unet(input_tensor, instruments, params={}):
def unet(
input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
""" Model function applier. """
return apply(apply_unet, input_tensor, instruments, params)
def softmax_unet(input_tensor, instruments, params={}):
""" Apply softmax to multitrack unet in order to have mask suming to one.
def softmax_unet(
input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
"""
Apply softmax to multitrack unet in order to have mask suming to one.
:param input_tensor: Tensor to apply blstm to.
:param instruments: Iterable that provides a collection of instruments.
:param params: (Optional) dict of BLSTM parameters.
:returns: Created output tensor dict.
Parameters:
input_tensor (tensorflow.Tensor):
Tensor to apply blstm to.
instruments (Iterable[str]):
Iterable that provides a collection of instruments.
params (Optional[Dict]):
(Optional) dict of BLSTM parameters.
Returns:
Dict:
Created output tensor dict.
"""
logit_mask_list = []
for instrument in instruments:
out_name = f'{instrument}_spectrogram'
out_name = f"{instrument}_spectrogram"
logit_mask_list.append(
apply_unet(
input_tensor,
output_name=out_name,
params=params,
output_mask_logit=True))
output_mask_logit=True,
)
)
masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
output_dict = {}
for i, instrument in enumerate(instruments):
out_name = f'{instrument}_spectrogram'
output_dict[out_name] = Multiply(name=out_name)([
masks[..., i],
input_tensor])
out_name = f"{instrument}_spectrogram"
output_dict[out_name] = Multiply(name=out_name)([masks[..., i], input_tensor])
return output_dict

View File

@@ -5,77 +5,91 @@
This package provides tools for downloading model from network
using remote storage abstraction.
:Example:
Examples:
```python
>>> provider = MyProviderImplementation()
>>> provider.get('/path/to/local/storage', params)
```
"""
from abc import ABC, abstractmethod
from os import environ, makedirs
from os.path import exists, isabs, join, sep
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class ModelProvider(ABC):
"""
A ModelProvider manages model files on disk and
file download is not available.
A ModelProvider manages model files on disk and
file download is not available.
"""
DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models')
MODEL_PROBE_PATH = '.probe'
DEFAULT_MODEL_PATH: str = environ.get("MODEL_PATH", "pretrained_models")
MODEL_PROBE_PATH: str = ".probe"
@abstractmethod
def download(self, name, path):
""" Download model denoted by the given name to disk.
def download(_, name: str, path: str) -> None:
"""
Download model denoted by the given name to disk.
:param name: Name of the model to download.
:param path: Path of the directory to save model into.
Parameters:
name (str):
Name of the model to download.
path (str):
Path of the directory to save model into.
"""
pass
@staticmethod
def writeProbe(directory):
""" Write a model probe file into the given directory.
:param directory: Directory to write probe into.
def writeProbe(directory: str) -> None:
"""
probe = join(directory, ModelProvider.MODEL_PROBE_PATH)
with open(probe, 'w') as stream:
stream.write('OK')
Write a model probe file into the given directory.
def get(self, model_directory):
""" Ensures required model is available at given location.
Parameters:
directory (str):
Directory to write probe into.
"""
probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
with open(probe, "w") as stream:
stream.write("OK")
:param model_directory: Expected model_directory to be available.
:raise IOError: If model can not be retrieved.
def get(self, model_directory: str) -> str:
"""
Ensures required model is available at given location.
Parameters:
model_directory (str):
Expected model_directory to be available.
Raises:
IOError:
If model can not be retrieved.
"""
# Expend model directory if needed.
if not isabs(model_directory):
model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
# Download it if not exists.
model_probe = join(model_directory, self.MODEL_PROBE_PATH)
model_probe: str = join(model_directory, self.MODEL_PROBE_PATH)
if not exists(model_probe):
if not exists(model_directory):
makedirs(model_directory)
self.download(
model_directory.split(sep)[-1],
model_directory)
self.download(model_directory.split(sep)[-1], model_directory)
self.writeProbe(model_directory)
return model_directory
@classmethod
def default(_: type) -> "ModelProvider":
"""
Builds and returns a default model provider.
def get_default_model_provider():
""" Builds and returns a default model provider.
Returns:
ModelProvider:
A default model provider instance to use.
"""
from .github import GithubModelProvider
:returns: A default model provider instance to use.
"""
from .github import GithubModelProvider
host = environ.get('GITHUB_HOST', 'https://github.com')
repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
return GithubModelProvider(host, repository, release)
return GithubModelProvider.from_environ()

View File

@@ -4,41 +4,48 @@
"""
A ModelProvider backed by Github Release feature.
:Example:
Examples:
```python
>>> from spleeter.model.provider import github
>>> provider = github.GithubModelProvider(
'github.com',
'Deezer/spleeter',
'latest')
>>> provider.download('2stems', '/path/to/local/storage')
```
"""
import hashlib
import tarfile
import os
import tarfile
from os import environ
from tempfile import NamedTemporaryFile
from typing import Dict
import requests
# pyright: reportMissingImports=false
# pylint: disable=import-error
import httpx
from ...utils.logging import logger
from . import ModelProvider
from ...utils.logging import get_logger
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def compute_file_checksum(path):
""" Computes given path file sha256.
"""Computes given path file sha256.
:param path: Path of the file to compute checksum for.
:returns: File checksum.
"""
sha256 = hashlib.sha256()
with open(path, 'rb') as stream:
for chunk in iter(lambda: stream.read(4096), b''):
with open(path, "rb") as stream:
for chunk in iter(lambda: stream.read(4096), b""):
sha256.update(chunk)
return sha256.hexdigest()
@@ -46,69 +53,104 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider):
""" A ModelProvider implementation backed on Github for remote storage. """
LATEST_RELEASE = 'v1.4.0'
RELEASE_PATH = 'releases/download'
CHECKSUM_INDEX = 'checksum.json'
DEFAULT_HOST: str = "https://github.com"
DEFAULT_REPOSITORY: str = "deezer/spleeter"
def __init__(self, host, repository, release):
""" Default constructor.
CHECKSUM_INDEX: str = "checksum.json"
LATEST_RELEASE: str = "v1.4.0"
RELEASE_PATH: str = "releases/download"
:param host: Host to the Github instance to reach.
:param repository: Repository path within target Github.
:param release: Release name to get models from.
def __init__(self, host: str, repository: str, release: str) -> None:
"""Default constructor.
Parameters:
host (str):
Host to the Github instance to reach.
repository (str):
Repository path within target Github.
release (str):
Release name to get models from.
"""
self._host = host
self._repository = repository
self._release = release
self._host: str = host
self._repository: str = repository
self._release: str = release
def checksum(self, name):
""" Downloads and returns reference checksum for the given model name.
:param name: Name of the model to get checksum for.
:returns: Checksum of the required model.
:raise ValueError: If the given model name is not indexed.
@classmethod
def from_environ(cls: type) -> "GithubModelProvider":
"""
url = '{}/{}/{}/{}/{}'.format(
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
self.CHECKSUM_INDEX)
response = requests.get(url)
Factory method that creates provider from envvars.
Returns:
GithubModelProvider:
Created instance.
"""
return cls(
environ.get("GITHUB_HOST", cls.DEFAULT_HOST),
environ.get("GITHUB_REPOSITORY", cls.DEFAULT_REPOSITORY),
environ.get("GITHUB_RELEASE", cls.LATEST_RELEASE),
)
def checksum(self, name: str) -> str:
"""
Downloads and returns reference checksum for the given model name.
Parameters:
name (str):
Name of the model to get checksum for.
Returns:
str:
Checksum of the required model.
Raises:
ValueError:
If the given model name is not indexed.
"""
url: str = "/".join(
(
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
self.CHECKSUM_INDEX,
)
)
response: httpx.Response = httpx.get(url)
response.raise_for_status()
index = response.json()
index: Dict = response.json()
if name not in index:
raise ValueError('No checksum for model {}'.format(name))
raise ValueError(f"No checksum for model {name}")
return index[name]
def download(self, name, path):
""" Download model denoted by the given name to disk.
:param name: Name of the model to download.
:param path: Path of the directory to save model into.
def download(self, name: str, path: str) -> None:
"""
url = '{}/{}/{}/{}/{}.tar.gz'.format(
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
name)
get_logger().info('Downloading model archive %s', url)
with requests.get(url, stream=True) as response:
response.raise_for_status()
archive = NamedTemporaryFile(delete=False)
try:
with archive as stream:
# Note: check for chunk size parameters ?
for chunk in response.iter_content(chunk_size=8192):
if chunk:
Download model denoted by the given name to disk.
Parameters:
name (str):
Name of the model to download.
path (str):
Path of the directory to save model into.
"""
url: str = "/".join(
(self._host, self._repository, self.RELEASE_PATH, self._release, name)
)
url = f"{url}.tar.gz"
logger.info(f"Downloading model archive {url}")
with httpx.Client(http2=True) as client:
with client.stream("GET", url) as response:
response.raise_for_status()
archive = NamedTemporaryFile(delete=False)
try:
with archive as stream:
for chunk in response.iter_raw():
stream.write(chunk)
get_logger().info('Validating archive checksum')
if compute_file_checksum(archive.name) != self.checksum(name):
raise IOError('Downloaded file is corrupted, please retry')
get_logger().info('Extracting downloaded %s archive', name)
with tarfile.open(name=archive.name) as tar:
tar.extractall(path=path)
finally:
os.unlink(archive.name)
get_logger().info('%s model file(s) extracted', name)
logger.info("Validating archive checksum")
checksum: str = compute_file_checksum(archive.name)
if checksum != self.checksum(name):
raise IOError("Downloaded file is corrupted, please retry")
logger.info(f"Extracting downloaded {name} archive")
with tarfile.open(name=archive.name) as tar:
tar.extractall(path=path)
finally:
os.unlink(archive.name)
logger.info(f"{name} model file(s) extracted")

128
spleeter/options.py Normal file
View File

@@ -0,0 +1,128 @@
#!/usr/bin/env python
# coding: utf8
""" This modules provides spleeter command as well as CLI parsing methods. """
from os.path import join
from tempfile import gettempdir
from typer import Argument, Option
from typer.models import ArgumentInfo, OptionInfo
from .audio import Codec, STFTBackend
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
AudioInputArgument: ArgumentInfo = Argument(
...,
help="List of input audio file path",
exists=True,
file_okay=True,
dir_okay=False,
readable=True,
resolve_path=True,
)
AudioInputOption: OptionInfo = Option(
None, "--inputs", "-i", help="(DEPRECATED) placeholder for deprecated input option"
)
AudioAdapterOption: OptionInfo = Option(
"spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter",
"--adapter",
"-a",
help="Name of the audio adapter to use for audio I/O",
)
AudioOutputOption: OptionInfo = Option(
join(gettempdir(), "separated_audio"),
"--output_path",
"-o",
help="Path of the output directory to write audio files in",
)
AudioOffsetOption: OptionInfo = Option(
0.0, "--offset", "-s", help="Set the starting offset to separate audio from"
)
AudioDurationOption: OptionInfo = Option(
600.0,
"--duration",
"-d",
help=(
"Set a maximum duration for processing audio "
"(only separate offset + duration first seconds of "
"the input file)"
),
)
AudioSTFTBackendOption: OptionInfo = Option(
STFTBackend.AUTO,
"--stft-backend",
"-B",
case_sensitive=False,
help=(
"Who should be in charge of computing the stfts. Librosa is faster "
'than tensorflow on CPU and uses less memory. "auto" will use '
"tensorflow when GPU acceleration is available and librosa when not"
),
)
AudioCodecOption: OptionInfo = Option(
Codec.WAV, "--codec", "-c", help="Audio codec to be used for the separated output"
)
AudioBitrateOption: OptionInfo = Option(
"128k", "--bitrate", "-b", help="Audio bitrate to be used for the separated output"
)
FilenameFormatOption: OptionInfo = Option(
"{filename}/{instrument}.{codec}",
"--filename_format",
"-f",
help=(
"Template string that will be formatted to generated"
"output filename. Such template should be Python formattable"
"string, and could use {filename}, {instrument}, and {codec}"
"variables"
),
)
ModelParametersOption: OptionInfo = Option(
"spleeter:2stems",
"--params_filename",
"-p",
help="JSON filename that contains params",
)
MWFOption: OptionInfo = Option(
False, "--mwf", help="Whether to use multichannel Wiener filtering for separation"
)
MUSDBDirectoryOption: OptionInfo = Option(
...,
"--mus_dir",
exists=True,
dir_okay=True,
file_okay=False,
readable=True,
resolve_path=True,
help="Path to musDB dataset directory",
)
TrainingDataDirectoryOption: OptionInfo = Option(
...,
"--data",
"-d",
exists=True,
dir_okay=True,
file_okay=False,
readable=True,
resolve_path=True,
help="Path of the folder containing audio data for training",
)
VerboseOption: OptionInfo = Option(False, "--verbose", help="Enable verbose logs")

0
spleeter/py.typed Normal file
View File

View File

@@ -3,6 +3,6 @@
""" Packages that provides static resources file for the library. """
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

View File

@@ -4,60 +4,63 @@
"""
Module that provides a class wrapper for source separation.
:Example:
Examples:
```python
>>> from spleeter.separator import Separator
>>> separator = Separator('spleeter:2stems')
>>> separator.separate(waveform, lambda instrument, data: ...)
>>> separator.separate_to_file(...)
```
"""
import atexit
import os
import logging
from multiprocessing import Pool
from os.path import basename, join, splitext, dirname
from time import time
from typing import Container, NoReturn
from os.path import basename, dirname, join, splitext
from typing import Dict, Generator, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from librosa.core import stft, istft
from librosa.core import istft, stft
from scipy.signal.windows import hann
from spleeter.model.provider import ModelProvider
from . import SpleeterError
from .audio.adapter import get_default_audio_adapter
from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
from .model.provider import ModelProvider
from .types import AudioDescriptor
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pylint: enable=import-error
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
""" """
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
class DataGenerator():
class DataGenerator(object):
"""
Generator object that store a sample and generate it once while called.
Used to feed a tensorflow estimator without knowing the whole data at
build time.
Generator object that store a sample and generate it once while called.
Used to feed a tensorflow estimator without knowing the whole data at
build time.
"""
def __init__(self):
def __init__(self) -> None:
""" Default constructor. """
self._current_data = None
def update_data(self, data):
def update_data(self, data) -> None:
""" Replace internal data. """
self._current_data = data
def __call__(self):
def __call__(self) -> Generator:
""" Generation process. """
buffer = self._current_data
while buffer:
@@ -65,34 +68,52 @@ class DataGenerator():
buffer = self._current_data
def get_backend(backend: str) -> str:
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
if backend not in SUPPORTED_BACKEND:
raise ValueError(f'Unsupported backend {backend}')
if backend == 'auto':
if len(tf.config.list_physical_devices('GPU')):
return 'tensorflow'
return 'librosa'
return backend
# Load model.
provider: ModelProvider = ModelProvider.default()
params["model_dir"] = provider.get(params["model_dir"])
params["MWF"] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
)
return estimator
class Separator(object):
""" A wrapper class for performing separation. """
def __init__(
self,
params_descriptor,
MWF: bool = False,
stft_backend: str = 'auto',
multiprocess: bool = True):
""" Default constructor.
self,
params_descriptor: str,
MWF: bool = False,
stft_backend: STFTBackend = STFTBackend.AUTO,
multiprocess: bool = True,
) -> None:
"""
Default constructor.
:param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise.
Parameters:
params_descriptor (str):
Descriptor for TF params to be used.
MWF (bool):
(Optional) `True` if MWF should be used, `False` otherwise.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._sample_rate = self._params["sample_rate"]
self._MWF = MWF
self._tf_graph = tf.Graph()
self._prediction_generator = None
@@ -106,19 +127,21 @@ class Separator(object):
else:
self._pool = None
self._tasks = []
self._params['stft_backend'] = get_backend(stft_backend)
self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
self._data_generator = DataGenerator()
def __del__(self):
""" """
def __del__(self) -> None:
if self._session:
self._session.close()
def _get_prediction_generator(self):
""" Lazy loading access method for internal prediction generator
def _get_prediction_generator(self) -> Generator:
"""
Lazy loading access method for internal prediction generator
returned by the predict method of a tensorflow estimator.
:returns: generator of prediction.
Returns:
Generator:
Generator of prediction.
"""
if self._prediction_generator is None:
estimator = create_estimator(self._params, self._MWF)
@@ -126,82 +149,74 @@ class Separator(object):
def get_dataset():
return tf.data.Dataset.from_generator(
self._data_generator,
output_types={
'waveform': tf.float32,
'audio_id': tf.string},
output_shapes={
'waveform': (None, 2),
'audio_id': ()})
output_types={"waveform": tf.float32, "audio_id": tf.string},
output_shapes={"waveform": (None, 2), "audio_id": ()},
)
self._prediction_generator = estimator.predict(
get_dataset,
yield_single_examples=False)
get_dataset, yield_single_examples=False
)
return self._prediction_generator
def join(self, timeout: int = 200) -> NoReturn:
""" Wait for all pending tasks to be finished.
def join(self, timeout: int = 200) -> None:
"""
Wait for all pending tasks to be finished.
:param timeout: (Optional) task waiting timeout.
Parameters:
timeout (int):
(Optional) task waiting timeout.
"""
while len(self._tasks) > 0:
task = self._tasks.pop()
task.get()
task.wait(timeout=timeout)
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
""" Performs source separation over the given waveform with tensorflow
backend.
:param waveform: Waveform to apply separation on.
:returns: Separated waveforms.
def _stft(
self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
) -> np.ndarray:
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
def _stft(self, data, inverse: bool = False, length=None):
""" Single entrypoint for both stft and istft. This computes stft and
Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The expected
input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
separately and are concatenated together in the result. The
expected input formats are: (n_samples, 2) for stft and (T, F, 2)
for istft.
:param data: np.array with either the waveform or the complex
spectrogram depending on the parameter inverse
:param inverse: should a stft or an istft be computed.
:returns: Stereo data as numpy array for the transform.
The channels are stored in the last dimension.
Parameters:
data (numpy.array):
Array with either the waveform or the complex spectrogram
depending on the parameter inverse
inverse (bool):
(Optional) Should a stft or an istft be computed.
length (Optional[int]):
Returns:
numpy.ndarray:
Stereo data as numpy array for the transform. The channels
are stored in the last dimension.
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
N = self._params['frame_length']
H = self._params['frame_step']
N = self._params["frame_length"]
H = self._params["frame_step"]
win = hann(N, sym=False)
fstft = istft if inverse else stft
win_len_arg = {
'win_length': None,
'length': None} if inverse else {'n_fft': N}
win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
n_channels = data.shape[-1]
out = []
for c in range(n_channels):
d = np.concatenate(
(np.zeros((N, )), data[:, c], np.zeros((N, )))
) if not inverse else data[:, :, c].T
d = (
np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
if not inverse
else data[:, :, c].T
)
s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
if inverse:
s = s[N:N+length]
s = np.expand_dims(s.T, 2-inverse)
s = s[N : N + length]
s = np.expand_dims(s.T, 2 - inverse)
out.append(s)
if len(out) == 1:
return out[0]
return np.concatenate(out, axis=2-inverse)
return np.concatenate(out, axis=2 - inverse)
def _get_input_provider(self):
if self._input_provider is None:
@@ -216,22 +231,29 @@ class Separator(object):
def _get_builder(self):
if self._builder is None:
self._builder = EstimatorSpecBuilder(
self._get_features(),
self._params)
self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
return self._builder
def _get_session(self):
if self._session is None:
saver = tf.compat.v1.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint(
get_default_model_dir(self._params['model_dir']))
provider = ModelProvider.default()
model_directory: str = provider.get(self._params["model_dir"])
latest_checkpoint = tf.train.latest_checkpoint(model_directory)
self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint)
return self._session
def _separate_librosa(self, waveform: np.ndarray, audio_id):
""" Performs separation with librosa backend for STFT.
def _separate_librosa(
self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
) -> Dict:
"""
Performs separation with librosa backend for STFT.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (AudioDescriptor):
"""
with self._tf_graph.as_default():
out = {}
@@ -248,65 +270,115 @@ class Separator(object):
outputs = sess.run(
outputs,
feed_dict=self._get_input_provider().get_feed_dict(
features,
stft,
audio_id))
features, stft, audio_descriptor
),
)
for inst in self._get_builder().instruments:
out[inst] = self._stft(
outputs[inst],
inverse=True,
length=waveform.shape[0])
outputs[inst], inverse=True, length=waveform.shape[0]
)
return out
def separate(self, waveform: np.ndarray, audio_descriptor=''):
""" Performs separation on a waveform.
:param waveform: Waveform to be separated (as a numpy array)
:param audio_descriptor: (Optional) string describing the waveform
(e.g. filename).
def _separate_tensorflow(
self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
) -> Dict:
"""
if self._params['stft_backend'] == 'tensorflow':
Performs source separation over the given waveform with tensorflow
backend.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (AudioDescriptor):
Returns:
Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data(
{"waveform": waveform, "audio_id": np.array(audio_descriptor)}
)
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop("audio_id")
return prediction
def separate(
self, waveform: np.ndarray, audio_descriptor: Optional[str] = None
) -> None:
"""
Performs separation on a waveform.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
"""
backend: str = self._params["stft_backend"]
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor)
else:
elif backend == STFTBackend.LIBROSA:
return self._separate_librosa(waveform, audio_descriptor)
raise ValueError(f"Unsupported STFT backend {backend}")
def separate_to_file(
self,
audio_descriptor,
destination,
audio_adapter=get_default_audio_adapter(),
offset=0,
duration=600.,
codec='wav',
bitrate='128k',
filename_format='{filename}/{instrument}.{codec}',
synchronous=True):
""" Performs source separation and export result to file using
self,
audio_descriptor: AudioDescriptor,
destination: str,
audio_adapter: Optional[AudioAdapter] = None,
offset: int = 0,
duration: float = 600.0,
codec: Codec = Codec.WAV,
bitrate: str = "128k",
filename_format: str = "{filename}/{instrument}.{codec}",
synchronous: bool = True,
) -> None:
"""
Performs source separation and export result to file using
given audio adapter.
Filename format should be a Python formattable string that could use
following parameters : {instrument}, {filename}, {foldername} and
{codec}.
Filename format should be a Python formattable string that could
use following parameters :
:param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data,
in case of file based audio adapter, such
descriptor would be a file path.
:param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param offset: (Optional) Offset of loaded song.
:param duration: (Optional) Duration of loaded song
(default: 600s).
:param codec: (Optional) Export codec.
:param bitrate: (Optional) Export bitrate.
:param filename_format: (Optional) Filename format.
:param synchronous: (Optional) True is should by synchronous.
- {instrument}
- {filename}
- {foldername}
- {codec}.
Parameters:
audio_descriptor (AudioDescriptor):
Describe song to separate, used by audio adapter to
retrieve and load audio data, in case of file based
audio adapter, such descriptor would be a file path.
destination (str):
Target directory to write output to.
audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
offset (int):
(Optional) Offset of loaded song.
duration (float):
(Optional) Duration of loaded song (default: 600s).
codec (Codec):
(Optional) Export codec.
bitrate (str):
(Optional) Export bitrate.
filename_format (str):
(Optional) Filename format.
synchronous (bool):
(Optional) True is should by synchronous.
"""
waveform, sample_rate = audio_adapter.load(
if audio_adapter is None:
audio_adapter = AudioAdapter.default()
waveform, _ = audio_adapter.load(
audio_descriptor,
offset=offset,
duration=duration,
sample_rate=self._sample_rate)
sample_rate=self._sample_rate,
)
sources = self.separate(waveform, audio_descriptor)
self.save_to_file(
sources,
@@ -316,69 +388,78 @@ class Separator(object):
codec,
audio_adapter,
bitrate,
synchronous)
synchronous,
)
def save_to_file(
self,
sources,
audio_descriptor,
destination,
filename_format='{filename}/{instrument}.{codec}',
codec='wav',
audio_adapter=get_default_audio_adapter(),
bitrate='128k',
synchronous=True):
""" Export dictionary of sources to files.
:param sources: Dictionary of sources to be exported. The
keys are the name of the instruments, and
the values are Nx2 numpy arrays containing
the corresponding intrument waveform, as
returned by the separate method
:param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data,
in case of file based audio adapter, such
descriptor would be a file path.
:param destination: Target directory to write output to.
:param filename_format: (Optional) Filename format.
:param codec: (Optional) Export codec.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param bitrate: (Optional) Export bitrate.
:param synchronous: (Optional) True is should by synchronous.
self,
sources: Dict,
audio_descriptor: AudioDescriptor,
destination: str,
filename_format: str = "{filename}/{instrument}.{codec}",
codec: Codec = Codec.WAV,
audio_adapter: Optional[AudioAdapter] = None,
bitrate: str = "128k",
synchronous: bool = True,
) -> None:
"""
Export dictionary of sources to files.
Parameters:
sources (Dict):
Dictionary of sources to be exported. The keys are the name
of the instruments, and the values are `N x 2` numpy arrays
containing the corresponding intrument waveform, as
returned by the separate method
audio_descriptor (AudioDescriptor):
Describe song to separate, used by audio adapter to
retrieve and load audio data, in case of file based audio
adapter, such descriptor would be a file path.
destination (str):
Target directory to write output to.
filename_format (str):
(Optional) Filename format.
codec (Codec):
(Optional) Export codec.
audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
bitrate (str):
(Optional) Export bitrate.
synchronous (bool):
(Optional) True is should by synchronous.
"""
if audio_adapter is None:
audio_adapter = AudioAdapter.default()
foldername = basename(dirname(audio_descriptor))
filename = splitext(basename(audio_descriptor))[0]
generated = []
for instrument, data in sources.items():
path = join(destination, filename_format.format(
filename=filename,
instrument=instrument,
foldername=foldername,
codec=codec,
))
path = join(
destination,
filename_format.format(
filename=filename,
instrument=instrument,
foldername=foldername,
codec=codec,
),
)
directory = os.path.dirname(path)
if not os.path.exists(directory):
os.makedirs(directory)
if path in generated:
raise SpleeterError((
f'Separated source path conflict : {path},'
'please check your filename format'))
raise SpleeterError(
(
f"Separated source path conflict : {path},"
"please check your filename format"
)
)
generated.append(path)
if self._pool:
task = self._pool.apply_async(audio_adapter.save, (
path,
data,
self._sample_rate,
codec,
bitrate))
task = self._pool.apply_async(
audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
)
self._tasks.append(task)
else:
audio_adapter.save(
path,
data,
self._sample_rate,
codec,
bitrate)
audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
if synchronous and self._pool:
self.join()

15
spleeter/types.py Normal file
View File

@@ -0,0 +1,15 @@
#!/usr/bin/env python
# coding: utf8
""" Custom types definition. """
from typing import Any, Tuple
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
# pylint: enable=import-error
AudioDescriptor: type = Any
Signal: type = Tuple[np.ndarray, float]

View File

@@ -3,6 +3,6 @@
""" This package provides utility function and classes. """
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

View File

@@ -3,45 +3,49 @@
""" Module that provides configuration loading function. """
import importlib.resources as loader
import json
try:
import importlib.resources as loader
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as loader
from os.path import exists
from typing import Dict
from .. import resources, SpleeterError
from .. import SpleeterError, resources
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
_EMBEDDED_CONFIGURATION_PREFIX: str = "spleeter:"
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def load_configuration(descriptor: str) -> Dict:
"""
Load configuration from the given descriptor. Could be either a
`spleeter:` prefixed embedded configuration name or a file system path
to read configuration from.
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
Parameters:
descriptor (str):
Configuration descriptor to use for lookup.
Returns:
Dict:
Loaded description as dict.
def load_configuration(descriptor):
""" Load configuration from the given descriptor. Could be
either a `spleeter:` prefixed embedded configuration name
or a file system path to read configuration from.
:param descriptor: Configuration descriptor to use for lookup.
:returns: Loaded description as dict.
:raise ValueError: If required embedded configuration does not exists.
:raise SpleeterError: If required configuration file does not exists.
Raises:
ValueError:
If required embedded configuration does not exists.
SpleeterError:
If required configuration file does not exists.
"""
# Embedded configuration reading.
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):]
if not loader.is_resource(resources, f'{name}.json'):
raise SpleeterError(f'No embedded configuration {name} found')
with loader.open_text(resources, f'{name}.json') as stream:
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX) :]
if not loader.is_resource(resources, f"{name}.json"):
raise SpleeterError(f"No embedded configuration {name} found")
with loader.open_text(resources, f"{name}.json") as stream:
return json.load(stream)
# Standard file reading.
if not exists(descriptor):
raise SpleeterError(f'Configuration file {descriptor} not found')
with open(descriptor, 'r') as stream:
raise SpleeterError(f"Configuration file {descriptor} not found")
with open(descriptor, "r") as stream:
return json.load(stream)

View File

@@ -1,46 +0,0 @@
#!/usr/bin/env python
# coding: utf8
""" Utility functions for creating estimator. """
import tensorflow as tf # pylint: disable=import-error
from ..model import model_fn
from ..model.provider import get_default_model_provider
def get_default_model_dir(model_dir):
"""
Transforms a string like 'spleeter:2stems' into an actual path.
:param model_dir:
:return:
"""
model_provider = get_default_model_provider()
return model_provider.get(model_dir)
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
params['model_dir'] = get_default_model_dir(params['model_dir'])
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config
)
return estimator

View File

@@ -4,58 +4,53 @@
""" Centralized logging facilities for Spleeter. """
import logging
import warnings
from os import environ
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# pyright: reportMissingImports=false
# pylint: disable=import-error
from typer import echo
_FORMAT = '%(levelname)s:%(name)s:%(message)s'
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
class _LoggerHolder(object):
""" Logger singleton instance holder. """
class TyperLoggerHandler(logging.Handler):
""" A custom logger handler that use Typer echo. """
INSTANCE = None
def emit(self, record: logging.LogRecord) -> None:
echo(self.format(record))
def get_tensorflow_logger():
formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
handler = TyperLoggerHandler()
handler.setFormatter(formatter)
logger: logging.Logger = logging.getLogger("spleeter")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
def configure_logger(verbose: bool) -> None:
"""
Configure application logger.
Parameters:
verbose (bool):
`True` to use verbose logger, `False` otherwise.
"""
# pylint: disable=import-error
from tensorflow.compat.v1 import logging
# pylint: enable=import-error
return logging
from tensorflow import get_logger
from tensorflow.compat.v1 import logging as tf_logging
def get_logger():
""" Returns library scoped logger.
:returns: Library logger.
"""
if _LoggerHolder.INSTANCE is None:
formatter = logging.Formatter(_FORMAT)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger('spleeter')
logger.addHandler(handler)
logger.setLevel(logging.INFO)
_LoggerHolder.INSTANCE = logger
return _LoggerHolder.INSTANCE
def enable_tensorflow_logging():
""" Enable tensorflow logging. """
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf_logger = get_tensorflow_logger()
tf_logger.set_verbosity(tf_logger.INFO)
logger = get_logger()
logger.setLevel(logging.DEBUG)
def enable_logging():
""" Configure default logging. """
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf_logger = get_tensorflow_logger()
tf_logger.set_verbosity(tf_logger.ERROR)
tf_logger = get_logger()
tf_logger.handlers = [handler]
if verbose:
tf_logging.set_verbosity(tf_logging.INFO)
logger.setLevel(logging.DEBUG)
else:
warnings.filterwarnings("ignore")
tf_logging.set_verbosity(tf_logging.ERROR)

View File

@@ -3,43 +3,54 @@
""" Utility function for tensorflow. """
from typing import Any, Callable, Dict
import pandas as pd
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
import pandas as pd
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def sync_apply(tensor_dict, func, concat_axis=1):
""" Return a function that applies synchronously the provided func on the
def sync_apply(
tensor_dict: tf.Tensor, func: Callable, concat_axis: int = 1
) -> Dict[str, tf.Tensor]:
"""
Return a function that applies synchronously the provided func on the
provided dictionnary of tensor. This means that func is applied to the
concatenation of the tensors in tensor_dict. This is useful for performing
random operation that needs the same drawn value on multiple tensor, such
as a random time-crop on both input data and label (the same crop should be
applied to both input data and label, so random crop cannot be applied
separately on each of them).
concatenation of the tensors in tensor_dict. This is useful for
performing random operation that needs the same drawn value on multiple
tensor, such as a random time-crop on both input data and label (the
same crop should be applied to both input data and label, so random
crop cannot be applied separately on each of them).
IMPORTANT NOTE: all tensor are assumed to be the same shape.
Notes:
All tensor are assumed to be the same shape.
Params:
- tensor_dict: dictionary (key: strings, values: tf.tensor)
a dictionary of tensor.
- func: function
function to be applied to the concatenation of the tensors in
tensor_dict
- concat_axis: int
The axis on which to perform the concatenation.
Parameters:
tensor_dict (Dict[str, tensorflow.Tensor]):
A dictionary of tensor.
func (Callable):
Function to be applied to the concatenation of the tensors in
`tensor_dict`.
concat_axis (int):
The axis on which to perform the concatenation.
Returns:
processed tensors dictionary with the same name (keys) as input
tensor_dict.
Returns:
Dict[str, tensorflow.Tensor]:
Processed tensors dictionary with the same name (keys) as input
tensor_dict.
"""
if concat_axis not in {0, 1}:
raise NotImplementedError(
'Function only implemented for concat_axis equal to 0 or 1')
"Function only implemented for concat_axis equal to 0 or 1"
)
tensor_list = list(tensor_dict.values())
concat_tensor = tf.concat(tensor_list, concat_axis)
processed_concat_tensor = func(concat_tensor)
@@ -47,90 +58,104 @@ def sync_apply(tensor_dict, func, concat_axis=1):
D = tensor_shape[concat_axis]
if concat_axis == 0:
return {
name: processed_concat_tensor[index * D:(index + 1) * D, :, :]
name: processed_concat_tensor[index * D : (index + 1) * D, :, :]
for index, name in enumerate(tensor_dict)
}
return {
name: processed_concat_tensor[:, index * D:(index + 1) * D, :]
name: processed_concat_tensor[:, index * D : (index + 1) * D, :]
for index, name in enumerate(tensor_dict)
}
def from_float32_to_uint8(
tensor,
tensor_key='tensor',
min_key='min',
max_key='max'):
tensor: tf.Tensor,
tensor_key: str = "tensor",
min_key: str = "min",
max_key: str = "max",
) -> tf.Tensor:
"""
:param tensor:
:param tensor_key:
:param min_key:
:param max_key:
:returns:
Parameters:
tensor (tensorflow.Tensor):
tensor_key (str):
min_key (str):
max_key (str):
Returns:
tensorflow.Tensor:
"""
tensor_min = tf.reduce_min(tensor)
tensor_max = tf.reduce_max(tensor)
return {
tensor_key: tf.cast(
(tensor - tensor_min) / (tensor_max - tensor_min + 1e-16)
* 255.9999, dtype=tf.uint8),
(tensor - tensor_min) / (tensor_max - tensor_min + 1e-16) * 255.9999,
dtype=tf.uint8,
),
min_key: tensor_min,
max_key: tensor_max
max_key: tensor_max,
}
def from_uint8_to_float32(tensor, tensor_min, tensor_max):
def from_uint8_to_float32(
tensor: tf.Tensor, tensor_min: tf.Tensor, tensor_max: tf.Tensor
) -> tf.Tensor:
"""
:param tensor:
:param tensor_min:
:param tensor_max:
:returns:
Parameters:
tensor (tensorflow.Tensor):
tensor_min (tensorflow.Tensor):
tensor_max (tensorflow.Tensor):
Returns:
tensorflow.Tensor:
"""
return (
tf.cast(tensor, tf.float32)
* (tensor_max - tensor_min)
/ 255.9999 + tensor_min)
tf.cast(tensor, tf.float32) * (tensor_max - tensor_min) / 255.9999 + tensor_min
)
def pad_and_partition(tensor, segment_len):
""" Pad and partition a tensor into segment of len segment_len
def pad_and_partition(tensor: tf.Tensor, segment_len: int) -> tf.Tensor:
"""
Pad and partition a tensor into segment of len `segment_len`
along the first dimension. The tensor is padded with 0 in order
to ensure that the first dimension is a multiple of segment_len.
to ensure that the first dimension is a multiple of `segment_len`.
Tensor must be of known fixed rank
:Example:
Examples:
>>> tensor = [[1, 2, 3], [4, 5, 6]]
>>> segment_len = 2
>>> pad_and_partition(tensor, segment_len)
[[[1, 2], [4, 5]], [[3, 0], [6, 0]]]
```python
>>> tensor = [[1, 2, 3], [4, 5, 6]]
>>> segment_len = 2
>>> pad_and_partition(tensor, segment_len)
[[[1, 2], [4, 5]], [[3, 0], [6, 0]]]
````
:param tensor:
:param segment_len:
:returns:
Parameters:
tensor (tensorflow.Tensor):
segment_len (int):
Returns:
tensorflow.Tensor:
"""
tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
padded = tf.pad(
tensor,
[[0, pad_size]] + [[0, 0]] * (len(tensor.shape)-1))
padded = tf.pad(tensor, [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1))
split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
return tf.reshape(
padded,
tf.concat(
[[split, segment_len], tf.shape(padded)[1:]],
axis=0))
padded, tf.concat([[split, segment_len], tf.shape(padded)[1:]], axis=0)
)
def pad_and_reshape(instr_spec, frame_length, F):
def pad_and_reshape(instr_spec, frame_length, F) -> Any:
"""
:param instr_spec:
:param frame_length:
:param F:
:returns:
Parameters:
instr_spec:
frame_length:
F:
Returns:
Any:
"""
spec_shape = tf.shape(instr_spec)
extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
@@ -138,53 +163,67 @@ def pad_and_reshape(instr_spec, frame_length, F):
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
extended_spec = tf.concat([instr_spec, extension], axis=2)
old_shape = tf.shape(extended_spec)
new_shape = tf.concat([
[old_shape[0] * old_shape[1]],
old_shape[2:]],
axis=0)
new_shape = tf.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
processed_instr_spec = tf.reshape(extended_spec, new_shape)
return processed_instr_spec
def dataset_from_csv(csv_path, **kwargs):
""" Load dataset from a CSV file using Pandas. kwargs if any are
def dataset_from_csv(csv_path: str, **kwargs) -> Any:
"""
Load dataset from a CSV file using Pandas. kwargs if any are
forwarded to the `pandas.read_csv` function.
:param csv_path: Path of the CSV file to load dataset from.
:returns: Loaded dataset.
Parameters:
csv_path (str):
Path of the CSV file to load dataset from.
Returns:
Any:
Loaded dataset.
"""
df = pd.read_csv(csv_path, **kwargs)
dataset = (
tf.data.Dataset.from_tensor_slices(
{key: df[key].values for key in df})
)
dataset = tf.data.Dataset.from_tensor_slices({key: df[key].values for key in df})
return dataset
def check_tensor_shape(tensor_tf, target_shape):
""" Return a Tensorflow boolean graph that indicates whether
def check_tensor_shape(tensor_tf: tf.Tensor, target_shape: Any) -> bool:
"""
Return a Tensorflow boolean graph that indicates whether
sample[features_key] has the specified target shape. Only check
not None entries of target_shape.
:param tensor_tf: Tensor to check shape for.
:param target_shape: Target shape to compare tensor to.
:returns: True if shape is valid, False otherwise (as TF boolean).
Parameters:
tensor_tf (tensorflow.Tensor):
Tensor to check shape for.
target_shape (Any):
Target shape to compare tensor to.
Returns:
bool:
`True` if shape is valid, `False` otherwise (as TF boolean).
"""
result = tf.constant(True)
for i, target_length in enumerate(target_shape):
if target_length:
result = tf.logical_and(
result,
tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i]))
result, tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i])
)
return result
def set_tensor_shape(tensor, tensor_shape):
""" Set shape for a tensor (not in place, as opposed to tf.set_shape)
def set_tensor_shape(tensor: tf.Tensor, tensor_shape: Any) -> tf.Tensor:
"""
Set shape for a tensor (not in place, as opposed to tf.set_shape)
:param tensor: Tensor to reshape.
:param tensor_shape: Shape to apply to the tensor.
:returns: A reshaped tensor.
Parameters:
tensor (tensorflow.Tensor):
Tensor to reshape.
tensor_shape (Any):
Shape to apply to the tensor.
Returns:
tensorflow.Tensor:
A reshaped tensor.
"""
# NOTE: That SOUND LIKE IN PLACE HERE ?
tensor.set_shape(tensor_shape)

View File

@@ -7,82 +7,82 @@ __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
import filecmp
import itertools
from os import makedirs
from os.path import splitext, basename, exists, join
from os.path import join
from tempfile import TemporaryDirectory
import pytest
import numpy as np
import tensorflow as tf
from spleeter.__main__ import evaluate
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.commands import create_argument_parser
from spleeter.commands import evaluate
from spleeter.utils.configuration import load_configuration
BACKENDS = ["tensorflow", "librosa"]
TEST_CONFIGURATIONS = {el:el for el in BACKENDS}
BACKENDS = ['tensorflow', 'librosa']
TEST_CONFIGURATIONS = {el: el for el in BACKENDS}
res_4stems = {
"vocals": {
"SDR": 3.25e-05,
"SAR": -11.153575,
"SIR": -1.3849,
"ISR": 2.75e-05
},
"drums": {
"SDR": -0.079505,
"SAR": -15.7073575,
"SIR": -4.972755,
"ISR": 0.0013575
},
"bass":{
"SDR": 2.5e-06,
"SAR": -10.3520575,
"SIR": -4.272325,
"ISR": 2.5e-06
},
"other":{
"SDR": -1.359175,
"SAR": -14.7076775,
"SIR": -4.761505,
"ISR": -0.01528
}
}
'vocals': {
'SDR': 3.25e-05,
'SAR': -11.153575,
'SIR': -1.3849,
'ISR': 2.75e-05
},
'drums': {
'SDR': -0.079505,
'SAR': -15.7073575,
'SIR': -4.972755,
'ISR': 0.0013575
},
'bass': {
'SDR': 2.5e-06,
'SAR': -10.3520575,
'SIR': -4.272325,
'ISR': 2.5e-06
},
'other': {
'SDR': -1.359175,
'SAR': -14.7076775,
'SIR': -4.761505,
'ISR': -0.01528
}
}
def generate_fake_eval_dataset(path):
"""
generate fake evaluation dataset
"""
aa = get_default_audio_adapter()
aa = AudioAdapter.default()
n_songs = 2
fs = 44100
duration = 3
n_channels = 2
rng = np.random.RandomState(seed=0)
for song in range(n_songs):
song_path = join(path, "test", f"song{song}")
song_path = join(path, 'test', f'song{song}')
makedirs(song_path, exist_ok=True)
for instr in ["mixture", "vocals", "bass", "drums", "other"]:
filename = join(song_path, f"{instr}.wav")
for instr in ['mixture', 'vocals', 'bass', 'drums', 'other']:
filename = join(song_path, f'{instr}.wav')
data = rng.rand(duration*fs, n_channels)-0.5
aa.save(filename, data, fs)
@pytest.mark.parametrize('backend', TEST_CONFIGURATIONS)
def test_evaluate(backend):
with TemporaryDirectory() as directory:
generate_fake_eval_dataset(directory)
p = create_argument_parser()
arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", directory, "-B", backend])
params = load_configuration(arguments.configuration)
metrics = evaluate.entrypoint(arguments, params)
for instrument, metric in metrics.items():
for m, value in metric.items():
assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3)
with TemporaryDirectory() as dataset:
with TemporaryDirectory() as evaluation:
generate_fake_eval_dataset(dataset)
metrics = evaluate(
adapter='spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter',
output_path=evaluation,
stft_backend=backend,
params_filename='spleeter:4stems',
mus_dir=dataset,
mwf=False,
verbose=False)
for instrument, metric in metrics.items():
for m, value in metric.items():
assert np.allclose(
np.median(value),
res_4stems[instrument][m],
atol=1e-3)

View File

@@ -10,6 +10,11 @@ __license__ = 'MIT License'
from os.path import join
from tempfile import TemporaryDirectory
from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
# pyright: reportMissingImports=false
# pylint: disable=import-error
from pytest import fixture, raises
@@ -17,12 +22,6 @@ import numpy as np
import ffmpeg
# pylint: enable=import-error
from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.audio.adapter import get_audio_adapter
from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
TEST_OFFSET = 0
TEST_DURATION = 600.
@@ -32,7 +31,7 @@ TEST_SAMPLE_RATE = 44100
@fixture(scope='session')
def adapter():
""" Target test audio adapter fixture. """
return get_default_audio_adapter()
return AudioAdapter.default()
@fixture(scope='session')
@@ -48,7 +47,7 @@ def audio_data(adapter):
def test_default_adapter(adapter):
""" Test adapter as default adapter. """
assert isinstance(adapter, FFMPEGProcessAudioAdapter)
assert adapter is AudioAdapter.DEFAULT
assert adapter is AudioAdapter._DEFAULT
def test_load(audio_data):

View File

@@ -5,12 +5,12 @@
from pytest import raises
from spleeter.model.provider import get_default_model_provider
from spleeter.model.provider import ModelProvider
def test_checksum():
""" Test archive checksum index retrieval. """
provider = get_default_model_provider()
provider = ModelProvider.default()
assert provider.checksum('2stems') == \
'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692'
assert provider.checksum('4stems') == \

View File

@@ -17,7 +17,7 @@ import numpy as np
import tensorflow as tf
from spleeter import SpleeterError
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.audio.adapter import AudioAdapter
from spleeter.separator import Separator
TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
@@ -41,7 +41,7 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
def test_separator_backends(test_file):
adapter = get_default_audio_adapter()
adapter = AudioAdapter.default()
waveform, _ = adapter.load(test_file)
separator_lib = Separator(
@@ -64,11 +64,13 @@ def test_separator_backends(test_file):
assert np.allclose(out_tf[instrument], out_lib[instrument], atol=1e-5)
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
@pytest.mark.parametrize(
'test_file, configuration, backend',
TEST_CONFIGURATIONS)
def test_separate(test_file, configuration, backend):
""" Test separation from raw data. """
instruments = MODEL_TO_INST[configuration]
adapter = get_default_audio_adapter()
adapter = AudioAdapter.default()
waveform, _ = adapter.load(test_file)
separator = Separator(
configuration, stft_backend=backend, multiprocess=False)
@@ -85,7 +87,9 @@ def test_separate(test_file, configuration, backend):
assert not np.allclose(track, prediction[compared])
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
@pytest.mark.parametrize(
'test_file, configuration, backend',
TEST_CONFIGURATIONS)
def test_separate_to_file(test_file, configuration, backend):
""" Test file based separation. """
instruments = MODEL_TO_INST[configuration]
@@ -102,7 +106,9 @@ def test_separate_to_file(test_file, configuration, backend):
'{}/{}.wav'.format(name, instrument)))
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
@pytest.mark.parametrize(
'test_file, configuration, backend',
TEST_CONFIGURATIONS)
def test_filename_format(test_file, configuration, backend):
""" Test custom filename format. """
instruments = MODEL_TO_INST[configuration]
@@ -120,7 +126,9 @@ def test_filename_format(test_file, configuration, backend):
'export/{}/{}.wav'.format(name, instrument)))
@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
@pytest.mark.parametrize(
'test_file, configuration',
MODELS_AND_TEST_FILES)
def test_filename_conflict(test_file, configuration):
""" Test error handling with static pattern. """
separator = Separator(configuration, multiprocess=False)

View File

@@ -7,107 +7,102 @@ __email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
import filecmp
import itertools
import json
import os
from os import makedirs
from os.path import splitext, basename, exists, join
from os.path import join
from tempfile import TemporaryDirectory
import numpy as np
import pandas as pd
import json
import tensorflow as tf
from spleeter.audio.adapter import AudioAdapter
from spleeter.__main__ import spleeter
from typer.testing import CliRunner
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.commands import create_argument_parser
from spleeter.commands import train
from spleeter.utils.configuration import load_configuration
TRAIN_CONFIG = {
"mix_name": "mix",
"instrument_list": ["vocals", "other"],
"sample_rate":44100,
"frame_length":4096,
"frame_step":1024,
"T":128,
"F":128,
"n_channels":2,
"chunk_duration":4,
"n_chunks_per_song":1,
"separation_exponent":2,
"mask_extension":"zeros",
"learning_rate": 1e-4,
"batch_size":2,
"train_max_steps": 10,
"throttle_secs":20,
"save_checkpoints_steps":100,
"save_summary_steps":5,
"random_seed":0,
"model":{
"type":"unet.unet",
"params":{
"conv_activation":"ELU",
"deconv_activation":"ELU"
'mix_name': 'mix',
'instrument_list': ['vocals', 'other'],
'sample_rate': 44100,
'frame_length': 4096,
'frame_step': 1024,
'T': 128,
'F': 128,
'n_channels': 2,
'chunk_duration': 4,
'n_chunks_per_song': 1,
'separation_exponent': 2,
'mask_extension': 'zeros',
'learning_rate': 1e-4,
'batch_size': 2,
'train_max_steps': 10,
'throttle_secs': 20,
'save_checkpoints_steps': 100,
'save_summary_steps': 5,
'random_seed': 0,
'model': {
'type': 'unet.unet',
'params': {
'conv_activation': 'ELU',
'deconv_activation': 'ELU'
}
}
}
def generate_fake_training_dataset(path, instrument_list=["vocals", "other"]):
def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
"""
generates a fake training dataset in path:
- generates audio files
- generates a csv file describing the dataset
"""
aa = get_default_audio_adapter()
aa = AudioAdapter.default()
n_songs = 2
fs = 44100
duration = 6
n_channels = 2
rng = np.random.RandomState(seed=0)
dataset_df = pd.DataFrame(columns=["mix_path"]+[f"{instr}_path" for instr in instrument_list]+["duration"])
dataset_df = pd.DataFrame(
columns=['mix_path'] + [
f'{instr}_path' for instr in instrument_list] + ['duration'])
for song in range(n_songs):
song_path = join(path, "train", f"song{song}")
song_path = join(path, 'train', f'song{song}')
makedirs(song_path, exist_ok=True)
dataset_df.loc[song, f"duration"] = duration
for instr in instrument_list+["mix"]:
filename = join(song_path, f"{instr}.wav")
dataset_df.loc[song, f'duration'] = duration
for instr in instrument_list+['mix']:
filename = join(song_path, f'{instr}.wav')
data = rng.rand(duration*fs, n_channels)-0.5
aa.save(filename, data, fs)
dataset_df.loc[song, f"{instr}_path"] = join("train", f"song{song}", f"{instr}.wav")
dataset_df.to_csv(join(path, "train", "train.csv"), index=False)
dataset_df.loc[song, f'{instr}_path'] = join(
'train',
f'song{song}',
f'{instr}.wav')
dataset_df.to_csv(join(path, 'train', 'train.csv'), index=False)
def test_train():
with TemporaryDirectory() as path:
# generate training dataset
generate_fake_training_dataset(path)
# set training command aruments
p = create_argument_parser()
arguments = p.parse_args(["train", "-p", "useless_config.json", "-d", path])
TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv")
TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv")
TRAIN_CONFIG["model_dir"] = join(path, "model")
TRAIN_CONFIG["training_cache"] = join(path, "cache", "training")
TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation")
runner = CliRunner()
TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
TRAIN_CONFIG['model_dir'] = join(path, 'model')
TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
with open('useless_config.json', 'w') as stream:
json.dump(TRAIN_CONFIG, stream)
# execute training
res = train.entrypoint(arguments, TRAIN_CONFIG)
result = runner.invoke(spleeter, [
'train',
'-p', 'useless_config.json',
'-d', path
])
# assert that model checkpoint was created.
assert os.path.exists(join(path,'model','model.ckpt-10.index'))
assert os.path.exists(join(path,'model','checkpoint'))
assert os.path.exists(join(path,'model','model.ckpt-0.meta'))
if __name__=="__main__":
test_train()
assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
assert os.path.exists(join(path, 'model', 'checkpoint'))
assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
assert result.exit_code == 0