Merge pull request #529 from deezer/typer

Spleeter 2.1.0
Félix Voituret, 2021-01-08 18:40:12 +01:00 (committed by GitHub)
46 changed files with 4482 additions and 2176 deletions


@@ -1,8 +1,6 @@
 name: conda
 on:
-  push:
-    branches:
-      - master
+  - workflow_dispatch
 env:
   ANACONDA_USERNAME: ${{ secrets.ANACONDA_USERNAME }}
   ANACONDA_PASSWORD: ${{ secrets.ANACONDA_PASSWORD }}
@@ -11,7 +9,7 @@ jobs:
     strategy:
       matrix:
         python: [3.7, 3.8]
-        package: [spleeter]
+        package: [spleeter, spleeter-gpu]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2


@@ -4,40 +4,22 @@ on:
   branches:
     - master
 env:
-  TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }}
-  TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
+  PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
 jobs:
   package-and-deploy:
-    strategy:
-      matrix:
-        platform: [cpu, gpu]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
         with:
           python-version: 3.7
-      - uses: actions/cache@v2
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
-      - uses: actions/cache@v2
-        with:
-          path: ${{ env.GITHUB_WORKSPACE }}/dist
-          key: sdist-${{ matrix.platform }}-${{ hashFiles('**/setup.py') }}
-          restore-keys: |
-            sdist-${{ matrix.platform }}-${{ hashFiles('**/setup.py') }}
-            sdist-${{ matrix.platform }}
-            sdist-
-      - name: Install dependencies
-        run: pip install --upgrade pip setuptools twine
-      - if: ${{ matrix.platform == 'cpu' }}
-        name: Package CPU distribution
-        run: make build
-      - if: ${{ matrix.platform == 'gpu' }}
-        name: Package GPU distribution)
-        run: make build-gpu
+      - name: Install Poetry
+        run: |
+          pip install poetry
+          poetry config virtualenvs.in-project false
+          poetry config virtualenvs.path ~/.virtualenvs
+          poetry config pypi-token.pypi $PYPI_TOKEN
       - name: Deploy to pypi
-        run: make deploy
+        run: |
+          poetry build
+          poetry publish


@@ -1,41 +0,0 @@
-name: pytest
-on:
-  pull_request:
-    branches:
-      - master
-jobs:
-  tests:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: [3.6, 3.7, 3.8]
-    steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
-        with:
-          python-version: ${{ matrix.python-version }}
-      - uses: actions/cache@v2
-        id: spleeter-pip-cache
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
-      - uses: actions/cache@v2
-        env:
-          model-release: 1
-        id: spleeter-model-cache
-        with:
-          path: ${{ env.GITHUB_WORKSPACE }}/pretrained_models
-          key: models-${{ env.model-release }}
-          restore-keys: |
-            models-${{ env.model-release }}
-      - name: Install dependencies
-        run: |
-          sudo apt-get update && sudo apt-get install -y ffmpeg
-          pip install --upgrade pip setuptools
-          pip install pytest==5.4.3 pytest-xdist==1.32.0 pytest-forked==1.1.3 musdb museval
-          python setup.py install
-      - name: Test with pytest
-        run: make test

.github/workflows/test.yml (new file)

@@ -0,0 +1,51 @@
+name: test
+on:
+  pull_request:
+    branches:
+      - master
+jobs:
+  tests:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.7, 3.8]
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - uses: actions/cache@v2
+        env:
+          model-release: 1
+        id: spleeter-model-cache
+        with:
+          path: ${{ env.GITHUB_WORKSPACE }}/pretrained_models
+          key: models-${{ env.model-release }}
+          restore-keys: |
+            models-${{ env.model-release }}
+      - name: Install ffmpeg
+        run: |
+          sudo apt-get update && sudo apt-get install -y ffmpeg
+      - name: Install Poetry
+        run: |
+          pip install poetry
+          poetry config virtualenvs.in-project false
+          poetry config virtualenvs.path ~/.virtualenvs
+      - name: Cache Poetry virtualenv
+        uses: actions/cache@v1
+        id: cache
+        with:
+          path: ~/.virtualenvs
+          key: poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
+          restore-keys: |
+            poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
+      - name: Install Dependencies
+        run: poetry install
+        if: steps.cache.outputs.cache-hit != 'true'
+      - name: Code quality checks
+        run: |
+          poetry run black spleeter --check
+          poetry run isort spleeter --check
+      - name: Test with pytest
+        run: poetry run pytest tests/


@@ -1,5 +1,27 @@
 # Changelog History

+## 2.1.0
+
+This version introduces design-related changes, most notably the transition to Typer for CLI
+management and to Poetry as the library build backend:
+
+* The `-i` option is now deprecated and replaced by a traditional CLI input argument list
+* The project is now built using Poetry
+* The project requires code formatting using Black and isort
+* The dedicated GPU package `spleeter-gpu` is no longer supported; the `spleeter` package now supports both CPU and GPU hardware
+
+### API changes:
+
+* Function `get_default_audio_adapter` is now available as the `default()` class method of the `AudioAdapter` class
+* Function `get_default_model_provider` is now available as the `default()` class method of the `ModelProvider` class
+* `STFTBackend` and `Codec` are now string enums
+* `GithubModelProvider` now uses `httpx` with HTTP/2 support
+* Commands are now located in the `__main__` module, wrapped as simple functions using Typer; the `options` module provides the specification for each available option and argument
+* The `types` module provides custom type specifications and should be enhanced in future releases to provide more robust typing support with MyPy
+* The `utils.logging` module has been cleaned up: the logger instance is now a module singleton, and a single function is used to configure it with a verbose parameter
+* Added a custom logger handler (see the tiangolo/typer#203 discussion)
+
 ## 2.0

 First release, October 9th 2020
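
The API renames above map old module-level helpers onto class methods. A minimal migration sketch (illustrative only, not part of this diff; it assumes the default providers are wanted):

```python
# Spleeter 2.0:
#   from spleeter.audio.adapter import get_default_audio_adapter
#   adapter = get_default_audio_adapter()
# Spleeter 2.1:
from spleeter.audio.adapter import AudioAdapter
from spleeter.model.provider import ModelProvider

# default() lazily builds and caches a singleton instance on first call.
adapter = AudioAdapter.default()
provider = ModelProvider.default()
```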


@@ -1,3 +0,0 @@
-include spleeter/resources/*.json
-include README.md
-include LICENSE


@@ -1,34 +0,0 @@
-# =======================================================
-# Library lifecycle management.
-#
-# @author Deezer Research <spleeter@deezer.com>
-# @licence MIT Licence
-# =======================================================
-
-FEEDSTOCK = spleeter-feedstock
-FEEDSTOCK_REPOSITORY = https://github.com/deezer/$(FEEDSTOCK)
-FEEDSTOCK_RECIPE = $(FEEDSTOCK)/recipe/spleeter/meta.yaml
-PYTEST_CMD = pytest -W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked
-
-all: clean build test deploy
-
-clean:
-	rm -Rf *.egg-info
-	rm -Rf dist
-
-build: clean
-	sed -i "s/project_name = '[^']*'/project_name = 'spleeter'/g" setup.py
-	sed -i "s/tensorflow_dependency = '[^']*'/tensorflow_dependency = 'tensorflow'/g" setup.py
-	python3 setup.py sdist
-
-build-gpu: clean
-	sed -i "s/project_name = '[^']*'/project_name = 'spleeter-gpu'/g" setup.py
-	sed -i "s/tensorflow_dependency = '[^']*'/tensorflow_dependency = 'tensorflow-gpu'/g" setup.py
-	python3 setup.py sdist
-
-test:
-	$(PYTEST_CMD) tests/
-
-deploy:
-	pip install twine
-	twine upload --skip-existing dist/*


@@ -2,6 +2,9 @@
 [![Github actions](https://github.com/deezer/spleeter/workflows/pytest/badge.svg)](https://github.com/deezer/spleeter/actions) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/spleeter) [![PyPI version](https://badge.fury.io/py/spleeter.svg)](https://badge.fury.io/py/spleeter) [![Conda](https://img.shields.io/conda/vn/conda-forge/spleeter)](https://anaconda.org/conda-forge/spleeter) [![Docker Pulls](https://img.shields.io/docker/pulls/researchdeezer/spleeter)](https://hub.docker.com/r/researchdeezer/spleeter) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deezer/spleeter/blob/master/spleeter.ipynb) [![Gitter chat](https://badges.gitter.im/gitterHQ/gitter.png)](https://gitter.im/spleeter/community) [![status](https://joss.theoj.org/papers/259e5efe669945a343bad6eccb89018b/status.svg)](https://joss.theoj.org/papers/259e5efe669945a343bad6eccb89018b)

+> :warning: The [Spleeter 2.1.0](https://pypi.org/project/spleeter/) release introduces some breaking changes, including new CLI option naming for input and the drop
+> of the dedicated GPU package. Please read the [CHANGELOG](CHANGELOG.md) for more details.
+
 ## About

 **Spleeter** is [Deezer](https://www.deezer.com/) source separation library with pretrained models

@@ -46,7 +49,7 @@ conda install -c conda-forge spleeter
 # download an example audio file (if you don't have wget, use another tool for downloading)
 wget https://github.com/deezer/spleeter/raw/master/audio_example.mp3
 # separate the example audio into two components
-spleeter separate -i audio_example.mp3 -p spleeter:2stems -o output
+spleeter separate -p spleeter:2stems -o output audio_example.mp3

 You should get two separated audio files (`vocals.wav` and `accompaniment.wav`) in the `output/audio_example` folder.
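
The same separation can also be performed programmatically, which is unaffected by the CLI change. A minimal sketch using the `Separator` API that the new `separate` command wraps (argument defaults here are assumptions, not part of this diff):

```python
from spleeter.separator import Separator

# "spleeter:2stems" is the same model descriptor passed to the CLI -p option.
separator = Separator("spleeter:2stems")
# Writes vocals.wav and accompaniment.wav under output/audio_example/.
separator.separate_to_file("audio_example.mp3", "output")
```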
@@ -55,13 +58,18 @@ For a detailed documentation, please check the [repository wiki](https://github.
 ## Development and Testing

-The following set of commands will clone this repository, create a virtual environment provisioned with the dependencies and run the tests (will take a few minutes):
+This project is managed using [Poetry](https://python-poetry.org/docs/basic-usage/); to run the test suite you
+can execute the following set of commands:

 ```bash
+# Clone spleeter repository
 git clone https://github.com/Deezer/spleeter && cd spleeter
-python -m venv spleeterenv && source spleeterenv/bin/activate
-pip install . && pip install pytest pytest-xdist
-make test
+# Install poetry
+pip install poetry
+# Install spleeter dependencies
+poetry install
+# Run unit test suite
+poetry run pytest tests/
 ```

 ## Reference


@@ -0,0 +1,52 @@
+{% set name = "spleeter-gpu" %}
+{% set version = "2.0.2" %}
+
+package:
+  name: {{ name|lower }}
+  version: {{ version }}
+
+source:
+  - url: https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz
+    sha256: ecd3518a98f9978b9088d1cb2ef98f766401fd9007c2bf72a34e5b5bc5a6fdc3
+
+build:
+  number: 0
+  script: {{ PYTHON }} -m pip install . -vv
+  skip: True  # [osx]
+  entry_points:
+    - spleeter = spleeter.__main__:entrypoint
+
+requirements:
+  host:
+    - python {{ python }}
+    - pip
+  run:
+    - python {{ python }}
+    - tensorflow-gpu ==2.2.0  # [linux]
+    - tensorflow-gpu ==2.3.0  # [win]
+    - pandas
+    - ffmpeg-python
+    - norbert
+    - librosa
+
+test:
+  imports:
+    - spleeter
+    - spleeter.commands
+    - spleeter.model
+    - spleeter.utils
+    - spleeter.separator
+
+about:
+  home: https://github.com/deezer/spleeter
+  license: MIT
+  license_family: MIT
+  license_file: LICENSE
+  summary: The Deezer source separation library with pretrained models based on tensorflow.
+  doc_url: https://github.com/deezer/spleeter/wiki
+  dev_url: https://github.com/deezer/spleeter
+
+extra:
+  recipe-maintainers:
+    - Faylixe
+    - romi1502


@@ -1,3 +0,0 @@
-python:
-  - 3.7
-  - 3.8

poetry.lock (generated file; diff suppressed because it is too large)

pyproject.toml (new file)

@@ -0,0 +1,83 @@
+[tool.poetry]
+name = "spleeter"
+version = "2.1.0"
+description = "The Deezer source separation library with pretrained models based on tensorflow."
+authors = ["Deezer Research <spleeter@deezer.com>"]
+license = "MIT License"
+readme = "README.md"
+repository = "https://github.com/deezer/spleeter"
+homepage = "https://github.com/deezer/spleeter"
+classifiers = [
+    "Environment :: Console",
+    "Environment :: MacOS X",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Information Technology",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: MIT License",
+    "Natural Language :: English",
+    "Operating System :: MacOS",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "Operating System :: Unix",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.6",
+    "Programming Language :: Python :: 3.7",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Topic :: Artistic Software",
+    "Topic :: Multimedia",
+    "Topic :: Multimedia :: Sound/Audio",
+    "Topic :: Multimedia :: Sound/Audio :: Analysis",
+    "Topic :: Multimedia :: Sound/Audio :: Conversion",
+    "Topic :: Multimedia :: Sound/Audio :: Sound Synthesis",
+    "Topic :: Scientific/Engineering",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Scientific/Engineering :: Information Analysis",
+    "Topic :: Software Development",
+    "Topic :: Software Development :: Libraries",
+    "Topic :: Software Development :: Libraries :: Python Modules",
+    "Topic :: Utilities"
+]
+packages = [ { include = "spleeter" } ]
+include = ["LICENSE", "spleeter/resources/*.json"]
+
+[tool.poetry.dependencies]
+python = "^3.7"
+ffmpeg-python = "0.2.0"
+norbert = "0.2.1"
+httpx = {extras = ["http2"], version = "^0.16.1"}
+typer = "^0.3.2"
+librosa = "0.8.0"
+musdb = {version = "0.3.1", optional = true}
+museval = {version = "0.3.0", optional = true}
+tensorflow = "2.3.0"
+pandas = "1.1.2"
+numpy = "<1.19.0,>=1.16.0"
+
+[tool.poetry.dev-dependencies]
+pytest = "^6.2.1"
+isort = "^5.7.0"
+black = "^20.8b1"
+mypy = "^0.790"
+pytest-forked = "^1.3.0"
+musdb = "0.3.1"
+museval = "0.3.0"
+
+[tool.poetry.scripts]
+spleeter = 'spleeter.__main__:entrypoint'
+
+[tool.poetry.extras]
+evaluation = ["musdb", "museval"]
+
+[tool.isort]
+profile = "black"
+multi_line_output = 3
+
+[tool.pytest.ini_options]
+addopts = "-W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked"
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"

setup.py (deleted)

@@ -1,102 +0,0 @@
-#!/usr/bin/env python
-# coding: utf8
-
-""" Distribution script. """
-
-import sys
-
-from os import path
-from setuptools import setup
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-# Default project values.
-project_name = 'spleeter'
-project_version = '2.0.2'
-tensorflow_dependency = 'tensorflow'
-tensorflow_version = '2.3.0'
-here = path.abspath(path.dirname(__file__))
-readme_path = path.join(here, 'README.md')
-with open(readme_path, 'r') as stream:
-    readme = stream.read()
-
-# Package setup entrypoint.
-setup(
-    name=project_name,
-    version=project_version,
-    description='''
-        The Deezer source separation library with
-        pretrained models based on tensorflow.
-    ''',
-    long_description=readme,
-    long_description_content_type='text/markdown',
-    author='Deezer Research',
-    author_email='spleeter@deezer.com',
-    url='https://github.com/deezer/spleeter',
-    license='MIT License',
-    packages=[
-        'spleeter',
-        'spleeter.audio',
-        'spleeter.commands',
-        'spleeter.model',
-        'spleeter.model.functions',
-        'spleeter.model.provider',
-        'spleeter.resources',
-        'spleeter.utils',
-    ],
-    package_data={'spleeter.resources': ['*.json']},
-    python_requires='>=3.6, <3.9',
-    include_package_data=True,
-    install_requires=[
-        'ffmpeg-python==0.2.0',
-        'importlib_resources ; python_version<"3.7"',
-        'norbert==0.2.1',
-        'numpy<1.19.0,>=1.16.0',
-        'pandas==1.1.2',
-        'requests',
-        'scipy==1.4.1',
-        'setuptools>=41.0.0',
-        'librosa==0.8.0',
-        '{}=={}'.format(tensorflow_dependency, tensorflow_version),
-    ],
-    extras_require={
-        'evaluation': ['musdb==0.3.1', 'museval==0.3.0']
-    },
-    entry_points={
-        'console_scripts': ['spleeter=spleeter.__main__:entrypoint']
-    },
-    classifiers=[
-        'Environment :: Console',
-        'Environment :: MacOS X',
-        'Intended Audience :: Developers',
-        'Intended Audience :: Information Technology',
-        'Intended Audience :: Science/Research',
-        'License :: OSI Approved :: MIT License',
-        'Natural Language :: English',
-        'Operating System :: MacOS',
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: POSIX :: Linux',
-        'Operating System :: Unix',
-        'Programming Language :: Python',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.6',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3 :: Only',
-        'Programming Language :: Python :: Implementation :: CPython',
-        'Topic :: Artistic Software',
-        'Topic :: Multimedia',
-        'Topic :: Multimedia :: Sound/Audio',
-        'Topic :: Multimedia :: Sound/Audio :: Analysis',
-        'Topic :: Multimedia :: Sound/Audio :: Conversion',
-        'Topic :: Multimedia :: Sound/Audio :: Sound Synthesis',
-        'Topic :: Scientific/Engineering',
-        'Topic :: Scientific/Engineering :: Artificial Intelligence',
-        'Topic :: Scientific/Engineering :: Information Analysis',
-        'Topic :: Software Development',
-        'Topic :: Software Development :: Libraries',
-        'Topic :: Software Development :: Libraries :: Python Modules',
-        'Topic :: Utilities']
-)


@@ -13,9 +13,9 @@
     by providing train, evaluation and source separation action.
 """

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 class SpleeterError(Exception):


@@ -5,54 +5,252 @@
     Python oneliner script usage.

     USAGE: python -m spleeter {train,evaluate,separate} ...
+
+    Notes:
+        All critical import involving TF, numpy or Pandas are deported to
+        command function scope to avoid heavy import on CLI evaluation,
+        leading to large bootstraping time.
 """

-import sys
-import warnings
+import json
+from functools import partial
+from glob import glob
+from itertools import product
+from os.path import join
+from pathlib import Path
+from typing import Container, Dict, List, Optional
+
+# pyright: reportMissingImports=false
+# pylint: disable=import-error
+from typer import Exit, Typer

 from . import SpleeterError
-from .commands import create_argument_parser
-from .utils.configuration import load_configuration
-from .utils.logging import (
-    enable_logging,
-    enable_tensorflow_logging,
-    get_logger)
+from .options import *
+from .utils.logging import configure_logger, logger

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error

+spleeter: Typer = Typer(add_completion=False)
+""" CLI application. """

-def main(argv):
-    """ Spleeter runner. Parse provided command line arguments
-    and run entrypoint for required command (either train,
-    evaluate or separate).
-
-    :param argv: Provided command line arguments.
-    """
-    try:
-        parser = create_argument_parser()
-        arguments = parser.parse_args(argv[1:])
-        enable_logging()
-        if arguments.verbose:
-            enable_tensorflow_logging()
-        if arguments.command == 'separate':
-            from .commands.separate import entrypoint
-        elif arguments.command == 'train':
-            from .commands.train import entrypoint
-        elif arguments.command == 'evaluate':
-            from .commands.evaluate import entrypoint
-        params = load_configuration(arguments.configuration)
-        entrypoint(arguments, params)
-    except SpleeterError as e:
-        get_logger().error(e)
+
+@spleeter.command()
+def train(
+    adapter: str = AudioAdapterOption,
+    data: Path = TrainingDataDirectoryOption,
+    params_filename: str = ModelParametersOption,
+    verbose: bool = VerboseOption,
+) -> None:
+    """
+    Train a source separation model
+    """
+    import tensorflow as tf
+
+    from .audio.adapter import AudioAdapter
+    from .dataset import get_training_dataset, get_validation_dataset
+    from .model import model_fn
+    from .model.provider import ModelProvider
+    from .utils.configuration import load_configuration
+
+    configure_logger(verbose)
+    audio_adapter = AudioAdapter.get(adapter)
+    audio_path = str(data)
+    params = load_configuration(params_filename)
+    session_config = tf.compat.v1.ConfigProto()
+    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
+    estimator = tf.estimator.Estimator(
+        model_fn=model_fn,
+        model_dir=params["model_dir"],
+        params=params,
+        config=tf.estimator.RunConfig(
+            save_checkpoints_steps=params["save_checkpoints_steps"],
+            tf_random_seed=params["random_seed"],
+            save_summary_steps=params["save_summary_steps"],
+            session_config=session_config,
+            log_step_count_steps=10,
+            keep_checkpoint_max=2,
+        ),
+    )
+    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
+    train_spec = tf.estimator.TrainSpec(
+        input_fn=input_fn, max_steps=params["train_max_steps"]
+    )
+    input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
+    evaluation_spec = tf.estimator.EvalSpec(
+        input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
+    )
+    logger.info("Start model training")
+    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
+    ModelProvider.writeProbe(params["model_dir"])
+    logger.info("Model training done")
+
+
+@spleeter.command()
+def separate(
+    deprecated_files: Optional[str] = AudioInputOption,
+    files: List[Path] = AudioInputArgument,
+    adapter: str = AudioAdapterOption,
+    bitrate: str = AudioBitrateOption,
+    codec: Codec = AudioCodecOption,
+    duration: float = AudioDurationOption,
+    offset: float = AudioOffsetOption,
+    output_path: Path = AudioOutputOption,
+    stft_backend: STFTBackend = AudioSTFTBackendOption,
+    filename_format: str = FilenameFormatOption,
+    params_filename: str = ModelParametersOption,
+    mwf: bool = MWFOption,
+    verbose: bool = VerboseOption,
+) -> None:
+    """
+    Separate audio file(s)
+    """
+    from .audio.adapter import AudioAdapter
+    from .separator import Separator
+
+    configure_logger(verbose)
+    if deprecated_files is not None:
+        logger.error(
+            "⚠️ -i option is not supported anymore, audio files must be supplied "
+            "using input argument instead (see spleeter separate --help)"
+        )
+        raise Exit(20)
+    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
+    separator: Separator = Separator(
+        params_filename, MWF=mwf, stft_backend=stft_backend
+    )
+    for filename in files:
+        separator.separate_to_file(
+            str(filename),
+            str(output_path),
+            audio_adapter=audio_adapter,
+            offset=offset,
+            duration=duration,
+            codec=codec,
+            bitrate=bitrate,
+            filename_format=filename_format,
+            synchronous=False,
+        )
+    separator.join()
+
+
+EVALUATION_SPLIT: str = "test"
+EVALUATION_METRICS_DIRECTORY: str = "metrics"
+EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
+EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
+EVALUATION_MIXTURE: str = "mixture.wav"
+EVALUATION_AUDIO_DIRECTORY: str = "audio"
+
+
+def _compile_metrics(metrics_output_directory) -> Dict:
+    """
+    Compiles metrics from given directory and returns results as dict.
+
+    Parameters:
+        metrics_output_directory (str):
+            Directory to get metrics from.
+
+    Returns:
+        Dict:
+            Compiled metrics as dict.
+    """
+    import numpy as np
+    import pandas as pd
+
+    songs = glob(join(metrics_output_directory, "test/*.json"))
+    index = pd.MultiIndex.from_tuples(
+        product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
+        names=["instrument", "metric"],
+    )
+    pd.DataFrame([], index=["config1", "config2"], columns=index)
+    metrics = {
+        instrument: {k: [] for k in EVALUATION_METRICS}
+        for instrument in EVALUATION_INSTRUMENTS
+    }
+    for song in songs:
+        with open(song, "r") as stream:
+            data = json.load(stream)
+        for target in data["targets"]:
+            instrument = target["name"]
+            for metric in EVALUATION_METRICS:
+                sdr_med = np.median(
+                    [
+                        frame["metrics"][metric]
+                        for frame in target["frames"]
+                        if not np.isnan(frame["metrics"][metric])
+                    ]
+                )
+                metrics[instrument][metric].append(sdr_med)
+    return metrics
+
+
+@spleeter.command()
+def evaluate(
+    adapter: str = AudioAdapterOption,
+    output_path: Path = AudioOutputOption,
+    stft_backend: STFTBackend = AudioSTFTBackendOption,
+    params_filename: str = ModelParametersOption,
+    mus_dir: Path = MUSDBDirectoryOption,
+    mwf: bool = MWFOption,
+    verbose: bool = VerboseOption,
+) -> Dict:
+    """
+    Evaluate a model on the musDB test dataset
+    """
+    import numpy as np
+
+    configure_logger(verbose)
+    try:
+        import musdb
+        import museval
+    except ImportError:
+        logger.error("Extra dependencies musdb and museval not found")
+        logger.error("Please install musdb and museval first, abort")
+        raise Exit(10)
+    # Separate musdb sources.
+    songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
+    mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
+    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
+    separate(
+        deprecated_files=None,
+        files=mixtures,
+        adapter=adapter,
+        bitrate="128k",
+        codec=Codec.WAV,
+        duration=600.0,
+        offset=0,
+        output_path=join(audio_output_directory, EVALUATION_SPLIT),
+        stft_backend=stft_backend,
+        filename_format="{foldername}/{instrument}.{codec}",
+        params_filename=params_filename,
+        mwf=mwf,
+        verbose=verbose,
+    )
+    # Compute metrics with musdb.
+    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
+    logger.info("Starting musdb evaluation (this could be long) ...")
+    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
+    museval.eval_mus_dir(
+        dataset=dataset,
+        estimates_dir=audio_output_directory,
+        output_dir=metrics_output_directory,
+    )
+    logger.info("musdb evaluation done")
+    # Compute and pretty print median metrics.
+    metrics = _compile_metrics(metrics_output_directory)
+    for instrument, metric in metrics.items():
+        logger.info(f"{instrument}:")
+        for metric, value in metric.items():
+            logger.info(f"{metric}: {np.median(value):.3f}")
+    return metrics


 def entrypoint():
-    """ Command line entrypoint. """
-    warnings.filterwarnings('ignore')
-    main(sys.argv)
+    """ Application entrypoint. """
+    try:
+        spleeter()
+    except SpleeterError as e:
+        logger.error(e)


-if __name__ == '__main__':
+if __name__ == "__main__":
     entrypoint()


@@ -10,6 +10,43 @@
     - Waveform convertion and transforming functions.
 """

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+from enum import Enum
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
+
+
+class Codec(str, Enum):
+    """ Enumeration of supported audio codec. """
+
+    WAV: str = "wav"
+    MP3: str = "mp3"
+    OGG: str = "ogg"
+    M4A: str = "m4a"
+    WMA: str = "wma"
+    FLAC: str = "flac"
+
+
+class STFTBackend(str, Enum):
+    """ Enumeration of supported STFT backend. """
+
+    AUTO: str = "auto"
+    TENSORFLOW: str = "tensorflow"
+    LIBROSA: str = "librosa"
+
+    @classmethod
+    def resolve(cls: type, backend: str) -> str:
+        # NOTE: import is resolved here to avoid performance issues on command
+        # evaluation.
+        # pyright: reportMissingImports=false
+        # pylint: disable=import-error
+        import tensorflow as tf
+
+        if backend not in cls.__members__.values():
+            raise ValueError(f"Unsupported backend {backend}")
+        if backend == cls.AUTO:
+            if len(tf.config.list_physical_devices("GPU")):
+                return cls.TENSORFLOW
+            return cls.LIBROSA
+        return backend
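
Since `Codec` and `STFTBackend` now subclass both `str` and `Enum`, members compare equal to their raw string values, which is what lets Typer parse them directly from the command line. A small illustrative sketch (assuming standard library string-enum semantics):

```python
from spleeter.audio import Codec, STFTBackend

assert Codec.MP3 == "mp3"                # str subclass: equal to its value
assert STFTBackend.LIBROSA == "librosa"
# resolve() maps AUTO to TENSORFLOW when a GPU is visible, else LIBROSA.
backend = STFTBackend.resolve(STFTBackend.AUTO)
```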


@@ -3,70 +3,101 @@
 """ AudioAdapter class defintion. """

-import subprocess
-
 from abc import ABC, abstractmethod
 from importlib import import_module
-from os.path import exists
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union

+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf

-from tensorflow.signal import stft, hann_window
-# pylint: enable=import-error
+from spleeter.audio import Codec

 from .. import SpleeterError
-from ..utils.logging import get_logger
+from ..types import AudioDescriptor, Signal
+from ..utils.logging import logger

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 class AudioAdapter(ABC):
     """ An abstract class for manipulating audio signal. """

-    # Default audio adapter singleton instance.
-    DEFAULT = None
+    _DEFAULT: "AudioAdapter" = None
+    """ Default audio adapter singleton instance. """

     @abstractmethod
     def load(
-            self, audio_descriptor, offset, duration,
-            sample_rate, dtype=np.float32):
-        """ Loads the audio file denoted by the given audio descriptor
-        and returns it data as a waveform. Aims to be implemented
-        by client.
+        self,
+        audio_descriptor: AudioDescriptor,
+        offset: Optional[float] = None,
+        duration: Optional[float] = None,
+        sample_rate: Optional[float] = None,
+        dtype: np.dtype = np.float32,
+    ) -> Signal:
+        """
+        Loads the audio file denoted by the given audio descriptor and
+        returns it data as a waveform. Aims to be implemented by client.

-        :param audio_descriptor: Describe song to load, in case of file
-                                 based audio adapter, such descriptor would
-                                 be a file path.
-        :param offset: Start offset to load from in seconds.
-        :param duration: Duration to load in seconds.
-        :param sample_rate: Sample rate to load audio with.
-        :param dtype: Numpy data type to use, default to float32.
-        :returns: Loaded data as (wf, sample_rate) tuple.
+        Parameters:
+            audio_descriptor (AudioDescriptor):
+                Describe song to load, in case of file based audio adapter,
+                such descriptor would be a file path.
+            offset (Optional[float]):
+                Start offset to load from in seconds.
+            duration (Optional[float]):
+                Duration to load in seconds.
+            sample_rate (Optional[float]):
+                Sample rate to load audio with.
+            dtype (numpy.dtype):
+                (Optional) Numpy data type to use, default to `float32`.
+
+        Returns:
+            Signal:
+                Loaded data as (wf, sample_rate) tuple.
         """
         pass

     def load_tf_waveform(
-            self, audio_descriptor,
-            offset=0.0, duration=1800., sample_rate=44100,
-            dtype=b'float32', waveform_name='waveform'):
-        """ Load the audio and convert it to a tensorflow waveform.
+        self,
+        audio_descriptor,
+        offset: float = 0.0,
+        duration: float = 1800.0,
+        sample_rate: int = 44100,
+        dtype: bytes = b"float32",
+        waveform_name: str = "waveform",
+    ) -> Dict[str, Any]:
+        """
+        Load the audio and convert it to a tensorflow waveform.

-        :param audio_descriptor: Describe song to load, in case of file
-                                 based audio adapter, such descriptor would
-                                 be a file path.
-        :param offset: Start offset to load from in seconds.
-        :param duration: Duration to load in seconds.
-        :param sample_rate: Sample rate to load audio with.
-        :param dtype: Numpy data type to use, default to float32.
-        :param waveform_name: (Optional) Name of the key in output dict.
-        :returns: TF output dict with waveform as
-                  (T x chan numpy array) and a boolean that
-                  tells whether there were an error while
-                  trying to load the waveform.
+        Parameters:
+            audio_descriptor ():
+                Describe song to load, in case of file based audio adapter,
+                such descriptor would be a file path.
+            offset (float):
+                Start offset to load from in seconds.
+            duration (float):
+                Duration to load in seconds.
+            sample_rate (float):
+                Sample rate to load audio with.
+            dtype (bytes):
+                (Optional) data type to use, default to `b'float32'`.
+            waveform_name (str):
+                (Optional) Name of the key in output dict, default to
+                `'waveform'`.
+
+        Returns:
+            Dict[str, Any]:
+                TF output dict with waveform as `(T x chan numpy array)`
+                and a boolean that tells whether there were an error while
+                trying to load the waveform.
         """
         # Cast parameters to TF format.
         offset = tf.cast(offset, tf.float64)
@@ -74,76 +105,96 @@ class AudioAdapter(ABC):
         # Defined safe loading function.
         def safe_load(path, offset, duration, sample_rate, dtype):
-            logger = get_logger()
-            logger.info(
-                f'Loading audio {path} from {offset} to {offset + duration}')
+            logger.info(f"Loading audio {path} from {offset} to {offset + duration}")
             try:
                 (data, _) = self.load(
                     path.numpy(),
                     offset.numpy(),
                     duration.numpy(),
                     sample_rate.numpy(),
-                    dtype=dtype.numpy())
-                logger.info('Audio data loaded successfully')
+                    dtype=dtype.numpy(),
+                )
+                logger.info("Audio data loaded successfully")
                 return (data, False)
             except Exception as e:
-                logger.exception(
-                    'An error occurs while loading audio',
-                    exc_info=e)
+                logger.exception("An error occurs while loading audio", exc_info=e)
                 return (np.float32(-1.0), True)

         # Execute function and format results.
-        results = tf.py_function(
-            safe_load,
-            [audio_descriptor, offset, duration, sample_rate, dtype],
-            (tf.float32, tf.bool)),
+        results = (
+            tf.py_function(
+                safe_load,
+                [audio_descriptor, offset, duration, sample_rate, dtype],
+                (tf.float32, tf.bool),
+            ),
+        )
         waveform, error = results[0]
-        return {
-            waveform_name: waveform,
-            f'{waveform_name}_error': error
-        }
+        return {waveform_name: waveform, f"{waveform_name}_error": error}

     @abstractmethod
     def save(
-            self, path, data, sample_rate,
-            codec=None, bitrate=None):
-        """ Save the given audio data to the file denoted by
-        the given path.
+        self,
+        path: Union[Path, str],
+        data: np.ndarray,
+        sample_rate: float,
+        codec: Codec = None,
+        bitrate: str = None,
+    ) -> None:
+        """
+        Save the given audio data to the file denoted by the given path.

-        :param path: Path of the audio file to save data in.
-        :param data: Waveform data to write.
-        :param sample_rate: Sample rate to write file in.
-        :param codec: (Optional) Writing codec to use.
-        :param bitrate: (Optional) Bitrate of the written audio file.
+        Parameters:
+            path (Union[Path, str]):
+                Path like of the audio file to save data in.
+            data (numpy.ndarray):
+                Waveform data to write.
+            sample_rate (float):
+                Sample rate to write file in.
+            codec ():
+                (Optional) Writing codec to use, default to `None`.
+            bitrate (str):
+                (Optional) Bitrate of the written audio file, default to
+                `None`.
         """
         pass

-
-def get_default_audio_adapter():
-    """ Builds and returns a default audio adapter instance.
-
-    :returns: An audio adapter instance.
-    """
-    if AudioAdapter.DEFAULT is None:
-        from .ffmpeg import FFMPEGProcessAudioAdapter
-        AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
-    return AudioAdapter.DEFAULT
-
-
-def get_audio_adapter(descriptor):
-    """ Load dynamically an AudioAdapter from given class descriptor.
-
-    :param descriptor: Adapter class descriptor (module.Class)
-    :returns: Created adapter instance.
-    """
-    if descriptor is None:
-        return get_default_audio_adapter()
-    module_path = descriptor.split('.')
-    adapter_class_name = module_path[-1]
-    module_path = '.'.join(module_path[:-1])
-    adapter_module = import_module(module_path)
-    adapter_class = getattr(adapter_module, adapter_class_name)
-    if not isinstance(adapter_class, AudioAdapter):
-        raise SpleeterError(
-            f'{adapter_class_name} is not a valid AudioAdapter class')
-    return adapter_class()
+    @classmethod
+    def default(cls: type) -> "AudioAdapter":
+        """
+        Builds and returns a default audio adapter instance.
+
+        Returns:
+            AudioAdapter:
+                Default adapter instance to use.
+        """
+        if cls._DEFAULT is None:
+            from .ffmpeg import FFMPEGProcessAudioAdapter
+
+            cls._DEFAULT = FFMPEGProcessAudioAdapter()
+        return cls._DEFAULT
+
+    @classmethod
+    def get(cls: type, descriptor: str) -> "AudioAdapter":
+        """
+        Load dynamically an AudioAdapter from given class descriptor.
+
+        Parameters:
+            descriptor (str):
+                Adapter class descriptor (module.Class)
+
+        Returns:
+            AudioAdapter:
+                Created adapter instance.
+        """
+        if not descriptor:
+            return cls.default()
+        module_path: List[str] = descriptor.split(".")
+        adapter_class_name: str = module_path[-1]
+        module_path: str = ".".join(module_path[:-1])
+        adapter_module = import_module(module_path)
+        adapter_class = getattr(adapter_module, adapter_class_name)
+        if not issubclass(adapter_class, AudioAdapter):
+            raise SpleeterError(
+                f"{adapter_class_name} is not a valid AudioAdapter class"
+            )
+        return adapter_class()


@@ -3,39 +3,54 @@
 """ This module provides audio data convertion functions. """

+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
-# pylint: enable=import-error

 from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


-def to_n_channels(waveform, n_channels):
-    """ Convert a waveform to n_channels by removing or
-    duplicating channels if needed (in tensorflow).
+def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
+    """
+    Convert a waveform to n_channels by removing or duplicating channels if
+    needed (in tensorflow).

-    :param waveform: Waveform to transform.
-    :param n_channels: Number of channel to reshape waveform in.
-    :returns: Reshaped waveform.
+    Parameters:
+        waveform (tensorflow.Tensor):
+            Waveform to transform.
+        n_channels (int):
+            Number of channel to reshape waveform in.
+
+    Returns:
+        tensorflow.Tensor:
+            Reshaped waveform.
     """
     return tf.cond(
         tf.shape(waveform)[1] >= n_channels,
         true_fn=lambda: waveform[:, :n_channels],
-        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels]
+        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels],
     )


-def to_stereo(waveform):
-    """ Convert a waveform to stereo by duplicating if mono,
-    or truncating if too many channels.
+def to_stereo(waveform: np.ndarray) -> np.ndarray:
+    """
+    Convert a waveform to stereo by duplicating if mono, or truncating
+    if too many channels.

-    :param waveform: a (N, d) numpy array.
-    :returns: A stereo waveform as a (N, 1) numpy array.
+    Parameters:
+        waveform (numpy.ndarray):
+            a `(N, d)` numpy array.
+
+    Returns:
+        numpy.ndarray:
+            A stereo waveform as a `(N, 1)` numpy array.
     """
     if waveform.shape[1] == 1:
         return np.repeat(waveform, 2, axis=-1)
@@ -44,45 +59,81 @@ def to_stereo(waveform):
     return waveform


-def gain_to_db(tensor, espilon=10e-10):
-    """ Convert from gain to decibel in tensorflow.
-
-    :param tensor: Tensor to convert.
-    :param epsilon: Operation constant.
-    :returns: Converted tensor.
+def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
     """
-    return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
+    Convert from gain to decibel in tensorflow.
+
+    Parameters:
+        tensor (tensorflow.Tensor):
+            Tensor to convert
+        epsilon (float):
+            Operation constant.
+
+    Returns:
+        tensorflow.Tensor:
+            Converted tensor.
+    """
+    return 20.0 / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))


-def db_to_gain(tensor):
-    """ Convert from decibel to gain in tensorflow.
-
-    :param tensor_db: Tensor to convert.
-    :returns: Converted tensor.
+def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
     """
-    return tf.pow(10., (tensor / 20.))
+    Convert from decibel to gain in tensorflow.
+
+    Parameters:
+        tensor (tensorflow.Tensor):
+            Tensor to convert
+
+    Returns:
+        tensorflow.Tensor:
+            Converted tensor.
+    """
+    return tf.pow(10.0, (tensor / 20.0))


-def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
-    """ Encodes given spectrogram into uint8 using decibel scale.
-
-    :param spectrogram: Spectrogram to be encoded as TF float tensor.
-    :param db_range: Range in decibel for encoding.
-    :returns: Encoded decibel spectrogram as uint8 tensor.
-    """
-    db_spectrogram = gain_to_db(spectrogram)
-    max_db_spectrogram = tf.reduce_max(db_spectrogram)
-    db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range)
+def spectrogram_to_db_uint(
+    spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
+) -> tf.Tensor:
+    """
+    Encodes given spectrogram into uint8 using decibel scale.
+
+    Parameters:
+        spectrogram (tensorflow.Tensor):
+            Spectrogram to be encoded as TF float tensor.
+        db_range (float):
+            Range in decibel for encoding.
+
+    Returns:
+        tensorflow.Tensor:
+            Encoded decibel spectrogram as `uint8` tensor.
+    """
+    db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
+    max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
+    db_spectrogram: tf.Tensor = tf.maximum(
+        db_spectrogram, max_db_spectrogram - db_range
+    )
     return from_float32_to_uint8(db_spectrogram, **kwargs)


-def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
-    """ Decode spectrogram from uint8 decibel scale.
-
-    :param db_uint_spectrogram: Decibel pectrogram to decode.
-    :param min_db: Lower bound limit for decoding.
-    :param max_db: Upper bound limit for decoding.
-    :returns: Decoded spectrogram as float2 tensor.
+def db_uint_spectrogram_to_gain(
+    db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
+) -> tf.Tensor:
     """
-    db_spectrogram = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
+    Decode spectrogram from uint8 decibel scale.
+
+    Parameters:
+        db_uint_spectrogram (tensorflow.Tensor):
+            Decibel spectrogram to decode.
+        min_db (tensorflow.Tensor):
+            Lower bound limit for decoding.
+        max_db (tensorflow.Tensor):
+            Upper bound limit for decoding.
+
+    Returns:
+        tensorflow.Tensor:
+            Decoded spectrogram as `float32` tensor.
+    """
+    db_spectrogram: tf.Tensor = from_uint8_to_float32(
+        db_uint_spectrogram, min_db, max_db
+    )
     return db_to_gain(db_spectrogram)


@@ -8,143 +8,178 @@
     used within this library.
 """

+import datetime as dt
 import os
 import shutil
+from pathlib import Path
+from typing import Dict, Optional, Union

+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import ffmpeg
 import numpy as np

+from .. import SpleeterError
+from ..types import Signal
+from ..utils.logging import logger
+from . import Codec
+from .adapter import AudioAdapter
+
 # pylint: enable=import-error

-from .adapter import AudioAdapter
-from .. import SpleeterError
-from ..utils.logging import get_logger
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-
-def _check_ffmpeg_install():
-    """ Ensure FFMPEG binaries are available.
-
-    :raise SpleeterError: If ffmpeg or ffprobe is not found.
-    """
-    for binary in ('ffmpeg', 'ffprobe'):
-        if shutil.which(binary) is None:
-            raise SpleeterError('{} binary not found'.format(binary))
-
-
-def _to_ffmpeg_time(n):
-    """ Format number of seconds to time expected by FFMPEG.
-
-    :param n: Time in seconds to format.
-    :returns: Formatted time in FFMPEG format.
-    """
-    m, s = divmod(n, 60)
-    h, m = divmod(m, 60)
-    return '%d:%02d:%09.6f' % (h, m, s)
-
-
-def _to_ffmpeg_codec(codec):
-    ffmpeg_codecs = {
-        'm4a': 'aac',
-        'ogg': 'libvorbis',
-        'wma': 'wmav2',
-    }
-    return ffmpeg_codecs.get(codec) or codec
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 class FFMPEGProcessAudioAdapter(AudioAdapter):
-    """ An AudioAdapter implementation that use FFMPEG binary through
+    """
+    An AudioAdapter implementation that use FFMPEG binary through
     subprocess in order to perform I/O operation for audio processing.

     When created, FFMPEG binary path will be checked and expended,
     raising exception if not found. Such path could be infered using
-    FFMPEG_PATH environment variable.
+    `FFMPEG_PATH` environment variable.
     """

+    SUPPORTED_CODECS: Dict[Codec, str] = {
+        Codec.M4A: "aac",
+        Codec.OGG: "libvorbis",
+        Codec.WMA: "wmav2",
+    }
+    """ FFMPEG codec name mapping. """
+
+    def __init__(_) -> None:
+        """
+        Default constructor, ensure FFMPEG binaries are available.
+
+        Raises:
+            SpleeterError:
+                If ffmpeg or ffprobe is not found.
+        """
+        for binary in ("ffmpeg", "ffprobe"):
+            if shutil.which(binary) is None:
+                raise SpleeterError("{} binary not found".format(binary))
+
     def load(
-            self, path, offset=None, duration=None,
-            sample_rate=None, dtype=np.float32):
-        """ Loads the audio file denoted by the given path
+        _,
+        path: Union[Path, str],
+        offset: Optional[float] = None,
+        duration: Optional[float] = None,
+        sample_rate: Optional[float] = None,
+        dtype: np.dtype = np.float32,
+    ) -> Signal:
+        """
+        Loads the audio file denoted by the given path
         and returns it data as a waveform.

-        :param path: Path of the audio file to load data from.
-        :param offset: (Optional) Start offset to load from in seconds.
-        :param duration: (Optional) Duration to load in seconds.
-        :param sample_rate: (Optional) Sample rate to load audio with.
-        :param dtype: (Optional) Numpy data type to use, default to float32.
-        :returns: Loaded data a (waveform, sample_rate) tuple.
-        :raise SpleeterError: If any error occurs while loading audio.
+        Parameters:
+            path (Union[Path, str]):
+                Path of the audio file to load data from.
+            offset (Optional[float]):
+                Start offset to load from in seconds.
+            duration (Optional[float]):
+                Duration to load in seconds.
+            sample_rate (Optional[float]):
+                Sample rate to load audio with.
+            dtype (numpy.dtype):
+                (Optional) Numpy data type to use, default to `float32`.
+
+        Returns:
+            Signal:
+                Loaded data a (waveform, sample_rate) tuple.
+
+        Raises:
+            SpleeterError:
+                If any error occurs while loading audio.
         """
-        _check_ffmpeg_install()
+        if isinstance(path, Path):
+            path = str(path)
         if not isinstance(path, str):
             path = path.decode()
         try:
             probe = ffmpeg.probe(path)
         except ffmpeg._run.Error as e:
             raise SpleeterError(
-                'An error occurs with ffprobe (see ffprobe output below)\n\n{}'
-                .format(e.stderr.decode()))
-        if 'streams' not in probe or len(probe['streams']) == 0:
-            raise SpleeterError('No stream was found with ffprobe')
+                "An error occurs with ffprobe (see ffprobe output below)\n\n{}".format(
+                    e.stderr.decode()
+                )
+            )
+        if "streams" not in probe or len(probe["streams"]) == 0:
+            raise SpleeterError("No stream was found with ffprobe")
         metadata = next(
-            stream
-            for stream in probe['streams']
-            if stream['codec_type'] == 'audio')
-        n_channels = metadata['channels']
+            stream for stream in probe["streams"] if stream["codec_type"] == "audio"
+        )
+        n_channels = metadata["channels"]
         if sample_rate is None:
-            sample_rate = metadata['sample_rate']
-        output_kwargs = {'format': 'f32le', 'ar': sample_rate}
+            sample_rate = metadata["sample_rate"]
+        output_kwargs = {"format": "f32le", "ar": sample_rate}
         if duration is not None:
-            output_kwargs['t'] = _to_ffmpeg_time(duration)
+            output_kwargs["t"] = str(dt.timedelta(seconds=duration))
         if offset is not None:
-            output_kwargs['ss'] = _to_ffmpeg_time(offset)
+            output_kwargs["ss"] = str(dt.timedelta(seconds=offset))
         process = (
-            ffmpeg
-            .input(path)
-            .output('pipe:', **output_kwargs)
-            .run_async(pipe_stdout=True, pipe_stderr=True))
+            ffmpeg.input(path)
+            .output("pipe:", **output_kwargs)
+            .run_async(pipe_stdout=True, pipe_stderr=True)
+        )
         buffer, _ = process.communicate()
-        waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
+        waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
         if not waveform.dtype == np.dtype(dtype):
             waveform = waveform.astype(dtype)
         return (waveform, sample_rate)

     def save(
-            self, path, data, sample_rate,
-            codec=None, bitrate=None):
-        """ Write waveform data to the file denoted by the given path
-        using FFMPEG process.
+        self,
+        path: Union[Path, str],
+        data: np.ndarray,
+        sample_rate: float,
+        codec: Codec = None,
+        bitrate: str = None,
+    ) -> None:
+        """
+        Write waveform data to the file denoted by the given path using
+        FFMPEG process.

-        :param path: Path of the audio file to save data in.
-        :param data: Waveform data to write.
-        :param sample_rate: Sample rate to write file in.
-        :param codec: (Optional) Writing codec to use.
-        :param bitrate: (Optional) Bitrate of the written audio file.
-        :raise IOError: If any error occurs while using FFMPEG to write data.
+        Parameters:
+            path (Union[Path, str]):
+                Path like of the audio file to save data in.
+            data (numpy.ndarray):
+                Waveform data to write.
+            sample_rate (float):
+                Sample rate to write file in.
+            codec ():
+                (Optional) Writing codec to use, default to `None`.
+            bitrate (str):
+                (Optional) Bitrate of the written audio file, default to
+                `None`.
+
+        Raises:
+            IOError:
+                If any error occurs while using FFMPEG to write data.
         """
-        _check_ffmpeg_install()
+        if isinstance(path, Path):
+            path = str(path)
         directory = os.path.dirname(path)
         if not os.path.exists(directory):
-            raise SpleeterError(f'output directory does not exists: {directory}')
-        get_logger().debug('Writing file %s', path)
-        input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
-        output_kwargs = {'ar': sample_rate, 'strict': '-2'}
+            raise SpleeterError(f"output directory does not exists: {directory}")
+        logger.debug(f"Writing file {path}")
+        input_kwargs = {"ar": sample_rate, "ac": data.shape[1]}
+        output_kwargs = {"ar": sample_rate, "strict": "-2"}
         if bitrate:
-            output_kwargs['audio_bitrate'] = bitrate
-        if codec is not None and codec != 'wav':
-            output_kwargs['codec'] = _to_ffmpeg_codec(codec)
+            output_kwargs["audio_bitrate"] = bitrate
+        if codec is not None and codec != "wav":
+            output_kwargs["codec"] = self.SUPPORTED_CODECS.get(codec, codec)
         process = (
-            ffmpeg
-            .input('pipe:', format='f32le', **input_kwargs)
+            ffmpeg.input("pipe:", format="f32le", **input_kwargs)
             .output(path, **output_kwargs)
             .overwrite_output()
-            .run_async(pipe_stdin=True, pipe_stderr=True, quiet=True))
+            .run_async(pipe_stdin=True, pipe_stderr=True, quiet=True)
+        )
         try:
-            process.stdin.write(data.astype('<f4').tobytes())
+            process.stdin.write(data.astype("<f4").tobytes())
             process.stdin.close()
             process.wait()
         except IOError:
-            raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
-        get_logger().info('File %s written succesfully', path)
+            raise SpleeterError(f"FFMPEG error: {process.stderr.read()}")
+        logger.info(f"File {path} written succesfully")
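
The removed `_to_ffmpeg_time()` helper formatted seconds as `%d:%02d:%09.6f`; the replacement relies on `str(datetime.timedelta)` producing an equivalent `H:MM:SS.ffffff` string that FFMPEG accepts. A quick check of that equivalence (illustrative, not from the PR):

```python
import datetime as dt

# Both formats render 90.5 seconds as "0:01:30.500000".
assert str(dt.timedelta(seconds=90.5)) == "0:01:30.500000"
```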


@@ -1,128 +1,176 @@
#!/usr/bin/env python #!/usr/bin/env python
# coding: utf8 # coding: utf8
""" Spectrogram specific data augmentation """ """ Spectrogram specific data augmentation. """
# pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.signal import hann_window, stft
from tensorflow.signal import stft, hann_window
# pylint: enable=import-error # pylint: enable=import-error
__email__ = 'spleeter@deezer.com' __email__ = "spleeter@deezer.com"
__author__ = 'Deezer Research' __author__ = "Deezer Research"
__license__ = 'MIT License' __license__ = "MIT License"
def compute_spectrogram_tf( def compute_spectrogram_tf(
waveform, waveform: tf.Tensor,
frame_length=2048, frame_step=512, frame_length: int = 2048,
spec_exponent=1., window_exponent=1.): frame_step: int = 512,
""" Compute magnitude / power spectrogram from waveform as spec_exponent: float = 1.0,
a n_samples x n_channels tensor. window_exponent: float = 1.0,
) -> tf.Tensor:
:param waveform: Input waveform as (times x number of channels)
tensor.
:param frame_length: Length of a STFT frame to use.
:param frame_step: HOP between successive frames.
:param spec_exponent: Exponent of the spectrogram (usually 1 for
magnitude spectrogram, or 2 for power spectrogram).
:param window_exponent: Exponent applied to the Hann windowing function
(may be useful for making perfect STFT/iSTFT
reconstruction).
:returns: Computed magnitude / power spectrogram as a
(T x F x n_channels) tensor.
""" """
stft_tensor = tf.transpose( Compute magnitude / power spectrogram from waveform as a
`n_samples x n_channels` tensor.
Parameters:
waveform (tensorflow.Tensor):
Input waveform as `(times x number of channels)` tensor.
frame_length (int):
Length of a STFT frame to use.
frame_step (int):
HOP between successive frames.
spec_exponent (float):
Exponent of the spectrogram (usually 1 for magnitude
spectrogram, or 2 for power spectrogram).
window_exponent (float):
Exponent applied to the Hann windowing function (may be
useful for making perfect STFT/iSTFT reconstruction).
Returns:
tensorflow.Tensor:
Computed magnitude / power spectrogram as a
`(T x F x n_channels)` tensor.
"""
stft_tensor: tf.Tensor = tf.transpose(
stft( stft(
tf.transpose(waveform), tf.transpose(waveform),
frame_length, frame_length,
frame_step, frame_step,
window_fn=lambda f, dtype: hann_window( window_fn=lambda f, dtype: hann_window(
f, f, periodic=True, dtype=waveform.dtype
periodic=True, )
dtype=waveform.dtype) ** window_exponent), ** window_exponent,
perm=[1, 2, 0]) ),
perm=[1, 2, 0],
)
return tf.abs(stft_tensor) ** spec_exponent return tf.abs(stft_tensor) ** spec_exponent
def time_stretch( def time_stretch(
spectrogram, spectrogram: tf.Tensor,
factor=1.0, factor: float = 1.0,
method=tf.image.ResizeMethod.BILINEAR): method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
""" Time stretch a spectrogram preserving shape in tensorflow. Note that ) -> tf.Tensor:
"""
Time stretch a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain. this is an approximation in the frequency domain.
:param spectrogram: Input spectrogram to be time stretched as tensor. Parameters:
:param factor: (Optional) Time stretch factor, must be >0, default to 1. spectrogram (tensorflow.Tensor):
:param mehtod: (Optional) Interpolation method, default to BILINEAR. Input spectrogram to be time stretched as tensor.
:returns: Time stretched spectrogram as tensor with same shape. factor (float):
(Optional) Time stretch factor, must be > 0, default to `1`.
method (tensorflow.image.ResizeMethod):
(Optional) Interpolation method, default to `BILINEAR`.
Returns:
tensorflow.Tensor:
Time stretched spectrogram as tensor with same shape.
""" """
T = tf.shape(spectrogram)[0] T = tf.shape(spectrogram)[0]
T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0] T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
F = tf.shape(spectrogram)[1] F = tf.shape(spectrogram)[1]
ts_spec = tf.image.resize_images( ts_spec = tf.image.resize_images(
spectrogram, spectrogram, [T_ts, F], method=method, align_corners=True
[T_ts, F], )
method=method,
align_corners=True)
return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F) return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
def random_time_stretch(
    spectrogram: tf.Tensor, factor_min: float = 0.9, factor_max: float = 1.1, **kwargs
) -> tf.Tensor:
    """
    Time stretch a spectrogram preserving shape with random ratio in
    tensorflow. Applies time_stretch to spectrogram with a random ratio
    drawn uniformly in `[factor_min, factor_max]`.

    Parameters:
        spectrogram (tensorflow.Tensor):
            Input spectrogram to be time stretched as tensor.
        factor_min (float):
            (Optional) Min time stretch factor, default to `0.9`.
        factor_max (float):
            (Optional) Max time stretch factor, default to `1.1`.

    Returns:
        tensorflow.Tensor:
            Randomly time stretched spectrogram as tensor with same shape.
    """
    factor = (
        tf.random_uniform(shape=(1,), seed=0) * (factor_max - factor_min) + factor_min
    )
    return time_stretch(spectrogram, factor=factor, **kwargs)
def pitch_shift(
    spectrogram: tf.Tensor,
    semitone_shift: float = 0.0,
    method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
) -> tf.Tensor:
    """
    Pitch shift a spectrogram preserving shape in tensorflow. Note that
    this is an approximation in the frequency domain.

    Parameters:
        spectrogram (tensorflow.Tensor):
            Input spectrogram to be pitch shifted as tensor.
        semitone_shift (float):
            (Optional) Pitch shift in semitone, default to `0.0`.
        method (tensorflow.image.ResizeMethod):
            (Optional) Interpolation method, default to `BILINEAR`.

    Returns:
        tensorflow.Tensor:
            Pitch shifted spectrogram (same shape as spectrogram).
    """
    factor = 2 ** (semitone_shift / 12.0)
    T = tf.shape(spectrogram)[0]
    F = tf.shape(spectrogram)[1]
    F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
    ps_spec = tf.image.resize_images(
        spectrogram, [T, F_ps], method=method, align_corners=True
    )
    paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
    return tf.pad(ps_spec[:, :F, :], paddings, "CONSTANT")
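# --- Editorial usage sketch (hypothetical, not part of this commit) --------
# A shift of s semitones rescales the frequency axis by 2 ** (s / 12). As in
# time_stretch, the `[0]` indexing above expects a shape-(1,) tensor. For
# s = +7, factor is roughly 1.498, so 1024 bins are resized to 1534 and only
# the first 1024 are kept; a negative shift pads the removed top bins with
# zeros instead.
example_shifted = pitch_shift(
    tf.zeros((512, 1024, 2)), semitone_shift=tf.constant([7.0])
)
# ----------------------------------------------------------------------------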
def random_pitch_shift(
    spectrogram: tf.Tensor, shift_min: float = -1.0, shift_max: float = 1.0, **kwargs
) -> tf.Tensor:
    """
    Pitch shift a spectrogram preserving shape with random ratio in
    tensorflow. Applies pitch_shift to spectrogram with a random shift
    amount (expressed in semitones) drawn uniformly in
    `[shift_min, shift_max]`.

    Parameters:
        spectrogram (tensorflow.Tensor):
            Input spectrogram to be pitch shifted as tensor.
        shift_min (float):
            (Optional) Min pitch shift in semitone, default to -1.
        shift_max (float):
            (Optional) Max pitch shift in semitone, default to 1.

    Returns:
        tensorflow.Tensor:
            Randomly pitch shifted spectrogram (same shape as spectrogram).
    """
    semitone_shift = (
        tf.random_uniform(shape=(1,), seed=0) * (shift_max - shift_min) + shift_min
    )
    return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
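# --- Editorial usage sketch (hypothetical, not part of this commit) --------
# The dataset builder chains both random augmentations with these exact
# bounds; both draw from tf.random_uniform with a fixed op-level seed.
example_augmented = random_pitch_shift(
    random_time_stretch(tf.zeros((512, 1024, 2)), factor_min=0.9, factor_max=1.1),
    shift_min=-1.0,
    shift_max=1.0,
)
# ----------------------------------------------------------------------------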

View File

@@ -1,209 +0,0 @@
#!/usr/bin/env python
# coding: utf8
""" This modules provides spleeter command as well as CLI parsing methods. """
import json
import logging
from argparse import ArgumentParser
from tempfile import gettempdir
from os.path import exists, join
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# -i opt specification (separate).
OPT_INPUT = {
'dest': 'inputs',
'nargs': '+',
'help': 'List of input audio filenames',
'required': True
}
# -o opt specification (evaluate and separate).
OPT_OUTPUT = {
'dest': 'output_path',
'default': join(gettempdir(), 'separated_audio'),
'help': 'Path of the output directory to write audio files in'
}
# -f opt specification (separate).
OPT_FORMAT = {
'dest': 'filename_format',
'default': '{filename}/{instrument}.{codec}',
    'help': (
        'Template string that will be formatted to generate '
        'the output filename. Such template should be a Python '
        'formattable string, and could use {filename}, '
        '{instrument}, and {codec} variables.'
    )
}
# -p opt specification (train, evaluate and separate).
OPT_PARAMS = {
'dest': 'configuration',
'default': 'spleeter:2stems',
'type': str,
'action': 'store',
'help': 'JSON filename that contains params'
}
# -s opt specification (separate).
OPT_OFFSET = {
'dest': 'offset',
'type': float,
'default': 0.,
'help': 'Set the starting offset to separate audio from.'
}
# -d opt specification (separate).
OPT_DURATION = {
'dest': 'duration',
'type': float,
'default': 600.,
'help': (
'Set a maximum duration for processing audio '
'(only separate offset + duration first seconds of '
'the input file)')
}
# -w opt specification (separate)
OPT_STFT_BACKEND = {
'dest': 'stft_backend',
'type': str,
    'choices': ["tensorflow", "librosa", "auto"],
'default': "auto",
'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
}
# -c opt specification (separate).
OPT_CODEC = {
'dest': 'codec',
'choices': ('wav', 'mp3', 'ogg', 'm4a', 'wma', 'flac'),
'default': 'wav',
'help': 'Audio codec to be used for the separated output'
}
# -b opt specification (separate).
OPT_BITRATE = {
'dest': 'bitrate',
'default': '128k',
'help': 'Audio bitrate to be used for the separated output'
}
# -m opt specification (evaluate and separate).
OPT_MWF = {
'dest': 'MWF',
'action': 'store_const',
'const': True,
'default': False,
'help': 'Whether to use multichannel Wiener filtering for separation',
}
# --mus_dir opt specification (evaluate).
OPT_MUSDB = {
'dest': 'mus_dir',
'type': str,
'required': True,
'help': 'Path to folder with musDB'
}
# -d opt specification (train).
OPT_DATA = {
'dest': 'audio_path',
'type': str,
'required': True,
'help': 'Path of the folder containing audio data for training'
}
# -a opt specification (train, evaluate and separate).
OPT_ADAPTER = {
'dest': 'audio_adapter',
'type': str,
'help': 'Name of the audio adapter to use for audio I/O'
}
# -a opt specification (train, evaluate and separate).
OPT_VERBOSE = {
'action': 'store_true',
'help': 'Shows verbose logs'
}
def _add_common_options(parser):
""" Add common option to the given parser.
:param parser: Parser to add common opt to.
"""
parser.add_argument('-a', '--adapter', **OPT_ADAPTER)
parser.add_argument('-p', '--params_filename', **OPT_PARAMS)
parser.add_argument('--verbose', **OPT_VERBOSE)
def _create_train_parser(parser_factory):
""" Creates an argparser for training command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory('train', help='Train a source separation model')
_add_common_options(parser)
parser.add_argument('-d', '--data', **OPT_DATA)
return parser
def _create_evaluate_parser(parser_factory):
""" Creates an argparser for evaluation command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory(
'evaluate',
help='Evaluate a model on the musDB test dataset')
_add_common_options(parser)
parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
parser.add_argument('--mus_dir', **OPT_MUSDB)
parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
return parser
def _create_separate_parser(parser_factory):
""" Creates an argparser for separation command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory('separate', help='Separate audio files')
_add_common_options(parser)
parser.add_argument('-i', '--inputs', **OPT_INPUT)
parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
parser.add_argument('-f', '--filename_format', **OPT_FORMAT)
parser.add_argument('-d', '--duration', **OPT_DURATION)
parser.add_argument('-s', '--offset', **OPT_OFFSET)
parser.add_argument('-c', '--codec', **OPT_CODEC)
    parser.add_argument('-b', '--bitrate', **OPT_BITRATE)
parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
return parser
def create_argument_parser():
""" Creates overall command line parser for Spleeter.
:returns: Created argument parser.
"""
parser = ArgumentParser(prog='spleeter')
subparsers = parser.add_subparsers()
subparsers.dest = 'command'
subparsers.required = True
_create_separate_parser(subparsers.add_parser)
_create_train_parser(subparsers.add_parser)
_create_evaluate_parser(subparsers.add_parser)
return parser
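# --- Editorial usage sketch (hypothetical, not part of this commit) --------
# Parsing a separation command with the parser built above; `command` comes
# from the subparsers dest, and unspecified options fall back to defaults.
parser = create_argument_parser()
args = parser.parse_args([
    'separate',
    '-i', 'audio_example.mp3',
    '-o', '/tmp/separated',
])
assert args.command == 'separate' and args.codec == 'wav'
# ----------------------------------------------------------------------------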

View File

@@ -1,167 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model evaluation.
Evaluation is performed against musDB dataset.
    USAGE: python -m spleeter evaluate \
        -p /path/to/params \
        -o /path/to/output/dir \
        [-m] \
        --mus_dir /path/to/musdb/dataset
"""
import sys
import json
from argparse import Namespace
from itertools import product
from glob import glob
from os.path import join, exists
# pylint: disable=import-error
import numpy as np
import pandas as pd
# pylint: enable=import-error
from .separate import entrypoint as separate_entrypoint
from ..utils.logging import get_logger
try:
import musdb
import museval
except ImportError:
logger = get_logger()
logger.error('Extra dependencies musdb and museval not found')
logger.error('Please install musdb and museval first, abort')
sys.exit(1)
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_SPLIT = 'test'
_MIXTURE = 'mixture.wav'
_AUDIO_DIRECTORY = 'audio'
_METRICS_DIRECTORY = 'metrics'
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
""" Performs audio separation on the musdb dataset from
the given directory and params.
:param arguments: Entrypoint arguments.
:param musdb_root_directory: Directory to retrieve dataset from.
:param params: Spleeter configuration to apply to separation.
:returns: Separation output directory path.
"""
songs = glob(join(musdb_root_directory, _SPLIT, '*/'))
mixtures = [join(song, _MIXTURE) for song in songs]
audio_output_directory = join(
arguments.output_path,
_AUDIO_DIRECTORY)
separate_entrypoint(
Namespace(
audio_adapter=arguments.audio_adapter,
configuration=arguments.configuration,
inputs=mixtures,
output_path=join(audio_output_directory, _SPLIT),
filename_format='{foldername}/{instrument}.{codec}',
codec='wav',
duration=600.,
offset=0.,
bitrate='128k',
MWF=arguments.MWF,
verbose=arguments.verbose,
stft_backend=arguments.stft_backend),
params)
return audio_output_directory
def _compute_musdb_metrics(
arguments,
musdb_root_directory,
audio_output_directory):
""" Generates musdb metrics fro previsouly computed audio estimation.
:param arguments: Entrypoint arguments.
:param audio_output_directory: Directory to get audio estimation from.
:returns: Path of generated metrics directory.
"""
metrics_output_directory = join(
arguments.output_path,
_METRICS_DIRECTORY)
get_logger().info('Starting musdb evaluation (this could be long) ...')
dataset = musdb.DB(
root=musdb_root_directory,
is_wav=True,
subsets=[_SPLIT])
museval.eval_mus_dir(
dataset=dataset,
estimates_dir=audio_output_directory,
output_dir=metrics_output_directory)
get_logger().info('musdb evaluation done')
return metrics_output_directory
def _compile_metrics(metrics_output_directory):
""" Compiles metrics from given directory and returns
results as dict.
:param metrics_output_directory: Directory to get metrics from.
:returns: Compiled metrics as dict.
"""
songs = glob(join(metrics_output_directory, 'test/*.json'))
index = pd.MultiIndex.from_tuples(
product(_INSTRUMENTS, _METRICS),
names=['instrument', 'metric'])
pd.DataFrame([], index=['config1', 'config2'], columns=index)
metrics = {
instrument: {k: [] for k in _METRICS}
for instrument in _INSTRUMENTS}
for song in songs:
with open(song, 'r') as stream:
data = json.load(stream)
for target in data['targets']:
instrument = target['name']
for metric in _METRICS:
sdr_med = np.median([
frame['metrics'][metric]
for frame in target['frames']
if not np.isnan(frame['metrics'][metric])])
metrics[instrument][metric].append(sdr_med)
return metrics
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
# Parse and check musdb directory.
musdb_root_directory = arguments.mus_dir
if not exists(musdb_root_directory):
raise IOError(f'musdb directory {musdb_root_directory} not found')
# Separate musdb sources.
audio_output_directory = _separate_evaluation_dataset(
arguments,
musdb_root_directory,
params)
# Compute metrics with musdb.
metrics_output_directory = _compute_musdb_metrics(
arguments,
musdb_root_directory,
audio_output_directory)
# Compute and pretty print median metrics.
metrics = _compile_metrics(metrics_output_directory)
    for instrument, instrument_metrics in metrics.items():
        get_logger().info('%s:', instrument)
        for metric, value in instrument_metrics.items():
            get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
return metrics
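# --- Editorial usage sketch (hypothetical `arguments`/`params` from CLI) ---
# The returned structure maps instrument -> metric -> per-song medians, so a
# single global figure is one more median away:
metrics = entrypoint(arguments, params)
print(f"vocals SDR: {np.median(metrics['vocals']['SDR']):.3f}")
# ----------------------------------------------------------------------------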

View File

@@ -1,47 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing source separation.
    USAGE: python -m spleeter separate \
        -p /path/to/params \
        -o /path/to/output/dir \
        -i /path/to/audio1.wav /path/to/audio2.mp3
"""
from ..audio.adapter import get_audio_adapter
from ..separator import Separator
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
# TODO: check with output naming.
audio_adapter = get_audio_adapter(arguments.audio_adapter)
separator = Separator(
arguments.configuration,
MWF=arguments.MWF,
stft_backend=arguments.stft_backend)
for filename in arguments.inputs:
separator.separate_to_file(
filename,
arguments.output_path,
audio_adapter=audio_adapter,
offset=arguments.offset,
duration=arguments.duration,
codec=arguments.codec,
bitrate=arguments.bitrate,
filename_format=arguments.filename_format,
synchronous=False
)
separator.join()
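# --- Editorial sketch (hypothetical, using the same Separator API) ---------
# The entrypoint above boils down to driving Separator directly:
separator = Separator('spleeter:2stems', MWF=False, stft_backend='auto')
separator.separate_to_file(
    'audio_example.mp3', '/tmp/separated', synchronous=False
)
separator.join()
# ----------------------------------------------------------------------------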

View File

@@ -1,100 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model training.
USAGE: python -m spleeter train -p /path/to/params
"""
from functools import partial
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
from ..audio.adapter import get_audio_adapter
from ..dataset import get_training_dataset, get_validation_dataset
from ..model import model_fn
from ..model.provider import ModelProvider
from ..utils.logging import get_logger
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def _create_estimator(params):
""" Creates estimator.
:param params: TF params to build estimator from.
:returns: Built estimator.
"""
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=tf.estimator.RunConfig(
save_checkpoints_steps=params['save_checkpoints_steps'],
tf_random_seed=params['random_seed'],
save_summary_steps=params['save_summary_steps'],
session_config=session_config,
log_step_count_steps=10,
keep_checkpoint_max=2))
return estimator
def _create_train_spec(params, audio_adapter, audio_path):
""" Creates train spec.
:param params: TF params to build spec from.
:returns: Built train spec.
"""
input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
train_spec = tf.estimator.TrainSpec(
input_fn=input_fn,
max_steps=params['train_max_steps'])
return train_spec
def _create_evaluation_spec(params, audio_adapter, audio_path):
""" Setup eval spec evaluating ever n seconds
:param params: TF params to build spec from.
:returns: Built evaluation spec.
"""
input_fn = partial(
get_validation_dataset,
params,
audio_adapter,
audio_path)
evaluation_spec = tf.estimator.EvalSpec(
input_fn=input_fn,
steps=None,
throttle_secs=params['throttle_secs'])
return evaluation_spec
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
audio_adapter = get_audio_adapter(arguments.audio_adapter)
audio_path = arguments.audio_path
estimator = _create_estimator(params)
train_spec = _create_train_spec(params, audio_adapter, audio_path)
evaluation_spec = _create_evaluation_spec(
params,
audio_adapter,
audio_path)
get_logger().info('Start model training')
tf.estimator.train_and_evaluate(
estimator,
train_spec,
evaluation_spec)
ModelProvider.writeProbe(params['model_dir'])
get_logger().info('Model training done')
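# --- Editorial sketch: params keys read by this module (values are ---------
# hypothetical; keys are grounded in the code above):
params = {
    'model_dir': '/tmp/spleeter_model',
    'save_checkpoints_steps': 300,
    'save_summary_steps': 5,
    'random_seed': 0,
    'train_max_steps': 100000,
    'throttle_secs': 600,
}
# ----------------------------------------------------------------------------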

View File

@@ -14,87 +14,110 @@
    (ground truth)
"""

import os
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf

from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
from .audio.spectrogram import (
    compute_spectrogram_tf,
    random_pitch_shift,
    random_time_stretch,
)
from .utils.logging import logger
from .utils.tensor import (
    check_tensor_shape,
    dataset_from_csv,
    set_tensor_shape,
    sync_apply,
)

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS: Dict = {
    "instrument_list": ("vocals", "accompaniment"),
    "mix_name": "mix",
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
}
def get_training_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Builds training dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params,
        audio_adapter,
        audio_path,
        chunk_duration=audio_params.get("chunk_duration", 20.0),
        random_seed=audio_params.get("random_seed", 0),
    )
    return builder.build(
        audio_params.get("train_csv"),
        cache_directory=audio_params.get("training_cache"),
        batch_size=audio_params.get("batch_size"),
        n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
        random_data_augmentation=False,
        convert_to_uint=True,
        wait_for_cache=False,
    )
def get_validation_dataset(
    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
) -> Any:
    """
    Builds validation dataset.

    Parameters:
        audio_params (Dict):
            Audio parameters.
        audio_adapter (AudioAdapter):
            Adapter to load audio from.
        audio_path (str):
            Path of directory containing audio.

    Returns:
        Any:
            Built dataset.
    """
    builder = DatasetBuilder(
        audio_params, audio_adapter, audio_path, chunk_duration=12.0
    )
    return builder.build(
        audio_params.get("validation_csv"),
        batch_size=audio_params.get("batch_size"),
        cache_directory=audio_params.get("validation_cache"),
        convert_to_uint=True,
        infinite_generator=False,
        n_chunks_per_song=1,
@@ -108,127 +131,175 @@ def get_validation_dataset(audio_params, audio_adapter, audio_path):
class InstrumentDatasetBuilder(object):
    """ Instrument based filter and mapper provider. """

    def __init__(self, parent, instrument) -> None:
        """
        Default constructor.

        Parameters:
            parent:
                Parent dataset builder.
            instrument:
                Target instrument.
        """
        self._parent = parent
        self._instrument = instrument
        self._spectrogram_key = f"{instrument}_spectrogram"
        self._min_spectrogram_key = f"min_{instrument}_spectrogram"
        self._max_spectrogram_key = f"max_{instrument}_spectrogram"

    def load_waveform(self, sample):
        """ Load waveform for given sample. """
        return dict(
            sample,
            **self._parent._audio_adapter.load_tf_waveform(
                sample[f"{self._instrument}_path"],
                offset=sample["start"],
                duration=self._parent._chunk_duration,
                sample_rate=self._parent._sample_rate,
                waveform_name="waveform",
            ),
        )

    def compute_spectrogram(self, sample):
        """ Compute spectrogram of the given sample. """
        return dict(
            sample,
            **{
                self._spectrogram_key: compute_spectrogram_tf(
                    sample["waveform"],
                    frame_length=self._parent._frame_length,
                    frame_step=self._parent._frame_step,
                    spec_exponent=1.0,
                    window_exponent=1.0,
                )
            },
        )

    def filter_frequencies(self, sample):
        """ Keep only the first F frequency bins of the spectrogram. """
        return dict(
            sample,
            **{
                self._spectrogram_key: sample[self._spectrogram_key][
                    :, : self._parent._F, :
                ]
            },
        )

    def convert_to_uint(self, sample):
        """ Convert given sample from float to uint. """
        return dict(
            sample,
            **spectrogram_to_db_uint(
                sample[self._spectrogram_key],
                tensor_key=self._spectrogram_key,
                min_key=self._min_spectrogram_key,
                max_key=self._max_spectrogram_key,
            ),
        )

    def filter_infinity(self, sample):
        """ Filter infinity sample. """
        return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))

    def convert_to_float32(self, sample):
        """ Convert given sample from uint back to float. """
        return dict(
            sample,
            **{
                self._spectrogram_key: db_uint_spectrogram_to_gain(
                    sample[self._spectrogram_key],
                    sample[self._min_spectrogram_key],
                    sample[self._max_spectrogram_key],
                )
            },
        )

    def time_crop(self, sample):
        """ Crop a segment of T frames centered in time in the spectrogram. """

        def start(sample):
            """ mid_segment_start """
            return tf.cast(
                tf.maximum(
                    tf.shape(sample[self._spectrogram_key])[0] / 2
                    - self._parent._T / 2,
                    0,
                ),
                tf.int32,
            )

        return dict(
            sample,
            **{
                self._spectrogram_key: sample[self._spectrogram_key][
                    start(sample) : start(sample) + self._parent._T, :, :
                ]
            },
        )

    def filter_shape(self, sample):
        """ Filter badly shaped sample. """
        return check_tensor_shape(
            sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
        )

    def reshape_spectrogram(self, sample):
        """ Reshape given sample. """
        return dict(
            sample,
            **{
                self._spectrogram_key: set_tensor_shape(
                    sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
                )
            },
        )
class DatasetBuilder(object):
    """
    Builds training or validation datasets from a CSV description file,
    using the provided audio adapter and augmentation settings.
    """

    MARGIN: float = 0.5
    """ Margin at beginning and end of songs in seconds. """

    WAIT_PERIOD: int = 60
    """ Wait period for cache (in seconds). """

    def __init__(
        self,
        audio_params: Dict,
        audio_adapter: AudioAdapter,
        audio_path: str,
        random_seed: int = 0,
        chunk_duration: float = 20.0,
    ) -> None:
        """
        Default constructor.

        NOTE: Probably need for AudioAdapter.

        Parameters:
            audio_params (Dict):
                Audio parameters to use.
            audio_adapter (AudioAdapter):
                Audio adapter to use.
            audio_path (str):
                Path of the directory containing audio data.
            random_seed (int):
                Seed to use for random operations.
            chunk_duration (float):
                Duration of the chunk to extract from each song, in seconds.
        """
        # Length of segment in frames (if fs=22050 and
        # frame_step=512, then T=512 corresponds to 11.89s)
        self._T = audio_params["T"]
        # Number of frequency bins to be used (should
        # be less than frame_length/2 + 1)
        self._F = audio_params["F"]
        self._sample_rate = audio_params["sample_rate"]
        self._frame_length = audio_params["frame_length"]
        self._frame_step = audio_params["frame_step"]
        self._mix_name = audio_params["mix_name"]
        self._instruments = [self._mix_name] + audio_params["instrument_list"]
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
        self._audio_adapter = audio_adapter
@@ -238,130 +309,202 @@ class DatasetBuilder(object):
    def expand_path(self, sample):
        """ Expands audio paths for the given sample. """
        return dict(
            sample,
            **{
                f"{instrument}_path": tf.strings.join(
                    (self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
                )
                for instrument in self._instruments
            },
        )

    def filter_error(self, sample):
        """ Filter errored sample. """
        return tf.logical_not(sample["waveform_error"])

    def filter_waveform(self, sample):
        """ Filter waveform from sample. """
        return {k: v for k, v in sample.items() if not k == "waveform"}

    def harmonize_spectrogram(self, sample):
        """ Ensure same size for vocals and mix spectrograms. """

        def _reduce(sample):
            return tf.reduce_min(
                [
                    tf.shape(sample[f"{instrument}_spectrogram"])[0]
                    for instrument in self._instruments
                ]
            )

        return dict(
            sample,
            **{
                f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
                    : _reduce(sample), :, :
                ]
                for instrument in self._instruments
            },
        )

    def filter_short_segments(self, sample):
        """ Filter out too short segment. """
        return tf.reduce_any(
            [
                tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
                for instrument in self._instruments
            ]
        )

    def random_time_crop(self, sample):
        """ Random time crop of 11.88s. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: tf.image.random_crop(
                    x,
                    (self._T, len(self._instruments) * self._F, 2),
                    seed=self._random_seed,
                ),
            ),
        )

    def random_time_stretch(self, sample):
        """ Randomly time stretch the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
            ),
        )

    def random_pitch_shift(self, sample):
        """ Randomly pitch shift the given sample. """
        return dict(
            sample,
            **sync_apply(
                {
                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
                    for instrument in self._instruments
                },
                lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
                concat_axis=0,
            ),
        )

    def map_features(self, sample):
        """ Select features and annotation of the given sample. """
        input_ = {
            f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
        }
        output = {
            f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
            for instrument in self._audio_params["instrument_list"]
        }
        return (input_, output)
    def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
        """
        Computes segments for each song of the dataset.

        Parameters:
            dataset (Any):
                Dataset to compute segments for.
            n_chunks_per_song (int):
                Number of segment per song to compute.

        Returns:
            Any:
                Segmented dataset.
        """
        if n_chunks_per_song <= 0:
            raise ValueError("n_chunks_per_song must be positive")
        datasets = []
        for k in range(n_chunks_per_song):
            if n_chunks_per_song > 1:
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                k
                                * (
                                    sample["duration"]
                                    - self._chunk_duration
                                    - 2 * self.MARGIN
                                )
                                / (n_chunks_per_song - 1)
                                + self.MARGIN,
                                0,
                            ),
                        )
                    )
                )
            elif n_chunks_per_song == 1:  # Take central segment.
                datasets.append(
                    dataset.map(
                        lambda sample: dict(
                            sample,
                            start=tf.maximum(
                                sample["duration"] / 2 - self._chunk_duration / 2, 0
                            ),
                        )
                    )
                )
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
        return dataset
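    # Worked example (editorial note): for a 200 s song with
    # chunk_duration=20.0, MARGIN=0.5 and n_chunks_per_song=2, chunk k
    # starts at
    #     k * (200 - 20 - 2 * 0.5) / (2 - 1) + 0.5
    # i.e. at 0.5 s and 179.5 s, so the two chunks span [0.5, 20.5] and
    # [179.5, 199.5], keeping half a second of margin at each end.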
    @property
    def instruments(self) -> Any:
        """
        Instrument dataset builder generator.

        Yields:
            Any:
                InstrumentBuilder instance.
        """
        if self._instrument_builders is None:
            self._instrument_builders = []
            for instrument in self._instruments:
                self._instrument_builders.append(
                    InstrumentDatasetBuilder(self, instrument)
                )
        for builder in self._instrument_builders:
            yield builder
    def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
        """
        Cache the given dataset if cache is enabled. Eventually waits for
        cache to be available (useful if another process is already
        computing cache) if provided wait flag is `True`.

        Parameters:
            dataset (Any):
                Dataset to be cached if cache is required.
            cache (str):
                Path of cache directory to be used, None if no cache.
            wait (bool):
                If caching is enabled, True if the cache should be
                waited for.

        Returns:
            Any:
                Cached dataset if needed, original dataset otherwise.
        """
        if cache is not None:
            if wait:
                while not exists(f"{cache}.index"):
                    logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
                    time.sleep(self.WAIT_PERIOD)
            cache_path = os.path.split(cache)[0]
            os.makedirs(cache_path, exist_ok=True)
@@ -369,11 +512,19 @@ class DatasetBuilder(object):
        return dataset

    def build(
        self,
        csv_path: str,
        batch_size: int = 8,
        shuffle: bool = True,
        convert_to_uint: bool = True,
        random_data_augmentation: bool = False,
        random_time_crop: bool = True,
        infinite_generator: bool = True,
        cache_directory: Optional[str] = None,
        wait_for_cache: bool = False,
        num_parallel_calls: int = 4,
        n_chunks_per_song: float = 2,
    ) -> Any:
        """
        Builds the dataset described by the given CSV file, applying
        caching, shuffling, optional augmentation and batching.
        """
@@ -385,7 +536,8 @@ class DatasetBuilder(object):
            buffer_size=200000,
            seed=self._random_seed,
            # useless since it is cached :
            reshuffle_each_iteration=True,
        )
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveform, compute spectrogram, and filtering error,
@@ -393,11 +545,11 @@ class DatasetBuilder(object):
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset.map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies)
            )
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
@@ -428,26 +580,25 @@ class DatasetBuilder(object):
        # after cropping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
            )
        # Convert back to float32
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N
                )
        M = 8  # Parallel call post caching.
        # Must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
                self.random_pitch_shift, num_parallel_calls=M
            )
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_shape).map(
                instrument.reshape_spectrogram
            )
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid

View File

@@ -5,17 +5,19 @@
import importlib

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.signal import hann_window, inverse_stft, stft

from ..utils.tensor import pad_and_partition, pad_and_reshape

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"

placeholder = tf.compat.v1.placeholder
@@ -23,29 +25,28 @@ placeholder = tf.compat.v1.placeholder
def get_model_function(model_type):
    """
    Get tensorflow function of the model to be applied to the input tensor.
    For instance "unet.softmax_unet" will return the softmax_unet function
    in the "unet.py" submodule of the current module (spleeter.model).

    Params:
    - model_type: str
    the relative module path to the model function.

    Returns:
    A tensorflow function to be applied to the input tensor to get the
    multitrack output.
    """
    relative_path_to_module = ".".join(model_type.split(".")[:-1])
    model_name = model_type.split(".")[-1]
    main_module = ".".join((__name__, "functions"))
    path_to_module = f"{main_module}.{relative_path_to_module}"
    module = importlib.import_module(path_to_module)
    model_function = getattr(module, model_name)
    return model_function
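# --- Editorial usage sketch --------------------------------------------------
# Resolving the default model name used by EstimatorSpecBuilder below looks
# up spleeter.model.functions.unet.unet:
apply_model = get_model_function("unet.unet")
# ------------------------------------------------------------------------------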
class InputProvider(object):
    def __init__(self, params):
        self.params = params
@@ -61,16 +62,16 @@ class InputProvider(object):
class WaveformInputProvider(InputProvider):
    @property
    def input_names(self):
        return ["audio_id", "waveform"]

    def get_input_dict_placeholders(self):
        shape = (None, self.params["n_channels"])
        features = {
            "waveform": placeholder(tf.float32, shape=shape, name="waveform"),
            "audio_id": placeholder(tf.string, name="audio_id"),
        }
        return features

    def get_feed_dict(self, features, waveform, audio_id):
@@ -78,7 +79,6 @@ class WaveformInputProvider(InputProvider):
class SpectralInputProvider(InputProvider):
    def __init__(self, params):
        super().__init__(params)
        self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@@ -89,11 +89,17 @@ class SpectralInputProvider(InputProvider):
    def get_input_dict_placeholders(self):
        features = {
            self.stft_input_name: placeholder(
                tf.complex64,
                shape=(
                    None,
                    self.params["frame_length"] // 2 + 1,
                    self.params["n_channels"],
                ),
                name=self.stft_input_name,
            ),
            "audio_id": placeholder(tf.string, name="audio_id"),
        }
        return features

    def get_feed_dict(self, features, stft, audio_id):
@@ -101,11 +107,13 @@ class SpectralInputProvider(InputProvider):
class InputProviderFactory(object):
    @staticmethod
    def get(params):
        stft_backend = params["stft_backend"]
        assert stft_backend in (
            "tensorflow",
            "librosa",
        ), "Unexpected backend {}".format(stft_backend)
        if stft_backend == "tensorflow":
            return WaveformInputProvider(params)
        else:
@@ -113,7 +121,7 @@ class InputProviderFactory(object):
class EstimatorSpecBuilder(object):
    """A builder class that allows to build a multitrack unet model
    estimator. The built model estimator has a different behaviour when
    used in a train/eval mode and in predict mode.
@@ -137,22 +145,22 @@ class EstimatorSpecBuilder(object):
""" """
# Supported model functions. # Supported model functions.
DEFAULT_MODEL = 'unet.unet' DEFAULT_MODEL = "unet.unet"
# Supported loss functions. # Supported loss functions.
L1_MASK = 'L1_mask' L1_MASK = "L1_mask"
WEIGHTED_L1_MASK = 'weighted_L1_mask' WEIGHTED_L1_MASK = "weighted_L1_mask"
# Supported optimizers. # Supported optimizers.
ADADELTA = 'Adadelta' ADADELTA = "Adadelta"
SGD = 'SGD' SGD = "SGD"
# Math constants. # Math constants.
WINDOW_COMPENSATION_FACTOR = 2./3. WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0
EPSILON = 1e-10 EPSILON = 1e-10
def __init__(self, features, params): def __init__(self, features, params):
""" Default constructor. Depending on built model """Default constructor. Depending on built model
usage, the provided features should be different: usage, the provided features should be different:
* In train/eval mode: features is a dictionary with a * In train/eval mode: features is a dictionary with a
@@ -169,20 +177,20 @@ class EstimatorSpecBuilder(object):
        self._features = features
        self._params = params
        # Get instrument name.
        self._mix_name = params["mix_name"]
        self._instruments = params["instrument_list"]
        # Get STFT/signals parameters
        self._n_channels = params["n_channels"]
        self._T = params["T"]
        self._F = params["F"]
        self._frame_length = params["frame_length"]
        self._frame_step = params["frame_step"]

    def include_stft_computations(self):
        return self._params["stft_backend"] == "tensorflow"

    def _build_model_outputs(self):
        """Creates a (batch_size, T, F, n_channels) input tensor containing
        mix magnitude spectrogram, then an output dict from it according
        to the selected model in internal parameters.
@@ -191,22 +199,21 @@ class EstimatorSpecBuilder(object):
""" """
input_tensor = self.spectrogram_feature input_tensor = self.spectrogram_feature
model = self._params.get('model', None) model = self._params.get("model", None)
if model is not None: if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL) model_type = model.get("type", self.DEFAULT_MODEL)
else: else:
model_type = self.DEFAULT_MODEL model_type = self.DEFAULT_MODEL
try: try:
apply_model = get_model_function(model_type) apply_model = get_model_function(model_type)
except ModuleNotFoundError: except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found') raise ValueError(f"No model function {model_type} found")
self._model_outputs = apply_model( self._model_outputs = apply_model(
input_tensor, input_tensor, self._instruments, self._params["model"]["params"]
self._instruments, )
self._params['model']['params'])
def _build_loss(self, labels): def _build_loss(self, labels):
""" Construct tensorflow loss and metrics """Construct tensorflow loss and metrics
:param output_dict: dictionary of network outputs (key: instrument :param output_dict: dictionary of network outputs (key: instrument
name, value: estimated spectrogram of the instrument) name, value: estimated spectrogram of the instrument)
@@ -215,7 +222,7 @@ class EstimatorSpecBuilder(object):
        :returns: tensorflow (loss, metrics) tuple.
        """
        output_dict = self.model_outputs
        loss_type = self._params.get("loss_type", self.L1_MASK)
        if loss_type == self.L1_MASK:
            losses = {
                name: tf.reduce_mean(tf.abs(output - labels[name]))
@@ -224,11 +231,9 @@ class EstimatorSpecBuilder(object):
        elif loss_type == self.WEIGHTED_L1_MASK:
            losses = {
                name: tf.reduce_mean(
                    tf.reduce_mean(labels[name], axis=[1, 2, 3], keep_dims=True)
                    * tf.abs(output - labels[name])
                )
                for name, output in output_dict.items()
            }
        else:
@@ -236,20 +241,20 @@ class EstimatorSpecBuilder(object):
        loss = tf.reduce_sum(list(losses.values()))
        # Add metrics for monitoring each instrument.
        metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
        metrics["absolute_difference"] = tf.compat.v1.metrics.mean(loss)
        return loss, metrics

    def _build_optimizer(self):
        """Builds an optimizer instance from internal parameter values.

        Default to AdamOptimizer if not specified.

        :returns: Optimizer instance from internal configuration.
        """
        name = self._params.get("optimizer")
        if name == self.ADADELTA:
            return tf.compat.v1.train.AdadeltaOptimizer()
        rate = self._params["learning_rate"]
        if name == self.SGD:
            return tf.compat.v1.train.GradientDescentOptimizer(rate)
        return tf.compat.v1.train.AdamOptimizer(rate)
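    # --- Editorial note (hypothetical values): with
    #     {"optimizer": "SGD", "learning_rate": 1e-4}
    # the builder returns a GradientDescentOptimizer(1e-4); any name other
    # than "Adadelta" or "SGD" falls through to AdamOptimizer.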
@@ -260,15 +265,15 @@ class EstimatorSpecBuilder(object):
    @property
    def stft_name(self):
        return f"{self._mix_name}_stft"

    @property
    def spectrogram_name(self):
        return f"{self._mix_name}_spectrogram"

    def _build_stft_feature(self):
        """Compute STFT of waveform and slice the STFT in segment
        with the right length to feed the network.
        """
        stft_name = self.stft_name
@@ -276,25 +281,30 @@ class EstimatorSpecBuilder(object):
        if stft_name not in self._features:
            # pad input with a frame of zeros
            waveform = tf.concat(
                [
                    tf.zeros((self._frame_length, self._n_channels)),
                    self._features["waveform"],
                ],
                0,
            )
            stft_feature = tf.transpose(
                stft(
                    tf.transpose(waveform),
                    self._frame_length,
                    self._frame_step,
                    window_fn=lambda frame_length, dtype: (
                        hann_window(frame_length, periodic=True, dtype=dtype)
                    ),
                    pad_end=True,
                ),
                perm=[1, 2, 0],
            )
            self._features[f"{self._mix_name}_stft"] = stft_feature
        if spec_name not in self._features:
            self._features[spec_name] = tf.abs(
                pad_and_partition(self._features[stft_name], self._T)
            )[:, :, : self._F, :]

    @property
    def model_outputs(self):
@@ -333,25 +343,29 @@ class EstimatorSpecBuilder(object):
return self._masked_stfts return self._masked_stfts
def _inverse_stft(self, stft_t, time_crop=None): def _inverse_stft(self, stft_t, time_crop=None):
""" Inverse and reshape the given STFT """Inverse and reshape the given STFT
:param stft_t: input STFT :param stft_t: input STFT
:returns: inverse STFT (waveform) :returns: inverse STFT (waveform)
""" """
inversed = inverse_stft( inversed = (
tf.transpose(stft_t, perm=[2, 0, 1]), inverse_stft(
self._frame_length, tf.transpose(stft_t, perm=[2, 0, 1]),
self._frame_step, self._frame_length,
window_fn=lambda frame_length, dtype: ( self._frame_step,
hann_window(frame_length, periodic=True, dtype=dtype)) window_fn=lambda frame_length, dtype: (
) * self.WINDOW_COMPENSATION_FACTOR hann_window(frame_length, periodic=True, dtype=dtype)
),
)
* self.WINDOW_COMPENSATION_FACTOR
)
reshaped = tf.transpose(inversed) reshaped = tf.transpose(inversed)
if time_crop is None: if time_crop is None:
time_crop = tf.shape(self._features['waveform'])[0] time_crop = tf.shape(self._features["waveform"])[0]
return reshaped[self._frame_length:self._frame_length+time_crop, :] return reshaped[self._frame_length : self._frame_length + time_crop, :]
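The inverse path mirrors the forward one. A hedged sketch, assuming the class constant `WINDOW_COMPENSATION_FACTOR` is 2/3 for these window settings and that a frame of zeros was prepended before the forward STFT:

```python
# Hypothetical inverse of the sketch above; the compensation factor
# corrects for the periodic Hann analysis/synthesis overlap.
import tensorflow as tf

WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0  # assumed class constant

def stft_to_waveform(
    stft_t: tf.Tensor,
    frame_length: int = 4096,
    frame_step: int = 1024,
    time_crop: int = 44100,
) -> tf.Tensor:
    inversed = tf.signal.inverse_stft(
        tf.transpose(stft_t, perm=[2, 0, 1]),  # (channels, frames, bins)
        frame_length,
        frame_step,
        window_fn=lambda length, dtype: tf.signal.hann_window(
            length, periodic=True, dtype=dtype
        ),
    ) * WINDOW_COMPENSATION_FACTOR
    reshaped = tf.transpose(inversed)  # back to (samples, channels)
    # Drop the prepended frame of zeros, keep time_crop samples.
    return reshaped[frame_length:frame_length + time_crop, :]
```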
def _build_mwf_output_waveform(self): def _build_mwf_output_waveform(self):
""" Perform separation with multichannel Wiener Filtering using Norbert. """Perform separation with multichannel Wiener Filtering using Norbert.
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
may be quite slow. may be quite slow.
@@ -359,36 +373,42 @@ class EstimatorSpecBuilder(object):
value: estimated waveform of the instrument) value: estimated waveform of the instrument)
""" """
import norbert # pylint: disable=import-error import norbert # pylint: disable=import-error
output_dict = self.model_outputs output_dict = self.model_outputs
x = self.stft_feature x = self.stft_feature
v = tf.stack( v = tf.stack(
[ [
pad_and_reshape( pad_and_reshape(
output_dict[f'{instrument}_spectrogram'], output_dict[f"{instrument}_spectrogram"],
self._frame_length, self._frame_length,
self._F)[:tf.shape(x)[0], ...] self._F,
)[: tf.shape(x)[0], ...]
for instrument in self._instruments for instrument in self._instruments
], ],
axis=3) axis=3,
)
input_args = [v, x] input_args = [v, x]
stft_function = tf.py_function( stft_function = (
lambda v, x: norbert.wiener(v.numpy(), x.numpy()), tf.py_function(
input_args, lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
tf.complex64), input_args,
tf.complex64,
),
)
return { return {
instrument: self._inverse_stft(stft_function[0][:, :, :, k]) instrument: self._inverse_stft(stft_function[0][:, :, :, k])
for k, instrument in enumerate(self._instruments) for k, instrument in enumerate(self._instruments)
} }
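Stripped of the `tf.py_function` plumbing, the call delegated to Norbert looks like this (shapes invented for illustration):

```python
# Plain NumPy view of the delegated call: norbert.wiener refines magnitude
# estimates into complex per-source STFTs via multichannel Wiener filtering.
import norbert
import numpy as np

frames, bins, channels, sources = 100, 2049, 2, 2
v = np.abs(np.random.randn(frames, bins, channels, sources))  # magnitudes
x = (
    np.random.randn(frames, bins, channels)
    + 1j * np.random.randn(frames, bins, channels)
)  # complex mixture STFT

y = norbert.wiener(v, x)
print(y.shape)  # (frames, bins, channels, sources), complex
```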
def _extend_mask(self, mask): def _extend_mask(self, mask):
""" Extend mask, from reduced number of frequency bin to the number of """Extend mask, from reduced number of frequency bin to the number of
frequency bin in the STFT. frequency bin in the STFT.
:param mask: restricted mask :param mask: restricted mask
:returns: extended mask :returns: extended mask
:raise ValueError: If invalid mask_extension parameter is set. :raise ValueError: If invalid mask_extension parameter is set.
""" """
extension = self._params['mask_extension'] extension = self._params["mask_extension"]
# Extend with average # Extend with average
# (dispatch according to energy in the processed band) # (dispatch according to energy in the processed band)
if extension == "average": if extension == "average":
@@ -397,13 +417,9 @@ class EstimatorSpecBuilder(object):
# (avoid extension artifacts but not conservative separation) # (avoid extension artifacts but not conservative separation)
elif extension == "zeros": elif extension == "zeros":
mask_shape = tf.shape(mask) mask_shape = tf.shape(mask)
extension_row = tf.zeros(( extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))
mask_shape[0],
mask_shape[1],
1,
mask_shape[-1]))
else: else:
raise ValueError(f'Invalid mask_extension parameter {extension}') raise ValueError(f"Invalid mask_extension parameter {extension}")
n_extra_row = self._frame_length // 2 + 1 - self._F n_extra_row = self._frame_length // 2 + 1 - self._F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1]) extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
return tf.concat([mask, extension], axis=2) return tf.concat([mask, extension], axis=2)
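The "zeros" branch above is self-contained enough to restate as a runnable sketch; `frame_length // 2 + 1` is the full STFT bin count the mask must be grown to:

```python
# Runnable sketch of the "zeros" extension, assuming the (batch, T, F,
# channels) mask layout used above and spleeter-like frame settings.
import tensorflow as tf

def extend_mask_with_zeros(mask: tf.Tensor, frame_length: int, F: int) -> tf.Tensor:
    mask_shape = tf.shape(mask)
    extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))
    # Grow from F bins to the full frame_length // 2 + 1 STFT bins.
    n_extra_row = frame_length // 2 + 1 - F
    extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
    return tf.concat([mask, extension], axis=2)

mask = tf.ones((1, 8, 1024, 2))
print(extend_mask_with_zeros(mask, frame_length=4096, F=1024).shape)  # (1, 8, 2049, 2)
```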
@@ -415,29 +431,31 @@ class EstimatorSpecBuilder(object):
""" """
output_dict = self.model_outputs output_dict = self.model_outputs
stft_feature = self.stft_feature stft_feature = self.stft_feature
separation_exponent = self._params['separation_exponent'] separation_exponent = self._params["separation_exponent"]
output_sum = tf.reduce_sum( output_sum = (
[e ** separation_exponent for e in output_dict.values()], tf.reduce_sum(
axis=0 [e ** separation_exponent for e in output_dict.values()], axis=0
) + self.EPSILON )
+ self.EPSILON
)
out = {} out = {}
for instrument in self._instruments: for instrument in self._instruments:
output = output_dict[f'{instrument}_spectrogram'] output = output_dict[f"{instrument}_spectrogram"]
# Compute mask with the model. # Compute mask with the model.
instrument_mask = (output ** separation_exponent instrument_mask = (
+ (self.EPSILON / len(output_dict))) / output_sum output ** separation_exponent + (self.EPSILON / len(output_dict))
) / output_sum
# Extend mask; # Extend mask;
instrument_mask = self._extend_mask(instrument_mask) instrument_mask = self._extend_mask(instrument_mask)
# Stack back mask. # Stack back mask.
old_shape = tf.shape(instrument_mask) old_shape = tf.shape(instrument_mask)
new_shape = tf.concat( new_shape = tf.concat(
[[old_shape[0] * old_shape[1]], old_shape[2:]], [[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0
axis=0) )
instrument_mask = tf.reshape(instrument_mask, new_shape) instrument_mask = tf.reshape(instrument_mask, new_shape)
# Remove padded part (for mask having the same size as STFT); # Remove padded part (for mask having the same size as STFT);
instrument_mask = instrument_mask[ instrument_mask = instrument_mask[: tf.shape(stft_feature)[0], ...]
:tf.shape(stft_feature)[0], ...]
out[instrument] = instrument_mask out[instrument] = instrument_mask
self._masks = out self._masks = out
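The mask arithmetic above has a useful invariant: per time-frequency bin, the instrument masks sum to one. A NumPy restatement (hypothetical helper, same formula):

```python
# Each estimate is raised to `separation_exponent`, regularised by
# epsilon, and divided by the sum over instruments.
import numpy as np

def ratio_masks(estimates: dict, separation_exponent: float = 2.0,
                epsilon: float = 1e-10) -> dict:
    output_sum = sum(e ** separation_exponent for e in estimates.values()) + epsilon
    return {
        name: (e ** separation_exponent + epsilon / len(estimates)) / output_sum
        for name, e in estimates.items()
    }

estimates = {
    "vocals": np.random.rand(10, 1024, 2),
    "accompaniment": np.random.rand(10, 1024, 2),
}
masks = ratio_masks(estimates)
assert np.allclose(masks["vocals"] + masks["accompaniment"], 1.0)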
@@ -449,7 +467,7 @@ class EstimatorSpecBuilder(object):
self._masked_stfts = out self._masked_stfts = out
def _build_manual_output_waveform(self, masked_stft): def _build_manual_output_waveform(self, masked_stft):
""" Perform ratio mask separation """Perform ratio mask separation
:param output_dict: dictionary of estimated spectrogram (key: instrument :param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument) name, value: estimated spectrogram of the instrument)
@@ -463,14 +481,14 @@ class EstimatorSpecBuilder(object):
return output_waveform return output_waveform
def _build_output_waveform(self, masked_stft): def _build_output_waveform(self, masked_stft):
""" Build output waveform from given output dict in order to be used in """Build output waveform from given output dict in order to be used in
prediction context. Depending on the configuration, the building prediction context. Depending on the configuration, the building
method may use MWF. method may use MWF.
:returns: Built output waveform. :returns: Built output waveform.
""" """
if self._params.get('MWF', False): if self._params.get("MWF", False):
output_waveform = self._build_mwf_output_waveform() output_waveform = self._build_mwf_output_waveform()
else: else:
output_waveform = self._build_manual_output_waveform(masked_stft) output_waveform = self._build_manual_output_waveform(masked_stft)
@@ -482,11 +500,11 @@ class EstimatorSpecBuilder(object):
else: else:
self._outputs = self.masked_stfts self._outputs = self.masked_stfts
if 'audio_id' in self._features: if "audio_id" in self._features:
self._outputs['audio_id'] = self._features['audio_id'] self._outputs["audio_id"] = self._features["audio_id"]
def build_predict_model(self): def build_predict_model(self):
""" Builder interface for creating model instance that aims to perform """Builder interface for creating model instance that aims to perform
prediction / inference over given track. The output of such estimator prediction / inference over given track. The output of such estimator
will be a dictionary with a "<instrument>" key per separated instrument will be a dictionary with a "<instrument>" key per separated instrument
, associated to the estimated separated waveform of the instrument. , associated to the estimated separated waveform of the instrument.
@@ -495,11 +513,11 @@ class EstimatorSpecBuilder(object):
""" """
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT, tf.estimator.ModeKeys.PREDICT, predictions=self.outputs
predictions=self.outputs) )
def build_evaluation_model(self, labels): def build_evaluation_model(self, labels):
""" Builder interface for creating model instance that aims to perform """Builder interface for creating model instance that aims to perform
model evaluation. The output of such estimator will be a dictionary model evaluation. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument, with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram. associated to the estimated separated instrument magnitude spectrogram.
@@ -509,12 +527,11 @@ class EstimatorSpecBuilder(object):
""" """
loss, metrics = self._build_loss(labels) loss, metrics = self._build_loss(labels)
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics
loss=loss, )
eval_metric_ops=metrics)
def build_train_model(self, labels): def build_train_model(self, labels):
""" Builder interface for creating model instance that aims to perform """Builder interface for creating model instance that aims to perform
model training. The output of such estimator will be a dictionary model training. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument, with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram. associated to the estimated separated instrument magnitude spectrogram.
@@ -525,8 +542,8 @@ class EstimatorSpecBuilder(object):
loss, metrics = self._build_loss(labels) loss, metrics = self._build_loss(labels)
optimizer = self._build_optimizer() optimizer = self._build_optimizer()
train_operation = optimizer.minimize( train_operation = optimizer.minimize(
loss=loss, loss=loss, global_step=tf.compat.v1.train.get_global_step()
global_step=tf.compat.v1.train.get_global_step()) )
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN, mode=tf.estimator.ModeKeys.TRAIN,
loss=loss, loss=loss,
@@ -539,9 +556,9 @@ def model_fn(features, labels, mode, params, config):
""" """
:param features: :param features:
:param labels: :param labels:
:param mode: Estimator mode. :param mode: Estimator mode.
:param params: :param params:
:param config: TF configuration (not used). :param config: TF configuration (not used).
:returns: Built EstimatorSpec. :returns: Built EstimatorSpec.
:raise ValueError: If estimator mode is not supported. :raise ValueError: If estimator mode is not supported.
@@ -553,4 +570,4 @@ def model_fn(features, labels, mode, params, config):
return builder.build_evaluation_model(labels) return builder.build_evaluation_model(labels)
elif mode == tf.estimator.ModeKeys.TRAIN: elif mode == tf.estimator.ModeKeys.TRAIN:
return builder.build_train_model(labels) return builder.build_train_model(labels)
raise ValueError(f'Unknown mode {mode}') raise ValueError(f"Unknown mode {mode}")


@@ -3,25 +3,45 @@
""" This package provide model functions. """ """ This package provide model functions. """
__email__ = 'spleeter@deezer.com' from typing import Callable, Dict, Iterable, Optional
__author__ = 'Deezer Research'
__license__ = 'MIT License' # pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply(function, input_tensor, instruments, params={}): def apply(
""" Apply given function to the input tensor. function: Callable,
input_tensor: tf.Tensor,
:param function: Function to be applied to tensor. instruments: Iterable[str],
:param input_tensor: Tensor to apply blstm to. params: Optional[Dict] = None,
:param instruments: Iterable that provides a collection of instruments. ) -> Dict:
:param params: (Optional) dict of BLSTM parameters.
:returns: Created output tensor dict.
""" """
output_dict = {} Apply given function to the input tensor.
Parameters:
function:
Function to be applied to tensor.
input_tensor (tensorflow.Tensor):
Tensor to apply the given function to.
instruments (Iterable[str]):
Iterable that provides a collection of instruments.
params:
(Optional) dict of parameters passed to the function.
Returns:
Created output tensor dict.
"""
output_dict: Dict = {}
for instrument in instruments: for instrument in instruments:
out_name = f'{instrument}_spectrogram' out_name = f"{instrument}_spectrogram"
output_dict[out_name] = function( output_dict[out_name] = function(
input_tensor, input_tensor, output_name=out_name, params=params or {}
output_name=out_name, )
params=params)
return output_dict return output_dict
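A hypothetical usage sketch of `apply`, assuming this package is importable as `spleeter.model.functions`; `identity_model` stands in for a real model function such as the U-net applier defined later:

```python
# Builds one "<instrument>_spectrogram" output per instrument from a
# shared input tensor.
import tensorflow as tf
from spleeter.model.functions import apply

def identity_model(input_tensor, output_name="output", params=None):
    # Stand-in for a real model function such as apply_unet.
    return tf.identity(input_tensor, name=output_name)

mix = tf.zeros((1, 512, 1024, 2))
outputs = apply(identity_model, mix, ["vocals", "accompaniment"])
print(sorted(outputs))  # ['accompaniment_spectrogram', 'vocals_spectrogram']
```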


@@ -20,7 +20,11 @@
selection (LSTM layer dropout rate, regularization strength). selection (LSTM layer dropout rate, regularization strength).
""" """
from typing import Dict, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1.keras.initializers import he_uniform from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.compat.v1.keras.layers import CuDNNLSTM from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import ( from tensorflow.keras.layers import (
@@ -28,34 +32,48 @@ from tensorflow.keras.layers import (
Dense, Dense,
Flatten, Flatten,
Reshape, Reshape,
TimeDistributed) TimeDistributed,
# pylint: enable=import-error )
from . import apply from . import apply
__email__ = 'spleeter@deezer.com' # pylint: enable=import-error
__author__ = 'Deezer Research'
__license__ = 'MIT License' __email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def apply_blstm(input_tensor, output_name='output', params={}): def apply_blstm(
""" Apply BLSTM to the given input_tensor. input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
:param input_tensor: Input of the model.
:param output_name: (Optional) name of the output, default to 'output'.
:param params: (Optional) dict of BLSTM parameters.
:returns: Output tensor.
""" """
units = params.get('lstm_units', 250) Apply BLSTM to the given input_tensor.
Parameters:
input_tensor (tensorflow.Tensor):
Input of the model.
output_name (str):
(Optional) name of the output, default to 'output'.
params (Optional[Dict]):
(Optional) dict of BLSTM parameters.
Returns:
tensorflow.Tensor:
Output tensor.
"""
if params is None:
params = {}
units: int = params.get("lstm_units", 250)
kernel_initializer = he_uniform(seed=50) kernel_initializer = he_uniform(seed=50)
flatten_input = TimeDistributed(Flatten())((input_tensor)) flatten_input = TimeDistributed(Flatten())((input_tensor))
def create_bidirectional(): def create_bidirectional():
return Bidirectional( return Bidirectional(
CuDNNLSTM( CuDNNLSTM(
units, units, kernel_initializer=kernel_initializer, return_sequences=True
kernel_initializer=kernel_initializer, )
return_sequences=True)) )
l1 = create_bidirectional()((flatten_input)) l1 = create_bidirectional()((flatten_input))
l2 = create_bidirectional()((l1)) l2 = create_bidirectional()((l1))
@@ -63,14 +81,18 @@ def apply_blstm(input_tensor, output_name='output', params={}):
dense = TimeDistributed( dense = TimeDistributed(
Dense( Dense(
int(flatten_input.shape[2]), int(flatten_input.shape[2]),
activation='relu', activation="relu",
kernel_initializer=kernel_initializer))((l3)) kernel_initializer=kernel_initializer,
output = TimeDistributed( )
Reshape(input_tensor.shape[2:]), )((l3))
name=output_name)(dense) output: tf.Tensor = TimeDistributed(
Reshape(input_tensor.shape[2:]), name=output_name
)(dense)
return output return output
def blstm(input_tensor, output_name='output', params={}): def blstm(
input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
) -> tf.Tensor:
""" Model function applier. """ """ Model function applier. """
return apply(apply_blstm, input_tensor, output_name, params) return apply(apply_blstm, input_tensor, output_name, params)


@@ -2,92 +2,109 @@
# coding: utf8 # coding: utf8
""" """
This module contains building functions for U-net source This module contains building functions for U-net source
separation models in a similar way as in A. Jansson et al. "Singing separation models in a similar way as in A. Jansson et al. :
voice separation with deep u-net convolutional networks", ISMIR 2017.
Each instrument is modeled by a single U-net convolutional "Singing voice separation with deep u-net convolutional networks",
/ deconvolutional network that takes a mix spectrogram as input and the ISMIR 2017
estimated sound spectrogram as output.
Each instrument is modeled by a single U-net
convolutional / deconvolutional network that take a mix spectrogram
as input and the estimated sound spectrogram as output.
""" """
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import tensorflow as tf import tensorflow as tf
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.keras.layers import ( from tensorflow.keras.layers import (
ELU,
BatchNormalization, BatchNormalization,
Concatenate, Concatenate,
Conv2D, Conv2D,
Conv2DTranspose, Conv2DTranspose,
Dropout, Dropout,
ELU,
LeakyReLU, LeakyReLU,
Multiply, Multiply,
ReLU, ReLU,
Softmax) Softmax,
from tensorflow.compat.v1 import logging )
from tensorflow.compat.v1.keras.initializers import he_uniform
# pylint: enable=import-error
from . import apply from . import apply
__email__ = 'spleeter@deezer.com' # pylint: enable=import-error
__author__ = 'Deezer Research'
__license__ = 'MIT License' __email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def _get_conv_activation_layer(params): def _get_conv_activation_layer(params: Dict) -> Any:
""" """
Return the convolution activation layer selected by the "conv_activation" parameter.
:param params: Parameters:
:returns: Required Activation function. params (Dict):
Returns:
Any:
Required Activation function.
""" """
conv_activation = params.get('conv_activation') conv_activation: str = params.get("conv_activation")
if conv_activation == 'ReLU': if conv_activation == "ReLU":
return ReLU() return ReLU()
elif conv_activation == 'ELU': elif conv_activation == "ELU":
return ELU() return ELU()
return LeakyReLU(0.2) return LeakyReLU(0.2)
def _get_deconv_activation_layer(params): def _get_deconv_activation_layer(params: Dict) -> Any:
""" """
Return the deconvolution activation layer selected by the "deconv_activation" parameter.
:param params: Parameters:
:returns: Required Activation function. params (Dict):
Returns:
Any:
Required Activation function.
""" """
deconv_activation = params.get('deconv_activation') deconv_activation: str = params.get("deconv_activation")
if deconv_activation == 'LeakyReLU': if deconv_activation == "LeakyReLU":
return LeakyReLU(0.2) return LeakyReLU(0.2)
elif deconv_activation == 'ELU': elif deconv_activation == "ELU":
return ELU() return ELU()
return ReLU() return ReLU()
def apply_unet( def apply_unet(
input_tensor, input_tensor: tf.Tensor,
output_name='output', output_name: str = "output",
params={}, params: Optional[Dict] = None,
output_mask_logit=False): output_mask_logit: bool = False,
""" Apply a convolutionnal U-net to model a single instrument (one U-net ) -> Any:
"""
Apply a convolutionnal U-net to model a single instrument (one U-net
is used for each instrument). is used for each instrument).
:param input_tensor: Parameters:
:param output_name: (Optional) , default to 'output' input_tensor (tensorflow.Tensor):
:param params: (Optional) , default to empty dict. output_name (str):
:param output_mask_logit: (Optional) , default to False. params (Optional[Dict]):
output_mask_logit (bool):
""" """
logging.info(f'Apply unet for {output_name}') logging.info(f"Apply unet for {output_name}")
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512]) conv_n_filters = params.get("conv_n_filters", [16, 32, 64, 128, 256, 512])
conv_activation_layer = _get_conv_activation_layer(params) conv_activation_layer = _get_conv_activation_layer(params)
deconv_activation_layer = _get_deconv_activation_layer(params) deconv_activation_layer = _get_deconv_activation_layer(params)
kernel_initializer = he_uniform(seed=50) kernel_initializer = he_uniform(seed=50)
conv2d_factory = partial( conv2d_factory = partial(
Conv2D, Conv2D, strides=(2, 2), padding="same", kernel_initializer=kernel_initializer
strides=(2, 2), )
padding='same',
kernel_initializer=kernel_initializer)
# First layer. # First layer.
conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor) conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)
batch1 = BatchNormalization(axis=-1)(conv1) batch1 = BatchNormalization(axis=-1)(conv1)
@@ -117,8 +134,9 @@ def apply_unet(
conv2d_transpose_factory = partial( conv2d_transpose_factory = partial(
Conv2DTranspose, Conv2DTranspose,
strides=(2, 2), strides=(2, 2),
padding='same', padding="same",
kernel_initializer=kernel_initializer) kernel_initializer=kernel_initializer,
)
# #
up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6)) up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6))
up1 = deconv_activation_layer(up1) up1 = deconv_activation_layer(up1)
@@ -157,46 +175,60 @@ def apply_unet(
2, 2,
(4, 4), (4, 4),
dilation_rate=(2, 2), dilation_rate=(2, 2),
activation='sigmoid', activation="sigmoid",
padding='same', padding="same",
kernel_initializer=kernel_initializer)((batch12)) kernel_initializer=kernel_initializer,
)((batch12))
output = Multiply(name=output_name)([up7, input_tensor]) output = Multiply(name=output_name)([up7, input_tensor])
return output return output
return Conv2D( return Conv2D(
2, 2,
(4, 4), (4, 4),
dilation_rate=(2, 2), dilation_rate=(2, 2),
padding='same', padding="same",
kernel_initializer=kernel_initializer)((batch12)) kernel_initializer=kernel_initializer,
)((batch12))
def unet(input_tensor, instruments, params={}): def unet(
input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
""" Model function applier. """ """ Model function applier. """
return apply(apply_unet, input_tensor, instruments, params) return apply(apply_unet, input_tensor, instruments, params)
def softmax_unet(input_tensor, instruments, params={}): def softmax_unet(
""" Apply softmax to multitrack unet in order to have mask suming to one. input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
"""
Apply softmax to multitrack unet in order to have masks summing to one.
:param input_tensor: Tensor to apply blstm to. Parameters:
:param instruments: Iterable that provides a collection of instruments. input_tensor (tensorflow.Tensor):
:param params: (Optional) dict of BLSTM parameters. Tensor to apply the U-net to.
:returns: Created output tensor dict. instruments (Iterable[str]):
Iterable that provides a collection of instruments.
params (Optional[Dict]):
(Optional) dict of U-net parameters.
Returns:
Dict:
Created output tensor dict.
""" """
logit_mask_list = [] logit_mask_list = []
for instrument in instruments: for instrument in instruments:
out_name = f'{instrument}_spectrogram' out_name = f"{instrument}_spectrogram"
logit_mask_list.append( logit_mask_list.append(
apply_unet( apply_unet(
input_tensor, input_tensor,
output_name=out_name, output_name=out_name,
params=params, params=params,
output_mask_logit=True)) output_mask_logit=True,
)
)
masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4)) masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
output_dict = {} output_dict = {}
for i, instrument in enumerate(instruments): for i, instrument in enumerate(instruments):
out_name = f'{instrument}_spectrogram' out_name = f"{instrument}_spectrogram"
output_dict[out_name] = Multiply(name=out_name)([ output_dict[out_name] = Multiply(name=out_name)([masks[..., i], input_tensor])
masks[..., i],
input_tensor])
return output_dict return output_dict
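The constraint `softmax_unet` enforces can be checked in plain NumPy: stacking per-instrument logits on a new trailing axis and applying softmax there makes the masks sum to one in every time-frequency bin:

```python
# NumPy check of that property; shapes are illustrative only.
import numpy as np

# (batch, T, F, channels, instruments)
logits = np.random.randn(1, 16, 1024, 2, 4)
exp = np.exp(logits - logits.max(axis=4, keepdims=True))
masks = exp / exp.sum(axis=4, keepdims=True)
assert np.allclose(masks.sum(axis=4), 1.0)
```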


@@ -5,77 +5,91 @@
This package provides tools for downloading models from the network This package provides tools for downloading models from the network
using a remote storage abstraction. using a remote storage abstraction.
:Example: Examples:
```python
>>> provider = MyProviderImplementation() >>> provider = MyProviderImplementation()
>>> provider.get('/path/to/local/storage', params) >>> provider.get('/path/to/local/storage', params)
```
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from os import environ, makedirs from os import environ, makedirs
from os.path import exists, isabs, join, sep from os.path import exists, isabs, join, sep
__email__ = 'spleeter@deezer.com' __email__ = "spleeter@deezer.com"
__author__ = 'Deezer Research' __author__ = "Deezer Research"
__license__ = 'MIT License' __license__ = "MIT License"
class ModelProvider(ABC): class ModelProvider(ABC):
""" """
A ModelProvider manages model files on disk and A ModelProvider manages model files on disk and
file download if not available. file download if not available.
""" """
DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models') DEFAULT_MODEL_PATH: str = environ.get("MODEL_PATH", "pretrained_models")
MODEL_PROBE_PATH = '.probe' MODEL_PROBE_PATH: str = ".probe"
@abstractmethod @abstractmethod
def download(self, name, path): def download(_, name: str, path: str) -> None:
""" Download model denoted by the given name to disk. """
Download model denoted by the given name to disk.
:param name: Name of the model to download. Parameters:
:param path: Path of the directory to save model into. name (str):
Name of the model to download.
path (str):
Path of the directory to save model into.
""" """
pass pass
@staticmethod @staticmethod
def writeProbe(directory): def writeProbe(directory: str) -> None:
""" Write a model probe file into the given directory.
:param directory: Directory to write probe into.
""" """
probe = join(directory, ModelProvider.MODEL_PROBE_PATH) Write a model probe file into the given directory.
with open(probe, 'w') as stream:
stream.write('OK')
def get(self, model_directory): Parameters:
""" Ensures required model is available at given location. directory (str):
Directory to write probe into.
"""
probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
with open(probe, "w") as stream:
stream.write("OK")
:param model_directory: Expected model_directory to be available. def get(self, model_directory: str) -> str:
:raise IOError: If model can not be retrieved. """
Ensures required model is available at given location.
Parameters:
model_directory (str):
Expected model_directory to be available.
Raises:
IOError:
If model can not be retrieved.
""" """
# Expand model directory if needed. # Expand model directory if needed.
if not isabs(model_directory): if not isabs(model_directory):
model_directory = join(self.DEFAULT_MODEL_PATH, model_directory) model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
# Download it if not exists. # Download it if not exists.
model_probe = join(model_directory, self.MODEL_PROBE_PATH) model_probe: str = join(model_directory, self.MODEL_PROBE_PATH)
if not exists(model_probe): if not exists(model_probe):
if not exists(model_directory): if not exists(model_directory):
makedirs(model_directory) makedirs(model_directory)
self.download( self.download(model_directory.split(sep)[-1], model_directory)
model_directory.split(sep)[-1],
model_directory)
self.writeProbe(model_directory) self.writeProbe(model_directory)
return model_directory return model_directory
@classmethod
def default(_: type) -> "ModelProvider":
"""
Builds and returns a default model provider.
def get_default_model_provider(): Returns:
""" Builds and returns a default model provider. ModelProvider:
A default model provider instance to use.
"""
from .github import GithubModelProvider
:returns: A default model provider instance to use. return GithubModelProvider.from_environ()
"""
from .github import GithubModelProvider
host = environ.get('GITHUB_HOST', 'https://github.com')
repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
return GithubModelProvider(host, repository, release)
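A hedged usage sketch of the abstraction above: a relative directory name is expanded against `MODEL_PATH` (default `pretrained_models`), downloaded on first use, and marked with a `.probe` file so later calls are no-ops:

```python
# Resolve a model directory through the default provider.
from spleeter.model.provider import ModelProvider

provider = ModelProvider.default()
model_directory = provider.get("2stems")  # downloads only when missing
print(model_directory)  # e.g. pretrained_models/2stems
```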


@@ -4,41 +4,48 @@
""" """
A ModelProvider backed by Github Release feature. A ModelProvider backed by Github Release feature.
:Example: Examples:
```python
>>> from spleeter.model.provider import github >>> from spleeter.model.provider import github
>>> provider = github.GithubModelProvider( >>> provider = github.GithubModelProvider(
'github.com', 'github.com',
'Deezer/spleeter', 'Deezer/spleeter',
'latest') 'latest')
>>> provider.download('2stems', '/path/to/local/storage') >>> provider.download('2stems', '/path/to/local/storage')
```
""" """
import hashlib import hashlib
import tarfile
import os import os
import tarfile
from os import environ
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Dict
import requests # pyright: reportMissingImports=false
# pylint: disable=import-error
import httpx
from ...utils.logging import logger
from . import ModelProvider from . import ModelProvider
from ...utils.logging import get_logger
__email__ = 'spleeter@deezer.com' # pylint: enable=import-error
__author__ = 'Deezer Research'
__license__ = 'MIT License' __email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
def compute_file_checksum(path): def compute_file_checksum(path):
""" Computes given path file sha256. """Computes given path file sha256.
:param path: Path of the file to compute checksum for. :param path: Path of the file to compute checksum for.
:returns: File checksum. :returns: File checksum.
""" """
sha256 = hashlib.sha256() sha256 = hashlib.sha256()
with open(path, 'rb') as stream: with open(path, "rb") as stream:
for chunk in iter(lambda: stream.read(4096), b''): for chunk in iter(lambda: stream.read(4096), b""):
sha256.update(chunk) sha256.update(chunk)
return sha256.hexdigest() return sha256.hexdigest()
@@ -46,69 +53,104 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider): class GithubModelProvider(ModelProvider):
""" A ModelProvider implementation backed on Github for remote storage. """ """ A ModelProvider implementation backed on Github for remote storage. """
LATEST_RELEASE = 'v1.4.0' DEFAULT_HOST: str = "https://github.com"
RELEASE_PATH = 'releases/download' DEFAULT_REPOSITORY: str = "deezer/spleeter"
CHECKSUM_INDEX = 'checksum.json'
def __init__(self, host, repository, release): CHECKSUM_INDEX: str = "checksum.json"
""" Default constructor. LATEST_RELEASE: str = "v1.4.0"
RELEASE_PATH: str = "releases/download"
:param host: Host to the Github instance to reach. def __init__(self, host: str, repository: str, release: str) -> None:
:param repository: Repository path within target Github. """Default constructor.
:param release: Release name to get models from.
Parameters:
host (str):
Host to the Github instance to reach.
repository (str):
Repository path within target Github.
release (str):
Release name to get models from.
""" """
self._host = host self._host: str = host
self._repository = repository self._repository: str = repository
self._release = release self._release: str = release
def checksum(self, name): @classmethod
""" Downloads and returns reference checksum for the given model name. def from_environ(cls: type) -> "GithubModelProvider":
:param name: Name of the model to get checksum for.
:returns: Checksum of the required model.
:raise ValueError: If the given model name is not indexed.
""" """
url = '{}/{}/{}/{}/{}'.format( Factory method that creates provider from envvars.
self._host,
self._repository, Returns:
self.RELEASE_PATH, GithubModelProvider:
self._release, Created instance.
self.CHECKSUM_INDEX) """
response = requests.get(url) return cls(
environ.get("GITHUB_HOST", cls.DEFAULT_HOST),
environ.get("GITHUB_REPOSITORY", cls.DEFAULT_REPOSITORY),
environ.get("GITHUB_RELEASE", cls.LATEST_RELEASE),
)
def checksum(self, name: str) -> str:
"""
Downloads and returns reference checksum for the given model name.
Parameters:
name (str):
Name of the model to get checksum for.
Returns:
str:
Checksum of the required model.
Raises:
ValueError:
If the given model name is not indexed.
"""
url: str = "/".join(
(
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
self.CHECKSUM_INDEX,
)
)
response: httpx.Response = httpx.get(url)
response.raise_for_status() response.raise_for_status()
index = response.json() index: Dict = response.json()
if name not in index: if name not in index:
raise ValueError('No checksum for model {}'.format(name)) raise ValueError(f"No checksum for model {name}")
return index[name] return index[name]
def download(self, name, path): def download(self, name: str, path: str) -> None:
""" Download model denoted by the given name to disk.
:param name: Name of the model to download.
:param path: Path of the directory to save model into.
""" """
url = '{}/{}/{}/{}/{}.tar.gz'.format( Download model denoted by the given name to disk.
self._host,
self._repository, Parameters:
self.RELEASE_PATH, name (str):
self._release, Name of the model to download.
name) path (str):
get_logger().info('Downloading model archive %s', url) Path of the directory to save model into.
with requests.get(url, stream=True) as response: """
response.raise_for_status() url: str = "/".join(
archive = NamedTemporaryFile(delete=False) (self._host, self._repository, self.RELEASE_PATH, self._release, name)
try: )
with archive as stream: url = f"{url}.tar.gz"
# Note: check for chunk size parameters ? logger.info(f"Downloading model archive {url}")
for chunk in response.iter_content(chunk_size=8192): with httpx.Client(http2=True) as client:
if chunk: with client.stream("GET", url) as response:
response.raise_for_status()
archive = NamedTemporaryFile(delete=False)
try:
with archive as stream:
for chunk in response.iter_raw():
stream.write(chunk) stream.write(chunk)
get_logger().info('Validating archive checksum') logger.info("Validating archive checksum")
if compute_file_checksum(archive.name) != self.checksum(name): checksum: str = compute_file_checksum(archive.name)
raise IOError('Downloaded file is corrupted, please retry') if checksum != self.checksum(name):
get_logger().info('Extracting downloaded %s archive', name) raise IOError("Downloaded file is corrupted, please retry")
with tarfile.open(name=archive.name) as tar: logger.info(f"Extracting downloaded {name} archive")
tar.extractall(path=path) with tarfile.open(name=archive.name) as tar:
finally: tar.extractall(path=path)
os.unlink(archive.name) finally:
get_logger().info('%s model file(s) extracted', name) os.unlink(archive.name)
logger.info(f"{name} model file(s) extracted")

128 spleeter/options.py Normal file

@@ -0,0 +1,128 @@
#!/usr/bin/env python
# coding: utf8
""" This modules provides spleeter command as well as CLI parsing methods. """
from os.path import join
from tempfile import gettempdir
from typer import Argument, Option
from typer.models import ArgumentInfo, OptionInfo
from .audio import Codec, STFTBackend
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
AudioInputArgument: ArgumentInfo = Argument(
...,
help="List of input audio file path",
exists=True,
file_okay=True,
dir_okay=False,
readable=True,
resolve_path=True,
)
AudioInputOption: OptionInfo = Option(
None, "--inputs", "-i", help="(DEPRECATED) placeholder for deprecated input option"
)
AudioAdapterOption: OptionInfo = Option(
"spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter",
"--adapter",
"-a",
help="Name of the audio adapter to use for audio I/O",
)
AudioOutputOption: OptionInfo = Option(
join(gettempdir(), "separated_audio"),
"--output_path",
"-o",
help="Path of the output directory to write audio files in",
)
AudioOffsetOption: OptionInfo = Option(
0.0, "--offset", "-s", help="Set the starting offset to separate audio from"
)
AudioDurationOption: OptionInfo = Option(
600.0,
"--duration",
"-d",
help=(
"Set a maximum duration for processing audio "
"(only separate offset + duration first seconds of "
"the input file)"
),
)
AudioSTFTBackendOption: OptionInfo = Option(
STFTBackend.AUTO,
"--stft-backend",
"-B",
case_sensitive=False,
help=(
"Who should be in charge of computing the stfts. Librosa is faster "
'than tensorflow on CPU and uses less memory. "auto" will use '
"tensorflow when GPU acceleration is available and librosa when not"
),
)
AudioCodecOption: OptionInfo = Option(
Codec.WAV, "--codec", "-c", help="Audio codec to be used for the separated output"
)
AudioBitrateOption: OptionInfo = Option(
"128k", "--bitrate", "-b", help="Audio bitrate to be used for the separated output"
)
FilenameFormatOption: OptionInfo = Option(
"{filename}/{instrument}.{codec}",
"--filename_format",
"-f",
help=(
"Template string that will be formatted to generated"
"output filename. Such template should be Python formattable"
"string, and could use {filename}, {instrument}, and {codec}"
"variables"
),
)
ModelParametersOption: OptionInfo = Option(
"spleeter:2stems",
"--params_filename",
"-p",
help="JSON filename that contains params",
)
MWFOption: OptionInfo = Option(
False, "--mwf", help="Whether to use multichannel Wiener filtering for separation"
)
MUSDBDirectoryOption: OptionInfo = Option(
...,
"--mus_dir",
exists=True,
dir_okay=True,
file_okay=False,
readable=True,
resolve_path=True,
help="Path to musDB dataset directory",
)
TrainingDataDirectoryOption: OptionInfo = Option(
...,
"--data",
"-d",
exists=True,
dir_okay=True,
file_okay=False,
readable=True,
resolve_path=True,
help="Path of the folder containing audio data for training",
)
VerboseOption: OptionInfo = Option(False, "--verbose", help="Enable verbose logs")
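A hedged sketch of how these objects are consumed: each serves as a parameter default of a typer command, so the flags above map onto plain function arguments. The command body is an illustrative stub, not the actual spleeter entrypoint:

```python
from typing import List

from typer import Typer

from spleeter.options import (
    AudioInputArgument,
    AudioOutputOption,
    VerboseOption,
)

app = Typer()

@app.command()
def separate(
    files: List[str] = AudioInputArgument,
    output_path: str = AudioOutputOption,
    verbose: bool = VerboseOption,
) -> None:
    # Illustrative stub: typer fills these arguments from the CLI flags.
    print(files, output_path, verbose)

if __name__ == "__main__":
    app()
```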

0 spleeter/py.typed Normal file


@@ -3,6 +3,6 @@
""" Packages that provides static resources file for the library. """ """ Packages that provides static resources file for the library. """
__email__ = 'spleeter@deezer.com' __email__ = "spleeter@deezer.com"
__author__ = 'Deezer Research' __author__ = "Deezer Research"
__license__ = 'MIT License' __license__ = "MIT License"

View File

@@ -4,60 +4,63 @@
""" """
Module that provides a class wrapper for source separation. Module that provides a class wrapper for source separation.
:Example: Examples:
```python
>>> from spleeter.separator import Separator >>> from spleeter.separator import Separator
>>> separator = Separator('spleeter:2stems') >>> separator = Separator('spleeter:2stems')
>>> separator.separate(waveform, lambda instrument, data: ...) >>> separator.separate(waveform, lambda instrument, data: ...)
>>> separator.separate_to_file(...) >>> separator.separate_to_file(...)
```
""" """
import atexit import atexit
import os import os
import logging
from multiprocessing import Pool from multiprocessing import Pool
from os.path import basename, join, splitext, dirname from os.path import basename, dirname, join, splitext
from time import time from typing import Dict, Generator, Optional
from typing import Container, NoReturn
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from librosa.core import istft, stft
from librosa.core import stft, istft
from scipy.signal.windows import hann from scipy.signal.windows import hann
from spleeter.model.provider import ModelProvider
from . import SpleeterError from . import SpleeterError
from .audio.adapter import get_default_audio_adapter from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
from .model.provider import ModelProvider
from .types import AudioDescriptor
from .utils.configuration import load_configuration from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
__email__ = 'spleeter@deezer.com' # pylint: enable=import-error
__author__ = 'Deezer Research'
__license__ = 'MIT License'
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa') __email__ = "spleeter@deezer.com"
""" """ __author__ = "Deezer Research"
__license__ = "MIT License"
class DataGenerator(): class DataGenerator(object):
""" """
Generator object that stores a sample and yields it while called. Generator object that stores a sample and yields it while called.
Used to feed a tensorflow estimator without knowing the whole data at Used to feed a tensorflow estimator without knowing the whole data at
build time. build time.
""" """
def __init__(self): def __init__(self) -> None:
""" Default constructor. """ """ Default constructor. """
self._current_data = None self._current_data = None
def update_data(self, data): def update_data(self, data) -> None:
""" Replace internal data. """ """ Replace internal data. """
self._current_data = data self._current_data = data
def __call__(self): def __call__(self) -> Generator:
""" Generation process. """ """ Generation process. """
buffer = self._current_data buffer = self._current_data
while buffer: while buffer:
@@ -65,34 +68,52 @@ class DataGenerator():
buffer = self._current_data buffer = self._current_data
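In isolation, the class above behaves as follows: the estimator's input function pulls from the generator lazily, and setting the buffer to a falsy value ends the stream:

```python
# Minimal illustration of DataGenerator usage.
generator = DataGenerator()
generator.update_data({"waveform": "sample-1"})
stream = generator()
print(next(stream))          # {'waveform': 'sample-1'}
generator.update_data(None)  # next(stream) would now raise StopIteration
```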
def get_backend(backend: str) -> str: def create_estimator(params, MWF):
""" """
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
""" """
if backend not in SUPPORTED_BACKEND: # Load model.
raise ValueError(f'Unsupported backend {backend}') provider: ModelProvider = ModelProvider.default()
if backend == 'auto': params["model_dir"] = provider.get(params["model_dir"])
if len(tf.config.list_physical_devices('GPU')): params["MWF"] = MWF
return 'tensorflow' # Setup config
return 'librosa' session_config = tf.compat.v1.ConfigProto()
return backend session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
)
return estimator
class Separator(object): class Separator(object):
""" A wrapper class for performing separation. """ """ A wrapper class for performing separation. """
def __init__( def __init__(
self, self,
params_descriptor, params_descriptor: str,
MWF: bool = False, MWF: bool = False,
stft_backend: str = 'auto', stft_backend: STFTBackend = STFTBackend.AUTO,
multiprocess: bool = True): multiprocess: bool = True,
""" Default constructor. ) -> None:
"""
Default constructor.
:param params_descriptor: Descriptor for TF params to be used. Parameters:
:param MWF: (Optional) True if MWF should be used, False otherwise. params_descriptor (str):
Descriptor for TF params to be used.
MWF (bool):
(Optional) `True` if MWF should be used, `False` otherwise.
""" """
self._params = load_configuration(params_descriptor) self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate'] self._sample_rate = self._params["sample_rate"]
self._MWF = MWF self._MWF = MWF
self._tf_graph = tf.Graph() self._tf_graph = tf.Graph()
self._prediction_generator = None self._prediction_generator = None
@@ -106,19 +127,21 @@ class Separator(object):
else: else:
self._pool = None self._pool = None
self._tasks = [] self._tasks = []
self._params['stft_backend'] = get_backend(stft_backend) self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
self._data_generator = DataGenerator() self._data_generator = DataGenerator()
def __del__(self): def __del__(self) -> None:
""" """
if self._session: if self._session:
self._session.close() self._session.close()
def _get_prediction_generator(self): def _get_prediction_generator(self) -> Generator:
""" Lazy loading access method for internal prediction generator """
Lazy loading access method for internal prediction generator
returned by the predict method of a tensorflow estimator. returned by the predict method of a tensorflow estimator.
:returns: generator of prediction. Returns:
Generator:
Generator of prediction.
""" """
if self._prediction_generator is None: if self._prediction_generator is None:
estimator = create_estimator(self._params, self._MWF) estimator = create_estimator(self._params, self._MWF)
@@ -126,82 +149,74 @@ class Separator(object):
def get_dataset(): def get_dataset():
return tf.data.Dataset.from_generator( return tf.data.Dataset.from_generator(
self._data_generator, self._data_generator,
output_types={ output_types={"waveform": tf.float32, "audio_id": tf.string},
'waveform': tf.float32, output_shapes={"waveform": (None, 2), "audio_id": ()},
'audio_id': tf.string}, )
output_shapes={
'waveform': (None, 2),
'audio_id': ()})
self._prediction_generator = estimator.predict( self._prediction_generator = estimator.predict(
get_dataset, get_dataset, yield_single_examples=False
yield_single_examples=False) )
return self._prediction_generator return self._prediction_generator
def join(self, timeout: int = 200) -> NoReturn: def join(self, timeout: int = 200) -> None:
""" Wait for all pending tasks to be finished. """
Wait for all pending tasks to be finished.
:param timeout: (Optional) task waiting timeout. Parameters:
timeout (int):
(Optional) task waiting timeout.
""" """
while len(self._tasks) > 0: while len(self._tasks) > 0:
task = self._tasks.pop() task = self._tasks.pop()
task.get() task.get()
task.wait(timeout=timeout) task.wait(timeout=timeout)
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor): def _stft(
""" Performs source separation over the given waveform with tensorflow self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
backend. ) -> np.ndarray:
:param waveform: Waveform to apply separation on.
:returns: Separated waveforms.
""" """
if not waveform.shape[-1] == 2: Single entrypoint for both stft and istft. This computes stft and
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
def _stft(self, data, inverse: bool = False, length=None):
""" Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The expected separately and are concatenated together in the result. The
input formats are: (n_samples, 2) for stft and (T, F, 2) for istft. expected input formats are: (n_samples, 2) for stft and (T, F, 2)
for istft.
:param data: np.array with either the waveform or the complex Parameters:
spectrogram depending on the parameter inverse data (numpy.array):
:param inverse: should a stft or an istft be computed. Array with either the waveform or the complex spectrogram
:returns: Stereo data as numpy array for the transform. depending on the parameter inverse
The channels are stored in the last dimension. inverse (bool):
(Optional) Should a stft or an istft be computed.
length (Optional[int]):
Returns:
numpy.ndarray:
Stereo data as numpy array for the transform. The channels
are stored in the last dimension.
""" """
assert not (inverse and length is None) assert not (inverse and length is None)
data = np.asfortranarray(data) data = np.asfortranarray(data)
N = self._params['frame_length'] N = self._params["frame_length"]
H = self._params['frame_step'] H = self._params["frame_step"]
win = hann(N, sym=False) win = hann(N, sym=False)
fstft = istft if inverse else stft fstft = istft if inverse else stft
win_len_arg = { win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
'win_length': None,
'length': None} if inverse else {'n_fft': N}
n_channels = data.shape[-1] n_channels = data.shape[-1]
out = [] out = []
for c in range(n_channels): for c in range(n_channels):
d = np.concatenate( d = (
(np.zeros((N, )), data[:, c], np.zeros((N, ))) np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
) if not inverse else data[:, :, c].T if not inverse
else data[:, :, c].T
)
s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg) s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
if inverse: if inverse:
s = s[N:N+length] s = s[N : N + length]
s = np.expand_dims(s.T, 2-inverse) s = np.expand_dims(s.T, 2 - inverse)
out.append(s) out.append(s)
if len(out) == 1: if len(out) == 1:
return out[0] return out[0]
return np.concatenate(out, axis=2-inverse) return np.concatenate(out, axis=2 - inverse)
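A standalone restatement of the forward, per-channel branch of `_stft`, assuming spleeter's usual `frame_length=4096` and `frame_step=1024`:

```python
# Mirrors the zero padding and the per-channel librosa calls above.
import numpy as np
from librosa.core import stft
from scipy.signal.windows import hann

N, H = 4096, 1024
win = hann(N, sym=False)

def stereo_stft(waveform: np.ndarray) -> np.ndarray:
    """(n_samples, 2) float -> (T, F, 2) complex."""
    out = []
    for c in range(waveform.shape[-1]):
        padded = np.concatenate((np.zeros((N,)), waveform[:, c], np.zeros((N,))))
        s = stft(padded, n_fft=N, hop_length=H, window=win, center=False)
        out.append(np.expand_dims(s.T, 2))
    return np.concatenate(out, axis=2)

print(stereo_stft(np.random.randn(44100, 2)).shape)  # (frames, 2049, 2)
```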
def _get_input_provider(self): def _get_input_provider(self):
if self._input_provider is None: if self._input_provider is None:
@@ -216,22 +231,29 @@ class Separator(object):
def _get_builder(self): def _get_builder(self):
if self._builder is None: if self._builder is None:
self._builder = EstimatorSpecBuilder( self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
self._get_features(),
self._params)
return self._builder return self._builder
def _get_session(self): def _get_session(self):
if self._session is None: if self._session is None:
saver = tf.compat.v1.train.Saver() saver = tf.compat.v1.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint( provider = ModelProvider.default()
get_default_model_dir(self._params['model_dir'])) model_directory: str = provider.get(self._params["model_dir"])
latest_checkpoint = tf.train.latest_checkpoint(model_directory)
self._session = tf.compat.v1.Session() self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint) saver.restore(self._session, latest_checkpoint)
return self._session return self._session
def _separate_librosa(self, waveform: np.ndarray, audio_id): def _separate_librosa(
""" Performs separation with librosa backend for STFT. self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
) -> Dict:
"""
Performs separation with librosa backend for STFT.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (AudioDescriptor):
""" """
with self._tf_graph.as_default(): with self._tf_graph.as_default():
out = {} out = {}
@@ -248,65 +270,115 @@ class Separator(object):
outputs = sess.run( outputs = sess.run(
outputs, outputs,
feed_dict=self._get_input_provider().get_feed_dict( feed_dict=self._get_input_provider().get_feed_dict(
features, features, stft, audio_descriptor
stft, ),
audio_id)) )
for inst in self._get_builder().instruments: for inst in self._get_builder().instruments:
out[inst] = self._stft( out[inst] = self._stft(
outputs[inst], outputs[inst], inverse=True, length=waveform.shape[0]
inverse=True, )
length=waveform.shape[0])
return out return out
def separate(self, waveform: np.ndarray, audio_descriptor=''): def _separate_tensorflow(
""" Performs separation on a waveform. self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
) -> Dict:
:param waveform: Waveform to be separated (as a numpy array)
:param audio_descriptor: (Optional) string describing the waveform
(e.g. filename).
""" """
if self._params['stft_backend'] == 'tensorflow': Performs source separation over the given waveform with tensorflow
backend.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (AudioDescriptor):
Returns:
Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data(
{"waveform": waveform, "audio_id": np.array(audio_descriptor)}
)
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop("audio_id")
return prediction
def separate(
self, waveform: np.ndarray, audio_descriptor: Optional[str] = None
) -> None:
"""
Performs separation on a waveform.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
"""
backend: str = self._params["stft_backend"]
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor) return self._separate_tensorflow(waveform, audio_descriptor)
else: elif backend == STFTBackend.LIBROSA:
return self._separate_librosa(waveform, audio_descriptor) return self._separate_librosa(waveform, audio_descriptor)
raise ValueError(f"Unsupported STFT backend {backend}")
def separate_to_file( def separate_to_file(
self, self,
audio_descriptor, audio_descriptor: AudioDescriptor,
destination, destination: str,
audio_adapter=get_default_audio_adapter(), audio_adapter: Optional[AudioAdapter] = None,
offset=0, offset: int = 0,
duration=600., duration: float = 600.0,
codec='wav', codec: Codec = Codec.WAV,
bitrate='128k', bitrate: str = "128k",
filename_format='{filename}/{instrument}.{codec}', filename_format: str = "{filename}/{instrument}.{codec}",
synchronous=True): synchronous: bool = True,
""" Performs source separation and export result to file using ) -> None:
"""
Performs source separation and export result to file using
given audio adapter. given audio adapter.
Filename format should be a Python formattable string that could use Filename format should be a Python formattable string that could
following parameters : {instrument}, {filename}, {foldername} and use the following parameters:
{codec}.
:param audio_descriptor: Describe song to separate, used by audio - {instrument}
adapter to retrieve and load audio data, - {filename}
in case of file based audio adapter, such - {foldername}
descriptor would be a file path. - {codec}.
:param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O. Parameters:
:param offset: (Optional) Offset of loaded song. audio_descriptor (AudioDescriptor):
:param duration: (Optional) Duration of loaded song Describe song to separate, used by audio adapter to
(default: 600s). retrieve and load audio data, in case of file based
:param codec: (Optional) Export codec. audio adapter, such descriptor would be a file path.
:param bitrate: (Optional) Export bitrate. destination (str):
:param filename_format: (Optional) Filename format. Target directory to write output to.
:param synchronous: (Optional) True if it should be synchronous. audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
offset (int):
(Optional) Offset of loaded song.
duration (float):
(Optional) Duration of loaded song (default: 600s).
codec (Codec):
(Optional) Export codec.
bitrate (str):
(Optional) Export bitrate.
filename_format (str):
(Optional) Filename format.
synchronous (bool):
(Optional) True if it should be synchronous.
""" """
waveform, sample_rate = audio_adapter.load( if audio_adapter is None:
audio_adapter = AudioAdapter.default()
waveform, _ = audio_adapter.load(
audio_descriptor, audio_descriptor,
offset=offset, offset=offset,
duration=duration, duration=duration,
sample_rate=self._sample_rate) sample_rate=self._sample_rate,
)
sources = self.separate(waveform, audio_descriptor) sources = self.separate(waveform, audio_descriptor)
self.save_to_file( self.save_to_file(
sources, sources,
@@ -316,69 +388,78 @@ class Separator(object):
        codec,
        audio_adapter,
        bitrate,
        synchronous,
    )

def save_to_file(
    self,
    sources: Dict,
    audio_descriptor: AudioDescriptor,
    destination: str,
    filename_format: str = "{filename}/{instrument}.{codec}",
    codec: Codec = Codec.WAV,
    audio_adapter: Optional[AudioAdapter] = None,
    bitrate: str = "128k",
    synchronous: bool = True,
) -> None:
    """
    Export dictionary of sources to files.

    Parameters:
        sources (Dict):
            Dictionary of sources to be exported. The keys are the
            names of the instruments, and the values are `N x 2` numpy
            arrays containing the corresponding instrument waveform,
            as returned by the separate method.
        audio_descriptor (AudioDescriptor):
            Describes the song to separate, used by the audio adapter
            to retrieve and load audio data; in case of a file based
            audio adapter, such a descriptor would be a file path.
        destination (str):
            Target directory to write output to.
        filename_format (str):
            (Optional) Filename format.
        codec (Codec):
            (Optional) Export codec.
        audio_adapter (Optional[AudioAdapter]):
            (Optional) Audio adapter to use for I/O.
        bitrate (str):
            (Optional) Export bitrate.
        synchronous (bool):
            (Optional) `True` if the operation should be synchronous.
    """
    if audio_adapter is None:
        audio_adapter = AudioAdapter.default()
    foldername = basename(dirname(audio_descriptor))
    filename = splitext(basename(audio_descriptor))[0]
    generated = []
    for instrument, data in sources.items():
        path = join(
            destination,
            filename_format.format(
                filename=filename,
                instrument=instrument,
                foldername=foldername,
                codec=codec,
            ),
        )
        directory = os.path.dirname(path)
        if not os.path.exists(directory):
            os.makedirs(directory)
        if path in generated:
            raise SpleeterError(
                (
                    f"Separated source path conflict : {path}, "
                    "please check your filename format"
                )
            )
        generated.append(path)
        if self._pool:
            task = self._pool.apply_async(
                audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
            )
            self._tasks.append(task)
        else:
            audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
    if synchronous and self._pool:
        self.join()
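To make the new keyword-typed API concrete, here is a minimal usage sketch (the input file and output directory are assumptions; `spleeter:2stems` is one of the embedded configurations):

```python
from spleeter.separator import Separator

# Build a separator from an embedded configuration; the pretrained
# model is fetched by the model provider on first use.
separator = Separator("spleeter:2stems")

# Writes output/audio_example/vocals.wav and accompaniment.wav using
# the default FFMPEG adapter, since audio_adapter is left as None.
separator.separate_to_file(
    "audio_example.mp3",
    "output",
    filename_format="{filename}/{instrument}.{codec}",
)
```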

15
spleeter/types.py Normal file
View File

@@ -0,0 +1,15 @@
#!/usr/bin/env python
# coding: utf8
""" Custom types definition. """
from typing import Any, Tuple
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
# pylint: enable=import-error
AudioDescriptor: type = Any
Signal: type = Tuple[np.ndarray, float]
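Purely as an illustration of these aliases (the helper below is hypothetical, not part of the package):

```python
import numpy as np

from spleeter.types import AudioDescriptor, Signal


def load_silence(descriptor: AudioDescriptor, duration: float = 1.0) -> Signal:
    """ Hypothetical loader returning stereo silence at 44.1 kHz. """
    sample_rate = 44100.0
    waveform = np.zeros((int(duration * sample_rate), 2), dtype=np.float32)
    return waveform, sample_rate
```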

View File

@@ -3,6 +3,6 @@
""" This package provides utility function and classes. """ """ This package provides utility function and classes. """
__email__ = 'spleeter@deezer.com' __email__ = "spleeter@deezer.com"
__author__ = 'Deezer Research' __author__ = "Deezer Research"
__license__ = 'MIT License' __license__ = "MIT License"

View File

@@ -3,45 +3,49 @@
""" Module that provides configuration loading function. """ """ Module that provides configuration loading function. """
import importlib.resources as loader
import json import json
try:
import importlib.resources as loader
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as loader
from os.path import exists from os.path import exists
from typing import Dict
from .. import resources, SpleeterError from .. import SpleeterError, resources
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
_EMBEDDED_CONFIGURATION_PREFIX: str = "spleeter:"
__email__ = 'spleeter@deezer.com' def load_configuration(descriptor: str) -> Dict:
__author__ = 'Deezer Research' """
__license__ = 'MIT License' Load configuration from the given descriptor. Could be either a
`spleeter:` prefixed embedded configuration name or a file system path
to read configuration from.
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:' Parameters:
descriptor (str):
Configuration descriptor to use for lookup.
Returns:
Dict:
Loaded description as dict.
def load_configuration(descriptor): Raises:
""" Load configuration from the given descriptor. Could be ValueError:
either a `spleeter:` prefixed embedded configuration name If required embedded configuration does not exists.
or a file system path to read configuration from. SpleeterError:
If required configuration file does not exists.
:param descriptor: Configuration descriptor to use for lookup.
:returns: Loaded description as dict.
:raise ValueError: If required embedded configuration does not exists.
:raise SpleeterError: If required configuration file does not exists.
""" """
# Embedded configuration reading. # Embedded configuration reading.
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX): if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):] name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX) :]
if not loader.is_resource(resources, f'{name}.json'): if not loader.is_resource(resources, f"{name}.json"):
raise SpleeterError(f'No embedded configuration {name} found') raise SpleeterError(f"No embedded configuration {name} found")
with loader.open_text(resources, f'{name}.json') as stream: with loader.open_text(resources, f"{name}.json") as stream:
return json.load(stream) return json.load(stream)
# Standard file reading. # Standard file reading.
if not exists(descriptor): if not exists(descriptor):
raise SpleeterError(f'Configuration file {descriptor} not found') raise SpleeterError(f"Configuration file {descriptor} not found")
with open(descriptor, 'r') as stream: with open(descriptor, "r") as stream:
return json.load(stream) return json.load(stream)
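Both descriptor kinds in action (the custom JSON path is an assumption):

```python
from spleeter.utils.configuration import load_configuration

# Embedded configuration, resolved from package resources.
config = load_configuration("spleeter:2stems")

# File system path; raises SpleeterError if the file does not exist.
custom = load_configuration("configs/my_model.json")
```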

View File

@@ -1,46 +0,0 @@
#!/usr/bin/env python
# coding: utf8
""" Utility functions for creating estimator. """
import tensorflow as tf # pylint: disable=import-error
from ..model import model_fn
from ..model.provider import get_default_model_provider
def get_default_model_dir(model_dir):
"""
Transforms a string like 'spleeter:2stems' into an actual path.
:param model_dir:
:return:
"""
model_provider = get_default_model_provider()
return model_provider.get(model_dir)
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
params['model_dir'] = get_default_model_dir(params['model_dir'])
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config
)
return estimator

View File

@@ -4,58 +4,53 @@
""" Centralized logging facilities for Spleeter. """ """ Centralized logging facilities for Spleeter. """
import logging import logging
import warnings
from os import environ from os import environ
__email__ = 'spleeter@deezer.com' # pyright: reportMissingImports=false
__author__ = 'Deezer Research' # pylint: disable=import-error
__license__ = 'MIT License' from typer import echo
_FORMAT = '%(levelname)s:%(name)s:%(message)s' # pylint: enable=import-error
__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
class _LoggerHolder(object): class TyperLoggerHandler(logging.Handler):
""" Logger singleton instance holder. """ """ A custom logger handler that use Typer echo. """
INSTANCE = None def emit(self, record: logging.LogRecord) -> None:
echo(self.format(record))
def get_tensorflow_logger(): formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
handler = TyperLoggerHandler()
handler.setFormatter(formatter)
logger: logging.Logger = logging.getLogger("spleeter")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
def configure_logger(verbose: bool) -> None:
""" """
Configure application logger.
Parameters:
verbose (bool):
`True` to use verbose logger, `False` otherwise.
""" """
# pylint: disable=import-error from tensorflow import get_logger
from tensorflow.compat.v1 import logging from tensorflow.compat.v1 import logging as tf_logging
# pylint: enable=import-error
return logging
tf_logger = get_logger()
def get_logger(): tf_logger.handlers = [handler]
""" Returns library scoped logger. if verbose:
tf_logging.set_verbosity(tf_logging.INFO)
:returns: Library logger. logger.setLevel(logging.DEBUG)
""" else:
if _LoggerHolder.INSTANCE is None: warnings.filterwarnings("ignore")
formatter = logging.Formatter(_FORMAT) tf_logging.set_verbosity(tf_logging.ERROR)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger('spleeter')
logger.addHandler(handler)
logger.setLevel(logging.INFO)
_LoggerHolder.INSTANCE = logger
return _LoggerHolder.INSTANCE
def enable_tensorflow_logging():
""" Enable tensorflow logging. """
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf_logger = get_tensorflow_logger()
tf_logger.set_verbosity(tf_logger.INFO)
logger = get_logger()
logger.setLevel(logging.DEBUG)
def enable_logging():
""" Configure default logging. """
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf_logger = get_tensorflow_logger()
tf_logger.set_verbosity(tf_logger.ERROR)
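A short sketch of the new logging flow (the log message is illustrative):

```python
from spleeter.utils.logging import configure_logger, logger

# Route both spleeter and tensorflow logs through the Typer handler.
configure_logger(verbose=True)
logger.info("separation started")  # echoed as INFO:spleeter:separation started
```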

View File

@@ -3,43 +3,54 @@
""" Utility function for tensorflow. """ """ Utility function for tensorflow. """
from typing import Any, Callable, Dict
import pandas as pd
# pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import tensorflow as tf import tensorflow as tf
import pandas as pd
# pylint: enable=import-error # pylint: enable=import-error
__email__ = 'spleeter@deezer.com' __email__ = "spleeter@deezer.com"
__author__ = 'Deezer Research' __author__ = "Deezer Research"
__license__ = 'MIT License' __license__ = "MIT License"
def sync_apply(tensor_dict, func, concat_axis=1): def sync_apply(
""" Return a function that applies synchronously the provided func on the tensor_dict: tf.Tensor, func: Callable, concat_axis: int = 1
) -> Dict[str, tf.Tensor]:
"""
Return a function that applies synchronously the provided func on the
provided dictionnary of tensor. This means that func is applied to the provided dictionnary of tensor. This means that func is applied to the
concatenation of the tensors in tensor_dict. This is useful for performing concatenation of the tensors in tensor_dict. This is useful for
random operation that needs the same drawn value on multiple tensor, such performing random operation that needs the same drawn value on multiple
as a random time-crop on both input data and label (the same crop should be tensor, such as a random time-crop on both input data and label (the
applied to both input data and label, so random crop cannot be applied same crop should be applied to both input data and label, so random
separately on each of them). crop cannot be applied separately on each of them).
IMPORTANT NOTE: all tensor are assumed to be the same shape. Notes:
All tensor are assumed to be the same shape.
Params: Parameters:
- tensor_dict: dictionary (key: strings, values: tf.tensor) tensor_dict (Dict[str, tensorflow.Tensor]):
a dictionary of tensor. A dictionary of tensor.
- func: function func (Callable):
function to be applied to the concatenation of the tensors in Function to be applied to the concatenation of the tensors in
tensor_dict `tensor_dict`.
- concat_axis: int concat_axis (int):
The axis on which to perform the concatenation. The axis on which to perform the concatenation.
Returns: Returns:
processed tensors dictionary with the same name (keys) as input Dict[str, tensorflow.Tensor]:
tensor_dict. Processed tensors dictionary with the same name (keys) as input
tensor_dict.
""" """
if concat_axis not in {0, 1}: if concat_axis not in {0, 1}:
raise NotImplementedError( raise NotImplementedError(
'Function only implemented for concat_axis equal to 0 or 1') "Function only implemented for concat_axis equal to 0 or 1"
)
tensor_list = list(tensor_dict.values()) tensor_list = list(tensor_dict.values())
concat_tensor = tf.concat(tensor_list, concat_axis) concat_tensor = tf.concat(tensor_list, concat_axis)
processed_concat_tensor = func(concat_tensor) processed_concat_tensor = func(concat_tensor)
@@ -47,90 +58,104 @@ def sync_apply(tensor_dict, func, concat_axis=1):
    D = tensor_shape[concat_axis]
    if concat_axis == 0:
        return {
            name: processed_concat_tensor[index * D : (index + 1) * D, :, :]
            for index, name in enumerate(tensor_dict)
        }
    return {
        name: processed_concat_tensor[:, index * D : (index + 1) * D, :]
        for index, name in enumerate(tensor_dict)
    }
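For example, a single random crop drawn once and applied identically to both entries (module path and shapes are assumptions based on this diff):

```python
import tensorflow as tf

from spleeter.utils.tensor import sync_apply

# Input and label must receive the exact same random time-crop.
tensors = {
    "input": tf.random.uniform((64, 128, 2)),
    "label": tf.random.uniform((64, 128, 2)),
}

# Tensors are concatenated along axis 1, cropped once, then split back,
# so the same crop offset applies to both entries.
cropped = sync_apply(
    tensors,
    lambda t: tf.image.random_crop(t, (64, 2 * 32, 2)),
    concat_axis=1,
)
print(cropped["input"].shape)  # (64, 32, 2)
```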
def from_float32_to_uint8(
    tensor: tf.Tensor,
    tensor_key: str = "tensor",
    min_key: str = "min",
    max_key: str = "max",
) -> tf.Tensor:
    """
    Quantize a float32 tensor to uint8, keeping track of its original
    range for later reconstruction.

    Parameters:
        tensor (tensorflow.Tensor):
            Tensor to quantize.
        tensor_key (str):
            Key of the quantized tensor in the returned dict.
        min_key (str):
            Key of the tensor minimum in the returned dict.
        max_key (str):
            Key of the tensor maximum in the returned dict.

    Returns:
        Dictionary holding the quantized tensor and its original
        minimum and maximum.
    """
    tensor_min = tf.reduce_min(tensor)
    tensor_max = tf.reduce_max(tensor)
    return {
        tensor_key: tf.cast(
            (tensor - tensor_min) / (tensor_max - tensor_min + 1e-16) * 255.9999,
            dtype=tf.uint8,
        ),
        min_key: tensor_min,
        max_key: tensor_max,
    }
def from_uint8_to_float32(
    tensor: tf.Tensor, tensor_min: tf.Tensor, tensor_max: tf.Tensor
) -> tf.Tensor:
    """
    Dequantize a uint8 tensor back to float32 using the stored range.

    Parameters:
        tensor (tensorflow.Tensor):
            Quantized tensor to restore.
        tensor_min (tensorflow.Tensor):
            Original minimum of the tensor.
        tensor_max (tensorflow.Tensor):
            Original maximum of the tensor.

    Returns:
        tensorflow.Tensor:
            Restored float32 tensor.
    """
    return (
        tf.cast(tensor, tf.float32) * (tensor_max - tensor_min) / 255.9999 + tensor_min
    )
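A round-trip sketch of the two helpers (dictionary keys follow the defaults above; the tolerance is one uint8 step of the range):

```python
import tensorflow as tf

from spleeter.utils.tensor import from_float32_to_uint8, from_uint8_to_float32

spectrogram = tf.random.uniform((128, 512), minval=-40.0, maxval=3.0)

# Quantize to uint8 while remembering the original dynamic range.
packed = from_float32_to_uint8(spectrogram)

# Restore an approximation of the original values from the stored range.
restored = from_uint8_to_float32(packed["tensor"], packed["min"], packed["max"])
assert tf.reduce_max(tf.abs(restored - spectrogram)) < (3.0 + 40.0) / 255.0 + 1e-3
```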
def pad_and_partition(tensor: tf.Tensor, segment_len: int) -> tf.Tensor:
    """
    Pad and partition a tensor into segments of length `segment_len`
    along the first dimension. The tensor is padded with 0 in order
    to ensure that the first dimension is a multiple of `segment_len`.

    Tensor must be of known fixed rank.

    Examples:

    ```python
    >>> tensor = [[1, 2, 3], [4, 5, 6]]
    >>> segment_len = 2
    >>> pad_and_partition(tensor, segment_len)
    [[[1, 2, 3], [4, 5, 6]]]
    ```

    Parameters:
        tensor (tensorflow.Tensor):
            Tensor to pad and partition.
        segment_len (int):
            Segment length to partition along the first dimension.

    Returns:
        tensorflow.Tensor:
            Padded and partitioned tensor.
    """
    tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
    pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
    padded = tf.pad(tensor, [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1))
    split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
    return tf.reshape(
        padded, tf.concat([[split, segment_len], tf.shape(padded)[1:]], axis=0)
    )
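Another concrete case, with padding actually kicking in:

```python
import tensorflow as tf

from spleeter.utils.tensor import pad_and_partition

frames = tf.reshape(tf.range(10, dtype=tf.float32), (5, 2))

# 5 rows are padded with zeros up to 6 (next multiple of 3), then
# split into 2 segments of 3 rows each.
segments = pad_and_partition(frames, 3)
print(segments.shape)  # (2, 3, 2)
```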
def pad_and_reshape(instr_spec, frame_length, F) -> Any:
    """
    Parameters:
        instr_spec:
            Instrument spectrogram tensor to pad and reshape.
        frame_length:
            STFT frame length.
        F:
            Number of frequency bins.

    Returns:
        Any:
            Padded and reshaped spectrogram.
    """
    spec_shape = tf.shape(instr_spec)
    extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
@@ -138,53 +163,67 @@ def pad_and_reshape(instr_spec, frame_length, F):
    extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
    extended_spec = tf.concat([instr_spec, extension], axis=2)
    old_shape = tf.shape(extended_spec)
    new_shape = tf.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
    processed_instr_spec = tf.reshape(extended_spec, new_shape)
    return processed_instr_spec
def dataset_from_csv(csv_path: str, **kwargs) -> Any:
    """
    Load dataset from a CSV file using Pandas. kwargs, if any, are
    forwarded to the `pandas.read_csv` function.

    Parameters:
        csv_path (str):
            Path of the CSV file to load dataset from.

    Returns:
        Any:
            Loaded dataset.
    """
    df = pd.read_csv(csv_path, **kwargs)
    dataset = tf.data.Dataset.from_tensor_slices({key: df[key].values for key in df})
    return dataset
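For instance, with a training descriptor CSV like the one generated by the tests below (column names are assumptions):

```python
from spleeter.utils.tensor import dataset_from_csv

# Each dataset element is a dict mapping column name to scalar tensor.
dataset = dataset_from_csv("train/train.csv", sep=",")
for sample in dataset.take(2):
    print(sample["mix_path"], sample["duration"])
```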
def check_tensor_shape(tensor_tf: tf.Tensor, target_shape: Any) -> bool:
    """
    Return a Tensorflow boolean graph that indicates whether
    sample[features_key] has the specified target shape. Only checks
    non-None entries of target_shape.

    Parameters:
        tensor_tf (tensorflow.Tensor):
            Tensor to check shape for.
        target_shape (Any):
            Target shape to compare tensor to.

    Returns:
        bool:
            `True` if shape is valid, `False` otherwise (as TF boolean).
    """
    result = tf.constant(True)
    for i, target_length in enumerate(target_shape):
        if target_length:
            result = tf.logical_and(
                result, tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i])
            )
    return result
def set_tensor_shape(tensor: tf.Tensor, tensor_shape: Any) -> tf.Tensor:
    """
    Set shape for a tensor (not in place, as opposed to tf.set_shape).

    Parameters:
        tensor (tensorflow.Tensor):
            Tensor to reshape.
        tensor_shape (Any):
            Shape to apply to the tensor.

    Returns:
        tensorflow.Tensor:
            A reshaped tensor.
    """
    # NOTE: doesn't this operate in place here?
    tensor.set_shape(tensor_shape)
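A sketch of the pair used together for dataset validation (shapes are illustrative; `set_tensor_shape` is assumed to return its input, as in the full source):

```python
import tensorflow as tf

from spleeter.utils.tensor import check_tensor_shape, set_tensor_shape

waveform = tf.zeros((44100, 2))

# Graph-mode boolean check; the None entry is skipped.
is_valid = check_tensor_shape(waveform, (None, 2))

# Statically annotate the shape for downstream graph construction.
waveform = set_tensor_shape(waveform, (44100, 2))
```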

View File

@@ -7,82 +7,82 @@ __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

from os import makedirs
from os.path import join
from tempfile import TemporaryDirectory

import pytest
import numpy as np

from spleeter.__main__ import evaluate
from spleeter.audio.adapter import AudioAdapter

BACKENDS = ['tensorflow', 'librosa']
TEST_CONFIGURATIONS = {el: el for el in BACKENDS}
res_4stems = {
    'vocals': {
        'SDR': 3.25e-05,
        'SAR': -11.153575,
        'SIR': -1.3849,
        'ISR': 2.75e-05
    },
    'drums': {
        'SDR': -0.079505,
        'SAR': -15.7073575,
        'SIR': -4.972755,
        'ISR': 0.0013575
    },
    'bass': {
        'SDR': 2.5e-06,
        'SAR': -10.3520575,
        'SIR': -4.272325,
        'ISR': 2.5e-06
    },
    'other': {
        'SDR': -1.359175,
        'SAR': -14.7076775,
        'SIR': -4.761505,
        'ISR': -0.01528
    }
}
def generate_fake_eval_dataset(path):
    """
    Generate a fake evaluation dataset.
    """
    aa = AudioAdapter.default()
    n_songs = 2
    fs = 44100
    duration = 3
    n_channels = 2
    rng = np.random.RandomState(seed=0)
    for song in range(n_songs):
        song_path = join(path, 'test', f'song{song}')
        makedirs(song_path, exist_ok=True)
        for instr in ['mixture', 'vocals', 'bass', 'drums', 'other']:
            filename = join(song_path, f'{instr}.wav')
            data = rng.rand(duration * fs, n_channels) - 0.5
            aa.save(filename, data, fs)
@pytest.mark.parametrize('backend', TEST_CONFIGURATIONS)
def test_evaluate(backend):
    with TemporaryDirectory() as dataset:
        with TemporaryDirectory() as evaluation:
            generate_fake_eval_dataset(dataset)
            metrics = evaluate(
                adapter='spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter',
                output_path=evaluation,
                stft_backend=backend,
                params_filename='spleeter:4stems',
                mus_dir=dataset,
                mwf=False,
                verbose=False)
            for instrument, metric in metrics.items():
                for m, value in metric.items():
                    assert np.allclose(
                        np.median(value),
                        res_4stems[instrument][m],
                        atol=1e-3)

View File

@@ -10,6 +10,11 @@ __license__ = 'MIT License'
from os.path import join
from tempfile import TemporaryDirectory

from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter

# pyright: reportMissingImports=false
# pylint: disable=import-error
from pytest import fixture, raises
@@ -17,12 +22,6 @@ import numpy as np
import ffmpeg

# pylint: enable=import-error
from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.audio.adapter import get_audio_adapter
from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
TEST_OFFSET = 0
TEST_DURATION = 600.
@@ -32,7 +31,7 @@ TEST_SAMPLE_RATE = 44100
@fixture(scope='session')
def adapter():
    """ Target test audio adapter fixture. """
    return AudioAdapter.default()


@fixture(scope='session')
@@ -48,7 +47,7 @@ def audio_data(adapter):
def test_default_adapter(adapter):
    """ Test adapter as default adapter. """
    assert isinstance(adapter, FFMPEGProcessAudioAdapter)
    assert adapter is AudioAdapter._DEFAULT


def test_load(audio_data):
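The new accessor replaces `get_default_audio_adapter` everywhere; a minimal sketch (the file path is an assumption):

```python
from spleeter.audio.adapter import AudioAdapter

adapter = AudioAdapter.default()  # singleton FFMPEGProcessAudioAdapter

# Load ten seconds of audio, resampled to 44.1 kHz.
waveform, sample_rate = adapter.load(
    'audio_example.mp3', offset=0, duration=10., sample_rate=44100
)
```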

View File

@@ -5,12 +5,12 @@
from pytest import raises

from spleeter.model.provider import ModelProvider


def test_checksum():
    """ Test archive checksum index retrieval. """
    provider = ModelProvider.default()
    assert provider.checksum('2stems') == \
        'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692'
    assert provider.checksum('4stems') == \

View File

@@ -17,7 +17,7 @@ import numpy as np
import tensorflow as tf

from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.separator import Separator

TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
@@ -41,7 +41,7 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
def test_separator_backends(test_file):
    adapter = AudioAdapter.default()
    waveform, _ = adapter.load(test_file)
    separator_lib = Separator(
@@ -64,11 +64,13 @@ def test_separator_backends(test_file):
    assert np.allclose(out_tf[instrument], out_lib[instrument], atol=1e-5)


@pytest.mark.parametrize(
    'test_file, configuration, backend',
    TEST_CONFIGURATIONS)
def test_separate(test_file, configuration, backend):
    """ Test separation from raw data. """
    instruments = MODEL_TO_INST[configuration]
    adapter = AudioAdapter.default()
    waveform, _ = adapter.load(test_file)
    separator = Separator(
        configuration, stft_backend=backend, multiprocess=False)
@@ -85,7 +87,9 @@ def test_separate(test_file, configuration, backend):
    assert not np.allclose(track, prediction[compared])


@pytest.mark.parametrize(
    'test_file, configuration, backend',
    TEST_CONFIGURATIONS)
def test_separate_to_file(test_file, configuration, backend):
    """ Test file based separation. """
    instruments = MODEL_TO_INST[configuration]
@@ -102,7 +106,9 @@ def test_separate_to_file(test_file, configuration, backend):
            '{}/{}.wav'.format(name, instrument)))


@pytest.mark.parametrize(
    'test_file, configuration, backend',
    TEST_CONFIGURATIONS)
def test_filename_format(test_file, configuration, backend):
    """ Test custom filename format. """
    instruments = MODEL_TO_INST[configuration]
@@ -120,7 +126,9 @@ def test_filename_format(test_file, configuration, backend):
            'export/{}/{}.wav'.format(name, instrument)))


@pytest.mark.parametrize(
    'test_file, configuration',
    MODELS_AND_TEST_FILES)
def test_filename_conflict(test_file, configuration):
    """ Test error handling with static pattern. """
    separator = Separator(configuration, multiprocess=False)

View File

@@ -7,107 +7,102 @@ __email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

import json
import os
from os import makedirs
from os.path import join
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd

from spleeter.audio.adapter import AudioAdapter
from spleeter.__main__ import spleeter
from typer.testing import CliRunner
TRAIN_CONFIG = {
    'mix_name': 'mix',
    'instrument_list': ['vocals', 'other'],
    'sample_rate': 44100,
    'frame_length': 4096,
    'frame_step': 1024,
    'T': 128,
    'F': 128,
    'n_channels': 2,
    'chunk_duration': 4,
    'n_chunks_per_song': 1,
    'separation_exponent': 2,
    'mask_extension': 'zeros',
    'learning_rate': 1e-4,
    'batch_size': 2,
    'train_max_steps': 10,
    'throttle_secs': 20,
    'save_checkpoints_steps': 100,
    'save_summary_steps': 5,
    'random_seed': 0,
    'model': {
        'type': 'unet.unet',
        'params': {
            'conv_activation': 'ELU',
            'deconv_activation': 'ELU'
        }
    }
}
def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
    """
    Generates a fake training dataset in path:
    - generates audio files
    - generates a csv file describing the dataset
    """
    aa = AudioAdapter.default()
    n_songs = 2
    fs = 44100
    duration = 6
    n_channels = 2
    rng = np.random.RandomState(seed=0)
    dataset_df = pd.DataFrame(
        columns=['mix_path'] + [
            f'{instr}_path' for instr in instrument_list] + ['duration'])
    for song in range(n_songs):
        song_path = join(path, 'train', f'song{song}')
        makedirs(song_path, exist_ok=True)
        dataset_df.loc[song, 'duration'] = duration
        for instr in instrument_list + ['mix']:
            filename = join(song_path, f'{instr}.wav')
            data = rng.rand(duration * fs, n_channels) - 0.5
            aa.save(filename, data, fs)
            dataset_df.loc[song, f'{instr}_path'] = join(
                'train',
                f'song{song}',
                f'{instr}.wav')
    dataset_df.to_csv(join(path, 'train', 'train.csv'), index=False)
def test_train():
    with TemporaryDirectory() as path:
        # generate training dataset
        generate_fake_training_dataset(path)
        # set training command arguments
        runner = CliRunner()
        TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
        TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
        TRAIN_CONFIG['model_dir'] = join(path, 'model')
        TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
        TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
        with open('useless_config.json', 'w') as stream:
            json.dump(TRAIN_CONFIG, stream)
        # execute training
        result = runner.invoke(spleeter, [
            'train',
            '-p', 'useless_config.json',
            '-d', path
        ])
        # assert that model checkpoint was created.
        assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
        assert os.path.exists(join(path, 'model', 'checkpoint'))
        assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
        assert result.exit_code == 0


if __name__ == '__main__':
    test_train()