Mirror of https://github.com/YuzuZensai/spleeter.git (synced 2026-01-31 14:58:23 +00:00)
.github/workflows/conda.yml (vendored, 6 changed lines)

@@ -1,8 +1,6 @@
 name: conda
 on:
-  push:
-    branches:
-      - master
+  - workflow_dispatch
 env:
   ANACONDA_USERNAME: ${{ secrets.ANACONDA_USERNAME }}
   ANACONDA_PASSWORD: ${{ secrets.ANACONDA_PASSWORD }}
@@ -11,7 +9,7 @@ jobs:
     strategy:
       matrix:
         python: [3.7, 3.8]
-        package: [spleeter]
+        package: [spleeter, spleeter-gpu]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
.github/workflows/pypi.yml (vendored, 38 changed lines)

@@ -4,40 +4,22 @@ on:
     branches:
       - master
 env:
-  TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }}
-  TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }}
+  PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
 jobs:
   package-and-deploy:
-    strategy:
-      matrix:
-        platform: [cpu, gpu]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
         with:
           python-version: 3.7
-      - uses: actions/cache@v2
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
-      - uses: actions/cache@v2
-        with:
-          path: ${{ env.GITHUB_WORKSPACE }}/dist
-          key: sdist-${{ matrix.platform }}-${{ hashFiles('**/setup.py') }}
-          restore-keys: |
-            sdist-${{ matrix.platform }}-${{ hashFiles('**/setup.py') }}
-            sdist-${{ matrix.platform }}
-            sdist-
-      - name: Install dependencies
-        run: pip install --upgrade pip setuptools twine
-      - if: ${{ matrix.platform == 'cpu' }}
-        name: Package CPU distribution
-        run: make build
-      - if: ${{ matrix.platform == 'gpu' }}
-        name: Package GPU distribution
-        run: make build-gpu
+      - name: Install Poetry
+        run: |
+          pip install poetry
+          poetry config virtualenvs.in-project false
+          poetry config virtualenvs.path ~/.virtualenvs
+          poetry config pypi-token.pypi $PYPI_TOKEN
       - name: Deploy to pypi
-        run: make deploy
+        run: |
+          poetry build
+          poetry publish
.github/workflows/pytest.yml (vendored, 41 deleted lines)

@@ -1,41 +0,0 @@
-name: pytest
-on:
-  pull_request:
-    branches:
-      - master
-jobs:
-  tests:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: [3.6, 3.7, 3.8]
-    steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
-        with:
-          python-version: ${{ matrix.python-version }}
-      - uses: actions/cache@v2
-        id: spleeter-pip-cache
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
-      - uses: actions/cache@v2
-        env:
-          model-release: 1
-        id: spleeter-model-cache
-        with:
-          path: ${{ env.GITHUB_WORKSPACE }}/pretrained_models
-          key: models-${{ env.model-release }}
-          restore-keys: |
-            models-${{ env.model-release }}
-      - name: Install dependencies
-        run: |
-          sudo apt-get update && sudo apt-get install -y ffmpeg
-          pip install --upgrade pip setuptools
-          pip install pytest==5.4.3 pytest-xdist==1.32.0 pytest-forked==1.1.3 musdb museval
-          python setup.py install
-      - name: Test with pytest
-        run: make test
.github/workflows/test.yml (vendored, new file, 51 lines)

@@ -0,0 +1,51 @@
+name: test
+on:
+  pull_request:
+    branches:
+      - master
+jobs:
+  tests:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.7, 3.8]
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - uses: actions/cache@v2
+        env:
+          model-release: 1
+        id: spleeter-model-cache
+        with:
+          path: ${{ env.GITHUB_WORKSPACE }}/pretrained_models
+          key: models-${{ env.model-release }}
+          restore-keys: |
+            models-${{ env.model-release }}
+      - name: Install ffmpeg
+        run: |
+          sudo apt-get update && sudo apt-get install -y ffmpeg
+      - name: Install Poetry
+        run: |
+          pip install poetry
+          poetry config virtualenvs.in-project false
+          poetry config virtualenvs.path ~/.virtualenvs
+      - name: Cache Poetry virtualenv
+        uses: actions/cache@v1
+        id: cache
+        with:
+          path: ~/.virtualenvs
+          key: poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
+          restore-keys: |
+            poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
+      - name: Install Dependencies
+        run: poetry install
+        if: steps.cache.outputs.cache-hit != 'true'
+      - name: Code quality checks
+        run: |
+          poetry run black spleeter --check
+          poetry run isort spleeter --check
+      - name: Test with pytest
+        run: poetry run pytest tests/
CHANGELOG.md (22 changed lines)

@@ -1,5 +1,27 @@
 # Changelog History
 
+## 2.1.0
+
+This version introduces design-related changes, especially the transition to Typer for CLI management and
+Poetry as the library build backend.
+
+* `-i` option is now deprecated and replaced by a traditional CLI input argument listing
+* Project is now built using Poetry
+* Project requires code formatting using Black and isort
+* Dedicated GPU package `spleeter-gpu` is not supported anymore; the `spleeter` package now supports both CPU and GPU hardware
+
+### API changes:
+
+* function `get_default_audio_adapter` is now available as a `default()` class method within the `AudioAdapter` class
+* function `get_default_model_provider` is now available as a `default()` class method within the `ModelProvider` class
+* `STFTBackend` and `Codec` are now string enums
+* `GithubModelProvider` now uses `httpx` with HTTP/2 support
+* Commands are now located in the `__main__` module, wrapped as simple functions; the Typer options module provides the specification for each available option and argument
+* the `types` module provides custom type specifications and should be enhanced in a future release for more robust typing support with MyPy
+* the `utils.logging` module has been cleaned up: the logger instance is now a module singleton, and a single function configures it with a verbose parameter
+* Added a custom logger handler (see the tiangolo/typer#203 discussion)
+
+
 ## 2.0
 
 First release, October 9th 2020
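In practice, the `default()` migration called out in the changelog entry above looks like the following sketch (hypothetical usage, assuming spleeter 2.1.0 is installed; the import paths are the ones visible elsewhere in this diff):

```python
# 2.0.x style (removed helpers):
#     from spleeter.audio.adapter import get_default_audio_adapter
#     adapter = get_default_audio_adapter()
# 2.1.0 replacement described above:
from spleeter.audio.adapter import AudioAdapter
from spleeter.model.provider import ModelProvider

adapter = AudioAdapter.default()    # lazily built FFMPEG-based singleton
provider = ModelProvider.default()  # default pretrained model provider
```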
MANIFEST.in (3 deleted lines)

@@ -1,3 +0,0 @@
-include spleeter/resources/*.json
-include README.md
-include LICENSE
Makefile (34 deleted lines)

@@ -1,34 +0,0 @@
-# =======================================================
-# Library lifecycle management.
-#
-# @author Deezer Research <spleeter@deezer.com>
-# @licence MIT Licence
-# =======================================================
-
-FEEDSTOCK = spleeter-feedstock
-FEEDSTOCK_REPOSITORY = https://github.com/deezer/$(FEEDSTOCK)
-FEEDSTOCK_RECIPE = $(FEEDSTOCK)/recipe/spleeter/meta.yaml
-PYTEST_CMD = pytest -W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked
-
-all: clean build test deploy
-
-clean:
-	rm -Rf *.egg-info
-	rm -Rf dist
-
-build: clean
-	sed -i "s/project_name = '[^']*'/project_name = 'spleeter'/g" setup.py
-	sed -i "s/tensorflow_dependency = '[^']*'/tensorflow_dependency = 'tensorflow'/g" setup.py
-	python3 setup.py sdist
-
-build-gpu: clean
-	sed -i "s/project_name = '[^']*'/project_name = 'spleeter-gpu'/g" setup.py
-	sed -i "s/tensorflow_dependency = '[^']*'/tensorflow_dependency = 'tensorflow-gpu'/g" setup.py
-	python3 setup.py sdist
-
-test:
-	$(PYTEST_CMD) tests/
-
-deploy:
-	pip install twine
-	twine upload --skip-existing dist/*
README.md (18 changed lines)

@@ -2,6 +2,9 @@
 
 [badges: GitHub Actions build, PyPI, conda-forge, Docker Hub, Colab notebook, Gitter, JOSS paper]
 
+> :warning: [Spleeter 2.1.0](https://pypi.org/project/spleeter/) release introduces some breaking changes, including new CLI option naming for input, and the drop
+> of the dedicated GPU package. Please read the [CHANGELOG](CHANGELOG.md) for more details.
+
 ## About
 
 **Spleeter** is [Deezer](https://www.deezer.com/) source separation library with pretrained models
@@ -46,7 +49,7 @@ conda install -c conda-forge spleeter
 # download an example audio file (if you don't have wget, use another tool for downloading)
 wget https://github.com/deezer/spleeter/raw/master/audio_example.mp3
 # separate the example audio into two components
-spleeter separate -i audio_example.mp3 -p spleeter:2stems -o output
+spleeter separate -p spleeter:2stems -o output audio_example.mp3
 ```
 
 You should get two separated audio files (`vocals.wav` and `accompaniment.wav`) in the `output/audio_example` folder.
@@ -55,13 +58,18 @@ For a detailed documentation, please check the [repository wiki](https://github.
 
 ## Development and Testing
 
-The following set of commands will clone this repository, create a virtual environment provisioned with the dependencies and run the tests (will take a few minutes):
+This project is managed using [Poetry](https://python-poetry.org/docs/basic-usage/); to run the test suite you
+can execute the following set of commands:
 
 ```bash
+# Clone spleeter repository
 git clone https://github.com/Deezer/spleeter && cd spleeter
-python -m venv spleeterenv && source spleeterenv/bin/activate
-pip install . && pip install pytest pytest-xdist
-make test
+# Install poetry
+pip install poetry
+# Install spleeter dependencies
+poetry install
+# Run unit test suite
+poetry run pytest tests/
 ```
 
 ## Reference
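The CLI example above also has a Python-level equivalent; the following is a minimal sketch based on the `Separator` calls visible in `spleeter/__main__.py` later in this diff (the one-argument constructor form relies on default values and is an assumption here):

```python
from spleeter.separator import Separator

# "spleeter:2stems" is the same model descriptor passed to the CLI -p option.
separator = Separator("spleeter:2stems")
# Should write vocals.wav and accompaniment.wav under output/audio_example/.
separator.separate_to_file("audio_example.mp3", "output")
```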
conda/spleeter-gpu/meta.yaml (new file, 52 lines)

@@ -0,0 +1,52 @@
+{% set name = "spleeter-gpu" %}
+{% set version = "2.0.2" %}
+
+package:
+  name: {{ name|lower }}
+  version: {{ version }}
+
+source:
+  - url: https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz
+    sha256: ecd3518a98f9978b9088d1cb2ef98f766401fd9007c2bf72a34e5b5bc5a6fdc3
+
+build:
+  number: 0
+  script: {{ PYTHON }} -m pip install . -vv
+  skip: True  # [osx]
+  entry_points:
+    - spleeter = spleeter.__main__:entrypoint
+
+requirements:
+  host:
+    - python {{ python }}
+    - pip
+  run:
+    - python {{ python }}
+    - tensorflow-gpu ==2.2.0  # [linux]
+    - tensorflow-gpu ==23.0  # [win]
+    - pandas
+    - ffmpeg-python
+    - norbert
+    - librosa
+
+test:
+  imports:
+    - spleeter
+    - spleeter.commands
+    - spleeter.model
+    - spleeter.utils
+    - spleeter.separator
+
+about:
+  home: https://github.com/deezer/spleeter
+  license: MIT
+  license_family: MIT
+  license_file: LICENSE
+  summary: The Deezer source separation library with pretrained models based on tensorflow.
+  doc_url: https://github.com/deezer/spleeter/wiki
+  dev_url: https://github.com/deezer/spleeter
+
+extra:
+  recipe-maintainers:
+    - Faylixe
+    - romi1502
(deleted 3-line conda build configuration; file name not preserved in this mirror)

@@ -1,3 +0,0 @@
-python:
-  - 3.7
-  - 3.8
poetry.lock (generated, new file, 1880 lines)

File diff suppressed because it is too large.
pyproject.toml (new file, 83 lines)

@@ -0,0 +1,83 @@
+[tool.poetry]
+name = "spleeter"
+version = "2.1.0"
+description = "The Deezer source separation library with pretrained models based on tensorflow."
+authors = ["Deezer Research <spleeter@deezer.com>"]
+license = "MIT License"
+readme = "README.md"
+repository = "https://github.com/deezer/spleeter"
+homepage = "https://github.com/deezer/spleeter"
+classifiers = [
+    "Environment :: Console",
+    "Environment :: MacOS X",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Information Technology",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: MIT License",
+    "Natural Language :: English",
+    "Operating System :: MacOS",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX :: Linux",
+    "Operating System :: Unix",
+    "Programming Language :: Python",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.6",
+    "Programming Language :: Python :: 3.7",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Topic :: Artistic Software",
+    "Topic :: Multimedia",
+    "Topic :: Multimedia :: Sound/Audio",
+    "Topic :: Multimedia :: Sound/Audio :: Analysis",
+    "Topic :: Multimedia :: Sound/Audio :: Conversion",
+    "Topic :: Multimedia :: Sound/Audio :: Sound Synthesis",
+    "Topic :: Scientific/Engineering",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Scientific/Engineering :: Information Analysis",
+    "Topic :: Software Development",
+    "Topic :: Software Development :: Libraries",
+    "Topic :: Software Development :: Libraries :: Python Modules",
+    "Topic :: Utilities"
+]
+packages = [ { include = "spleeter" } ]
+include = ["LICENSE", "spleeter/resources/*.json"]
+
+[tool.poetry.dependencies]
+python = "^3.7"
+ffmpeg-python = "0.2.0"
+norbert = "0.2.1"
+httpx = {extras = ["http2"], version = "^0.16.1"}
+typer = "^0.3.2"
+librosa = "0.8.0"
+musdb = {version = "0.3.1", optional = true}
+museval = {version = "0.3.0", optional = true}
+tensorflow = "2.3.0"
+pandas = "1.1.2"
+numpy = "<1.19.0,>=1.16.0"
+
+[tool.poetry.dev-dependencies]
+pytest = "^6.2.1"
+isort = "^5.7.0"
+black = "^20.8b1"
+mypy = "^0.790"
+pytest-forked = "^1.3.0"
+musdb = "0.3.1"
+museval = "0.3.0"
+
+[tool.poetry.scripts]
+spleeter = 'spleeter.__main__:entrypoint'
+
+[tool.poetry.extras]
+evaluation = ["musdb", "museval"]
+
+[tool.isort]
+profile = "black"
+multi_line_output = 3
+
+[tool.pytest.ini_options]
+addopts = "-W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked"
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
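The `[tool.poetry.scripts]` entry replaces the `console_scripts` declaration from the deleted `setup.py` below; both resolve to the same callable. A minimal sketch of what the installed `spleeter` command executes:

```python
# Equivalent of running `spleeter` from a shell once the package is installed:
# the console script points at spleeter.__main__:entrypoint.
from spleeter.__main__ import entrypoint

if __name__ == "__main__":
    entrypoint()
```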
setup.py (102 deleted lines)

@@ -1,102 +0,0 @@
-#!/usr/bin/env python
-# coding: utf8
-
-""" Distribution script. """
-
-import sys
-
-from os import path
-from setuptools import setup
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-# Default project values.
-project_name = 'spleeter'
-project_version = '2.0.2'
-tensorflow_dependency = 'tensorflow'
-tensorflow_version = '2.3.0'
-here = path.abspath(path.dirname(__file__))
-readme_path = path.join(here, 'README.md')
-with open(readme_path, 'r') as stream:
-    readme = stream.read()
-
-# Package setup entrypoint.
-setup(
-    name=project_name,
-    version=project_version,
-    description='''
-        The Deezer source separation library with
-        pretrained models based on tensorflow.
-    ''',
-    long_description=readme,
-    long_description_content_type='text/markdown',
-    author='Deezer Research',
-    author_email='spleeter@deezer.com',
-    url='https://github.com/deezer/spleeter',
-    license='MIT License',
-    packages=[
-        'spleeter',
-        'spleeter.audio',
-        'spleeter.commands',
-        'spleeter.model',
-        'spleeter.model.functions',
-        'spleeter.model.provider',
-        'spleeter.resources',
-        'spleeter.utils',
-    ],
-    package_data={'spleeter.resources': ['*.json']},
-    python_requires='>=3.6, <3.9',
-    include_package_data=True,
-    install_requires=[
-        'ffmpeg-python==0.2.0',
-        'importlib_resources ; python_version<"3.7"',
-        'norbert==0.2.1',
-        'numpy<1.19.0,>=1.16.0',
-        'pandas==1.1.2',
-        'requests',
-        'scipy==1.4.1',
-        'setuptools>=41.0.0',
-        'librosa==0.8.0',
-        '{}=={}'.format(tensorflow_dependency, tensorflow_version),
-    ],
-    extras_require={
-        'evaluation': ['musdb==0.3.1', 'museval==0.3.0']
-    },
-    entry_points={
-        'console_scripts': ['spleeter=spleeter.__main__:entrypoint']
-    },
-    classifiers=[
-        'Environment :: Console',
-        'Environment :: MacOS X',
-        'Intended Audience :: Developers',
-        'Intended Audience :: Information Technology',
-        'Intended Audience :: Science/Research',
-        'License :: OSI Approved :: MIT License',
-        'Natural Language :: English',
-        'Operating System :: MacOS',
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: POSIX :: Linux',
-        'Operating System :: Unix',
-        'Programming Language :: Python',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.6',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3 :: Only',
-        'Programming Language :: Python :: Implementation :: CPython',
-        'Topic :: Artistic Software',
-        'Topic :: Multimedia',
-        'Topic :: Multimedia :: Sound/Audio',
-        'Topic :: Multimedia :: Sound/Audio :: Analysis',
-        'Topic :: Multimedia :: Sound/Audio :: Conversion',
-        'Topic :: Multimedia :: Sound/Audio :: Sound Synthesis',
-        'Topic :: Scientific/Engineering',
-        'Topic :: Scientific/Engineering :: Artificial Intelligence',
-        'Topic :: Scientific/Engineering :: Information Analysis',
-        'Topic :: Software Development',
-        'Topic :: Software Development :: Libraries',
-        'Topic :: Software Development :: Libraries :: Python Modules',
-        'Topic :: Utilities']
-)
spleeter/__init__.py

@@ -13,9 +13,9 @@
 by providing train, evaluation and source separation action.
 """
 
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
 
 
 class SpleeterError(Exception):
spleeter/__main__.py

@@ -5,54 +5,252 @@
 Python oneliner script usage.
 
 USAGE: python -m spleeter {train,evaluate,separate} ...
+
+Notes:
+    All critical import involving TF, numpy or Pandas are deported to
+    command function scope to avoid heavy import on CLI evaluation,
+    leading to large bootstraping time.
 """
 
-import sys
-import warnings
+import json
+from functools import partial
+from glob import glob
+from itertools import product
+from os.path import join
+from pathlib import Path
+from typing import Container, Dict, List, Optional
+
+# pyright: reportMissingImports=false
+# pylint: disable=import-error
+from typer import Exit, Typer
 
 from . import SpleeterError
-from .commands import create_argument_parser
-from .utils.configuration import load_configuration
-from .utils.logging import (
-    enable_logging,
-    enable_tensorflow_logging,
-    get_logger)
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-
-def main(argv):
-    """ Spleeter runner. Parse provided command line arguments
-    and run entrypoint for required command (either train,
-    evaluate or separate).
-
-    :param argv: Provided command line arguments.
+from .options import *
+from .utils.logging import configure_logger, logger
+
+# pylint: enable=import-error
+
+spleeter: Typer = Typer(add_completion=False)
+""" CLI application. """
+
+
+@spleeter.command()
+def train(
+    adapter: str = AudioAdapterOption,
+    data: Path = TrainingDataDirectoryOption,
+    params_filename: str = ModelParametersOption,
+    verbose: bool = VerboseOption,
+) -> None:
     """
+    Train a source separation model
+    """
+    import tensorflow as tf
+
+    from .audio.adapter import AudioAdapter
+    from .dataset import get_training_dataset, get_validation_dataset
+    from .model import model_fn
+    from .model.provider import ModelProvider
+    from .utils.configuration import load_configuration
+
+    configure_logger(verbose)
+    audio_adapter = AudioAdapter.get(adapter)
+    audio_path = str(data)
+    params = load_configuration(params_filename)
+    session_config = tf.compat.v1.ConfigProto()
+    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
+    estimator = tf.estimator.Estimator(
+        model_fn=model_fn,
+        model_dir=params["model_dir"],
+        params=params,
+        config=tf.estimator.RunConfig(
+            save_checkpoints_steps=params["save_checkpoints_steps"],
+            tf_random_seed=params["random_seed"],
+            save_summary_steps=params["save_summary_steps"],
+            session_config=session_config,
+            log_step_count_steps=10,
+            keep_checkpoint_max=2,
+        ),
+    )
+    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
+    train_spec = tf.estimator.TrainSpec(
+        input_fn=input_fn, max_steps=params["train_max_steps"]
+    )
+    input_fn = partial(get_validation_dataset, params, audio_adapter, audio_path)
+    evaluation_spec = tf.estimator.EvalSpec(
+        input_fn=input_fn, steps=None, throttle_secs=params["throttle_secs"]
+    )
+    logger.info("Start model training")
+    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
+    ModelProvider.writeProbe(params["model_dir"])
+    logger.info("Model training done")
+
+
+@spleeter.command()
+def separate(
+    deprecated_files: Optional[str] = AudioInputOption,
+    files: List[Path] = AudioInputArgument,
+    adapter: str = AudioAdapterOption,
+    bitrate: str = AudioBitrateOption,
+    codec: Codec = AudioCodecOption,
+    duration: float = AudioDurationOption,
+    offset: float = AudioOffsetOption,
+    output_path: Path = AudioOutputOption,
+    stft_backend: STFTBackend = AudioSTFTBackendOption,
+    filename_format: str = FilenameFormatOption,
+    params_filename: str = ModelParametersOption,
+    mwf: bool = MWFOption,
+    verbose: bool = VerboseOption,
+) -> None:
+    """
+    Separate audio file(s)
+    """
+    from .audio.adapter import AudioAdapter
+    from .separator import Separator
+
+    configure_logger(verbose)
+    if deprecated_files is not None:
+        logger.error(
+            "⚠️ -i option is not supported anymore, audio files must be supplied "
+            "using input argument instead (see spleeter separate --help)"
+        )
+        raise Exit(20)
+    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
+    separator: Separator = Separator(
+        params_filename, MWF=mwf, stft_backend=stft_backend
+    )
+    for filename in files:
+        separator.separate_to_file(
+            str(filename),
+            str(output_path),
+            audio_adapter=audio_adapter,
+            offset=offset,
+            duration=duration,
+            codec=codec,
+            bitrate=bitrate,
+            filename_format=filename_format,
+            synchronous=False,
+        )
+    separator.join()
+
+
+EVALUATION_SPLIT: str = "test"
+EVALUATION_METRICS_DIRECTORY: str = "metrics"
+EVALUATION_INSTRUMENTS: Container[str] = ("vocals", "drums", "bass", "other")
+EVALUATION_METRICS: Container[str] = ("SDR", "SAR", "SIR", "ISR")
+EVALUATION_MIXTURE: str = "mixture.wav"
+EVALUATION_AUDIO_DIRECTORY: str = "audio"
+
+
+def _compile_metrics(metrics_output_directory) -> Dict:
+    """
+    Compiles metrics from given directory and returns results as dict.
+
+    Parameters:
+        metrics_output_directory (str):
+            Directory to get metrics from.
+
+    Returns:
+        Dict:
+            Compiled metrics as dict.
+    """
+    import numpy as np
+    import pandas as pd
+
+    songs = glob(join(metrics_output_directory, "test/*.json"))
+    index = pd.MultiIndex.from_tuples(
+        product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
+        names=["instrument", "metric"],
+    )
+    pd.DataFrame([], index=["config1", "config2"], columns=index)
+    metrics = {
+        instrument: {k: [] for k in EVALUATION_METRICS}
+        for instrument in EVALUATION_INSTRUMENTS
+    }
+    for song in songs:
+        with open(song, "r") as stream:
+            data = json.load(stream)
+        for target in data["targets"]:
+            instrument = target["name"]
+            for metric in EVALUATION_METRICS:
+                sdr_med = np.median(
+                    [
+                        frame["metrics"][metric]
+                        for frame in target["frames"]
+                        if not np.isnan(frame["metrics"][metric])
+                    ]
+                )
+                metrics[instrument][metric].append(sdr_med)
+    return metrics
+
+
+@spleeter.command()
+def evaluate(
+    adapter: str = AudioAdapterOption,
+    output_path: Path = AudioOutputOption,
+    stft_backend: STFTBackend = AudioSTFTBackendOption,
+    params_filename: str = ModelParametersOption,
+    mus_dir: Path = MUSDBDirectoryOption,
+    mwf: bool = MWFOption,
+    verbose: bool = VerboseOption,
+) -> Dict:
+    """
+    Evaluate a model on the musDB test dataset
+    """
+    import numpy as np
+
+    configure_logger(verbose)
     try:
-        parser = create_argument_parser()
-        arguments = parser.parse_args(argv[1:])
-        enable_logging()
-        if arguments.verbose:
-            enable_tensorflow_logging()
-        if arguments.command == 'separate':
-            from .commands.separate import entrypoint
-        elif arguments.command == 'train':
-            from .commands.train import entrypoint
-        elif arguments.command == 'evaluate':
-            from .commands.evaluate import entrypoint
-        params = load_configuration(arguments.configuration)
-        entrypoint(arguments, params)
-    except SpleeterError as e:
-        get_logger().error(e)
+        import musdb
+        import museval
+    except ImportError:
+        logger.error("Extra dependencies musdb and museval not found")
+        logger.error("Please install musdb and museval first, abort")
+        raise Exit(10)
+    # Separate musdb sources.
+    songs = glob(join(mus_dir, EVALUATION_SPLIT, "*/"))
+    mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
+    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
+    separate(
+        deprecated_files=None,
+        files=mixtures,
+        adapter=adapter,
+        bitrate="128k",
+        codec=Codec.WAV,
+        duration=600.0,
+        offset=0,
+        output_path=join(audio_output_directory, EVALUATION_SPLIT),
+        stft_backend=stft_backend,
+        filename_format="{foldername}/{instrument}.{codec}",
+        params_filename=params_filename,
+        mwf=mwf,
+        verbose=verbose,
+    )
+    # Compute metrics with musdb.
+    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
+    logger.info("Starting musdb evaluation (this could be long) ...")
+    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
+    museval.eval_mus_dir(
+        dataset=dataset,
+        estimates_dir=audio_output_directory,
+        output_dir=metrics_output_directory,
+    )
+    logger.info("musdb evaluation done")
+    # Compute and pretty print median metrics.
+    metrics = _compile_metrics(metrics_output_directory)
+    for instrument, metric in metrics.items():
+        logger.info(f"{instrument}:")
+        for metric, value in metric.items():
+            logger.info(f"{metric}: {np.median(value):.3f}")
+    return metrics
 
 
 def entrypoint():
-    """ Command line entrypoint. """
-    warnings.filterwarnings('ignore')
-    main(sys.argv)
+    """ Application entrypoint. """
+    try:
+        spleeter()
+    except SpleeterError as e:
+        logger.error(e)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     entrypoint()
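Since commands are now plain Typer functions, the application object can also be exercised programmatically, which is convenient in tests. A hedged sketch using Typer's own test runner (not part of this diff):

```python
from typer.testing import CliRunner

from spleeter.__main__ import spleeter  # the Typer app declared above

runner = CliRunner()
result = runner.invoke(spleeter, ["separate", "--help"])
assert result.exit_code == 0  # help output renders without error
print(result.output)
```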
spleeter/audio/__init__.py

@@ -10,6 +10,43 @@
 - Waveform convertion and transforming functions.
 """
 
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+from enum import Enum
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
+
+
+class Codec(str, Enum):
+    """ Enumeration of supported audio codec. """
+
+    WAV: str = "wav"
+    MP3: str = "mp3"
+    OGG: str = "ogg"
+    M4A: str = "m4a"
+    WMA: str = "wma"
+    FLAC: str = "flac"
+
+
+class STFTBackend(str, Enum):
+    """ Enumeration of supported STFT backend. """
+
+    AUTO: str = "auto"
+    TENSORFLOW: str = "tensorflow"
+    LIBROSA: str = "librosa"
+
+    @classmethod
+    def resolve(cls: type, backend: str) -> str:
+        # NOTE: import is resolved here to avoid performance issues on command
+        # evaluation.
+        # pyright: reportMissingImports=false
+        # pylint: disable=import-error
+        import tensorflow as tf
+
+        if backend not in cls.__members__.values():
+            raise ValueError(f"Unsupported backend {backend}")
+        if backend == cls.AUTO:
+            if len(tf.config.list_physical_devices("GPU")):
+                return cls.TENSORFLOW
+            return cls.LIBROSA
+        return backend
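Because `Codec` and `STFTBackend` subclass `str`, members compare equal to their raw values, and `STFTBackend.resolve` picks a concrete backend for `auto`. A small sketch of the behavior introduced above:

```python
from spleeter.audio import Codec, STFTBackend

assert Codec.MP3 == "mp3"  # str-enum members compare equal to their values
# resolve() imports tensorflow lazily and returns "tensorflow" when a GPU
# is visible, "librosa" otherwise; passing a concrete backend is a no-op.
backend = STFTBackend.resolve(STFTBackend.AUTO)
print(backend)
```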
spleeter/audio/adapter.py

@@ -3,70 +3,101 @@
 
 """ AudioAdapter class defintion. """
 
-import subprocess
-
 from abc import ABC, abstractmethod
 from importlib import import_module
-from os.path import exists
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
 
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
 
-from tensorflow.signal import stft, hann_window
-# pylint: enable=import-error
+from spleeter.audio import Codec
 
 from .. import SpleeterError
-from ..utils.logging import get_logger
+from ..types import AudioDescriptor, Signal
+from ..utils.logging import logger
 
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
 
 
 class AudioAdapter(ABC):
     """ An abstract class for manipulating audio signal. """
 
-    # Default audio adapter singleton instance.
-    DEFAULT = None
+    _DEFAULT: "AudioAdapter" = None
+    """ Default audio adapter singleton instance. """
 
     @abstractmethod
     def load(
-            self, audio_descriptor, offset, duration,
-            sample_rate, dtype=np.float32):
-        """ Loads the audio file denoted by the given audio descriptor
-        and returns it data as a waveform. Aims to be implemented
-        by client.
-
-        :param audio_descriptor: Describe song to load, in case of file
-                                 based audio adapter, such descriptor would
-                                 be a file path.
-        :param offset: Start offset to load from in seconds.
-        :param duration: Duration to load in seconds.
-        :param sample_rate: Sample rate to load audio with.
-        :param dtype: Numpy data type to use, default to float32.
-        :returns: Loaded data as (wf, sample_rate) tuple.
+        self,
+        audio_descriptor: AudioDescriptor,
+        offset: Optional[float] = None,
+        duration: Optional[float] = None,
+        sample_rate: Optional[float] = None,
+        dtype: np.dtype = np.float32,
+    ) -> Signal:
+        """
+        Loads the audio file denoted by the given audio descriptor and
+        returns it data as a waveform. Aims to be implemented by client.
+
+        Parameters:
+            audio_descriptor (AudioDescriptor):
+                Describe song to load, in case of file based audio adapter,
+                such descriptor would be a file path.
+            offset (Optional[float]):
+                Start offset to load from in seconds.
+            duration (Optional[float]):
+                Duration to load in seconds.
+            sample_rate (Optional[float]):
+                Sample rate to load audio with.
+            dtype (numpy.dtype):
+                (Optional) Numpy data type to use, default to `float32`.
+
+        Returns:
+            Signal:
+                Loaded data as (wf, sample_rate) tuple.
         """
         pass
 
     def load_tf_waveform(
-            self, audio_descriptor,
-            offset=0.0, duration=1800., sample_rate=44100,
-            dtype=b'float32', waveform_name='waveform'):
-        """ Load the audio and convert it to a tensorflow waveform.
-
-        :param audio_descriptor: Describe song to load, in case of file
-                                 based audio adapter, such descriptor would
-                                 be a file path.
-        :param offset: Start offset to load from in seconds.
-        :param duration: Duration to load in seconds.
-        :param sample_rate: Sample rate to load audio with.
-        :param dtype: Numpy data type to use, default to float32.
-        :param waveform_name: (Optional) Name of the key in output dict.
-        :returns: TF output dict with waveform as
-                  (T x chan numpy array) and a boolean that
-                  tells whether there were an error while
-                  trying to load the waveform.
+        self,
+        audio_descriptor,
+        offset: float = 0.0,
+        duration: float = 1800.0,
+        sample_rate: int = 44100,
+        dtype: bytes = b"float32",
+        waveform_name: str = "waveform",
+    ) -> Dict[str, Any]:
+        """
+        Load the audio and convert it to a tensorflow waveform.
+
+        Parameters:
+            audio_descriptor ():
+                Describe song to load, in case of file based audio adapter,
+                such descriptor would be a file path.
+            offset (float):
+                Start offset to load from in seconds.
+            duration (float):
+                Duration to load in seconds.
+            sample_rate (float):
+                Sample rate to load audio with.
+            dtype (bytes):
+                (Optional)data type to use, default to `b'float32'`.
+            waveform_name (str):
+                (Optional) Name of the key in output dict, default to
+                `'waveform'`.
+
+        Returns:
+            Dict[str, Any]:
+                TF output dict with waveform as `(T x chan numpy array)`
+                and a boolean that tells whether there were an error while
+                trying to load the waveform.
         """
         # Cast parameters to TF format.
         offset = tf.cast(offset, tf.float64)
@@ -74,76 +105,96 @@ class AudioAdapter(ABC):
 
         # Defined safe loading function.
         def safe_load(path, offset, duration, sample_rate, dtype):
-            logger = get_logger()
-            logger.info(
-                f'Loading audio {path} from {offset} to {offset + duration}')
+            logger.info(f"Loading audio {path} from {offset} to {offset + duration}")
             try:
                 (data, _) = self.load(
                     path.numpy(),
                     offset.numpy(),
                     duration.numpy(),
                     sample_rate.numpy(),
-                    dtype=dtype.numpy())
-                logger.info('Audio data loaded successfully')
+                    dtype=dtype.numpy(),
+                )
+                logger.info("Audio data loaded successfully")
                 return (data, False)
             except Exception as e:
-                logger.exception(
-                    'An error occurs while loading audio',
-                    exc_info=e)
+                logger.exception("An error occurs while loading audio", exc_info=e)
                 return (np.float32(-1.0), True)
 
         # Execute function and format results.
-        results = tf.py_function(
-            safe_load,
-            [audio_descriptor, offset, duration, sample_rate, dtype],
-            (tf.float32, tf.bool)),
+        results = (
+            tf.py_function(
+                safe_load,
+                [audio_descriptor, offset, duration, sample_rate, dtype],
+                (tf.float32, tf.bool),
+            ),
+        )
         waveform, error = results[0]
-        return {
-            waveform_name: waveform,
-            f'{waveform_name}_error': error
-        }
+        return {waveform_name: waveform, f"{waveform_name}_error": error}
 
     @abstractmethod
     def save(
-            self, path, data, sample_rate,
-            codec=None, bitrate=None):
-        """ Save the given audio data to the file denoted by
-        the given path.
-
-        :param path: Path of the audio file to save data in.
-        :param data: Waveform data to write.
-        :param sample_rate: Sample rate to write file in.
-        :param codec: (Optional) Writing codec to use.
-        :param bitrate: (Optional) Bitrate of the written audio file.
+        self,
+        path: Union[Path, str],
+        data: np.ndarray,
+        sample_rate: float,
+        codec: Codec = None,
+        bitrate: str = None,
+    ) -> None:
+        """
+        Save the given audio data to the file denoted by the given path.
+
+        Parameters:
+            path (Union[Path, str]):
+                Path like of the audio file to save data in.
+            data (numpy.ndarray):
+                Waveform data to write.
+            sample_rate (float):
+                Sample rate to write file in.
+            codec ():
+                (Optional) Writing codec to use, default to `None`.
+            bitrate (str):
+                (Optional) Bitrate of the written audio file, default to
+                `None`.
         """
         pass
 
-
-def get_default_audio_adapter():
-    """ Builds and returns a default audio adapter instance.
-
-    :returns: An audio adapter instance.
-    """
-    if AudioAdapter.DEFAULT is None:
-        from .ffmpeg import FFMPEGProcessAudioAdapter
-        AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
-    return AudioAdapter.DEFAULT
-
-
-def get_audio_adapter(descriptor):
-    """ Load dynamically an AudioAdapter from given class descriptor.
-
-    :param descriptor: Adapter class descriptor (module.Class)
-    :returns: Created adapter instance.
-    """
-    if descriptor is None:
-        return get_default_audio_adapter()
-    module_path = descriptor.split('.')
-    adapter_class_name = module_path[-1]
-    module_path = '.'.join(module_path[:-1])
-    adapter_module = import_module(module_path)
-    adapter_class = getattr(adapter_module, adapter_class_name)
-    if not isinstance(adapter_class, AudioAdapter):
-        raise SpleeterError(
-            f'{adapter_class_name} is not a valid AudioAdapter class')
-    return adapter_class()
+    @classmethod
+    def default(cls: type) -> "AudioAdapter":
+        """
+        Builds and returns a default audio adapter instance.
+
+        Returns:
+            AudioAdapter:
+                Default adapter instance to use.
+        """
+        if cls._DEFAULT is None:
+            from .ffmpeg import FFMPEGProcessAudioAdapter
+
+            cls._DEFAULT = FFMPEGProcessAudioAdapter()
+        return cls._DEFAULT
+
+    @classmethod
+    def get(cls: type, descriptor: str) -> "AudioAdapter":
+        """
+        Load dynamically an AudioAdapter from given class descriptor.
+
+        Parameters:
+            descriptor (str):
+                Adapter class descriptor (module.Class)
+
+        Returns:
+            AudioAdapter:
+                Created adapter instance.
+        """
+        if not descriptor:
+            return cls.default()
+        module_path: List[str] = descriptor.split(".")
+        adapter_class_name: str = module_path[-1]
+        module_path: str = ".".join(module_path[:-1])
+        adapter_module = import_module(module_path)
+        adapter_class = getattr(adapter_module, adapter_class_name)
+        if not issubclass(adapter_class, AudioAdapter):
+            raise SpleeterError(
+                f"{adapter_class_name} is not a valid AudioAdapter class"
+            )
        return adapter_class()
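The two class methods added above replace the module-level factory functions; a short usage sketch (the dotted descriptor string is illustrative):

```python
from spleeter.audio.adapter import AudioAdapter

default = AudioAdapter.default()  # builds and caches an FFMPEGProcessAudioAdapter
same = AudioAdapter.get("")       # falsy descriptor falls back to default()
# A dotted module.Class descriptor is imported dynamically and validated
# with issubclass() before instantiation:
custom = AudioAdapter.get("spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter")
```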
spleeter/audio/convertor.py

@@ -3,39 +3,54 @@
 
 """ This module provides audio data convertion functions. """
 
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
-# pylint: enable=import-error
 
 from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
 
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
 
 
-def to_n_channels(waveform, n_channels):
-    """ Convert a waveform to n_channels by removing or
-    duplicating channels if needed (in tensorflow).
-
-    :param waveform: Waveform to transform.
-    :param n_channels: Number of channel to reshape waveform in.
-    :returns: Reshaped waveform.
+def to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:
+    """
+    Convert a waveform to n_channels by removing or duplicating channels if
+    needed (in tensorflow).
+
+    Parameters:
+        waveform (tensorflow.Tensor):
+            Waveform to transform.
+        n_channels (int):
+            Number of channel to reshape waveform in.
+
+    Returns:
+        tensorflow.Tensor:
+            Reshaped waveform.
     """
     return tf.cond(
         tf.shape(waveform)[1] >= n_channels,
         true_fn=lambda: waveform[:, :n_channels],
-        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels]
+        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels],
     )
 
 
-def to_stereo(waveform):
-    """ Convert a waveform to stereo by duplicating if mono,
-    or truncating if too many channels.
-
-    :param waveform: a (N, d) numpy array.
-    :returns: A stereo waveform as a (N, 1) numpy array.
+def to_stereo(waveform: np.ndarray) -> np.ndarray:
+    """
+    Convert a waveform to stereo by duplicating if mono, or truncating
+    if too many channels.
+
+    Parameters:
+        waveform (numpy.ndarray):
+            a `(N, d)` numpy array.
+
+    Returns:
+        numpy.ndarray:
+            A stereo waveform as a `(N, 1)` numpy array.
     """
     if waveform.shape[1] == 1:
         return np.repeat(waveform, 2, axis=-1)
@@ -44,45 +59,81 @@ def to_stereo(waveform):
     return waveform
 
 
-def gain_to_db(tensor, espilon=10e-10):
-    """ Convert from gain to decibel in tensorflow.
-
-    :param tensor: Tensor to convert.
-    :param epsilon: Operation constant.
-    :returns: Converted tensor.
-    """
-    return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
+def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
+    """
+    Convert from gain to decibel in tensorflow.
+
+    Parameters:
+        tensor (tensorflow.Tensor):
+            Tensor to convert
+        epsilon (float):
+            Operation constant.
+
+    Returns:
+        tensorflow.Tensor:
+            Converted tensor.
+    """
+    return 20.0 / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
 
 
-def db_to_gain(tensor):
-    """ Convert from decibel to gain in tensorflow.
-
-    :param tensor_db: Tensor to convert.
-    :returns: Converted tensor.
-    """
-    return tf.pow(10., (tensor / 20.))
+def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
+    """
+    Convert from decibel to gain in tensorflow.
+
+    Parameters:
+        tensor (tensorflow.Tensor):
+            Tensor to convert
+
+    Returns:
+        tensorflow.Tensor:
+            Converted tensor.
+    """
+    return tf.pow(10.0, (tensor / 20.0))
 
 
-def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
-    """ Encodes given spectrogram into uint8 using decibel scale.
-
-    :param spectrogram: Spectrogram to be encoded as TF float tensor.
-    :param db_range: Range in decibel for encoding.
-    :returns: Encoded decibel spectrogram as uint8 tensor.
-    """
-    db_spectrogram = gain_to_db(spectrogram)
-    max_db_spectrogram = tf.reduce_max(db_spectrogram)
-    db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range)
+def spectrogram_to_db_uint(
+    spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs
+) -> tf.Tensor:
+    """
+    Encodes given spectrogram into uint8 using decibel scale.
+
+    Parameters:
+        spectrogram (tensorflow.Tensor):
+            Spectrogram to be encoded as TF float tensor.
+        db_range (float):
+            Range in decibel for encoding.
+
+    Returns:
+        tensorflow.Tensor:
+            Encoded decibel spectrogram as `uint8` tensor.
+    """
+    db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
+    max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
+    db_spectrogram: tf.Tensor = tf.maximum(
+        db_spectrogram, max_db_spectrogram - db_range
+    )
     return from_float32_to_uint8(db_spectrogram, **kwargs)
 
 
-def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
-    """ Decode spectrogram from uint8 decibel scale.
-
-    :param db_uint_spectrogram: Decibel pectrogram to decode.
-    :param min_db: Lower bound limit for decoding.
-    :param max_db: Upper bound limit for decoding.
-    :returns: Decoded spectrogram as float2 tensor.
-    """
-    db_spectrogram = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
+def db_uint_spectrogram_to_gain(
+    db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor
+) -> tf.Tensor:
+    """
+    Decode spectrogram from uint8 decibel scale.
+
+    Paramters:
+        db_uint_spectrogram (tensorflow.Tensor):
+            Decibel spectrogram to decode.
+        min_db (tensorflow.Tensor):
+            Lower bound limit for decoding.
+        max_db (tensorflow.Tensor):
+            Upper bound limit for decoding.
+
+    Returns:
+        tensorflow.Tensor:
+            Decoded spectrogram as `float32` tensor.
+    """
+    db_spectrogram: tf.Tensor = from_uint8_to_float32(
+        db_uint_spectrogram, min_db, max_db
+    )
     return db_to_gain(db_spectrogram)
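As a quick sanity check of the gain/decibel helpers above, the two formulas round-trip exactly (a sketch; runs with tensorflow 2.x installed):

```python
import numpy as np
import tensorflow as tf

gain = tf.constant([1.0, 10.0, 100.0])
# gain_to_db: 20 / ln(10) * ln(max(x, eps)) == 20 * log10(x)
db = 20.0 / np.log(10) * tf.math.log(tf.maximum(gain, 1e-9))
print(db.numpy())                       # -> [ 0. 20. 40.]
# db_to_gain inverts it: 10 ** (db / 20)
print(tf.pow(10.0, db / 20.0).numpy())  # -> [  1.  10. 100.]
```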
@@ -8,143 +8,178 @@
 used within this library.
 """
 
+import datetime as dt
 import os
 import shutil
+from pathlib import Path
+from typing import Dict, Optional, Union
 
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import ffmpeg
 import numpy as np
+
+from .. import SpleeterError
+from ..types import Signal
+from ..utils.logging import logger
+from . import Codec
+from .adapter import AudioAdapter
+
 # pylint: enable=import-error
 
-from .adapter import AudioAdapter
-from .. import SpleeterError
-from ..utils.logging import get_logger
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-
-def _check_ffmpeg_install():
-    """ Ensure FFMPEG binaries are available.
-
-    :raise SpleeterError: If ffmpeg or ffprobe is not found.
-    """
-    for binary in ('ffmpeg', 'ffprobe'):
-        if shutil.which(binary) is None:
-            raise SpleeterError('{} binary not found'.format(binary))
-
-
-def _to_ffmpeg_time(n):
-    """ Format number of seconds to time expected by FFMPEG.
-    :param n: Time in seconds to format.
-    :returns: Formatted time in FFMPEG format.
-    """
-    m, s = divmod(n, 60)
-    h, m = divmod(m, 60)
-    return '%d:%02d:%09.6f' % (h, m, s)
-
-
-def _to_ffmpeg_codec(codec):
-    ffmpeg_codecs = {
-        'm4a': 'aac',
-        'ogg': 'libvorbis',
-        'wma': 'wmav2',
-    }
-    return ffmpeg_codecs.get(codec) or codec
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
 
 
 class FFMPEGProcessAudioAdapter(AudioAdapter):
-    """ An AudioAdapter implementation that use FFMPEG binary through
+    """
+    An AudioAdapter implementation that use FFMPEG binary through
     subprocess in order to perform I/O operation for audio processing.
 
     When created, FFMPEG binary path will be checked and expended,
     raising exception if not found. Such path could be infered using
-    FFMPEG_PATH environment variable.
+    `FFMPEG_PATH` environment variable.
     """
 
+    SUPPORTED_CODECS: Dict[Codec, str] = {
+        Codec.M4A: "aac",
+        Codec.OGG: "libvorbis",
+        Codec.WMA: "wmav2",
+    }
+    """ FFMPEG codec name mapping. """
+
+    def __init__(_) -> None:
+        """
+        Default constructor, ensure FFMPEG binaries are available.
+
+        Raises:
+            SpleeterError:
+                If ffmpeg or ffprobe is not found.
+        """
+        for binary in ("ffmpeg", "ffprobe"):
+            if shutil.which(binary) is None:
+                raise SpleeterError("{} binary not found".format(binary))
+
     def load(
-            self, path, offset=None, duration=None,
-            sample_rate=None, dtype=np.float32):
-        """ Loads the audio file denoted by the given path
+        _,
+        path: Union[Path, str],
+        offset: Optional[float] = None,
+        duration: Optional[float] = None,
+        sample_rate: Optional[float] = None,
+        dtype: np.dtype = np.float32,
+    ) -> Signal:
+        """
+        Loads the audio file denoted by the given path
         and returns its data as a waveform.
 
-        :param path: Path of the audio file to load data from.
-        :param offset: (Optional) Start offset to load from in seconds.
-        :param duration: (Optional) Duration to load in seconds.
-        :param sample_rate: (Optional) Sample rate to load audio with.
-        :param dtype: (Optional) Numpy data type to use, default to float32.
-        :returns: Loaded data a (waveform, sample_rate) tuple.
-        :raise SpleeterError: If any error occurs while loading audio.
+        Parameters:
+            path (Union[Path, str]):
+                Path of the audio file to load data from.
+            offset (Optional[float]):
+                Start offset to load from in seconds.
+            duration (Optional[float]):
+                Duration to load in seconds.
+            sample_rate (Optional[float]):
+                Sample rate to load audio with.
+            dtype (numpy.dtype):
+                (Optional) Numpy data type to use, default to `float32`.
+
+        Returns:
+            Signal:
+                Loaded data as a (waveform, sample_rate) tuple.
+
+        Raises:
+            SpleeterError:
+                If any error occurs while loading audio.
         """
-        _check_ffmpeg_install()
+        if isinstance(path, Path):
+            path = str(path)
         if not isinstance(path, str):
             path = path.decode()
         try:
             probe = ffmpeg.probe(path)
         except ffmpeg._run.Error as e:
             raise SpleeterError(
-                'An error occurs with ffprobe (see ffprobe output below)\n\n{}'
-                .format(e.stderr.decode()))
-        if 'streams' not in probe or len(probe['streams']) == 0:
-            raise SpleeterError('No stream was found with ffprobe')
+                "An error occurred with ffprobe (see ffprobe output below)\n\n{}".format(
+                    e.stderr.decode()
+                )
+            )
+        if "streams" not in probe or len(probe["streams"]) == 0:
+            raise SpleeterError("No stream was found with ffprobe")
         metadata = next(
-            stream
-            for stream in probe['streams']
-            if stream['codec_type'] == 'audio')
-        n_channels = metadata['channels']
+            stream for stream in probe["streams"] if stream["codec_type"] == "audio"
+        )
+        n_channels = metadata["channels"]
         if sample_rate is None:
-            sample_rate = metadata['sample_rate']
-        output_kwargs = {'format': 'f32le', 'ar': sample_rate}
+            sample_rate = metadata["sample_rate"]
+        output_kwargs = {"format": "f32le", "ar": sample_rate}
         if duration is not None:
-            output_kwargs['t'] = _to_ffmpeg_time(duration)
+            output_kwargs["t"] = str(dt.timedelta(seconds=duration))
         if offset is not None:
-            output_kwargs['ss'] = _to_ffmpeg_time(offset)
+            output_kwargs["ss"] = str(dt.timedelta(seconds=offset))
         process = (
-            ffmpeg
-            .input(path)
-            .output('pipe:', **output_kwargs)
-            .run_async(pipe_stdout=True, pipe_stderr=True))
+            ffmpeg.input(path)
+            .output("pipe:", **output_kwargs)
+            .run_async(pipe_stdout=True, pipe_stderr=True)
+        )
         buffer, _ = process.communicate()
-        waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
+        waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
         if not waveform.dtype == np.dtype(dtype):
             waveform = waveform.astype(dtype)
         return (waveform, sample_rate)
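
The new `load` formats the `-ss`/`-t` values through `datetime.timedelta` instead of the removed `_to_ffmpeg_time` helper; Python's default timedelta string form is a valid FFMPEG time specification, though integral values drop the fractional part the old formatter always emitted:

    import datetime as dt

    print(str(dt.timedelta(seconds=75.5)))  # 0:01:15.500000
    print(str(dt.timedelta(seconds=3600)))  # 1:00:00 (old helper: 1:00:00.000000)
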
     def save(
-            self, path, data, sample_rate,
-            codec=None, bitrate=None):
-        """ Write waveform data to the file denoted by the given path
-        using FFMPEG process.
+        self,
+        path: Union[Path, str],
+        data: np.ndarray,
+        sample_rate: float,
+        codec: Codec = None,
+        bitrate: str = None,
+    ) -> None:
+        """
+        Write waveform data to the file denoted by the given path using
+        FFMPEG process.
 
-        :param path: Path of the audio file to save data in.
-        :param data: Waveform data to write.
-        :param sample_rate: Sample rate to write file in.
-        :param codec: (Optional) Writing codec to use.
-        :param bitrate: (Optional) Bitrate of the written audio file.
-        :raise IOError: If any error occurs while using FFMPEG to write data.
+        Parameters:
+            path (Union[Path, str]):
+                Path like of the audio file to save data in.
+            data (numpy.ndarray):
+                Waveform data to write.
+            sample_rate (float):
+                Sample rate to write file in.
+            codec (Codec):
+                (Optional) Writing codec to use, default to `None`.
+            bitrate (str):
+                (Optional) Bitrate of the written audio file, default to
+                `None`.
+
+        Raises:
+            IOError:
+                If any error occurs while using FFMPEG to write data.
         """
-        _check_ffmpeg_install()
+        if isinstance(path, Path):
+            path = str(path)
         directory = os.path.dirname(path)
         if not os.path.exists(directory):
-            raise SpleeterError(f'output directory does not exists: {directory}')
-        get_logger().debug('Writing file %s', path)
-        input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
-        output_kwargs = {'ar': sample_rate, 'strict': '-2'}
+            raise SpleeterError(f"output directory does not exist: {directory}")
+        logger.debug(f"Writing file {path}")
+        input_kwargs = {"ar": sample_rate, "ac": data.shape[1]}
+        output_kwargs = {"ar": sample_rate, "strict": "-2"}
         if bitrate:
-            output_kwargs['audio_bitrate'] = bitrate
-        if codec is not None and codec != 'wav':
-            output_kwargs['codec'] = _to_ffmpeg_codec(codec)
+            output_kwargs["audio_bitrate"] = bitrate
+        if codec is not None and codec != "wav":
+            output_kwargs["codec"] = self.SUPPORTED_CODECS.get(codec, codec)
         process = (
-            ffmpeg
-            .input('pipe:', format='f32le', **input_kwargs)
+            ffmpeg.input("pipe:", format="f32le", **input_kwargs)
             .output(path, **output_kwargs)
             .overwrite_output()
-            .run_async(pipe_stdin=True, pipe_stderr=True, quiet=True))
+            .run_async(pipe_stdin=True, pipe_stderr=True, quiet=True)
+        )
         try:
-            process.stdin.write(data.astype('<f4').tobytes())
+            process.stdin.write(data.astype("<f4").tobytes())
             process.stdin.close()
             process.wait()
         except IOError:
-            raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
-        get_logger().info('File %s written succesfully', path)
+            raise SpleeterError(f"FFMPEG error: {process.stderr.read()}")
+        logger.info(f"File {path} written successfully")
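
A hypothetical end-to-end use of the adapter (file paths illustrative; the module path is inferred from this diff, not from documentation):

    from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter

    adapter = FFMPEGProcessAudioAdapter()
    # Load 30 seconds starting at 10 s, resampled to 44.1 kHz.
    waveform, sample_rate = adapter.load(
        "song.mp3", offset=10.0, duration=30.0, sample_rate=44100
    )
    # Write it back out; the target directory must already exist.
    adapter.save("/tmp/out/song.wav", waveform, sample_rate)
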
@@ -1,128 +1,176 @@
 #!/usr/bin/env python
 # coding: utf8
 
-""" Spectrogram specific data augmentation """
+""" Spectrogram specific data augmentation. """
 
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
+from tensorflow.signal import hann_window, stft
 
-from tensorflow.signal import stft, hann_window
 # pylint: enable=import-error
 
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
 
 
 def compute_spectrogram_tf(
-        waveform,
-        frame_length=2048, frame_step=512,
-        spec_exponent=1., window_exponent=1.):
-    """ Compute magnitude / power spectrogram from waveform as
-    a n_samples x n_channels tensor.
+    waveform: tf.Tensor,
+    frame_length: int = 2048,
+    frame_step: int = 512,
+    spec_exponent: float = 1.0,
+    window_exponent: float = 1.0,
+) -> tf.Tensor:
+    """
+    Compute magnitude / power spectrogram from waveform as a
+    `n_samples x n_channels` tensor.
 
-    :param waveform: Input waveform as (times x number of channels)
-        tensor.
-    :param frame_length: Length of a STFT frame to use.
-    :param frame_step: HOP between successive frames.
-    :param spec_exponent: Exponent of the spectrogram (usually 1 for
-        magnitude spectrogram, or 2 for power spectrogram).
-    :param window_exponent: Exponent applied to the Hann windowing function
-        (may be useful for making perfect STFT/iSTFT
-        reconstruction).
-    :returns: Computed magnitude / power spectrogram as a
-        (T x F x n_channels) tensor.
+    Parameters:
+        waveform (tensorflow.Tensor):
+            Input waveform as `(times x number of channels)` tensor.
+        frame_length (int):
+            Length of a STFT frame to use.
+        frame_step (int):
+            HOP between successive frames.
+        spec_exponent (float):
+            Exponent of the spectrogram (usually 1 for magnitude
+            spectrogram, or 2 for power spectrogram).
+        window_exponent (float):
+            Exponent applied to the Hann windowing function (may be
+            useful for making perfect STFT/iSTFT reconstruction).
+
+    Returns:
+        tensorflow.Tensor:
+            Computed magnitude / power spectrogram as a
+            `(T x F x n_channels)` tensor.
     """
-    stft_tensor = tf.transpose(
+    stft_tensor: tf.Tensor = tf.transpose(
         stft(
             tf.transpose(waveform),
             frame_length,
             frame_step,
             window_fn=lambda f, dtype: hann_window(
-                f,
-                periodic=True,
-                dtype=waveform.dtype) ** window_exponent),
-        perm=[1, 2, 0])
+                f, periodic=True, dtype=waveform.dtype
+            )
+            ** window_exponent,
+        ),
+        perm=[1, 2, 0],
+    )
     return tf.abs(stft_tensor) ** spec_exponent
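
A small usage sketch for the function above (TF 1.x style API, matching the rest of this diff; shapes follow the docstring):

    import numpy as np
    import tensorflow as tf

    # One second of stereo noise as a (n_samples, n_channels) tensor.
    waveform = tf.constant(np.random.randn(44100, 2), dtype=tf.float32)
    spec = compute_spectrogram_tf(waveform, frame_length=4096, frame_step=1024)
    # spec: (T frames, frame_length / 2 + 1 = 2049 frequency bins, 2 channels)
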
 def time_stretch(
-        spectrogram,
-        factor=1.0,
-        method=tf.image.ResizeMethod.BILINEAR):
-    """ Time stretch a spectrogram preserving shape in tensorflow. Note that
+    spectrogram: tf.Tensor,
+    factor: float = 1.0,
+    method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
+) -> tf.Tensor:
+    """
+    Time stretch a spectrogram preserving shape in tensorflow. Note that
     this is an approximation in the frequency domain.
 
-    :param spectrogram: Input spectrogram to be time stretched as tensor.
-    :param factor: (Optional) Time stretch factor, must be >0, default to 1.
-    :param mehtod: (Optional) Interpolation method, default to BILINEAR.
-    :returns: Time stretched spectrogram as tensor with same shape.
+    Parameters:
+        spectrogram (tensorflow.Tensor):
+            Input spectrogram to be time stretched as tensor.
+        factor (float):
+            (Optional) Time stretch factor, must be > 0, default to `1`.
+        method (tensorflow.image.ResizeMethod):
+            (Optional) Interpolation method, default to `BILINEAR`.
+
+    Returns:
+        tensorflow.Tensor:
+            Time stretched spectrogram as tensor with same shape.
     """
     T = tf.shape(spectrogram)[0]
     T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
     F = tf.shape(spectrogram)[1]
     ts_spec = tf.image.resize_images(
-        spectrogram,
-        [T_ts, F],
-        method=method,
-        align_corners=True)
+        spectrogram, [T_ts, F], method=method, align_corners=True
+    )
     return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
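
The output shape is preserved because the resized spectrogram is center cropped (or zero padded) back to the original `T` frames. A toy NumPy analogue of that last step, under the same assumed semantics as `resize_image_with_crop_or_pad`:

    import numpy as np

    def crop_or_pad_time(spec, T):
        t = spec.shape[0]
        if t >= T:  # stretched longer: center crop back to T frames
            start = (t - T) // 2
            return spec[start : start + T]
        pad = T - t  # stretched shorter: zero pad both ends
        return np.pad(spec, ((pad // 2, pad - pad // 2), (0, 0), (0, 0)))
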
-def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs):
-    """ Time stretch a spectrogram preserving shape with random ratio in
-    tensorflow. Applies time_stretch to spectrogram with a random ratio drawn
-    uniformly in [factor_min, factor_max].
+def random_time_stretch(
+    spectrogram: tf.Tensor, factor_min: float = 0.9, factor_max: float = 1.1, **kwargs
+) -> tf.Tensor:
+    """
+    Time stretch a spectrogram preserving shape with random ratio in
+    tensorflow. Applies time_stretch to spectrogram with a random ratio
+    drawn uniformly in `[factor_min, factor_max]`.
 
-    :param spectrogram: Input spectrogram to be time stretched as tensor.
-    :param factor_min: (Optional) Min time stretch factor, default to 0.9.
-    :param factor_max: (Optional) Max time stretch factor, default to 1.1.
-    :returns: Randomly time stretched spectrogram as tensor with same shape.
+    Parameters:
+        spectrogram (tensorflow.Tensor):
+            Input spectrogram to be time stretched as tensor.
+        factor_min (float):
+            (Optional) Min time stretch factor, default to `0.9`.
+        factor_max (float):
+            (Optional) Max time stretch factor, default to `1.1`.
+
+    Returns:
+        tensorflow.Tensor:
+            Randomly time stretched spectrogram as tensor with same shape.
     """
-    factor = tf.random_uniform(
-        shape=(1,),
-        seed=0) * (factor_max - factor_min) + factor_min
+    factor = (
+        tf.random_uniform(shape=(1,), seed=0) * (factor_max - factor_min) + factor_min
+    )
     return time_stretch(spectrogram, factor=factor, **kwargs)
 
 
 def pitch_shift(
-        spectrogram,
-        semitone_shift=0.0,
-        method=tf.image.ResizeMethod.BILINEAR):
-    """ Pitch shift a spectrogram preserving shape in tensorflow. Note that
+    spectrogram: tf.Tensor,
+    semitone_shift: float = 0.0,
+    method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,
+) -> tf.Tensor:
+    """
+    Pitch shift a spectrogram preserving shape in tensorflow. Note that
     this is an approximation in the frequency domain.
 
-    :param spectrogram: Input spectrogram to be pitch shifted as tensor.
-    :param semitone_shift: (Optional) Pitch shift in semitone, default to 0.0.
-    :param mehtod: (Optional) Interpolation method, default to BILINEAR.
-    :returns: Pitch shifted spectrogram (same shape as spectrogram).
+    Parameters:
+        spectrogram (tensorflow.Tensor):
+            Input spectrogram to be pitch shifted as tensor.
+        semitone_shift (float):
+            (Optional) Pitch shift in semitone, default to `0.0`.
+        method (tensorflow.image.ResizeMethod):
+            (Optional) Interpolation method, default to `BILINEAR`.
+
+    Returns:
+        tensorflow.Tensor:
+            Pitch shifted spectrogram (same shape as spectrogram).
     """
-    factor = 2 ** (semitone_shift / 12.)
+    factor = 2 ** (semitone_shift / 12.0)
     T = tf.shape(spectrogram)[0]
     F = tf.shape(spectrogram)[1]
     F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
     ps_spec = tf.image.resize_images(
-        spectrogram,
-        [T, F_ps],
-        method=method,
-        align_corners=True)
+        spectrogram, [T, F_ps], method=method, align_corners=True
+    )
     paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
-    return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
+    return tf.pad(ps_spec[:, :F, :], paddings, "CONSTANT")
 
 
-def random_pitch_shift(spectrogram, shift_min=-1., shift_max=1., **kwargs):
-    """ Pitch shift a spectrogram preserving shape with random ratio in
-    tensorflow. Applies pitch_shift to spectrogram with a random shift
-    amount (expressed in semitones) drawn uniformly in [shift_min, shift_max].
+def random_pitch_shift(
+    spectrogram: tf.Tensor, shift_min: float = -1.0, shift_max: float = 1.0, **kwargs
+) -> tf.Tensor:
+    """
+    Pitch shift a spectrogram preserving shape with random ratio in
+    tensorflow. Applies pitch_shift to spectrogram with a random shift
+    amount (expressed in semitones) drawn uniformly in
+    `[shift_min, shift_max]`.
 
-    :param spectrogram: Input spectrogram to be pitch shifted as tensor.
-    :param shift_min: (Optional) Min pitch shift in semitone, default to -1.
-    :param shift_max: (Optional) Max pitch shift in semitone, default to 1.
-    :returns: Randomly pitch shifted spectrogram (same shape as spectrogram).
+    Parameters:
+        spectrogram (tensorflow.Tensor):
+            Input spectrogram to be pitch shifted as tensor.
+        shift_min (float):
+            (Optional) Min pitch shift in semitone, default to -1.
+        shift_max (float):
+            (Optional) Max pitch shift in semitone, default to 1.
+
+    Returns:
+        tensorflow.Tensor:
+            Randomly pitch shifted spectrogram (same shape as spectrogram).
     """
-    semitone_shift = tf.random_uniform(
-        shape=(1,),
-        seed=0) * (shift_max - shift_min) + shift_min
+    semitone_shift = (
+        tf.random_uniform(shape=(1,), seed=0) * (shift_max - shift_min) + shift_min
+    )
     return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
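
The `2 ** (semitone_shift / 12)` expression in `pitch_shift` is the standard equal-temperament frequency ratio; a quick numeric check:

    for shift in (-12, 0, 1, 12):
        print(shift, 2 ** (shift / 12.0))
    # -12 -> 0.5 (octave down), 0 -> 1.0, 1 -> ~1.0595, 12 -> 2.0 (octave up)
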
@@ -1,209 +0,0 @@
-#!/usr/bin/env python
-# coding: utf8
-
-""" This modules provides spleeter command as well as CLI parsing methods. """
-
-import json
-import logging
-from argparse import ArgumentParser
-from tempfile import gettempdir
-from os.path import exists, join
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-
-# -i opt specification (separate).
-OPT_INPUT = {
-    'dest': 'inputs',
-    'nargs': '+',
-    'help': 'List of input audio filenames',
-    'required': True
-}
-
-# -o opt specification (evaluate and separate).
-OPT_OUTPUT = {
-    'dest': 'output_path',
-    'default': join(gettempdir(), 'separated_audio'),
-    'help': 'Path of the output directory to write audio files in'
-}
-
-# -f opt specification (separate).
-OPT_FORMAT = {
-    'dest': 'filename_format',
-    'default': '{filename}/{instrument}.{codec}',
-    'help': (
-        'Template string that will be formatted to generated'
-        'output filename. Such template should be Python formattable'
-        'string, and could use {filename}, {instrument}, and {codec}'
-        'variables.'
-    )
-}
-
-# -p opt specification (train, evaluate and separate).
-OPT_PARAMS = {
-    'dest': 'configuration',
-    'default': 'spleeter:2stems',
-    'type': str,
-    'action': 'store',
-    'help': 'JSON filename that contains params'
-}
-
-# -s opt specification (separate).
-OPT_OFFSET = {
-    'dest': 'offset',
-    'type': float,
-    'default': 0.,
-    'help': 'Set the starting offset to separate audio from.'
-}
-
-# -d opt specification (separate).
-OPT_DURATION = {
-    'dest': 'duration',
-    'type': float,
-    'default': 600.,
-    'help': (
-        'Set a maximum duration for processing audio '
-        '(only separate offset + duration first seconds of '
-        'the input file)')
-}
-
-# -w opt specification (separate)
-OPT_STFT_BACKEND = {
-    'dest': 'stft_backend',
-    'type': str,
-    'choices' : ["tensorflow", "librosa", "auto"],
-    'default': "auto",
-    'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
-            ' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
-}
-
-
-# -c opt specification (separate).
-OPT_CODEC = {
-    'dest': 'codec',
-    'choices': ('wav', 'mp3', 'ogg', 'm4a', 'wma', 'flac'),
-    'default': 'wav',
-    'help': 'Audio codec to be used for the separated output'
-}
-
-# -b opt specification (separate).
-OPT_BITRATE = {
-    'dest': 'bitrate',
-    'default': '128k',
-    'help': 'Audio bitrate to be used for the separated output'
-}
-
-# -m opt specification (evaluate and separate).
-OPT_MWF = {
-    'dest': 'MWF',
-    'action': 'store_const',
-    'const': True,
-    'default': False,
-    'help': 'Whether to use multichannel Wiener filtering for separation',
-}
-
-# --mus_dir opt specification (evaluate).
-OPT_MUSDB = {
-    'dest': 'mus_dir',
-    'type': str,
-    'required': True,
-    'help': 'Path to folder with musDB'
-}
-
-# -d opt specification (train).
-OPT_DATA = {
-    'dest': 'audio_path',
-    'type': str,
-    'required': True,
-    'help': 'Path of the folder containing audio data for training'
-}
-
-# -a opt specification (train, evaluate and separate).
-OPT_ADAPTER = {
-    'dest': 'audio_adapter',
-    'type': str,
-    'help': 'Name of the audio adapter to use for audio I/O'
-}
-
-# -a opt specification (train, evaluate and separate).
-OPT_VERBOSE = {
-    'action': 'store_true',
-    'help': 'Shows verbose logs'
-}
-
-
-def _add_common_options(parser):
-    """ Add common option to the given parser.
-
-    :param parser: Parser to add common opt to.
-    """
-    parser.add_argument('-a', '--adapter', **OPT_ADAPTER)
-    parser.add_argument('-p', '--params_filename', **OPT_PARAMS)
-    parser.add_argument('--verbose', **OPT_VERBOSE)
-
-
-def _create_train_parser(parser_factory):
-    """ Creates an argparser for training command
-
-    :param parser_factory: Factory to use to create parser instance.
-    :returns: Created and configured parser.
-    """
-    parser = parser_factory('train', help='Train a source separation model')
-    _add_common_options(parser)
-    parser.add_argument('-d', '--data', **OPT_DATA)
-    return parser
-
-
-def _create_evaluate_parser(parser_factory):
-    """ Creates an argparser for evaluation command
-
-    :param parser_factory: Factory to use to create parser instance.
-    :returns: Created and configured parser.
-    """
-    parser = parser_factory(
-        'evaluate',
-        help='Evaluate a model on the musDB test dataset')
-    _add_common_options(parser)
-    parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
-    parser.add_argument('--mus_dir', **OPT_MUSDB)
-    parser.add_argument('-m', '--mwf', **OPT_MWF)
-    parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
-    return parser
-
-
-def _create_separate_parser(parser_factory):
-    """ Creates an argparser for separation command
-
-    :param parser_factory: Factory to use to create parser instance.
-    :returns: Created and configured parser.
-    """
-    parser = parser_factory('separate', help='Separate audio files')
-    _add_common_options(parser)
-    parser.add_argument('-i', '--inputs', **OPT_INPUT)
-    parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
-    parser.add_argument('-f', '--filename_format', **OPT_FORMAT)
-    parser.add_argument('-d', '--duration', **OPT_DURATION)
-    parser.add_argument('-s', '--offset', **OPT_OFFSET)
-    parser.add_argument('-c', '--codec', **OPT_CODEC)
-    parser.add_argument('-b', '--birate', **OPT_BITRATE)
-    parser.add_argument('-m', '--mwf', **OPT_MWF)
-    parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
-    return parser
-
-
-def create_argument_parser():
-    """ Creates overall command line parser for Spleeter.
-
-    :returns: Created argument parser.
-    """
-    parser = ArgumentParser(prog='spleeter')
-    subparsers = parser.add_subparsers()
-    subparsers.dest = 'command'
-    subparsers.required = True
-    _create_separate_parser(subparsers.add_parser)
-    _create_train_parser(subparsers.add_parser)
-    _create_evaluate_parser(subparsers.add_parser)
-    return parser
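
Before its removal, the parser above supported round trips of this shape (file names illustrative):

    parser = create_argument_parser()
    args = parser.parse_args(
        ["separate", "-i", "audio_example.mp3", "-o", "/tmp/out", "-p", "spleeter:2stems"]
    )
    print(args.inputs, args.output_path, args.configuration)
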
@@ -1,167 +0,0 @@
-#!/usr/bin/env python
-# coding: utf8
-
-"""
-Entrypoint provider for performing model evaluation.
-
-Evaluation is performed against musDB dataset.
-
-USAGE: python -m spleeter evaluate \
-    -p /path/to/params \
-    -o /path/to/output/dir \
-    [-m] \
-    --mus_dir /path/to/musdb dataset
-"""
-
-import sys
-import json
-
-from argparse import Namespace
-from itertools import product
-from glob import glob
-from os.path import join, exists
-
-# pylint: disable=import-error
-import numpy as np
-import pandas as pd
-# pylint: enable=import-error
-
-from .separate import entrypoint as separate_entrypoint
-from ..utils.logging import get_logger
-
-try:
-    import musdb
-    import museval
-except ImportError:
-    logger = get_logger()
-    logger.error('Extra dependencies musdb and museval not found')
-    logger.error('Please install musdb and museval first, abort')
-    sys.exit(1)
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-_SPLIT = 'test'
-_MIXTURE = 'mixture.wav'
-_AUDIO_DIRECTORY = 'audio'
-_METRICS_DIRECTORY = 'metrics'
-_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
-_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
-
-
-def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
-    """ Performs audio separation on the musdb dataset from
-    the given directory and params.
-
-    :param arguments: Entrypoint arguments.
-    :param musdb_root_directory: Directory to retrieve dataset from.
-    :param params: Spleeter configuration to apply to separation.
-    :returns: Separation output directory path.
-    """
-    songs = glob(join(musdb_root_directory, _SPLIT, '*/'))
-    mixtures = [join(song, _MIXTURE) for song in songs]
-    audio_output_directory = join(
-        arguments.output_path,
-        _AUDIO_DIRECTORY)
-    separate_entrypoint(
-        Namespace(
-            audio_adapter=arguments.audio_adapter,
-            configuration=arguments.configuration,
-            inputs=mixtures,
-            output_path=join(audio_output_directory, _SPLIT),
-            filename_format='{foldername}/{instrument}.{codec}',
-            codec='wav',
-            duration=600.,
-            offset=0.,
-            bitrate='128k',
-            MWF=arguments.MWF,
-            verbose=arguments.verbose,
-            stft_backend=arguments.stft_backend),
-        params)
-    return audio_output_directory
-
-
-def _compute_musdb_metrics(
-        arguments,
-        musdb_root_directory,
-        audio_output_directory):
-    """ Generates musdb metrics fro previsouly computed audio estimation.
-
-    :param arguments: Entrypoint arguments.
-    :param audio_output_directory: Directory to get audio estimation from.
-    :returns: Path of generated metrics directory.
-    """
-    metrics_output_directory = join(
-        arguments.output_path,
-        _METRICS_DIRECTORY)
-    get_logger().info('Starting musdb evaluation (this could be long) ...')
-    dataset = musdb.DB(
-        root=musdb_root_directory,
-        is_wav=True,
-        subsets=[_SPLIT])
-    museval.eval_mus_dir(
-        dataset=dataset,
-        estimates_dir=audio_output_directory,
-        output_dir=metrics_output_directory)
-    get_logger().info('musdb evaluation done')
-    return metrics_output_directory
-
-
-def _compile_metrics(metrics_output_directory):
-    """ Compiles metrics from given directory and returns
-    results as dict.
-
-    :param metrics_output_directory: Directory to get metrics from.
-    :returns: Compiled metrics as dict.
-    """
-    songs = glob(join(metrics_output_directory, 'test/*.json'))
-    index = pd.MultiIndex.from_tuples(
-        product(_INSTRUMENTS, _METRICS),
-        names=['instrument', 'metric'])
-    pd.DataFrame([], index=['config1', 'config2'], columns=index)
-    metrics = {
-        instrument: {k: [] for k in _METRICS}
-        for instrument in _INSTRUMENTS}
-    for song in songs:
-        with open(song, 'r') as stream:
-            data = json.load(stream)
-        for target in data['targets']:
-            instrument = target['name']
-            for metric in _METRICS:
-                sdr_med = np.median([
-                    frame['metrics'][metric]
-                    for frame in target['frames']
-                    if not np.isnan(frame['metrics'][metric])])
-                metrics[instrument][metric].append(sdr_med)
-    return metrics
-
-
-def entrypoint(arguments, params):
-    """ Command entrypoint.
-
-    :param arguments: Command line parsed argument as argparse.Namespace.
-    :param params: Deserialized JSON configuration file provided in CLI args.
-    """
-    # Parse and check musdb directory.
-    musdb_root_directory = arguments.mus_dir
-    if not exists(musdb_root_directory):
-        raise IOError(f'musdb directory {musdb_root_directory} not found')
-    # Separate musdb sources.
-    audio_output_directory = _separate_evaluation_dataset(
-        arguments,
-        musdb_root_directory,
-        params)
-    # Compute metrics with musdb.
-    metrics_output_directory = _compute_musdb_metrics(
-        arguments,
-        musdb_root_directory,
-        audio_output_directory)
-    # Compute and pretty print median metrics.
-    metrics = _compile_metrics(metrics_output_directory)
-    for instrument, metric in metrics.items():
-        get_logger().info('%s:', instrument)
-        for metric, value in metric.items():
-            get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
-
-    return metrics
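
The removed `_compile_metrics` aggregated museval's per-frame scores with a NaN-aware median; a minimal sketch of that step (frame values illustrative):

    import numpy as np

    frames = [
        {"metrics": {"SDR": 4.0}},
        {"metrics": {"SDR": float("nan")}},  # NaN frames are skipped
        {"metrics": {"SDR": 6.0}},
    ]
    sdr_med = np.median(
        [f["metrics"]["SDR"] for f in frames if not np.isnan(f["metrics"]["SDR"])]
    )
    print(sdr_med)  # 5.0
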
@@ -1,47 +0,0 @@
-#!/usr/bin/env python
-# coding: utf8
-
-"""
-Entrypoint provider for performing source separation.
-
-USAGE: python -m spleeter separate \
-    -p /path/to/params \
-    -i inputfile1 inputfile2 ... inputfilen
-    -o /path/to/output/dir \
-    -i /path/to/audio1.wav /path/to/audio2.mp3
-"""
-
-from ..audio.adapter import get_audio_adapter
-from ..separator import Separator
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-
-def entrypoint(arguments, params):
-    """ Command entrypoint.
-
-    :param arguments: Command line parsed argument as argparse.Namespace.
-    :param params: Deserialized JSON configuration file provided in CLI args.
-    """
-    # TODO: check with output naming.
-    audio_adapter = get_audio_adapter(arguments.audio_adapter)
-    separator = Separator(
-        arguments.configuration,
-        MWF=arguments.MWF,
-        stft_backend=arguments.stft_backend)
-    for filename in arguments.inputs:
-        separator.separate_to_file(
-            filename,
-            arguments.output_path,
-            audio_adapter=audio_adapter,
-            offset=arguments.offset,
-            duration=arguments.duration,
-            codec=arguments.codec,
-            bitrate=arguments.bitrate,
-            filename_format=arguments.filename_format,
-            synchronous=False
-        )
-    separator.join()
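
The removed entrypoint reduced to the following flow (paths illustrative; constructor and keyword arguments as used in the code above):

    from spleeter.separator import Separator

    separator = Separator("spleeter:2stems", MWF=False, stft_backend="auto")
    separator.separate_to_file("audio_example.mp3", "/tmp/out", synchronous=False)
    separator.join()  # wait for asynchronous separation jobs to finish
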
@@ -1,100 +0,0 @@
-#!/usr/bin/env python
-# coding: utf8
-
-"""
-Entrypoint provider for performing model training.
-
-USAGE: python -m spleeter train -p /path/to/params
-"""
-
-from functools import partial
-
-# pylint: disable=import-error
-import tensorflow as tf
-# pylint: enable=import-error
-
-from ..audio.adapter import get_audio_adapter
-from ..dataset import get_training_dataset, get_validation_dataset
-from ..model import model_fn
-from ..model.provider import ModelProvider
-from ..utils.logging import get_logger
-
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-
-def _create_estimator(params):
-    """ Creates estimator.
-
-    :param params: TF params to build estimator from.
-    :returns: Built estimator.
-    """
-    session_config = tf.compat.v1.ConfigProto()
-    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
-    estimator = tf.estimator.Estimator(
-        model_fn=model_fn,
-        model_dir=params['model_dir'],
-        params=params,
-        config=tf.estimator.RunConfig(
-            save_checkpoints_steps=params['save_checkpoints_steps'],
-            tf_random_seed=params['random_seed'],
-            save_summary_steps=params['save_summary_steps'],
-            session_config=session_config,
-            log_step_count_steps=10,
-            keep_checkpoint_max=2))
-    return estimator
-
-
-def _create_train_spec(params, audio_adapter, audio_path):
-    """ Creates train spec.
-
-    :param params: TF params to build spec from.
-    :returns: Built train spec.
-    """
-    input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
-    train_spec = tf.estimator.TrainSpec(
-        input_fn=input_fn,
-        max_steps=params['train_max_steps'])
-    return train_spec
-
-
-def _create_evaluation_spec(params, audio_adapter, audio_path):
-    """ Setup eval spec evaluating ever n seconds
-
-    :param params: TF params to build spec from.
-    :returns: Built evaluation spec.
-    """
-    input_fn = partial(
-        get_validation_dataset,
-        params,
-        audio_adapter,
-        audio_path)
-    evaluation_spec = tf.estimator.EvalSpec(
-        input_fn=input_fn,
-        steps=None,
-        throttle_secs=params['throttle_secs'])
-    return evaluation_spec
-
-
-def entrypoint(arguments, params):
-    """ Command entrypoint.
-
-    :param arguments: Command line parsed argument as argparse.Namespace.
-    :param params: Deserialized JSON configuration file provided in CLI args.
-    """
-    audio_adapter = get_audio_adapter(arguments.audio_adapter)
-    audio_path = arguments.audio_path
-    estimator = _create_estimator(params)
-    train_spec = _create_train_spec(params, audio_adapter, audio_path)
-    evaluation_spec = _create_evaluation_spec(
-        params,
-        audio_adapter,
-        audio_path)
-    get_logger().info('Start model training')
-    tf.estimator.train_and_evaluate(
-        estimator,
-        train_spec,
-        evaluation_spec)
-    ModelProvider.writeProbe(params['model_dir'])
-    get_logger().info('Model training done')
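
For reference, the configuration keys the removed helpers read from `params` (values are illustrative placeholders, not recommended settings):

    params = {
        "model_dir": "/tmp/model",        # _create_estimator
        "save_checkpoints_steps": 150,
        "random_seed": 0,
        "save_summary_steps": 5,
        "train_max_steps": 100000,        # _create_train_spec
        "throttle_secs": 300,             # _create_evaluation_spec
    }
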
@@ -14,87 +14,110 @@
 (ground truth)
 """
 
-import time
 import os
-from os.path import exists, join, sep as SEPARATOR
+import time
+from os.path import exists
+from os.path import sep as SEPARATOR
+from typing import Any, Dict, Optional
 
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
-import pandas as pd
-import numpy as np
 import tensorflow as tf
-# pylint: enable=import-error
 
-from .audio.convertor import (
-    db_uint_spectrogram_to_gain,
-    spectrogram_to_db_uint)
+from .audio.adapter import AudioAdapter
+from .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint
 from .audio.spectrogram import (
     compute_spectrogram_tf,
     random_pitch_shift,
-    random_time_stretch)
-from .utils.logging import get_logger
+    random_time_stretch,
+)
+from .utils.logging import logger
 from .utils.tensor import (
     check_tensor_shape,
     dataset_from_csv,
     set_tensor_shape,
-    sync_apply)
+    sync_apply,
+)
 
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
 
 # Default audio parameters to use.
-DEFAULT_AUDIO_PARAMS = {
-    'instrument_list': ('vocals', 'accompaniment'),
-    'mix_name': 'mix',
-    'sample_rate': 44100,
-    'frame_length': 4096,
-    'frame_step': 1024,
-    'T': 512,
-    'F': 1024
-}
+DEFAULT_AUDIO_PARAMS: Dict = {
+    "instrument_list": ("vocals", "accompaniment"),
+    "mix_name": "mix",
+    "sample_rate": 44100,
+    "frame_length": 4096,
+    "frame_step": 1024,
+    "T": 512,
+    "F": 1024,
+}
 
 
-def get_training_dataset(audio_params, audio_adapter, audio_path):
-    """ Builds training dataset.
+def get_training_dataset(
+    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
+) -> Any:
+    """
+    Builds training dataset.
 
-    :param audio_params: Audio parameters.
-    :param audio_adapter: Adapter to load audio from.
-    :param audio_path: Path of directory containing audio.
-    :returns: Built dataset.
+    Parameters:
+        audio_params (Dict):
+            Audio parameters.
+        audio_adapter (AudioAdapter):
+            Adapter to load audio from.
+        audio_path (str):
+            Path of directory containing audio.
+
+    Returns:
+        Any:
+            Built dataset.
     """
     builder = DatasetBuilder(
         audio_params,
         audio_adapter,
         audio_path,
-        chunk_duration=audio_params.get('chunk_duration', 20.0),
-        random_seed=audio_params.get('random_seed', 0))
+        chunk_duration=audio_params.get("chunk_duration", 20.0),
+        random_seed=audio_params.get("random_seed", 0),
+    )
     return builder.build(
-        audio_params.get('train_csv'),
-        cache_directory=audio_params.get('training_cache'),
-        batch_size=audio_params.get('batch_size'),
-        n_chunks_per_song=audio_params.get('n_chunks_per_song', 2),
+        audio_params.get("train_csv"),
+        cache_directory=audio_params.get("training_cache"),
+        batch_size=audio_params.get("batch_size"),
+        n_chunks_per_song=audio_params.get("n_chunks_per_song", 2),
         random_data_augmentation=False,
         convert_to_uint=True,
-        wait_for_cache=False)
+        wait_for_cache=False,
+    )
 
 
-def get_validation_dataset(audio_params, audio_adapter, audio_path):
-    """ Builds validation dataset.
+def get_validation_dataset(
+    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str
+) -> Any:
+    """
+    Builds validation dataset.
 
-    :param audio_params: Audio parameters.
-    :param audio_adapter: Adapter to load audio from.
-    :param audio_path: Path of directory containing audio.
-    :returns: Built dataset.
+    Parameters:
+        audio_params (Dict):
+            Audio parameters.
+        audio_adapter (AudioAdapter):
+            Adapter to load audio from.
+        audio_path (str):
+            Path of directory containing audio.
+
+    Returns:
+        Any:
+            Built dataset.
     """
     builder = DatasetBuilder(
-        audio_params,
-        audio_adapter,
-        audio_path,
-        chunk_duration=12.0)
+        audio_params, audio_adapter, audio_path, chunk_duration=12.0
+    )
     return builder.build(
-        audio_params.get('validation_csv'),
-        batch_size=audio_params.get('batch_size'),
-        cache_directory=audio_params.get('validation_cache'),
+        audio_params.get("validation_csv"),
+        batch_size=audio_params.get("batch_size"),
+        cache_directory=audio_params.get("validation_cache"),
         convert_to_uint=True,
         infinite_generator=False,
         n_chunks_per_song=1,
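
A worked check of the segment length implied by `DEFAULT_AUDIO_PARAMS` above: `T` frames with hop `frame_step` at `sample_rate` cover `T * frame_step / sample_rate` seconds:

    T, frame_step, sample_rate = 512, 1024, 44100
    print(T * frame_step / sample_rate)  # ~11.89 s, the crop length quoted in the next hunk
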
@@ -108,127 +131,175 @@ def get_validation_dataset(audio_params, audio_adapter, audio_path):
 class InstrumentDatasetBuilder(object):
     """ Instrument based filter and mapper provider. """
 
-    def __init__(self, parent, instrument):
-        """ Default constructor.
+    def __init__(self, parent, instrument) -> None:
+        """
+        Default constructor.
 
-        :param parent: Parent dataset builder.
-        :param instrument: Target instrument.
+        Parameters:
+            parent:
+                Parent dataset builder.
+            instrument:
+                Target instrument.
         """
         self._parent = parent
         self._instrument = instrument
-        self._spectrogram_key = f'{instrument}_spectrogram'
-        self._min_spectrogram_key = f'min_{instrument}_spectrogram'
-        self._max_spectrogram_key = f'max_{instrument}_spectrogram'
+        self._spectrogram_key = f"{instrument}_spectrogram"
+        self._min_spectrogram_key = f"min_{instrument}_spectrogram"
+        self._max_spectrogram_key = f"max_{instrument}_spectrogram"
 
     def load_waveform(self, sample):
         """ Load waveform for given sample. """
-        return dict(sample, **self._parent._audio_adapter.load_tf_waveform(
-            sample[f'{self._instrument}_path'],
-            offset=sample['start'],
-            duration=self._parent._chunk_duration,
-            sample_rate=self._parent._sample_rate,
-            waveform_name='waveform'))
+        return dict(
+            sample,
+            **self._parent._audio_adapter.load_tf_waveform(
+                sample[f"{self._instrument}_path"],
+                offset=sample["start"],
+                duration=self._parent._chunk_duration,
+                sample_rate=self._parent._sample_rate,
+                waveform_name="waveform",
+            ),
+        )
 
     def compute_spectrogram(self, sample):
         """ Compute spectrogram of the given sample. """
-        return dict(sample, **{
-            self._spectrogram_key: compute_spectrogram_tf(
-                sample['waveform'],
-                frame_length=self._parent._frame_length,
-                frame_step=self._parent._frame_step,
-                spec_exponent=1.,
-                window_exponent=1.)})
+        return dict(
+            sample,
+            **{
+                self._spectrogram_key: compute_spectrogram_tf(
+                    sample["waveform"],
+                    frame_length=self._parent._frame_length,
+                    frame_step=self._parent._frame_step,
+                    spec_exponent=1.0,
+                    window_exponent=1.0,
+                )
+            },
+        )
 
     def filter_frequencies(self, sample):
         """ """
-        return dict(sample, **{
-            self._spectrogram_key:
-                sample[self._spectrogram_key][:, :self._parent._F, :]})
+        return dict(
+            sample,
+            **{
+                self._spectrogram_key: sample[self._spectrogram_key][
+                    :, : self._parent._F, :
+                ]
+            },
+        )
 
     def convert_to_uint(self, sample):
         """ Convert given sample from float to unit. """
-        return dict(sample, **spectrogram_to_db_uint(
-            sample[self._spectrogram_key],
-            tensor_key=self._spectrogram_key,
-            min_key=self._min_spectrogram_key,
-            max_key=self._max_spectrogram_key))
+        return dict(
+            sample,
+            **spectrogram_to_db_uint(
+                sample[self._spectrogram_key],
+                tensor_key=self._spectrogram_key,
+                min_key=self._min_spectrogram_key,
+                max_key=self._max_spectrogram_key,
+            ),
+        )
 
     def filter_infinity(self, sample):
         """ Filter infinity sample. """
-        return tf.logical_not(
-            tf.math.is_inf(
-                sample[self._min_spectrogram_key]))
+        return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))
 
     def convert_to_float32(self, sample):
         """ Convert given sample from unit to float. """
-        return dict(sample, **{
-            self._spectrogram_key: db_uint_spectrogram_to_gain(
-                sample[self._spectrogram_key],
-                sample[self._min_spectrogram_key],
-                sample[self._max_spectrogram_key])})
+        return dict(
+            sample,
+            **{
+                self._spectrogram_key: db_uint_spectrogram_to_gain(
+                    sample[self._spectrogram_key],
+                    sample[self._min_spectrogram_key],
+                    sample[self._max_spectrogram_key],
+                )
+            },
+        )
 
     def time_crop(self, sample):
         """ """
+
         def start(sample):
             """ mid_segment_start """
             return tf.cast(
                 tf.maximum(
-                    tf.shape(sample[self._spectrogram_key])[0]
-                    / 2 - self._parent._T / 2, 0),
-                tf.int32)
-        return dict(sample, **{
-            self._spectrogram_key: sample[self._spectrogram_key][
-                start(sample):start(sample) + self._parent._T, :, :]})
+                    tf.shape(sample[self._spectrogram_key])[0] / 2
+                    - self._parent._T / 2,
+                    0,
+                ),
+                tf.int32,
+            )
+
+        return dict(
+            sample,
+            **{
+                self._spectrogram_key: sample[self._spectrogram_key][
+                    start(sample) : start(sample) + self._parent._T, :, :
+                ]
+            },
+        )
 
     def filter_shape(self, sample):
         """ Filter badly shaped sample. """
         return check_tensor_shape(
-            sample[self._spectrogram_key], (
-                self._parent._T, self._parent._F, 2))
+            sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
+        )
 
     def reshape_spectrogram(self, sample):
-        """ """
-        return dict(sample, **{
-            self._spectrogram_key: set_tensor_shape(
-                sample[self._spectrogram_key],
-                (self._parent._T, self._parent._F, 2))})
+        """ Reshape given sample. """
+        return dict(
+            sample,
+            **{
+                self._spectrogram_key: set_tensor_shape(
+                    sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
+                )
+            },
+        )
 
 
 class DatasetBuilder(object):
     """
+    TO BE DOCUMENTED.
     """
 
-    # Margin at beginning and end of songs in seconds.
-    MARGIN = 0.5
+    MARGIN: float = 0.5
+    """ Margin at beginning and end of songs in seconds. """
 
-    # Wait period for cache (in seconds).
-    WAIT_PERIOD = 60
+    WAIT_PERIOD: int = 60
+    """ Wait period for cache (in seconds). """
 
     def __init__(
-            self,
-            audio_params, audio_adapter, audio_path,
-            random_seed=0, chunk_duration=20.0):
-        """ Default constructor.
+        self,
+        audio_params: Dict,
+        audio_adapter: AudioAdapter,
+        audio_path: str,
+        random_seed: int = 0,
+        chunk_duration: float = 20.0,
+    ) -> None:
+        """
+        Default constructor.
 
         NOTE: Probably need for AudioAdapter.
 
-        :param audio_params: Audio parameters to use.
-        :param audio_adapter: Audio adapter to use.
-        :param audio_path:
-        :param random_seed:
-        :param chunk_duration:
+        Parameters:
+            audio_params (Dict):
+                Audio parameters to use.
+            audio_adapter (AudioAdapter):
+                Audio adapter to use.
+            audio_path (str):
+            random_seed (int):
+            chunk_duration (float):
         """
         # Length of segment in frames (if fs=22050 and
         # frame_step=512, then T=512 corresponds to 11.89s)
-        self._T = audio_params['T']
+        self._T = audio_params["T"]
         # Number of frequency bins to be used (should
         # be less than frame_length/2 + 1)
-        self._F = audio_params['F']
-        self._sample_rate = audio_params['sample_rate']
-        self._frame_length = audio_params['frame_length']
-        self._frame_step = audio_params['frame_step']
-        self._mix_name = audio_params['mix_name']
-        self._instruments = [self._mix_name] + audio_params['instrument_list']
+        self._F = audio_params["F"]
+        self._sample_rate = audio_params["sample_rate"]
+        self._frame_length = audio_params["frame_length"]
+        self._frame_step = audio_params["frame_step"]
+        self._mix_name = audio_params["mix_name"]
+        self._instruments = [self._mix_name] + audio_params["instrument_list"]
         self._instrument_builders = None
         self._chunk_duration = chunk_duration
         self._audio_adapter = audio_adapter
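
A numeric check of the centered crop used by `time_crop` above, where the start frame is `max(n / 2 - T / 2, 0)`:

    n_frames, T = 900, 512
    start = max(n_frames // 2 - T // 2, 0)
    print(start, start + T)  # 194 706: a T-frame window centered in the segment
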
@@ -238,130 +309,202 @@ class DatasetBuilder(object):

     def expand_path(self, sample):
         """ Expands audio paths for the given sample. """
-        return dict(sample, **{f'{instrument}_path': tf.strings.join(
-            (self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
-            for instrument in self._instruments})
+        return dict(
+            sample,
+            **{
+                f"{instrument}_path": tf.strings.join(
+                    (self._audio_path, sample[f"{instrument}_path"]), SEPARATOR
+                )
+                for instrument in self._instruments
+            },
+        )

     def filter_error(self, sample):
         """ Filter errored sample. """
-        return tf.logical_not(sample['waveform_error'])
+        return tf.logical_not(sample["waveform_error"])

     def filter_waveform(self, sample):
         """ Filter waveform from sample. """
-        return {k: v for k, v in sample.items() if not k == 'waveform'}
+        return {k: v for k, v in sample.items() if not k == "waveform"}

     def harmonize_spectrogram(self, sample):
         """ Ensure same size for vocals and mix spectrograms. """

         def _reduce(sample):
-            return tf.reduce_min([
-                tf.shape(sample[f'{instrument}_spectrogram'])[0]
-                for instrument in self._instruments])
-        return dict(sample, **{
-            f'{instrument}_spectrogram':
-                sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :]
-            for instrument in self._instruments})
+            return tf.reduce_min(
+                [
+                    tf.shape(sample[f"{instrument}_spectrogram"])[0]
+                    for instrument in self._instruments
+                ]
+            )
+
+        return dict(
+            sample,
+            **{
+                f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"][
+                    : _reduce(sample), :, :
+                ]
+                for instrument in self._instruments
+            },
+        )

     def filter_short_segments(self, sample):
         """ Filter out too short segment. """
-        return tf.reduce_any([
-            tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T
-            for instrument in self._instruments])
+        return tf.reduce_any(
+            [
+                tf.shape(sample[f"{instrument}_spectrogram"])[0] >= self._T
+                for instrument in self._instruments
+            ]
+        )

     def random_time_crop(self, sample):
         """ Random time crop of 11.88s. """
-        return dict(sample, **sync_apply({
-            f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
-            for instrument in self._instruments},
-            lambda x: tf.image.random_crop(
-                x, (self._T, len(self._instruments) * self._F, 2),
-                seed=self._random_seed)))
+        return dict(
+            sample,
+            **sync_apply(
+                {
+                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
+                    for instrument in self._instruments
+                },
+                lambda x: tf.image.random_crop(
+                    x,
+                    (self._T, len(self._instruments) * self._F, 2),
+                    seed=self._random_seed,
+                ),
+            ),
+        )

     def random_time_stretch(self, sample):
         """ Randomly time stretch the given sample. """
-        return dict(sample, **sync_apply({
-            f'{instrument}_spectrogram':
-                sample[f'{instrument}_spectrogram']
-            for instrument in self._instruments},
-            lambda x: random_time_stretch(
-                x, factor_min=0.9, factor_max=1.1)))
+        return dict(
+            sample,
+            **sync_apply(
+                {
+                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
+                    for instrument in self._instruments
+                },
+                lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),
+            ),
+        )

     def random_pitch_shift(self, sample):
         """ Randomly pitch shift the given sample. """
-        return dict(sample, **sync_apply({
-            f'{instrument}_spectrogram':
-                sample[f'{instrument}_spectrogram']
-            for instrument in self._instruments},
-            lambda x: random_pitch_shift(
-                x, shift_min=-1.0, shift_max=1.0), concat_axis=0))
+        return dict(
+            sample,
+            **sync_apply(
+                {
+                    f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
+                    for instrument in self._instruments
+                },
+                lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),
+                concat_axis=0,
+            ),
+        )

     def map_features(self, sample):
         """ Select features and annotation of the given sample. """
         input_ = {
-            f'{self._mix_name}_spectrogram':
-                sample[f'{self._mix_name}_spectrogram']}
+            f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
+        }
         output = {
-            f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
-            for instrument in self._audio_params['instrument_list']}
+            f"{instrument}_spectrogram": sample[f"{instrument}_spectrogram"]
+            for instrument in self._audio_params["instrument_list"]
+        }
         return (input_, output)

-    def compute_segments(self, dataset, n_chunks_per_song):
-        """ Computes segments for each song of the dataset.
+    def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:
+        """
+        Computes segments for each song of the dataset.

-        :param dataset: Dataset to compute segments for.
-        :param n_chunks_per_song: Number of segment per song to compute.
-        :returns: Segmented dataset.
+        Parameters:
+            dataset (Any):
+                Dataset to compute segments for.
+            n_chunks_per_song (int):
+                Number of segment per song to compute.
+
+        Returns:
+            Any:
+                Segmented dataset.
         """
         if n_chunks_per_song <= 0:
-            raise ValueError('n_chunks_per_song must be positif')
+            raise ValueError("n_chunks_per_song must be positif")
         datasets = []
         for k in range(n_chunks_per_song):
             if n_chunks_per_song > 1:
                 datasets.append(
-                    dataset.map(lambda sample: dict(sample, start=tf.maximum(
-                        k * (
-                            sample['duration'] - self._chunk_duration - 2
-                            * self.MARGIN) / (n_chunks_per_song - 1)
-                        + self.MARGIN, 0))))
+                    dataset.map(
+                        lambda sample: dict(
+                            sample,
+                            start=tf.maximum(
+                                k
+                                * (
+                                    sample["duration"]
+                                    - self._chunk_duration
+                                    - 2 * self.MARGIN
+                                )
+                                / (n_chunks_per_song - 1)
+                                + self.MARGIN,
+                                0,
+                            ),
+                        )
+                    )
+                )
             elif n_chunks_per_song == 1:  # Take central segment.
                 datasets.append(
-                    dataset.map(lambda sample: dict(sample, start=tf.maximum(
-                        sample['duration'] / 2 - self._chunk_duration / 2,
-                        0))))
+                    dataset.map(
+                        lambda sample: dict(
+                            sample,
+                            start=tf.maximum(
+                                sample["duration"] / 2 - self._chunk_duration / 2, 0
+                            ),
+                        )
+                    )
+                )
         dataset = datasets[-1]
         for d in datasets[:-1]:
             dataset = dataset.concatenate(d)
         return dataset
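The `start` expression reformatted above spreads chunk offsets evenly across each song while staying `MARGIN` seconds away from both ends. A plain-Python sketch of the same arithmetic, assuming the class constant MARGIN is 0.5 seconds:

def chunk_starts(duration, chunk_duration=20.0, margin=0.5, n_chunks_per_song=2):
    # Mirrors the tf.maximum(...) expressions above, outside of TensorFlow.
    if n_chunks_per_song == 1:
        return [max(duration / 2 - chunk_duration / 2, 0)]
    span = duration - chunk_duration - 2 * margin
    return [
        max(k * span / (n_chunks_per_song - 1) + margin, 0)
        for k in range(n_chunks_per_song)
    ]

print(chunk_starts(180.0))  # [0.5, 159.5] for a three-minute song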

     @property
-    def instruments(self):
-        """ Instrument dataset builder generator.
+    def instruments(self) -> Any:
+        """
+        Instrument dataset builder generator.

-        :yield InstrumentBuilder instance.
+        Yields:
+            Any:
+                InstrumentBuilder instance.
         """
         if self._instrument_builders is None:
             self._instrument_builders = []
             for instrument in self._instruments:
                 self._instrument_builders.append(
-                    InstrumentDatasetBuilder(self, instrument))
+                    InstrumentDatasetBuilder(self, instrument)
+                )
         for builder in self._instrument_builders:
             yield builder

-    def cache(self, dataset, cache, wait):
-        """ Cache the given dataset if cache is enabled. Eventually waits for
-        cache to be available (useful if another process is already computing
-        cache) if provided wait flag is True.
+    def cache(self, dataset: Any, cache: str, wait: bool) -> Any:
+        """
+        Cache the given dataset if cache is enabled. Eventually waits for
+        cache to be available (useful if another process is already
+        computing cache) if provided wait flag is `True`.

-        :param dataset: Dataset to be cached if cache is required.
-        :param cache: Path of cache directory to be used, None if no cache.
-        :param wait: If caching is enabled, True is cache should be waited.
-        :returns: Cached dataset if needed, original dataset otherwise.
+        Parameters:
+            dataset (Any):
+                Dataset to be cached if cache is required.
+            cache (str):
+                Path of cache directory to be used, None if no cache.
+            wait (bool):
+                If caching is enabled, True is cache should be waited.
+
+        Returns:
+            Any:
+                Cached dataset if needed, original dataset otherwise.
         """
         if cache is not None:
             if wait:
-                while not exists(f'{cache}.index'):
-                    get_logger().info(
-                        'Cache not available, wait %s',
-                        self.WAIT_PERIOD)
+                while not exists(f"{cache}.index"):
+                    logger.info(f"Cache not available, wait {self.WAIT_PERIOD}")
                     time.sleep(self.WAIT_PERIOD)
             cache_path = os.path.split(cache)[0]
             os.makedirs(cache_path, exist_ok=True)
@@ -369,11 +512,19 @@ class DatasetBuilder(object):
         return dataset

     def build(
-            self, csv_path,
-            batch_size=8, shuffle=True, convert_to_uint=True,
-            random_data_augmentation=False, random_time_crop=True,
-            infinite_generator=True, cache_directory=None,
-            wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,):
+        self,
+        csv_path: str,
+        batch_size: int = 8,
+        shuffle: bool = True,
+        convert_to_uint: bool = True,
+        random_data_augmentation: bool = False,
+        random_time_crop: bool = True,
+        infinite_generator: bool = True,
+        cache_directory: Optional[str] = None,
+        wait_for_cache: bool = False,
+        num_parallel_calls: int = 4,
+        n_chunks_per_song: float = 2,
+    ) -> Any:
         """
         TO BE DOCUMENTED.
         """
@@ -385,7 +536,8 @@ class DatasetBuilder(object):
             buffer_size=200000,
             seed=self._random_seed,
             # useless since it is cached :
-            reshuffle_each_iteration=True)
+            reshuffle_each_iteration=True,
+        )
         # Expand audio path.
         dataset = dataset.map(self.expand_path)
         # Load waveform, compute spectrogram, and filtering error,
@@ -393,11 +545,11 @@
         N = num_parallel_calls
         for instrument in self.instruments:
             dataset = (
-                dataset
-                .map(instrument.load_waveform, num_parallel_calls=N)
+                dataset.map(instrument.load_waveform, num_parallel_calls=N)
                 .filter(self.filter_error)
                 .map(instrument.compute_spectrogram, num_parallel_calls=N)
-                .map(instrument.filter_frequencies))
+                .map(instrument.filter_frequencies)
+            )
         dataset = dataset.map(self.filter_waveform)
         # Convert to uint before caching in order to save space.
         if convert_to_uint:
@@ -428,26 +580,25 @@
         # after croping but before converting back to float.
         if shuffle:
             dataset = dataset.shuffle(
-                buffer_size=256, seed=self._random_seed,
-                reshuffle_each_iteration=True)
+                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True
+            )
         # Convert back to float32
         if convert_to_uint:
             for instrument in self.instruments:
                 dataset = dataset.map(
-                    instrument.convert_to_float32, num_parallel_calls=N)
+                    instrument.convert_to_float32, num_parallel_calls=N
+                )
         M = 8  # Parallel call post caching.
         # Must be applied with the same factor on mix and vocals.
         if random_data_augmentation:
-            dataset = (
-                dataset
-                .map(self.random_time_stretch, num_parallel_calls=M)
-                .map(self.random_pitch_shift, num_parallel_calls=M))
+            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(
+                self.random_pitch_shift, num_parallel_calls=M
+            )
         # Filter by shape (remove badly shaped tensors).
         for instrument in self.instruments:
-            dataset = (
-                dataset
-                .filter(instrument.filter_shape)
-                .map(instrument.reshape_spectrogram))
+            dataset = dataset.filter(instrument.filter_shape).map(
+                instrument.reshape_spectrogram
+            )
         # Select features and annotation.
         dataset = dataset.map(self.map_features)
         # Make batch (done after selection to avoid
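Since `build` is still marked TO BE DOCUMENTED, note that these hunks only re-wrap an unchanged pipeline: CSV shuffle, path expansion, per-instrument waveform loading and spectrogram computation, uint8 conversion around the cache, optional augmentation, shape filtering, feature selection, and batching. A hypothetical call, with illustrative argument values:

train_dataset = builder.build(
    "config/train.csv",                     # hypothetical CSV path
    batch_size=4,
    cache_directory="/tmp/spleeter_cache",  # illustrative cache location
    wait_for_cache=False,
)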
@@ -5,17 +5,19 @@

 import importlib

+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import tensorflow as tf
-from tensorflow.signal import stft, inverse_stft, hann_window
-# pylint: enable=import-error
+from tensorflow.signal import hann_window, inverse_stft, stft

 from ..utils.tensor import pad_and_partition, pad_and_reshape

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"

 placeholder = tf.compat.v1.placeholder
@@ -23,29 +25,28 @@ placeholder = tf.compat.v1.placeholder

 def get_model_function(model_type):
     """
     Get tensorflow function of the model to be applied to the input tensor.
     For instance "unet.softmax_unet" will return the softmax_unet function
     in the "unet.py" submodule of the current module (spleeter.model).

     Params:
     - model_type: str
     the relative module path to the model function.

     Returns:
     A tensorflow function to be applied to the input tensor to get the
     multitrack output.
     """
-    relative_path_to_module = '.'.join(model_type.split('.')[:-1])
-    model_name = model_type.split('.')[-1]
-    main_module = '.'.join((__name__, 'functions'))
-    path_to_module = f'{main_module}.{relative_path_to_module}'
+    relative_path_to_module = ".".join(model_type.split(".")[:-1])
+    model_name = model_type.split(".")[-1]
+    main_module = ".".join((__name__, "functions"))
+    path_to_module = f"{main_module}.{relative_path_to_module}"
     module = importlib.import_module(path_to_module)
     model_function = getattr(module, model_name)
     return model_function
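Beyond the quoting change, the resolution logic is worth spelling out: the dotted name splits into a submodule path and an attribute, resolved under the package's `functions` submodule. A standalone sketch of the same idea (function and package names here are illustrative):

import importlib

def resolve_model_function(model_type, main_module="spleeter.model.functions"):
    # "unet.softmax_unet" -> module "spleeter.model.functions.unet",
    # attribute "softmax_unet".
    module_path, _, name = model_type.rpartition(".")
    module = importlib.import_module(f"{main_module}.{module_path}")
    return getattr(module, name)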

 class InputProvider(object):

     def __init__(self, params):
         self.params = params

@@ -61,16 +62,16 @@ class InputProvider(object):

 class WaveformInputProvider(InputProvider):

     @property
     def input_names(self):
         return ["audio_id", "waveform"]

     def get_input_dict_placeholders(self):
-        shape = (None, self.params['n_channels'])
+        shape = (None, self.params["n_channels"])
         features = {
-            'waveform': placeholder(tf.float32, shape=shape, name="waveform"),
-            'audio_id': placeholder(tf.string, name="audio_id")}
+            "waveform": placeholder(tf.float32, shape=shape, name="waveform"),
+            "audio_id": placeholder(tf.string, name="audio_id"),
+        }
         return features

     def get_feed_dict(self, features, waveform, audio_id):
@@ -78,7 +79,6 @@ class WaveformInputProvider(InputProvider):

 class SpectralInputProvider(InputProvider):

     def __init__(self, params):
         super().__init__(params)
         self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@@ -89,11 +89,17 @@ class SpectralInputProvider(InputProvider):

     def get_input_dict_placeholders(self):
         features = {
-            self.stft_input_name: placeholder(tf.complex64,
-                shape=(None, self.params["frame_length"]//2+1,
-                self.params['n_channels']),
-                name=self.stft_input_name),
-            'audio_id': placeholder(tf.string, name="audio_id")}
+            self.stft_input_name: placeholder(
+                tf.complex64,
+                shape=(
+                    None,
+                    self.params["frame_length"] // 2 + 1,
+                    self.params["n_channels"],
+                ),
+                name=self.stft_input_name,
+            ),
+            "audio_id": placeholder(tf.string, name="audio_id"),
+        }
         return features

     def get_feed_dict(self, features, stft, audio_id):
@@ -101,11 +107,13 @@ class SpectralInputProvider(InputProvider):

 class InputProviderFactory(object):

     @staticmethod
     def get(params):
         stft_backend = params["stft_backend"]
-        assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
+        assert stft_backend in (
+            "tensorflow",
+            "librosa",
+        ), "Unexpected backend {}".format(stft_backend)
         if stft_backend == "tensorflow":
             return WaveformInputProvider(params)
         else:
@@ -113,7 +121,7 @@ class InputProviderFactory(object):

 class EstimatorSpecBuilder(object):
-    """ A builder class that allows to builds a multitrack unet model
+    """A builder class that allows to builds a multitrack unet model
     estimator. The built model estimator has a different behaviour when
     used in a train/eval mode and in predict mode.

@@ -137,22 +145,22 @@ class EstimatorSpecBuilder(object):
     """

     # Supported model functions.
-    DEFAULT_MODEL = 'unet.unet'
+    DEFAULT_MODEL = "unet.unet"

     # Supported loss functions.
-    L1_MASK = 'L1_mask'
-    WEIGHTED_L1_MASK = 'weighted_L1_mask'
+    L1_MASK = "L1_mask"
+    WEIGHTED_L1_MASK = "weighted_L1_mask"

     # Supported optimizers.
-    ADADELTA = 'Adadelta'
-    SGD = 'SGD'
+    ADADELTA = "Adadelta"
+    SGD = "SGD"

     # Math constants.
-    WINDOW_COMPENSATION_FACTOR = 2./3.
+    WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0
     EPSILON = 1e-10

     def __init__(self, features, params):
-        """ Default constructor. Depending on built model
+        """Default constructor. Depending on built model
         usage, the provided features should be different:

         * In train/eval mode: features is a dictionary with a
@@ -169,20 +177,20 @@ class EstimatorSpecBuilder(object):
         self._features = features
         self._params = params
         # Get instrument name.
-        self._mix_name = params['mix_name']
-        self._instruments = params['instrument_list']
+        self._mix_name = params["mix_name"]
+        self._instruments = params["instrument_list"]
         # Get STFT/signals parameters
-        self._n_channels = params['n_channels']
-        self._T = params['T']
-        self._F = params['F']
-        self._frame_length = params['frame_length']
-        self._frame_step = params['frame_step']
+        self._n_channels = params["n_channels"]
+        self._T = params["T"]
+        self._F = params["F"]
+        self._frame_length = params["frame_length"]
+        self._frame_step = params["frame_step"]

     def include_stft_computations(self):
         return self._params["stft_backend"] == "tensorflow"

     def _build_model_outputs(self):
-        """ Created a batch_sizexTxFxn_channels input tensor containing
+        """Created a batch_sizexTxFxn_channels input tensor containing
         mix magnitude spectrogram, then an output dict from it according
         to the selected model in internal parameters.

@@ -191,22 +199,21 @@ class EstimatorSpecBuilder(object):
         """
         input_tensor = self.spectrogram_feature
-        model = self._params.get('model', None)
+        model = self._params.get("model", None)
         if model is not None:
-            model_type = model.get('type', self.DEFAULT_MODEL)
+            model_type = model.get("type", self.DEFAULT_MODEL)
         else:
             model_type = self.DEFAULT_MODEL
         try:
             apply_model = get_model_function(model_type)
         except ModuleNotFoundError:
-            raise ValueError(f'No model function {model_type} found')
+            raise ValueError(f"No model function {model_type} found")
         self._model_outputs = apply_model(
-            input_tensor,
-            self._instruments,
-            self._params['model']['params'])
+            input_tensor, self._instruments, self._params["model"]["params"]
+        )

     def _build_loss(self, labels):
-        """ Construct tensorflow loss and metrics
+        """Construct tensorflow loss and metrics

         :param output_dict: dictionary of network outputs (key: instrument
             name, value: estimated spectrogram of the instrument)
@@ -215,7 +222,7 @@ class EstimatorSpecBuilder(object):
         :returns: tensorflow (loss, metrics) tuple.
         """
         output_dict = self.model_outputs
-        loss_type = self._params.get('loss_type', self.L1_MASK)
+        loss_type = self._params.get("loss_type", self.L1_MASK)
         if loss_type == self.L1_MASK:
             losses = {
                 name: tf.reduce_mean(tf.abs(output - labels[name]))
@@ -224,11 +231,9 @@ class EstimatorSpecBuilder(object):
         elif loss_type == self.WEIGHTED_L1_MASK:
             losses = {
                 name: tf.reduce_mean(
-                    tf.reduce_mean(
-                        labels[name],
-                        axis=[1, 2, 3],
-                        keep_dims=True) *
-                    tf.abs(output - labels[name]))
+                    tf.reduce_mean(labels[name], axis=[1, 2, 3], keep_dims=True)
+                    * tf.abs(output - labels[name])
+                )
                 for name, output in output_dict.items()
             }
         else:
@@ -236,20 +241,20 @@ class EstimatorSpecBuilder(object):
         loss = tf.reduce_sum(list(losses.values()))
         # Add metrics for monitoring each instrument.
         metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
-        metrics['absolute_difference'] = tf.compat.v1.metrics.mean(loss)
+        metrics["absolute_difference"] = tf.compat.v1.metrics.mean(loss)
         return loss, metrics

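For orientation, the two branches compute, per instrument, mean(|output - label|) for `L1_mask` and, for `weighted_L1_mask`, the same error scaled per sample by the mean of the target spectrogram; the total loss then sums over instruments. A numpy sketch of the weighted variant (array shapes are illustrative):

import numpy as np

def weighted_l1(label, output):
    # label/output: (batch, T, F, channels) magnitude spectrograms.
    weights = label.mean(axis=(1, 2, 3), keepdims=True)  # keep_dims above
    return (weights * np.abs(output - label)).mean()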
     def _build_optimizer(self):
-        """ Builds an optimizer instance from internal parameter values.
+        """Builds an optimizer instance from internal parameter values.

         Default to AdamOptimizer if not specified.

         :returns: Optimizer instance from internal configuration.
         """
-        name = self._params.get('optimizer')
+        name = self._params.get("optimizer")
         if name == self.ADADELTA:
             return tf.compat.v1.train.AdadeltaOptimizer()
-        rate = self._params['learning_rate']
+        rate = self._params["learning_rate"]
         if name == self.SGD:
             return tf.compat.v1.train.GradientDescentOptimizer(rate)
         return tf.compat.v1.train.AdamOptimizer(rate)
@@ -260,15 +265,15 @@ class EstimatorSpecBuilder(object):

     @property
     def stft_name(self):
-        return f'{self._mix_name}_stft'
+        return f"{self._mix_name}_stft"

     @property
     def spectrogram_name(self):
-        return f'{self._mix_name}_spectrogram'
+        return f"{self._mix_name}_spectrogram"

     def _build_stft_feature(self):
-        """ Compute STFT of waveform and slice the STFT in segment
+        """Compute STFT of waveform and slice the STFT in segment
         with the right length to feed the network.
         """

         stft_name = self.stft_name
@@ -276,25 +281,30 @@ class EstimatorSpecBuilder(object):

         if stft_name not in self._features:
             # pad input with a frame of zeros
-            waveform = tf.concat([
-                tf.zeros((self._frame_length, self._n_channels)),
-                self._features['waveform']
-            ],
-                0
-            )
+            waveform = tf.concat(
+                [
+                    tf.zeros((self._frame_length, self._n_channels)),
+                    self._features["waveform"],
+                ],
+                0,
+            )
             stft_feature = tf.transpose(
                 stft(
                     tf.transpose(waveform),
                     self._frame_length,
                     self._frame_step,
                     window_fn=lambda frame_length, dtype: (
-                        hann_window(frame_length, periodic=True, dtype=dtype)),
-                    pad_end=True),
-                perm=[1, 2, 0])
-            self._features[f'{self._mix_name}_stft'] = stft_feature
+                        hann_window(frame_length, periodic=True, dtype=dtype)
+                    ),
+                    pad_end=True,
+                ),
+                perm=[1, 2, 0],
+            )
+            self._features[f"{self._mix_name}_stft"] = stft_feature
         if spec_name not in self._features:
             self._features[spec_name] = tf.abs(
-                pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :]
+                pad_and_partition(self._features[stft_name], self._T)
+            )[:, :, : self._F, :]

     @property
     def model_outputs(self):
@@ -333,25 +343,29 @@ class EstimatorSpecBuilder(object):
         return self._masked_stfts

     def _inverse_stft(self, stft_t, time_crop=None):
-        """ Inverse and reshape the given STFT
+        """Inverse and reshape the given STFT

         :param stft_t: input STFT
         :returns: inverse STFT (waveform)
         """
-        inversed = inverse_stft(
-            tf.transpose(stft_t, perm=[2, 0, 1]),
-            self._frame_length,
-            self._frame_step,
-            window_fn=lambda frame_length, dtype: (
-                hann_window(frame_length, periodic=True, dtype=dtype))
-        ) * self.WINDOW_COMPENSATION_FACTOR
+        inversed = (
+            inverse_stft(
+                tf.transpose(stft_t, perm=[2, 0, 1]),
+                self._frame_length,
+                self._frame_step,
+                window_fn=lambda frame_length, dtype: (
+                    hann_window(frame_length, periodic=True, dtype=dtype)
+                ),
+            )
+            * self.WINDOW_COMPENSATION_FACTOR
+        )
         reshaped = tf.transpose(inversed)
         if time_crop is None:
-            time_crop = tf.shape(self._features['waveform'])[0]
-        return reshaped[self._frame_length:self._frame_length+time_crop, :]
+            time_crop = tf.shape(self._features["waveform"])[0]
+        return reshaped[self._frame_length : self._frame_length + time_crop, :]

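The 2/3 factor is not arbitrary: `inverse_stft` applies the Hann window again on synthesis, and overlap-adding squared periodic Hann windows at the 4x overlap of the default configuration (frame_step = frame_length / 4) sums to 1.5, hence the 2/3 gain correction. A quick numpy check of that constant:

import numpy as np

frame_length, frame_step = 4096, 1024
window = np.hanning(frame_length + 1)[:-1]           # periodic Hann window
acc = np.zeros(frame_length * 8)
for start in range(0, len(acc) - frame_length + 1, frame_step):
    acc[start:start + frame_length] += window ** 2   # analysis x synthesis window
print(acc[len(acc) // 2])                            # 1.5, i.e. 1 / (2 / 3)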
     def _build_mwf_output_waveform(self):
-        """ Perform separation with multichannel Wiener Filtering using Norbert.
+        """Perform separation with multichannel Wiener Filtering using Norbert.
         Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
         may be quite slow.

@@ -359,36 +373,42 @@ class EstimatorSpecBuilder(object):
             value: estimated waveform of the instrument)
         """
         import norbert  # pylint: disable=import-error

         output_dict = self.model_outputs
         x = self.stft_feature
         v = tf.stack(
             [
                 pad_and_reshape(
-                    output_dict[f'{instrument}_spectrogram'],
+                    output_dict[f"{instrument}_spectrogram"],
                     self._frame_length,
-                    self._F)[:tf.shape(x)[0], ...]
+                    self._F,
+                )[: tf.shape(x)[0], ...]
                 for instrument in self._instruments
             ],
-            axis=3)
+            axis=3,
+        )
         input_args = [v, x]
-        stft_function = tf.py_function(
-            lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
-            input_args,
-            tf.complex64),
+        stft_function = (
+            tf.py_function(
+                lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
+                input_args,
+                tf.complex64,
+            ),
+        )
         return {
             instrument: self._inverse_stft(stft_function[0][:, :, :, k])
             for k, instrument in enumerate(self._instruments)
         }

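As the docstring warns, this step leaves the TensorFlow graph: `tf.py_function` hands eager tensors to norbert. In plain numpy terms, the call looks like the following sketch, using norbert's documented shape convention (frames x bins x channels x sources); the sizes are illustrative:

import numpy as np
import norbert

frames, bins, channels, sources = 100, 2049, 2, 2
v = np.abs(np.random.randn(frames, bins, channels, sources))   # magnitudes
x = (np.random.randn(frames, bins, channels)
     + 1j * np.random.randn(frames, bins, channels))           # mix STFT
y = norbert.wiener(v, x)  # complex, (frames, bins, channels, sources)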
     def _extend_mask(self, mask):
-        """ Extend mask, from reduced number of frequency bin to the number of
+        """Extend mask, from reduced number of frequency bin to the number of
         frequency bin in the STFT.

         :param mask: restricted mask
         :returns: extended mask
         :raise ValueError: If invalid mask_extension parameter is set.
         """
-        extension = self._params['mask_extension']
+        extension = self._params["mask_extension"]
         # Extend with average
         # (dispatch according to energy in the processed band)
         if extension == "average":
@@ -397,13 +417,9 @@ class EstimatorSpecBuilder(object):
         # (avoid extension artifacts but not conservative separation)
         elif extension == "zeros":
             mask_shape = tf.shape(mask)
-            extension_row = tf.zeros((
-                mask_shape[0],
-                mask_shape[1],
-                1,
-                mask_shape[-1]))
+            extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))
         else:
-            raise ValueError(f'Invalid mask_extension parameter {extension}')
+            raise ValueError(f"Invalid mask_extension parameter {extension}")
         n_extra_row = self._frame_length // 2 + 1 - self._F
         extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
         return tf.concat([mask, extension], axis=2)
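In both branches the mask covers only the F lowest bins, so `n_extra_row` rows are tiled on top to reach the full STFT height. With the illustrative default values this works out to:

frame_length, F = 4096, 1024             # illustrative default values
n_extra_row = frame_length // 2 + 1 - F  # 2049 total bins - 1024 kept = 1025 rows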
@@ -415,29 +431,31 @@ class EstimatorSpecBuilder(object):
         """
         output_dict = self.model_outputs
         stft_feature = self.stft_feature
-        separation_exponent = self._params['separation_exponent']
-        output_sum = tf.reduce_sum(
-            [e ** separation_exponent for e in output_dict.values()],
-            axis=0
-        ) + self.EPSILON
+        separation_exponent = self._params["separation_exponent"]
+        output_sum = (
+            tf.reduce_sum(
+                [e ** separation_exponent for e in output_dict.values()], axis=0
+            )
+            + self.EPSILON
+        )
         out = {}
         for instrument in self._instruments:
-            output = output_dict[f'{instrument}_spectrogram']
+            output = output_dict[f"{instrument}_spectrogram"]
             # Compute mask with the model.
-            instrument_mask = (output ** separation_exponent
-                + (self.EPSILON / len(output_dict))) / output_sum
+            instrument_mask = (
+                output ** separation_exponent + (self.EPSILON / len(output_dict))
+            ) / output_sum
             # Extend mask;
             instrument_mask = self._extend_mask(instrument_mask)
             # Stack back mask.
             old_shape = tf.shape(instrument_mask)
             new_shape = tf.concat(
-                [[old_shape[0] * old_shape[1]], old_shape[2:]],
-                axis=0)
+                [[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0
+            )
             instrument_mask = tf.reshape(instrument_mask, new_shape)
             # Remove padded part (for mask having the same size as STFT);
-            instrument_mask = instrument_mask[
-                :tf.shape(stft_feature)[0], ...]
+            instrument_mask = instrument_mask[: tf.shape(stft_feature)[0], ...]
             out[instrument] = instrument_mask
         self._masks = out

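The quantity computed above is a generalized ratio mask: each instrument's spectrogram is raised to `separation_exponent` and normalized by the sum over instruments, with EPSILON terms so the masks stay defined in silent bins and sum to one. A numpy sketch:

import numpy as np

def ratio_masks(spectrograms, p=2, eps=1e-10):
    # spectrograms: dict of equally shaped magnitude arrays; p is the
    # separation exponent (2 in typical configurations).
    total = sum(s ** p for s in spectrograms.values()) + eps
    n = len(spectrograms)
    return {name: (s ** p + eps / n) / total for name, s in spectrograms.items()}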
@@ -449,7 +467,7 @@ class EstimatorSpecBuilder(object):
         self._masked_stfts = out

     def _build_manual_output_waveform(self, masked_stft):
-        """ Perform ratio mask separation
+        """Perform ratio mask separation

         :param output_dict: dictionary of estimated spectrogram (key: instrument
             name, value: estimated spectrogram of the instrument)
@@ -463,14 +481,14 @@ class EstimatorSpecBuilder(object):
         return output_waveform

     def _build_output_waveform(self, masked_stft):
-        """ Build output waveform from given output dict in order to be used in
+        """Build output waveform from given output dict in order to be used in
         prediction context. Regarding of the configuration building method will
         be using MWF.

         :returns: Built output waveform.
         """

-        if self._params.get('MWF', False):
+        if self._params.get("MWF", False):
             output_waveform = self._build_mwf_output_waveform()
         else:
             output_waveform = self._build_manual_output_waveform(masked_stft)
@@ -482,11 +500,11 @@ class EstimatorSpecBuilder(object):
         else:
             self._outputs = self.masked_stfts

-        if 'audio_id' in self._features:
-            self._outputs['audio_id'] = self._features['audio_id']
+        if "audio_id" in self._features:
+            self._outputs["audio_id"] = self._features["audio_id"]

     def build_predict_model(self):
-        """ Builder interface for creating model instance that aims to perform
+        """Builder interface for creating model instance that aims to perform
         prediction / inference over given track. The output of such estimator
         will be a dictionary with a "<instrument>" key per separated instrument
         , associated to the estimated separated waveform of the instrument.
@@ -495,11 +513,11 @@ class EstimatorSpecBuilder(object):
         """

         return tf.estimator.EstimatorSpec(
-            tf.estimator.ModeKeys.PREDICT,
-            predictions=self.outputs)
+            tf.estimator.ModeKeys.PREDICT, predictions=self.outputs
+        )

     def build_evaluation_model(self, labels):
-        """ Builder interface for creating model instance that aims to perform
+        """Builder interface for creating model instance that aims to perform
         model evaluation. The output of such estimator will be a dictionary
         with a key "<instrument>_spectrogram" per separated instrument,
         associated to the estimated separated instrument magnitude spectrogram.
@@ -509,12 +527,11 @@ class EstimatorSpecBuilder(object):
         """
         loss, metrics = self._build_loss(labels)
         return tf.estimator.EstimatorSpec(
-            tf.estimator.ModeKeys.EVAL,
-            loss=loss,
-            eval_metric_ops=metrics)
+            tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics
+        )

     def build_train_model(self, labels):
-        """ Builder interface for creating model instance that aims to perform
+        """Builder interface for creating model instance that aims to perform
         model training. The output of such estimator will be a dictionary
         with a key "<instrument>_spectrogram" per separated instrument,
         associated to the estimated separated instrument magnitude spectrogram.
@@ -525,8 +542,8 @@ class EstimatorSpecBuilder(object):
         loss, metrics = self._build_loss(labels)
         optimizer = self._build_optimizer()
         train_operation = optimizer.minimize(
-            loss=loss,
-            global_step=tf.compat.v1.train.get_global_step())
+            loss=loss, global_step=tf.compat.v1.train.get_global_step()
+        )
         return tf.estimator.EstimatorSpec(
             mode=tf.estimator.ModeKeys.TRAIN,
             loss=loss,
@@ -553,4 +570,4 @@ def model_fn(features, labels, mode, params, config):
         return builder.build_evaluation_model(labels)
     elif mode == tf.estimator.ModeKeys.TRAIN:
         return builder.build_train_model(labels)
-    raise ValueError(f'Unknown mode {mode}')
+    raise ValueError(f"Unknown mode {mode}")
@@ -3,25 +3,45 @@

 """ This package provide model functions. """

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+from typing import Callable, Dict, Iterable, Optional
+
+# pyright: reportMissingImports=false
+# pylint: disable=import-error
+import tensorflow as tf
+
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


-def apply(function, input_tensor, instruments, params={}):
-    """ Apply given function to the input tensor.
-
-    :param function: Function to be applied to tensor.
-    :param input_tensor: Tensor to apply blstm to.
-    :param instruments: Iterable that provides a collection of instruments.
-    :param params: (Optional) dict of BLSTM parameters.
-    :returns: Created output tensor dict.
-    """
-    output_dict = {}
+def apply(
+    function: Callable,
+    input_tensor: tf.Tensor,
+    instruments: Iterable[str],
+    params: Optional[Dict] = None,
+) -> Dict:
+    """
+    Apply given function to the input tensor.
+
+    Parameters:
+        function:
+            Function to be applied to tensor.
+        input_tensor (tensorflow.Tensor):
+            Tensor to apply blstm to.
+        instruments (Iterable[str]):
+            Iterable that provides a collection of instruments.
+        params:
+            (Optional) dict of BLSTM parameters.
+
+    Returns:
+        Created output tensor dict.
+    """
+    output_dict: Dict = {}
     for instrument in instruments:
-        out_name = f'{instrument}_spectrogram'
+        out_name = f"{instrument}_spectrogram"
         output_dict[out_name] = function(
-            input_tensor,
-            output_name=out_name,
-            params=params)
+            input_tensor, output_name=out_name, params=params or {}
+        )
     return output_dict
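`apply` is the shared driver for every model function in this package: it invokes the given builder once per instrument and names each output `<instrument>_spectrogram`. A sketch with a stand-in builder (not a real spleeter model function):

import tensorflow as tf

def identity_model(input_tensor, output_name="output", params=None):
    # Stand-in model function: returns its input under the requested name.
    return tf.identity(input_tensor, name=output_name)

outputs = apply(identity_model, tf.zeros((1, 512, 1024, 2)), ["vocals", "drums"])
# -> {"vocals_spectrogram": ..., "drums_spectrogram": ...}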
@@ -20,7 +20,11 @@
     selection (LSTM layer dropout rate, regularization strength).
 """

+from typing import Dict, Optional
+
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
+import tensorflow as tf
 from tensorflow.compat.v1.keras.initializers import he_uniform
 from tensorflow.compat.v1.keras.layers import CuDNNLSTM
 from tensorflow.keras.layers import (
@@ -28,34 +32,48 @@ from tensorflow.keras.layers import (
     Dense,
     Flatten,
     Reshape,
-    TimeDistributed)
-# pylint: enable=import-error
+    TimeDistributed,
+)

 from . import apply

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


-def apply_blstm(input_tensor, output_name='output', params={}):
-    """ Apply BLSTM to the given input_tensor.
-
-    :param input_tensor: Input of the model.
-    :param output_name: (Optional) name of the output, default to 'output'.
-    :param params: (Optional) dict of BLSTM parameters.
-    :returns: Output tensor.
-    """
-    units = params.get('lstm_units', 250)
+def apply_blstm(
+    input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
+) -> tf.Tensor:
+    """
+    Apply BLSTM to the given input_tensor.
+
+    Parameters:
+        input_tensor (tensorflow.Tensor):
+            Input of the model.
+        output_name (str):
+            (Optional) name of the output, default to 'output'.
+        params (Optional[Dict]):
+            (Optional) dict of BLSTM parameters.
+
+    Returns:
+        tensorflow.Tensor:
+            Output tensor.
+    """
+    if params is None:
+        params = {}
+    units: int = params.get("lstm_units", 250)
     kernel_initializer = he_uniform(seed=50)
     flatten_input = TimeDistributed(Flatten())((input_tensor))

     def create_bidirectional():
         return Bidirectional(
             CuDNNLSTM(
-                units,
-                kernel_initializer=kernel_initializer,
-                return_sequences=True))
+                units, kernel_initializer=kernel_initializer, return_sequences=True
+            )
+        )

     l1 = create_bidirectional()((flatten_input))
     l2 = create_bidirectional()((l1))
@@ -63,14 +81,18 @@ def apply_blstm(input_tensor, output_name='output', params={}):
     dense = TimeDistributed(
         Dense(
             int(flatten_input.shape[2]),
-            activation='relu',
-            kernel_initializer=kernel_initializer))((l3))
-    output = TimeDistributed(
-        Reshape(input_tensor.shape[2:]),
-        name=output_name)(dense)
+            activation="relu",
+            kernel_initializer=kernel_initializer,
+        )
+    )((l3))
+    output: tf.Tensor = TimeDistributed(
+        Reshape(input_tensor.shape[2:]), name=output_name
+    )(dense)
     return output


-def blstm(input_tensor, output_name='output', params={}):
+def blstm(
+    input_tensor: tf.Tensor, output_name: str = "output", params: Optional[Dict] = None
+) -> tf.Tensor:
     """ Model function applier. """
     return apply(apply_blstm, input_tensor, output_name, params)
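Shape-wise nothing changes in this hunk; as a reading aid, the tensor flow through `apply_blstm` under illustrative dimensions (T=512, F=1024, 2 channels, the default 250 LSTM units) would be:

# input                                 -> (batch, 512, 1024, 2)
# TimeDistributed(Flatten())            -> (batch, 512, 2048)
# stacked Bidirectional(CuDNNLSTM(250)) -> (batch, 512, 500)
# TimeDistributed(Dense(2048))          -> (batch, 512, 2048)
# TimeDistributed(Reshape((1024, 2)))   -> (batch, 512, 1024, 2)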
|
|||||||
@@ -2,92 +2,109 @@
|
|||||||
# coding: utf8
|
# coding: utf8
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This module contains building functions for U-net source
|
This module contains building functions for U-net source
|
||||||
separation models in a similar way as in A. Jansson et al. "Singing
|
separation models in a similar way as in A. Jansson et al. :
|
||||||
voice separation with deep u-net convolutional networks", ISMIR 2017.
|
|
||||||
Each instrument is modeled by a single U-net convolutional
|
"Singing voice separation with deep u-net convolutional networks",
|
||||||
/ deconvolutional network that take a mix spectrogram as input and the
|
ISMIR 2017
|
||||||
estimated sound spectrogram as output.
|
|
||||||
|
Each instrument is modeled by a single U-net
|
||||||
|
convolutional / deconvolutional network that take a mix spectrogram
|
||||||
|
as input and the estimated sound spectrogram as output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Any, Dict, Iterable, Optional
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow.compat.v1 import logging
|
||||||
|
from tensorflow.compat.v1.keras.initializers import he_uniform
|
||||||
from tensorflow.keras.layers import (
|
from tensorflow.keras.layers import (
|
||||||
|
ELU,
|
||||||
BatchNormalization,
|
BatchNormalization,
|
||||||
Concatenate,
|
Concatenate,
|
||||||
Conv2D,
|
Conv2D,
|
||||||
Conv2DTranspose,
|
Conv2DTranspose,
|
||||||
Dropout,
|
Dropout,
|
||||||
ELU,
|
|
||||||
LeakyReLU,
|
LeakyReLU,
|
||||||
Multiply,
|
Multiply,
|
||||||
ReLU,
|
ReLU,
|
||||||
Softmax)
|
Softmax,
|
||||||
-from tensorflow.compat.v1 import logging
-from tensorflow.compat.v1.keras.initializers import he_uniform
-# pylint: enable=import-error
+)

 from . import apply

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


-def _get_conv_activation_layer(params):
+def _get_conv_activation_layer(params: Dict) -> Any:
     """
+    > To be documented.

-    :param params:
-    :returns: Required Activation function.
+    Parameters:
+        params (Dict):
+
+    Returns:
+        Any:
+            Required Activation function.
     """
-    conv_activation = params.get('conv_activation')
-    if conv_activation == 'ReLU':
+    conv_activation: str = params.get("conv_activation")
+    if conv_activation == "ReLU":
         return ReLU()
-    elif conv_activation == 'ELU':
+    elif conv_activation == "ELU":
         return ELU()
     return LeakyReLU(0.2)


-def _get_deconv_activation_layer(params):
+def _get_deconv_activation_layer(params: Dict) -> Any:
     """
+    > To be documented.

-    :param params:
-    :returns: Required Activation function.
+    Parameters:
+        params (Dict):
+
+    Returns:
+        Any:
+            Required Activation function.
     """
-    deconv_activation = params.get('deconv_activation')
-    if deconv_activation == 'LeakyReLU':
+    deconv_activation: str = params.get("deconv_activation")
+    if deconv_activation == "LeakyReLU":
         return LeakyReLU(0.2)
-    elif deconv_activation == 'ELU':
+    elif deconv_activation == "ELU":
         return ELU()
     return ReLU()


 def apply_unet(
-        input_tensor,
-        output_name='output',
-        params={},
-        output_mask_logit=False):
-    """ Apply a convolutionnal U-net to model a single instrument (one U-net
+    input_tensor: tf.Tensor,
+    output_name: str = "output",
+    params: Optional[Dict] = None,
+    output_mask_logit: bool = False,
+) -> Any:
+    """
+    Apply a convolutionnal U-net to model a single instrument (one U-net
     is used for each instrument).

-    :param input_tensor:
-    :param output_name: (Optional) , default to 'output'
-    :param params: (Optional) , default to empty dict.
-    :param output_mask_logit: (Optional) , default to False.
+    Parameters:
+        input_tensor (tensorflow.Tensor):
+        output_name (str):
+        params (Optional[Dict]):
+        output_mask_logit (bool):
     """
-    logging.info(f'Apply unet for {output_name}')
-    conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
+    logging.info(f"Apply unet for {output_name}")
+    conv_n_filters = params.get("conv_n_filters", [16, 32, 64, 128, 256, 512])
     conv_activation_layer = _get_conv_activation_layer(params)
     deconv_activation_layer = _get_deconv_activation_layer(params)
     kernel_initializer = he_uniform(seed=50)
     conv2d_factory = partial(
-        Conv2D,
-        strides=(2, 2),
-        padding='same',
-        kernel_initializer=kernel_initializer)
+        Conv2D, strides=(2, 2), padding="same", kernel_initializer=kernel_initializer
+    )
     # First layer.
     conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)
     batch1 = BatchNormalization(axis=-1)(conv1)
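Note how the two activation helpers above are driven entirely by the model params dict. A minimal sketch of the lookup behaviour, assuming the module is importable as `spleeter.model.functions.unet` as in upstream spleeter:

```python
# Sketch of the params-driven activation lookup shown above.
# Assumption: module path spleeter.model.functions.unet.
from spleeter.model.functions.unet import _get_conv_activation_layer

layer = _get_conv_activation_layer({"conv_activation": "ELU"})  # ELU layer
default = _get_conv_activation_layer({})  # falls back to LeakyReLU(0.2)
```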
@@ -117,8 +134,9 @@ def apply_unet(
     conv2d_transpose_factory = partial(
         Conv2DTranspose,
         strides=(2, 2),
-        padding='same',
-        kernel_initializer=kernel_initializer)
+        padding="same",
+        kernel_initializer=kernel_initializer,
+    )
     #
     up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6))
     up1 = deconv_activation_layer(up1)
@@ -157,46 +175,60 @@ def apply_unet(
             2,
             (4, 4),
             dilation_rate=(2, 2),
-            activation='sigmoid',
-            padding='same',
-            kernel_initializer=kernel_initializer)((batch12))
+            activation="sigmoid",
+            padding="same",
+            kernel_initializer=kernel_initializer,
+        )((batch12))
         output = Multiply(name=output_name)([up7, input_tensor])
         return output
     return Conv2D(
         2,
         (4, 4),
         dilation_rate=(2, 2),
-        padding='same',
-        kernel_initializer=kernel_initializer)((batch12))
+        padding="same",
+        kernel_initializer=kernel_initializer,
+    )((batch12))


-def unet(input_tensor, instruments, params={}):
+def unet(
+    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
+) -> Dict:
     """ Model function applier. """
     return apply(apply_unet, input_tensor, instruments, params)


-def softmax_unet(input_tensor, instruments, params={}):
-    """ Apply softmax to multitrack unet in order to have mask suming to one.
+def softmax_unet(
+    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
+) -> Dict:
+    """
+    Apply softmax to multitrack unet in order to have mask suming to one.

-    :param input_tensor: Tensor to apply blstm to.
-    :param instruments: Iterable that provides a collection of instruments.
-    :param params: (Optional) dict of BLSTM parameters.
-    :returns: Created output tensor dict.
+    Parameters:
+        input_tensor (tensorflow.Tensor):
+            Tensor to apply blstm to.
+        instruments (Iterable[str]):
+            Iterable that provides a collection of instruments.
+        params (Optional[Dict]):
+            (Optional) dict of BLSTM parameters.
+
+    Returns:
+        Dict:
+            Created output tensor dict.
     """
     logit_mask_list = []
     for instrument in instruments:
-        out_name = f'{instrument}_spectrogram'
+        out_name = f"{instrument}_spectrogram"
         logit_mask_list.append(
             apply_unet(
                 input_tensor,
                 output_name=out_name,
                 params=params,
-                output_mask_logit=True))
+                output_mask_logit=True,
+            )
+        )
     masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
     output_dict = {}
     for i, instrument in enumerate(instruments):
-        out_name = f'{instrument}_spectrogram'
-        output_dict[out_name] = Multiply(name=out_name)([
-            masks[..., i],
-            input_tensor])
+        out_name = f"{instrument}_spectrogram"
+        output_dict[out_name] = Multiply(name=out_name)([masks[..., i], input_tensor])
     return output_dict

@@ -5,77 +5,91 @@
 This package provides tools for downloading model from network
 using remote storage abstraction.

-:Example:
+Examples:

+    ```python
     >>> provider = MyProviderImplementation()
     >>> provider.get('/path/to/local/storage', params)
+    ```
 """

 from abc import ABC, abstractmethod
 from os import environ, makedirs
 from os.path import exists, isabs, join, sep

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 class ModelProvider(ABC):
     """
     A ModelProvider manages model files on disk and
     file download is not available.
     """

-    DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models')
-    MODEL_PROBE_PATH = '.probe'
+    DEFAULT_MODEL_PATH: str = environ.get("MODEL_PATH", "pretrained_models")
+    MODEL_PROBE_PATH: str = ".probe"

     @abstractmethod
-    def download(self, name, path):
-        """ Download model denoted by the given name to disk.
+    def download(_, name: str, path: str) -> None:
+        """
+        Download model denoted by the given name to disk.

-        :param name: Name of the model to download.
-        :param path: Path of the directory to save model into.
+        Parameters:
+            name (str):
+                Name of the model to download.
+            path (str):
+                Path of the directory to save model into.
         """
         pass

     @staticmethod
-    def writeProbe(directory):
-        """ Write a model probe file into the given directory.
-
-        :param directory: Directory to write probe into.
-        """
-        probe = join(directory, ModelProvider.MODEL_PROBE_PATH)
-        with open(probe, 'w') as stream:
-            stream.write('OK')
-
-    def get(self, model_directory):
-        """ Ensures required model is available at given location.
-
-        :param model_directory: Expected model_directory to be available.
-        :raise IOError: If model can not be retrieved.
+    def writeProbe(directory: str) -> None:
+        """
+        Write a model probe file into the given directory.
+
+        Parameters:
+            directory (str):
+                Directory to write probe into.
+        """
+        probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
+        with open(probe, "w") as stream:
+            stream.write("OK")
+
+    def get(self, model_directory: str) -> str:
+        """
+        Ensures required model is available at given location.
+
+        Parameters:
+            model_directory (str):
+                Expected model_directory to be available.
+
+        Raises:
+            IOError:
+                If model can not be retrieved.
         """
         # Expend model directory if needed.
         if not isabs(model_directory):
             model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
         # Download it if not exists.
-        model_probe = join(model_directory, self.MODEL_PROBE_PATH)
+        model_probe: str = join(model_directory, self.MODEL_PROBE_PATH)
         if not exists(model_probe):
             if not exists(model_directory):
                 makedirs(model_directory)
-                self.download(
-                    model_directory.split(sep)[-1],
-                    model_directory)
+                self.download(model_directory.split(sep)[-1], model_directory)
                 self.writeProbe(model_directory)
         return model_directory

-
-def get_default_model_provider():
-    """ Builds and returns a default model provider.
-
-    :returns: A default model provider instance to use.
-    """
-    from .github import GithubModelProvider
-    host = environ.get('GITHUB_HOST', 'https://github.com')
-    repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
-    release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
-    return GithubModelProvider(host, repository, release)
+    @classmethod
+    def default(_: type) -> "ModelProvider":
+        """
+        Builds and returns a default model provider.
+
+        Returns:
+            ModelProvider:
+                A default model provider instance to use.
+        """
+        from .github import GithubModelProvider
+
+        return GithubModelProvider.from_environ()

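A quick usage sketch for the new `ModelProvider.default()` factory. The `2stems` model name mirrors the example in the Github provider docstring below; relative names resolve under `MODEL_PATH` (default `pretrained_models`):

```python
from spleeter.model.provider import ModelProvider

# Resolves a relative name under MODEL_PATH, downloading the archive and
# writing the .probe file on first use.
provider = ModelProvider.default()
model_directory = provider.get("2stems")
print(model_directory)  # pretrained_models/2stems
```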
@@ -4,41 +4,48 @@
 """
 A ModelProvider backed by Github Release feature.

-:Example:
+Examples:

+    ```python
     >>> from spleeter.model.provider import github
     >>> provider = github.GithubModelProvider(
             'github.com',
             'Deezer/spleeter',
             'latest')
     >>> provider.download('2stems', '/path/to/local/storage')
+    ```
 """

 import hashlib
-import tarfile
 import os
+import tarfile
+from os import environ
 from tempfile import NamedTemporaryFile
+from typing import Dict

-import requests
+# pyright: reportMissingImports=false
+# pylint: disable=import-error
+import httpx

+from ...utils.logging import logger
 from . import ModelProvider
-from ...utils.logging import get_logger

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


 def compute_file_checksum(path):
-    """ Computes given path file sha256.
+    """Computes given path file sha256.

     :param path: Path of the file to compute checksum for.
     :returns: File checksum.
     """
     sha256 = hashlib.sha256()
-    with open(path, 'rb') as stream:
-        for chunk in iter(lambda: stream.read(4096), b''):
+    with open(path, "rb") as stream:
+        for chunk in iter(lambda: stream.read(4096), b""):
             sha256.update(chunk)
     return sha256.hexdigest()

@@ -46,69 +53,104 @@ def compute_file_checksum(path):
 class GithubModelProvider(ModelProvider):
     """ A ModelProvider implementation backed on Github for remote storage. """

-    LATEST_RELEASE = 'v1.4.0'
-    RELEASE_PATH = 'releases/download'
-    CHECKSUM_INDEX = 'checksum.json'
+    DEFAULT_HOST: str = "https://github.com"
+    DEFAULT_REPOSITORY: str = "deezer/spleeter"
+
+    CHECKSUM_INDEX: str = "checksum.json"
+    LATEST_RELEASE: str = "v1.4.0"
+    RELEASE_PATH: str = "releases/download"

-    def __init__(self, host, repository, release):
-        """ Default constructor.
+    def __init__(self, host: str, repository: str, release: str) -> None:
+        """Default constructor.

-        :param host: Host to the Github instance to reach.
-        :param repository: Repository path within target Github.
-        :param release: Release name to get models from.
+        Parameters:
+            host (str):
+                Host to the Github instance to reach.
+            repository (str):
+                Repository path within target Github.
+            release (str):
+                Release name to get models from.
         """
-        self._host = host
-        self._repository = repository
-        self._release = release
+        self._host: str = host
+        self._repository: str = repository
+        self._release: str = release

+    @classmethod
+    def from_environ(cls: type) -> "GithubModelProvider":
+        """
+        Factory method that creates provider from envvars.
+
+        Returns:
+            GithubModelProvider:
+                Created instance.
+        """
+        return cls(
+            environ.get("GITHUB_HOST", cls.DEFAULT_HOST),
+            environ.get("GITHUB_REPOSITORY", cls.DEFAULT_REPOSITORY),
+            environ.get("GITHUB_RELEASE", cls.LATEST_RELEASE),
+        )
+
-    def checksum(self, name):
-        """ Downloads and returns reference checksum for the given model name.
-
-        :param name: Name of the model to get checksum for.
-        :returns: Checksum of the required model.
-        :raise ValueError: If the given model name is not indexed.
-        """
-        url = '{}/{}/{}/{}/{}'.format(
-            self._host,
-            self._repository,
-            self.RELEASE_PATH,
-            self._release,
-            self.CHECKSUM_INDEX)
-        response = requests.get(url)
+    def checksum(self, name: str) -> str:
+        """
+        Downloads and returns reference checksum for the given model name.
+
+        Parameters:
+            name (str):
+                Name of the model to get checksum for.
+
+        Returns:
+            str:
+                Checksum of the required model.
+
+        Raises:
+            ValueError:
+                If the given model name is not indexed.
+        """
+        url: str = "/".join(
+            (
+                self._host,
+                self._repository,
+                self.RELEASE_PATH,
+                self._release,
+                self.CHECKSUM_INDEX,
+            )
+        )
+        response: httpx.Response = httpx.get(url)
         response.raise_for_status()
-        index = response.json()
+        index: Dict = response.json()
         if name not in index:
-            raise ValueError('No checksum for model {}'.format(name))
+            raise ValueError(f"No checksum for model {name}")
         return index[name]

-    def download(self, name, path):
-        """ Download model denoted by the given name to disk.
-
-        :param name: Name of the model to download.
-        :param path: Path of the directory to save model into.
-        """
-        url = '{}/{}/{}/{}/{}.tar.gz'.format(
-            self._host,
-            self._repository,
-            self.RELEASE_PATH,
-            self._release,
-            name)
-        get_logger().info('Downloading model archive %s', url)
-        with requests.get(url, stream=True) as response:
-            response.raise_for_status()
-            archive = NamedTemporaryFile(delete=False)
-            try:
-                with archive as stream:
-                    # Note: check for chunk size parameters ?
-                    for chunk in response.iter_content(chunk_size=8192):
-                        if chunk:
-                            stream.write(chunk)
-                get_logger().info('Validating archive checksum')
-                if compute_file_checksum(archive.name) != self.checksum(name):
-                    raise IOError('Downloaded file is corrupted, please retry')
-                get_logger().info('Extracting downloaded %s archive', name)
-                with tarfile.open(name=archive.name) as tar:
-                    tar.extractall(path=path)
-            finally:
-                os.unlink(archive.name)
-        get_logger().info('%s model file(s) extracted', name)
+    def download(self, name: str, path: str) -> None:
+        """
+        Download model denoted by the given name to disk.
+
+        Parameters:
+            name (str):
+                Name of the model to download.
+            path (str):
+                Path of the directory to save model into.
+        """
+        url: str = "/".join(
+            (self._host, self._repository, self.RELEASE_PATH, self._release, name)
+        )
+        url = f"{url}.tar.gz"
+        logger.info(f"Downloading model archive {url}")
+        with httpx.Client(http2=True) as client:
+            with client.stream("GET", url) as response:
+                response.raise_for_status()
+                archive = NamedTemporaryFile(delete=False)
+                try:
+                    with archive as stream:
+                        for chunk in response.iter_raw():
+                            stream.write(chunk)
+                    logger.info("Validating archive checksum")
+                    checksum: str = compute_file_checksum(archive.name)
+                    if checksum != self.checksum(name):
+                        raise IOError("Downloaded file is corrupted, please retry")
+                    logger.info(f"Extracting downloaded {name} archive")
+                    with tarfile.open(name=archive.name) as tar:
+                        tar.extractall(path=path)
+                finally:
+                    os.unlink(archive.name)
+        logger.info(f"{name} model file(s) extracted")

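The `from_environ` factory above reads three environment variables, all optional. A sketch of an explicit configuration, assuming network access to the release assets:

```python
from os import environ

from spleeter.model.provider import github

# All three variables are optional; defaults are the class constants above.
environ["GITHUB_HOST"] = "https://github.com"
environ["GITHUB_REPOSITORY"] = "deezer/spleeter"
environ["GITHUB_RELEASE"] = "v1.4.0"

provider = github.GithubModelProvider.from_environ()
provider.download("2stems", "/path/to/local/storage")  # fetches and extracts tar.gz
```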
spleeter/options.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" This modules provides spleeter command as well as CLI parsing methods. """
+
+from os.path import join
+from tempfile import gettempdir
+
+from typer import Argument, Option
+from typer.models import ArgumentInfo, OptionInfo
+
+from .audio import Codec, STFTBackend
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
+
+AudioInputArgument: ArgumentInfo = Argument(
+    ...,
+    help="List of input audio file path",
+    exists=True,
+    file_okay=True,
+    dir_okay=False,
+    readable=True,
+    resolve_path=True,
+)
+
+AudioInputOption: OptionInfo = Option(
+    None, "--inputs", "-i", help="(DEPRECATED) placeholder for deprecated input option"
+)
+
+AudioAdapterOption: OptionInfo = Option(
+    "spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter",
+    "--adapter",
+    "-a",
+    help="Name of the audio adapter to use for audio I/O",
+)
+
+AudioOutputOption: OptionInfo = Option(
+    join(gettempdir(), "separated_audio"),
+    "--output_path",
+    "-o",
+    help="Path of the output directory to write audio files in",
+)
+
+AudioOffsetOption: OptionInfo = Option(
+    0.0, "--offset", "-s", help="Set the starting offset to separate audio from"
+)
+
+AudioDurationOption: OptionInfo = Option(
+    600.0,
+    "--duration",
+    "-d",
+    help=(
+        "Set a maximum duration for processing audio "
+        "(only separate offset + duration first seconds of "
+        "the input file)"
+    ),
+)
+
+AudioSTFTBackendOption: OptionInfo = Option(
+    STFTBackend.AUTO,
+    "--stft-backend",
+    "-B",
+    case_sensitive=False,
+    help=(
+        "Who should be in charge of computing the stfts. Librosa is faster "
+        'than tensorflow on CPU and uses less memory. "auto" will use '
+        "tensorflow when GPU acceleration is available and librosa when not"
+    ),
+)
+
+AudioCodecOption: OptionInfo = Option(
+    Codec.WAV, "--codec", "-c", help="Audio codec to be used for the separated output"
+)
+
+AudioBitrateOption: OptionInfo = Option(
+    "128k", "--bitrate", "-b", help="Audio bitrate to be used for the separated output"
+)
+
+FilenameFormatOption: OptionInfo = Option(
+    "{filename}/{instrument}.{codec}",
+    "--filename_format",
+    "-f",
+    help=(
+        "Template string that will be formatted to generated"
+        "output filename. Such template should be Python formattable"
+        "string, and could use {filename}, {instrument}, and {codec}"
+        "variables"
+    ),
+)
+
+ModelParametersOption: OptionInfo = Option(
+    "spleeter:2stems",
+    "--params_filename",
+    "-p",
+    help="JSON filename that contains params",
+)
+
+
+MWFOption: OptionInfo = Option(
+    False, "--mwf", help="Whether to use multichannel Wiener filtering for separation"
+)
+
+MUSDBDirectoryOption: OptionInfo = Option(
+    ...,
+    "--mus_dir",
+    exists=True,
+    dir_okay=True,
+    file_okay=False,
+    readable=True,
+    resolve_path=True,
+    help="Path to musDB dataset directory",
+)
+
+TrainingDataDirectoryOption: OptionInfo = Option(
+    ...,
+    "--data",
+    "-d",
+    exists=True,
+    dir_okay=True,
+    file_okay=False,
+    readable=True,
+    resolve_path=True,
+    help="Path of the folder containing audio data for training",
+)
+
+VerboseOption: OptionInfo = Option(False, "--verbose", help="Enable verbose logs")

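These module-level `ArgumentInfo`/`OptionInfo` objects are meant to be reused as parameter defaults across Typer commands. The actual CLI entrypoint is not part of this diff, so the following is only a hedged sketch of how a command might consume them:

```python
from pathlib import Path
from typing import List

from typer import Typer

from spleeter.options import AudioInputArgument, AudioOutputOption, VerboseOption

app = Typer()


@app.command()
def separate(
    files: List[Path] = AudioInputArgument,
    output_path: Path = AudioOutputOption,
    verbose: bool = VerboseOption,
) -> None:
    # Typer fills these parameters from the CLI using the shared definitions.
    ...


if __name__ == "__main__":
    app()
```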
spleeter/py.typed (new empty file)
@@ -3,6 +3,6 @@

 """ Packages that provides static resources file for the library. """

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"

@@ -4,60 +4,63 @@
 """
 Module that provides a class wrapper for source separation.

-:Example:
+Examples:

+    ```python
     >>> from spleeter.separator import Separator
     >>> separator = Separator('spleeter:2stems')
     >>> separator.separate(waveform, lambda instrument, data: ...)
     >>> separator.separate_to_file(...)
+    ```
 """

 import atexit
 import os
-import logging

 from multiprocessing import Pool
-from os.path import basename, join, splitext, dirname
-from time import time
-from typing import Container, NoReturn
+from os.path import basename, dirname, join, splitext
+from typing import Dict, Generator, Optional

+# pyright: reportMissingImports=false
+# pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
-from librosa.core import stft, istft
+from librosa.core import istft, stft
 from scipy.signal.windows import hann

+from spleeter.model.provider import ModelProvider
+
 from . import SpleeterError
-from .audio.adapter import get_default_audio_adapter
+from .audio import Codec, STFTBackend
+from .audio.adapter import AudioAdapter
 from .audio.convertor import to_stereo
+from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
+from .model.provider import ModelProvider
+from .types import AudioDescriptor
 from .utils.configuration import load_configuration
-from .utils.estimator import create_estimator, get_default_model_dir
-from .model import EstimatorSpecBuilder, InputProviderFactory

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
-""" """
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"


-class DataGenerator():
+class DataGenerator(object):
     """
     Generator object that store a sample and generate it once while called.
     Used to feed a tensorflow estimator without knowing the whole data at
     build time.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         """ Default constructor. """
         self._current_data = None

-    def update_data(self, data):
+    def update_data(self, data) -> None:
         """ Replace internal data. """
         self._current_data = data

-    def __call__(self):
+    def __call__(self) -> Generator:
         """ Generation process. """
         buffer = self._current_data
         while buffer:
@@ -65,34 +68,52 @@ class DataGenerator():
             buffer = self._current_data


-def get_backend(backend: str) -> str:
-    """
-    """
-    if backend not in SUPPORTED_BACKEND:
-        raise ValueError(f'Unsupported backend {backend}')
-    if backend == 'auto':
-        if len(tf.config.list_physical_devices('GPU')):
-            return 'tensorflow'
-        return 'librosa'
-    return backend
+def create_estimator(params, MWF):
+    """
+    Initialize tensorflow estimator that will perform separation
+
+    Params:
+    - params: a dictionary of parameters for building the model
+
+    Returns:
+        a tensorflow estimator
+    """
+    # Load model.
+    provider: ModelProvider = ModelProvider.default()
+    params["model_dir"] = provider.get(params["model_dir"])
+    params["MWF"] = MWF
+    # Setup config
+    session_config = tf.compat.v1.ConfigProto()
+    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
+    config = tf.estimator.RunConfig(session_config=session_config)
+    # Setup estimator
+    estimator = tf.estimator.Estimator(
+        model_fn=model_fn, model_dir=params["model_dir"], params=params, config=config
+    )
+    return estimator


 class Separator(object):
     """ A wrapper class for performing separation. """

     def __init__(
         self,
-        params_descriptor,
+        params_descriptor: str,
         MWF: bool = False,
-        stft_backend: str = 'auto',
-        multiprocess: bool = True):
-        """ Default constructor.
+        stft_backend: STFTBackend = STFTBackend.AUTO,
+        multiprocess: bool = True,
+    ) -> None:
+        """
+        Default constructor.

-        :param params_descriptor: Descriptor for TF params to be used.
-        :param MWF: (Optional) True if MWF should be used, False otherwise.
+        Parameters:
+            params_descriptor (str):
+                Descriptor for TF params to be used.
+            MWF (bool):
+                (Optional) `True` if MWF should be used, `False` otherwise.
         """
         self._params = load_configuration(params_descriptor)
-        self._sample_rate = self._params['sample_rate']
+        self._sample_rate = self._params["sample_rate"]
         self._MWF = MWF
         self._tf_graph = tf.Graph()
         self._prediction_generator = None
@@ -106,19 +127,21 @@ class Separator(object):
         else:
             self._pool = None
         self._tasks = []
-        self._params['stft_backend'] = get_backend(stft_backend)
+        self._params["stft_backend"] = STFTBackend.resolve(stft_backend)
        self._data_generator = DataGenerator()

-    def __del__(self):
-        """ """
+    def __del__(self) -> None:
         if self._session:
             self._session.close()

-    def _get_prediction_generator(self):
-        """ Lazy loading access method for internal prediction generator
+    def _get_prediction_generator(self) -> Generator:
+        """
+        Lazy loading access method for internal prediction generator
         returned by the predict method of a tensorflow estimator.

-        :returns: generator of prediction.
+        Returns:
+            Generator:
+                Generator of prediction.
         """
         if self._prediction_generator is None:
             estimator = create_estimator(self._params, self._MWF)
@@ -126,82 +149,74 @@ class Separator(object):
             def get_dataset():
                 return tf.data.Dataset.from_generator(
                     self._data_generator,
-                    output_types={
-                        'waveform': tf.float32,
-                        'audio_id': tf.string},
-                    output_shapes={
-                        'waveform': (None, 2),
-                        'audio_id': ()})
+                    output_types={"waveform": tf.float32, "audio_id": tf.string},
+                    output_shapes={"waveform": (None, 2), "audio_id": ()},
+                )

             self._prediction_generator = estimator.predict(
-                get_dataset,
-                yield_single_examples=False)
+                get_dataset, yield_single_examples=False
+            )
         return self._prediction_generator

-    def join(self, timeout: int = 200) -> NoReturn:
-        """ Wait for all pending tasks to be finished.
+    def join(self, timeout: int = 200) -> None:
+        """
+        Wait for all pending tasks to be finished.

-        :param timeout: (Optional) task waiting timeout.
+        Parameters:
+            timeout (int):
+                (Optional) task waiting timeout.
         """
         while len(self._tasks) > 0:
             task = self._tasks.pop()
             task.get()
             task.wait(timeout=timeout)

-    def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
-        """ Performs source separation over the given waveform with tensorflow
-        backend.
-
-        :param waveform: Waveform to apply separation on.
-        :returns: Separated waveforms.
-        """
-        if not waveform.shape[-1] == 2:
-            waveform = to_stereo(waveform)
-        prediction_generator = self._get_prediction_generator()
-        # NOTE: update data in generator before performing separation.
-        self._data_generator.update_data({
-            'waveform': waveform,
-            'audio_id': np.array(audio_descriptor)})
-        # NOTE: perform separation.
-        prediction = next(prediction_generator)
-        prediction.pop('audio_id')
-        return prediction
-
-    def _stft(self, data, inverse: bool = False, length=None):
-        """ Single entrypoint for both stft and istft. This computes stft and
+    def _stft(
+        self, data: np.ndarray, inverse: bool = False, length: Optional[int] = None
+    ) -> np.ndarray:
+        """
+        Single entrypoint for both stft and istft. This computes stft and
         istft with librosa on stereo data. The two channels are processed
-        separately and are concatenated together in the result. The expected
-        input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
+        separately and are concatenated together in the result. The
+        expected input formats are: (n_samples, 2) for stft and (T, F, 2)
+        for istft.

-        :param data: np.array with either the waveform or the complex
-                     spectrogram depending on the parameter inverse
-        :param inverse: should a stft or an istft be computed.
-        :returns: Stereo data as numpy array for the transform.
-                  The channels are stored in the last dimension.
+        Parameters:
+            data (numpy.array):
+                Array with either the waveform or the complex spectrogram
+                depending on the parameter inverse
+            inverse (bool):
+                (Optional) Should a stft or an istft be computed.
+            length (Optional[int]):
+
+        Returns:
+            numpy.ndarray:
+                Stereo data as numpy array for the transform. The channels
+                are stored in the last dimension.
         """
         assert not (inverse and length is None)
         data = np.asfortranarray(data)
-        N = self._params['frame_length']
-        H = self._params['frame_step']
+        N = self._params["frame_length"]
+        H = self._params["frame_step"]
         win = hann(N, sym=False)
         fstft = istft if inverse else stft
-        win_len_arg = {
-            'win_length': None,
-            'length': None} if inverse else {'n_fft': N}
+        win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
         n_channels = data.shape[-1]
         out = []
         for c in range(n_channels):
-            d = np.concatenate(
-                (np.zeros((N, )), data[:, c], np.zeros((N, )))
-            ) if not inverse else data[:, :, c].T
+            d = (
+                np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
+                if not inverse
+                else data[:, :, c].T
+            )
             s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
             if inverse:
-                s = s[N:N+length]
-            s = np.expand_dims(s.T, 2-inverse)
+                s = s[N : N + length]
+            s = np.expand_dims(s.T, 2 - inverse)
             out.append(s)
         if len(out) == 1:
             return out[0]
-        return np.concatenate(out, axis=2-inverse)
+        return np.concatenate(out, axis=2 - inverse)

     def _get_input_provider(self):
         if self._input_provider is None:
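The shape convention stated in the `_stft` docstring, (n_samples, 2) in and (T, F, 2) out, can be checked with a standalone librosa round-trip. A sketch assuming a 4096/1024 frame setup like the default spleeter configuration:

```python
import numpy as np
from librosa.core import stft
from scipy.signal.windows import hann

N, H = 4096, 1024  # assumed frame_length / frame_step
waveform = np.random.randn(44100, 2).astype(np.float32)  # (n_samples, 2)

channels = []
for c in range(waveform.shape[-1]):
    # Same zero padding and window as Separator._stft above.
    d = np.concatenate((np.zeros((N,)), waveform[:, c], np.zeros((N,))))
    s = stft(d, n_fft=N, hop_length=H, window=hann(N, sym=False), center=False)
    channels.append(np.expand_dims(s.T, 2))

spectrogram = np.concatenate(channels, axis=2)
print(spectrogram.shape)  # (T, F, 2), complex-valued
```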
@@ -216,22 +231,29 @@ class Separator(object):

     def _get_builder(self):
         if self._builder is None:
-            self._builder = EstimatorSpecBuilder(
-                self._get_features(),
-                self._params)
+            self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
         return self._builder

     def _get_session(self):
         if self._session is None:
             saver = tf.compat.v1.train.Saver()
-            latest_checkpoint = tf.train.latest_checkpoint(
-                get_default_model_dir(self._params['model_dir']))
+            provider = ModelProvider.default()
+            model_directory: str = provider.get(self._params["model_dir"])
+            latest_checkpoint = tf.train.latest_checkpoint(model_directory)
             self._session = tf.compat.v1.Session()
             saver.restore(self._session, latest_checkpoint)
         return self._session

-    def _separate_librosa(self, waveform: np.ndarray, audio_id):
-        """ Performs separation with librosa backend for STFT.
+    def _separate_librosa(
+        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
+    ) -> Dict:
+        """
+        Performs separation with librosa backend for STFT.
+
+        Parameters:
+            waveform (numpy.ndarray):
+                Waveform to be separated (as a numpy array)
+            audio_descriptor (AudioDescriptor):
         """
         with self._tf_graph.as_default():
             out = {}
@@ -248,65 +270,115 @@ class Separator(object):
                 outputs = sess.run(
                     outputs,
                     feed_dict=self._get_input_provider().get_feed_dict(
-                        features,
-                        stft,
-                        audio_id))
+                        features, stft, audio_descriptor
+                    ),
+                )
                 for inst in self._get_builder().instruments:
                     out[inst] = self._stft(
-                        outputs[inst],
-                        inverse=True,
-                        length=waveform.shape[0])
+                        outputs[inst], inverse=True, length=waveform.shape[0]
+                    )
             return out

-    def separate(self, waveform: np.ndarray, audio_descriptor=''):
-        """ Performs separation on a waveform.
-
-        :param waveform: Waveform to be separated (as a numpy array)
-        :param audio_descriptor: (Optional) string describing the waveform
-                                 (e.g. filename).
-        """
-        if self._params['stft_backend'] == 'tensorflow':
+    def _separate_tensorflow(
+        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor
+    ) -> Dict:
+        """
+        Performs source separation over the given waveform with tensorflow
+        backend.
+
+        Parameters:
+            waveform (numpy.ndarray):
+                Waveform to be separated (as a numpy array)
+            audio_descriptor (AudioDescriptor):
+
+        Returns:
+            Separated waveforms.
+        """
+        if not waveform.shape[-1] == 2:
+            waveform = to_stereo(waveform)
+        prediction_generator = self._get_prediction_generator()
+        # NOTE: update data in generator before performing separation.
+        self._data_generator.update_data(
+            {"waveform": waveform, "audio_id": np.array(audio_descriptor)}
+        )
+        # NOTE: perform separation.
+        prediction = next(prediction_generator)
+        prediction.pop("audio_id")
+        return prediction
+
+    def separate(
+        self, waveform: np.ndarray, audio_descriptor: Optional[str] = None
+    ) -> None:
+        """
+        Performs separation on a waveform.
+
+        Parameters:
+            waveform (numpy.ndarray):
+                Waveform to be separated (as a numpy array)
+            audio_descriptor (str):
+                (Optional) string describing the waveform (e.g. filename).
+        """
+        backend: str = self._params["stft_backend"]
+        if backend == STFTBackend.TENSORFLOW:
             return self._separate_tensorflow(waveform, audio_descriptor)
-        else:
+        elif backend == STFTBackend.LIBROSA:
             return self._separate_librosa(waveform, audio_descriptor)
+        raise ValueError(f"Unsupported STFT backend {backend}")

     def separate_to_file(
         self,
-        audio_descriptor,
-        destination,
-        audio_adapter=get_default_audio_adapter(),
-        offset=0,
-        duration=600.,
-        codec='wav',
-        bitrate='128k',
-        filename_format='{filename}/{instrument}.{codec}',
-        synchronous=True):
-        """ Performs source separation and export result to file using
+        audio_descriptor: AudioDescriptor,
+        destination: str,
+        audio_adapter: Optional[AudioAdapter] = None,
+        offset: int = 0,
+        duration: float = 600.0,
+        codec: Codec = Codec.WAV,
+        bitrate: str = "128k",
+        filename_format: str = "{filename}/{instrument}.{codec}",
+        synchronous: bool = True,
+    ) -> None:
+        """
+        Performs source separation and export result to file using
         given audio adapter.

-        Filename format should be a Python formattable string that could use
-        following parameters : {instrument}, {filename}, {foldername} and
-        {codec}.
+        Filename format should be a Python formattable string that could
+        use following parameters :
+
+        - {instrument}
+        - {filename}
+        - {foldername}
+        - {codec}.

-        :param audio_descriptor: Describe song to separate, used by audio
-                                 adapter to retrieve and load audio data,
-                                 in case of file based audio adapter, such
-                                 descriptor would be a file path.
-        :param destination: Target directory to write output to.
-        :param audio_adapter: (Optional) Audio adapter to use for I/O.
-        :param offset: (Optional) Offset of loaded song.
-        :param duration: (Optional) Duration of loaded song
-                         (default: 600s).
-        :param codec: (Optional) Export codec.
-        :param bitrate: (Optional) Export bitrate.
-        :param filename_format: (Optional) Filename format.
-        :param synchronous: (Optional) True is should by synchronous.
+        Parameters:
+            audio_descriptor (AudioDescriptor):
+                Describe song to separate, used by audio adapter to
+                retrieve and load audio data, in case of file based
+                audio adapter, such descriptor would be a file path.
+            destination (str):
+                Target directory to write output to.
+            audio_adapter (Optional[AudioAdapter]):
+                (Optional) Audio adapter to use for I/O.
+            offset (int):
+                (Optional) Offset of loaded song.
+            duration (float):
+                (Optional) Duration of loaded song (default: 600s).
+            codec (Codec):
+                (Optional) Export codec.
+            bitrate (str):
+                (Optional) Export bitrate.
+            filename_format (str):
+                (Optional) Filename format.
+            synchronous (bool):
+                (Optional) True is should by synchronous.
         """
-        waveform, sample_rate = audio_adapter.load(
+        if audio_adapter is None:
+            audio_adapter = AudioAdapter.default()
+        waveform, _ = audio_adapter.load(
             audio_descriptor,
             offset=offset,
             duration=duration,
-            sample_rate=self._sample_rate)
+            sample_rate=self._sample_rate,
+        )
         sources = self.separate(waveform, audio_descriptor)
         self.save_to_file(
             sources,
@@ -316,69 +388,78 @@ class Separator(object):
             codec,
             audio_adapter,
             bitrate,
-            synchronous)
+            synchronous,
+        )

     def save_to_file(
         self,
-        sources,
-        audio_descriptor,
-        destination,
-        filename_format='{filename}/{instrument}.{codec}',
-        codec='wav',
-        audio_adapter=get_default_audio_adapter(),
-        bitrate='128k',
-        synchronous=True):
-        """ Export dictionary of sources to files.
-
-        :param sources: Dictionary of sources to be exported. The
-                        keys are the name of the instruments, and
-                        the values are Nx2 numpy arrays containing
-                        the corresponding intrument waveform, as
-                        returned by the separate method
-        :param audio_descriptor: Describe song to separate, used by audio
-                                 adapter to retrieve and load audio data,
-                                 in case of file based audio adapter, such
-                                 descriptor would be a file path.
-        :param destination: Target directory to write output to.
-        :param filename_format: (Optional) Filename format.
-        :param codec: (Optional) Export codec.
-        :param audio_adapter: (Optional) Audio adapter to use for I/O.
-        :param bitrate: (Optional) Export bitrate.
-        :param synchronous: (Optional) True is should by synchronous.
-        """
+        sources: Dict,
+        audio_descriptor: AudioDescriptor,
+        destination: str,
+        filename_format: str = "{filename}/{instrument}.{codec}",
+        codec: Codec = Codec.WAV,
+        audio_adapter: Optional[AudioAdapter] = None,
+        bitrate: str = "128k",
+        synchronous: bool = True,
+    ) -> None:
+        """
+        Export dictionary of sources to files.
+
+        Parameters:
+            sources (Dict):
+                Dictionary of sources to be exported. The keys are the name
+                of the instruments, and the values are `N x 2` numpy arrays
+                containing the corresponding intrument waveform, as
+                returned by the separate method
+            audio_descriptor (AudioDescriptor):
+                Describe song to separate, used by audio adapter to
+                retrieve and load audio data, in case of file based audio
+                adapter, such descriptor would be a file path.
+            destination (str):
+                Target directory to write output to.
+            filename_format (str):
+                (Optional) Filename format.
+            codec (Codec):
+                (Optional) Export codec.
+            audio_adapter (Optional[AudioAdapter]):
+                (Optional) Audio adapter to use for I/O.
+            bitrate (str):
+                (Optional) Export bitrate.
+            synchronous (bool):
+                (Optional) True is should by synchronous.
+        """
+        if audio_adapter is None:
+            audio_adapter = AudioAdapter.default()
         foldername = basename(dirname(audio_descriptor))
         filename = splitext(basename(audio_descriptor))[0]
         generated = []
         for instrument, data in sources.items():
-            path = join(destination, filename_format.format(
-                filename=filename,
-                instrument=instrument,
-                foldername=foldername,
-                codec=codec,
-            ))
+            path = join(
+                destination,
+                filename_format.format(
+                    filename=filename,
+                    instrument=instrument,
+                    foldername=foldername,
+                    codec=codec,
+                ),
+            )
             directory = os.path.dirname(path)
             if not os.path.exists(directory):
                 os.makedirs(directory)
             if path in generated:
-                raise SpleeterError((
-                    f'Separated source path conflict : {path},'
-                    'please check your filename format'))
+                raise SpleeterError(
+                    (
+                        f"Separated source path conflict : {path},"
+                        "please check your filename format"
+                    )
+                )
             generated.append(path)
             if self._pool:
-                task = self._pool.apply_async(audio_adapter.save, (
-                    path,
-                    data,
-                    self._sample_rate,
-                    codec,
-                    bitrate))
+                task = self._pool.apply_async(
+                    audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)
+                )
                 self._tasks.append(task)
             else:
-                audio_adapter.save(
-                    path,
-                    data,
-                    self._sample_rate,
-                    codec,
-                    bitrate)
+                audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
         if synchronous and self._pool:
             self.join()

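End to end, the refactored `Separator` keeps the workflow from the module docstring. A sketch with a hypothetical input file:

```python
from spleeter.separator import Separator

separator = Separator("spleeter:2stems")  # embedded configuration descriptor

# "song.mp3" is a hypothetical input path; with the default filename format
# this writes /tmp/out/song/vocals.wav and /tmp/out/song/accompaniment.wav.
separator.separate_to_file("song.mp3", "/tmp/out")
```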
spleeter/types.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" Custom types definition. """
+
+from typing import Any, Tuple
+
+# pyright: reportMissingImports=false
+# pylint: disable=import-error
+import numpy as np
+
+# pylint: enable=import-error
+
+AudioDescriptor: type = Any
+Signal: type = Tuple[np.ndarray, float]

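A short sketch of how these aliases are meant to annotate adapter-style code (the loader function is illustrative, not from the diff):

```python
import numpy as np

from spleeter.types import AudioDescriptor, Signal


def load_waveform(descriptor: AudioDescriptor) -> Signal:
    # Hypothetical loader: returns (waveform, sample_rate) per the Signal alias.
    return np.zeros((44100, 2), dtype=np.float32), 44100.0
```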
@@ -3,6 +3,6 @@

 """ This package provides utility function and classes. """

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"

@@ -3,45 +3,49 @@

 """ Module that provides configuration loading function. """

+import importlib.resources as loader
 import json
-
-try:
-    import importlib.resources as loader
-except ImportError:
-    # Try backported to PY<37 `importlib_resources`.
-    import importlib_resources as loader
-
 from os.path import exists
+from typing import Dict

-from .. import resources, SpleeterError
+from .. import SpleeterError, resources

-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"

-_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
+_EMBEDDED_CONFIGURATION_PREFIX: str = "spleeter:"


-def load_configuration(descriptor):
-    """ Load configuration from the given descriptor. Could be
-    either a `spleeter:` prefixed embedded configuration name
-    or a file system path to read configuration from.
+def load_configuration(descriptor: str) -> Dict:
+    """
+    Load configuration from the given descriptor. Could be either a
+    `spleeter:` prefixed embedded configuration name or a file system path
+    to read configuration from.

-    :param descriptor: Configuration descriptor to use for lookup.
-    :returns: Loaded description as dict.
-    :raise ValueError: If required embedded configuration does not exists.
-    :raise SpleeterError: If required configuration file does not exists.
+    Parameters:
+        descriptor (str):
+            Configuration descriptor to use for lookup.
+
+    Returns:
+        Dict:
+            Loaded description as dict.
+
+    Raises:
+        ValueError:
+            If required embedded configuration does not exists.
+        SpleeterError:
+            If required configuration file does not exists.
     """
     # Embedded configuration reading.
     if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
-        name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):]
-        if not loader.is_resource(resources, f'{name}.json'):
-            raise SpleeterError(f'No embedded configuration {name} found')
-        with loader.open_text(resources, f'{name}.json') as stream:
+        name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX) :]
+        if not loader.is_resource(resources, f"{name}.json"):
+            raise SpleeterError(f"No embedded configuration {name} found")
+        with loader.open_text(resources, f"{name}.json") as stream:
             return json.load(stream)
     # Standard file reading.
     if not exists(descriptor):
-        raise SpleeterError(f'Configuration file {descriptor} not found')
-    with open(descriptor, 'r') as stream:
+        raise SpleeterError(f"Configuration file {descriptor} not found")
+    with open(descriptor, "r") as stream:
         return json.load(stream)

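Both lookup paths of `load_configuration` in one sketch; `spleeter:2stems` is one of the embedded JSON resources shipped with the package:

```python
from spleeter.utils.configuration import load_configuration

# Embedded resource: resolved against the spleeter.resources package.
params = load_configuration("spleeter:2stems")
print(params["sample_rate"])

# Plain file path: read directly from disk ("configs/my_model.json" is
# a hypothetical path).
params = load_configuration("configs/my_model.json")
```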
spleeter/utils/estimator.py (deleted file)
@@ -1,46 +0,0 @@
-#!/usr/bin/env python
-# coding: utf8
-
-""" Utility functions for creating estimator. """
-
-import tensorflow as tf  # pylint: disable=import-error
-
-from ..model import model_fn
-from ..model.provider import get_default_model_provider
-
-
-def get_default_model_dir(model_dir):
-    """
-    Transforms a string like 'spleeter:2stems' into an actual path.
-    :param model_dir:
-    :return:
-    """
-    model_provider = get_default_model_provider()
-    return model_provider.get(model_dir)
-
-
-def create_estimator(params, MWF):
-    """
-    Initialize tensorflow estimator that will perform separation
-
-    Params:
-    - params: a dictionary of parameters for building the model
-
-    Returns:
-        a tensorflow estimator
-    """
-    # Load model.
-    params['model_dir'] = get_default_model_dir(params['model_dir'])
-    params['MWF'] = MWF
-    # Setup config
-    session_config = tf.compat.v1.ConfigProto()
-    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
-    config = tf.estimator.RunConfig(session_config=session_config)
-    # Setup estimator
-    estimator = tf.estimator.Estimator(
-        model_fn=model_fn,
-        model_dir=params['model_dir'],
-        params=params,
-        config=config
-    )
-    return estimator

@@ -4,58 +4,53 @@
 """ Centralized logging facilities for Spleeter. """
 
 import logging
+import warnings
 from os import environ
 
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
-
-_FORMAT = '%(levelname)s:%(name)s:%(message)s'
+# pyright: reportMissingImports=false
+# pylint: disable=import-error
+from typer import echo
+
+# pylint: enable=import-error
+
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
+
+environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
 
-class _LoggerHolder(object):
-    """ Logger singleton instance holder. """
-
-    INSTANCE = None
+class TyperLoggerHandler(logging.Handler):
+    """ A custom logger handler that use Typer echo. """
+
+    def emit(self, record: logging.LogRecord) -> None:
+        echo(self.format(record))
 
 
-def get_tensorflow_logger():
-    """
-    """
-    # pylint: disable=import-error
-    from tensorflow.compat.v1 import logging
-    # pylint: enable=import-error
-    return logging
-
-
-def get_logger():
-    """ Returns library scoped logger.
-
-    :returns: Library logger.
-    """
-    if _LoggerHolder.INSTANCE is None:
-        formatter = logging.Formatter(_FORMAT)
-        handler = logging.StreamHandler()
-        handler.setFormatter(formatter)
-        logger = logging.getLogger('spleeter')
-        logger.addHandler(handler)
-        logger.setLevel(logging.INFO)
-        _LoggerHolder.INSTANCE = logger
-    return _LoggerHolder.INSTANCE
-
-
-def enable_tensorflow_logging():
-    """ Enable tensorflow logging. """
-    environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
-    tf_logger = get_tensorflow_logger()
-    tf_logger.set_verbosity(tf_logger.INFO)
-    logger = get_logger()
-    logger.setLevel(logging.DEBUG)
-
-
-def enable_logging():
-    """ Configure default logging. """
-    environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-    tf_logger = get_tensorflow_logger()
-    tf_logger.set_verbosity(tf_logger.ERROR)
+formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
+handler = TyperLoggerHandler()
+handler.setFormatter(formatter)
+logger: logging.Logger = logging.getLogger("spleeter")
+logger.addHandler(handler)
+logger.setLevel(logging.INFO)
+
+
+def configure_logger(verbose: bool) -> None:
+    """
+    Configure application logger.
+
+    Parameters:
+        verbose (bool):
+            `True` to use verbose logger, `False` otherwise.
+    """
+    from tensorflow import get_logger
+    from tensorflow.compat.v1 import logging as tf_logging
+
+    tf_logger = get_logger()
+    tf_logger.handlers = [handler]
+    if verbose:
+        tf_logging.set_verbosity(tf_logging.INFO)
+        logger.setLevel(logging.DEBUG)
+    else:
+        warnings.filterwarnings("ignore")
+        tf_logging.set_verbosity(tf_logging.ERROR)
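A short usage sketch of the rewritten module (the module path spleeter.utils.logging is assumed; logger and configure_logger are the names introduced above):

    from spleeter.utils.logging import configure_logger, logger

    # Route both the library logger and TensorFlow's logger through the
    # Typer-backed handler; verbose=True raises the levels to DEBUG/INFO,
    # verbose=False silences warnings and TF chatter.
    configure_logger(verbose=True)
    logger.info('separation started')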
@@ -3,43 +3,54 @@
 
 """ Utility function for tensorflow. """
 
+from typing import Any, Callable, Dict
+
+import pandas as pd
+
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import tensorflow as tf
-import pandas as pd
 # pylint: enable=import-error
 
-__email__ = 'spleeter@deezer.com'
-__author__ = 'Deezer Research'
-__license__ = 'MIT License'
+__email__ = "spleeter@deezer.com"
+__author__ = "Deezer Research"
+__license__ = "MIT License"
 
 
-def sync_apply(tensor_dict, func, concat_axis=1):
-    """ Return a function that applies synchronously the provided func on the
+def sync_apply(
+    tensor_dict: tf.Tensor, func: Callable, concat_axis: int = 1
+) -> Dict[str, tf.Tensor]:
+    """
+    Return a function that applies synchronously the provided func on the
     provided dictionnary of tensor. This means that func is applied to the
-    concatenation of the tensors in tensor_dict. This is useful for performing
-    random operation that needs the same drawn value on multiple tensor, such
-    as a random time-crop on both input data and label (the same crop should be
-    applied to both input data and label, so random crop cannot be applied
-    separately on each of them).
+    concatenation of the tensors in tensor_dict. This is useful for
+    performing random operation that needs the same drawn value on multiple
+    tensor, such as a random time-crop on both input data and label (the
+    same crop should be applied to both input data and label, so random
+    crop cannot be applied separately on each of them).
 
-    IMPORTANT NOTE: all tensor are assumed to be the same shape.
+    Notes:
+        All tensor are assumed to be the same shape.
 
-    Params:
-    - tensor_dict: dictionary (key: strings, values: tf.tensor)
-    a dictionary of tensor.
-    - func: function
-    function to be applied to the concatenation of the tensors in
-    tensor_dict
-    - concat_axis: int
+    Parameters:
+        tensor_dict (Dict[str, tensorflow.Tensor]):
+            A dictionary of tensor.
+        func (Callable):
+            Function to be applied to the concatenation of the tensors in
+            `tensor_dict`.
+        concat_axis (int):
             The axis on which to perform the concatenation.
 
     Returns:
-    processed tensors dictionary with the same name (keys) as input
-    tensor_dict.
+        Dict[str, tensorflow.Tensor]:
+            Processed tensors dictionary with the same name (keys) as input
+            tensor_dict.
     """
     if concat_axis not in {0, 1}:
         raise NotImplementedError(
-            'Function only implemented for concat_axis equal to 0 or 1')
+            "Function only implemented for concat_axis equal to 0 or 1"
+        )
     tensor_list = list(tensor_dict.values())
     concat_tensor = tf.concat(tensor_list, concat_axis)
     processed_concat_tensor = func(concat_tensor)
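To make the intent concrete, a usage sketch with hypothetical tensor names and shapes (module path spleeter.utils.tensor assumed). The crop is drawn once on the concatenation, so features and labels stay aligned; note that func must preserve the size of the concatenation axis:

    import tensorflow as tf

    from spleeter.utils.tensor import sync_apply

    batch = {
        'features': tf.random.uniform((512, 1024, 2)),
        'labels': tf.random.uniform((512, 1024, 2)),
    }
    # Single random 128-frame time-crop (axis 0), applied to the
    # concatenation on axis 1, then split back per key.
    cropped = sync_apply(
        batch,
        lambda t: tf.image.random_crop(t, size=(128, 2048, 2)),
        concat_axis=1,
    )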
@@ -47,90 +58,104 @@ def sync_apply(tensor_dict, func, concat_axis=1):
     D = tensor_shape[concat_axis]
     if concat_axis == 0:
         return {
-            name: processed_concat_tensor[index * D:(index + 1) * D, :, :]
+            name: processed_concat_tensor[index * D : (index + 1) * D, :, :]
             for index, name in enumerate(tensor_dict)
         }
     return {
-        name: processed_concat_tensor[:, index * D:(index + 1) * D, :]
+        name: processed_concat_tensor[:, index * D : (index + 1) * D, :]
         for index, name in enumerate(tensor_dict)
     }
 
 
 def from_float32_to_uint8(
-        tensor,
-        tensor_key='tensor',
-        min_key='min',
-        max_key='max'):
+    tensor: tf.Tensor,
+    tensor_key: str = "tensor",
+    min_key: str = "min",
+    max_key: str = "max",
+) -> tf.Tensor:
     """
 
-    :param tensor:
-    :param tensor_key:
-    :param min_key:
-    :param max_key:
-    :returns:
+    Parameters:
+        tensor (tensorflow.Tensor):
+        tensor_key (str):
+        min_key (str):
+        max_key (str):
+
+    Returns:
+        tensorflow.Tensor:
     """
     tensor_min = tf.reduce_min(tensor)
     tensor_max = tf.reduce_max(tensor)
     return {
         tensor_key: tf.cast(
-            (tensor - tensor_min) / (tensor_max - tensor_min + 1e-16)
-            * 255.9999, dtype=tf.uint8),
+            (tensor - tensor_min) / (tensor_max - tensor_min + 1e-16) * 255.9999,
+            dtype=tf.uint8,
+        ),
         min_key: tensor_min,
-        max_key: tensor_max
+        max_key: tensor_max,
     }
 
 
-def from_uint8_to_float32(tensor, tensor_min, tensor_max):
+def from_uint8_to_float32(
+    tensor: tf.Tensor, tensor_min: tf.Tensor, tensor_max: tf.Tensor
+) -> tf.Tensor:
     """
 
-    :param tensor:
-    :param tensor_min:
-    :param tensor_max:
-    :returns:
+    Parameters:
+        tensor (tensorflow.Tensor):
+        tensor_min (tensorflow.Tensor):
+        tensor_max (tensorflow.Tensor):
+
+    Returns:
+        tensorflow.Tensor:
     """
     return (
-        tf.cast(tensor, tf.float32)
-        * (tensor_max - tensor_min)
-        / 255.9999 + tensor_min)
+        tf.cast(tensor, tf.float32) * (tensor_max - tensor_min) / 255.9999 + tensor_min
+    )
 
 
-def pad_and_partition(tensor, segment_len):
-    """ Pad and partition a tensor into segment of len segment_len
+def pad_and_partition(tensor: tf.Tensor, segment_len: int) -> tf.Tensor:
+    """
+    Pad and partition a tensor into segment of len `segment_len`
     along the first dimension. The tensor is padded with 0 in order
-    to ensure that the first dimension is a multiple of segment_len.
+    to ensure that the first dimension is a multiple of `segment_len`.
 
     Tensor must be of known fixed rank
 
-    :Example:
+    Examples:
 
-    >>> tensor = [[1, 2, 3], [4, 5, 6]]
-    >>> segment_len = 2
-    >>> pad_and_partition(tensor, segment_len)
-    [[[1, 2], [4, 5]], [[3, 0], [6, 0]]]
+    ```python
+    >>> tensor = [[1, 2, 3], [4, 5, 6]]
+    >>> segment_len = 2
+    >>> pad_and_partition(tensor, segment_len)
+    [[[1, 2], [4, 5]], [[3, 0], [6, 0]]]
+    ````
 
-    :param tensor:
-    :param segment_len:
-    :returns:
+    Parameters:
+        tensor (tensorflow.Tensor):
+        segment_len (int):
+
+    Returns:
+        tensorflow.Tensor:
     """
     tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
     pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
-    padded = tf.pad(
-        tensor,
-        [[0, pad_size]] + [[0, 0]] * (len(tensor.shape)-1))
+    padded = tf.pad(tensor, [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1))
     split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
     return tf.reshape(
-        padded,
-        tf.concat(
-            [[split, segment_len], tf.shape(padded)[1:]],
-            axis=0))
+        padded, tf.concat([[split, segment_len], tf.shape(padded)[1:]], axis=0)
+    )
 
 
-def pad_and_reshape(instr_spec, frame_length, F):
+def pad_and_reshape(instr_spec, frame_length, F) -> Any:
     """
-    :param instr_spec:
-    :param frame_length:
-    :param F:
-    :returns:
+    Parameters:
+        instr_spec:
+        frame_length:
+        F:
+
+    Returns:
+        Any:
     """
     spec_shape = tf.shape(instr_spec)
     extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
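A quick round-trip sketch for the two quantization helpers above (shapes illustrative, module path spleeter.utils.tensor assumed). from_float32_to_uint8 returns a dict holding the quantized tensor plus the min/max needed to invert it:

    import tensorflow as tf

    from spleeter.utils.tensor import from_float32_to_uint8, from_uint8_to_float32

    spectrogram = tf.random.uniform((128, 512, 2))
    packed = from_float32_to_uint8(spectrogram)
    # Approximate inverse, up to quantization error (~1/256 of the range).
    restored = from_uint8_to_float32(packed['tensor'], packed['min'], packed['max'])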
@@ -138,53 +163,67 @@ def pad_and_reshape(instr_spec, frame_length, F):
     extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
     extended_spec = tf.concat([instr_spec, extension], axis=2)
     old_shape = tf.shape(extended_spec)
-    new_shape = tf.concat([
-        [old_shape[0] * old_shape[1]],
-        old_shape[2:]],
-        axis=0)
+    new_shape = tf.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)
     processed_instr_spec = tf.reshape(extended_spec, new_shape)
     return processed_instr_spec
 
 
-def dataset_from_csv(csv_path, **kwargs):
-    """ Load dataset from a CSV file using Pandas. kwargs if any are
+def dataset_from_csv(csv_path: str, **kwargs) -> Any:
+    """
+    Load dataset from a CSV file using Pandas. kwargs if any are
     forwarded to the `pandas.read_csv` function.
 
-    :param csv_path: Path of the CSV file to load dataset from.
-    :returns: Loaded dataset.
+    Parameters:
+        csv_path (str):
+            Path of the CSV file to load dataset from.
+
+    Returns:
+        Any:
+            Loaded dataset.
     """
     df = pd.read_csv(csv_path, **kwargs)
-    dataset = (
-        tf.data.Dataset.from_tensor_slices(
-            {key: df[key].values for key in df})
-    )
+    dataset = tf.data.Dataset.from_tensor_slices({key: df[key].values for key in df})
     return dataset
 
 
-def check_tensor_shape(tensor_tf, target_shape):
-    """ Return a Tensorflow boolean graph that indicates whether
+def check_tensor_shape(tensor_tf: tf.Tensor, target_shape: Any) -> bool:
+    """
+    Return a Tensorflow boolean graph that indicates whether
     sample[features_key] has the specified target shape. Only check
     not None entries of target_shape.
 
-    :param tensor_tf: Tensor to check shape for.
-    :param target_shape: Target shape to compare tensor to.
-    :returns: True if shape is valid, False otherwise (as TF boolean).
+    Parameters:
+        tensor_tf (tensorflow.Tensor):
+            Tensor to check shape for.
+        target_shape (Any):
+            Target shape to compare tensor to.
+
+    Returns:
+        bool:
+            `True` if shape is valid, `False` otherwise (as TF boolean).
     """
     result = tf.constant(True)
     for i, target_length in enumerate(target_shape):
         if target_length:
             result = tf.logical_and(
-                result,
-                tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i]))
+                result, tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i])
+            )
     return result
 
 
-def set_tensor_shape(tensor, tensor_shape):
-    """ Set shape for a tensor (not in place, as opposed to tf.set_shape)
+def set_tensor_shape(tensor: tf.Tensor, tensor_shape: Any) -> tf.Tensor:
+    """
+    Set shape for a tensor (not in place, as opposed to tf.set_shape)
 
-    :param tensor: Tensor to reshape.
-    :param tensor_shape: Shape to apply to the tensor.
-    :returns: A reshaped tensor.
+    Parameters:
+        tensor (tensorflow.Tensor):
+            Tensor to reshape.
+        tensor_shape (Any):
+            Shape to apply to the tensor.
+
+    Returns:
+        tensorflow.Tensor:
+            A reshaped tensor.
     """
     # NOTE: That SOUND LIKE IN PLACE HERE ?
     tensor.set_shape(tensor_shape)
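To show how dataset_from_csv is meant to be used, a brief sketch (CSV name and columns hypothetical):

    from spleeter.utils.tensor import dataset_from_csv

    # One dataset element per CSV row; each column becomes a keyed tensor.
    dataset = dataset_from_csv('train.csv')
    for sample in dataset.take(1):
        print(sample['mix_path'])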
@@ -7,82 +7,82 @@ __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
-import filecmp
-import itertools
 from os import makedirs
-from os.path import splitext, basename, exists, join
+from os.path import join
 from tempfile import TemporaryDirectory
 
 import pytest
 import numpy as np
 
-import tensorflow as tf
-
-from spleeter.audio.adapter import get_default_audio_adapter
-from spleeter.commands import create_argument_parser
-from spleeter.commands import evaluate
-from spleeter.utils.configuration import load_configuration
-
-BACKENDS = ["tensorflow", "librosa"]
-TEST_CONFIGURATIONS = {el:el for el in BACKENDS}
+from spleeter.__main__ import evaluate
+from spleeter.audio.adapter import AudioAdapter
+
+BACKENDS = ['tensorflow', 'librosa']
+TEST_CONFIGURATIONS = {el: el for el in BACKENDS}
 
 res_4stems = {
-    "vocals": {
-        "SDR": 3.25e-05,
-        "SAR": -11.153575,
-        "SIR": -1.3849,
-        "ISR": 2.75e-05
+    'vocals': {
+        'SDR': 3.25e-05,
+        'SAR': -11.153575,
+        'SIR': -1.3849,
+        'ISR': 2.75e-05
     },
-    "drums": {
-        "SDR": -0.079505,
-        "SAR": -15.7073575,
-        "SIR": -4.972755,
-        "ISR": 0.0013575
+    'drums': {
+        'SDR': -0.079505,
+        'SAR': -15.7073575,
+        'SIR': -4.972755,
+        'ISR': 0.0013575
     },
-    "bass":{
-        "SDR": 2.5e-06,
-        "SAR": -10.3520575,
-        "SIR": -4.272325,
-        "ISR": 2.5e-06
+    'bass': {
+        'SDR': 2.5e-06,
+        'SAR': -10.3520575,
+        'SIR': -4.272325,
+        'ISR': 2.5e-06
     },
-    "other":{
-        "SDR": -1.359175,
-        "SAR": -14.7076775,
-        "SIR": -4.761505,
-        "ISR": -0.01528
+    'other': {
+        'SDR': -1.359175,
+        'SAR': -14.7076775,
+        'SIR': -4.761505,
+        'ISR': -0.01528
     }
 }
 
 
 def generate_fake_eval_dataset(path):
     """
     generate fake evaluation dataset
     """
-    aa = get_default_audio_adapter()
+    aa = AudioAdapter.default()
     n_songs = 2
     fs = 44100
     duration = 3
     n_channels = 2
     rng = np.random.RandomState(seed=0)
     for song in range(n_songs):
-        song_path = join(path, "test", f"song{song}")
+        song_path = join(path, 'test', f'song{song}')
         makedirs(song_path, exist_ok=True)
-        for instr in ["mixture", "vocals", "bass", "drums", "other"]:
-            filename = join(song_path, f"{instr}.wav")
+        for instr in ['mixture', 'vocals', 'bass', 'drums', 'other']:
+            filename = join(song_path, f'{instr}.wav')
             data = rng.rand(duration*fs, n_channels)-0.5
             aa.save(filename, data, fs)
 
 
 @pytest.mark.parametrize('backend', TEST_CONFIGURATIONS)
 def test_evaluate(backend):
-    with TemporaryDirectory() as directory:
-        generate_fake_eval_dataset(directory)
-        p = create_argument_parser()
-        arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", directory, "-B", backend])
-        params = load_configuration(arguments.configuration)
-        metrics = evaluate.entrypoint(arguments, params)
-        for instrument, metric in metrics.items():
-            for m, value in metric.items():
-                assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3)
+    with TemporaryDirectory() as dataset:
+        with TemporaryDirectory() as evaluation:
+            generate_fake_eval_dataset(dataset)
+            metrics = evaluate(
+                adapter='spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter',
+                output_path=evaluation,
+                stft_backend=backend,
+                params_filename='spleeter:4stems',
+                mus_dir=dataset,
+                mwf=False,
+                verbose=False)
+            for instrument, metric in metrics.items():
+                for m, value in metric.items():
+                    assert np.allclose(
+                        np.median(value),
+                        res_4stems[instrument][m],
+                        atol=1e-3)
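The recurring change across these test updates is the adapter API migration; a minimal before/after sketch (both forms appear in this diff):

    # Before: module-level factory function.
    from spleeter.audio.adapter import get_default_audio_adapter
    adapter = get_default_audio_adapter()

    # After: class method on AudioAdapter.
    from spleeter.audio.adapter import AudioAdapter
    adapter = AudioAdapter.default()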
@@ -10,6 +10,11 @@ __license__ = 'MIT License'
 from os.path import join
 from tempfile import TemporaryDirectory
 
+from spleeter import SpleeterError
+from spleeter.audio.adapter import AudioAdapter
+from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
+
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 from pytest import fixture, raises
@@ -17,12 +22,6 @@ import numpy as np
 import ffmpeg
 # pylint: enable=import-error
 
-from spleeter import SpleeterError
-from spleeter.audio.adapter import AudioAdapter
-from spleeter.audio.adapter import get_default_audio_adapter
-from spleeter.audio.adapter import get_audio_adapter
-from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
 
 TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
 TEST_OFFSET = 0
 TEST_DURATION = 600.
@@ -32,7 +31,7 @@ TEST_SAMPLE_RATE = 44100
 @fixture(scope='session')
 def adapter():
     """ Target test audio adapter fixture. """
-    return get_default_audio_adapter()
+    return AudioAdapter.default()
 
 
 @fixture(scope='session')
@@ -48,7 +47,7 @@ def audio_data(adapter):
 def test_default_adapter(adapter):
     """ Test adapter as default adapter. """
     assert isinstance(adapter, FFMPEGProcessAudioAdapter)
-    assert adapter is AudioAdapter.DEFAULT
+    assert adapter is AudioAdapter._DEFAULT
 
 
 def test_load(audio_data):
@@ -5,12 +5,12 @@
 
 from pytest import raises
 
-from spleeter.model.provider import get_default_model_provider
+from spleeter.model.provider import ModelProvider
 
 
 def test_checksum():
     """ Test archive checksum index retrieval. """
-    provider = get_default_model_provider()
+    provider = ModelProvider.default()
     assert provider.checksum('2stems') == \
         'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692'
     assert provider.checksum('4stems') == \
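The same migration applies to model providers; a short sketch combining the calls seen in this diff (the get call mirrors the deleted estimator helper):

    from spleeter.model.provider import ModelProvider

    provider = ModelProvider.default()
    # Verify an archive checksum, then resolve a model name to a local path.
    digest = provider.checksum('2stems')
    model_dir = provider.get('2stems')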
@@ -17,7 +17,7 @@ import numpy as np
 import tensorflow as tf
 
 from spleeter import SpleeterError
-from spleeter.audio.adapter import get_default_audio_adapter
+from spleeter.audio.adapter import AudioAdapter
 from spleeter.separator import Separator
 
 TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
@@ -41,7 +41,7 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
 
 @pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
 def test_separator_backends(test_file):
-    adapter = get_default_audio_adapter()
+    adapter = AudioAdapter.default()
     waveform, _ = adapter.load(test_file)
 
     separator_lib = Separator(
@@ -64,11 +64,13 @@ def test_separator_backends(test_file):
         assert np.allclose(out_tf[instrument], out_lib[instrument], atol=1e-5)
 
 
-@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
+@pytest.mark.parametrize(
+    'test_file, configuration, backend',
+    TEST_CONFIGURATIONS)
 def test_separate(test_file, configuration, backend):
     """ Test separation from raw data. """
     instruments = MODEL_TO_INST[configuration]
-    adapter = get_default_audio_adapter()
+    adapter = AudioAdapter.default()
     waveform, _ = adapter.load(test_file)
     separator = Separator(
         configuration, stft_backend=backend, multiprocess=False)
@@ -85,7 +87,9 @@ def test_separate(test_file, configuration, backend):
         assert not np.allclose(track, prediction[compared])
 
 
-@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
+@pytest.mark.parametrize(
+    'test_file, configuration, backend',
+    TEST_CONFIGURATIONS)
 def test_separate_to_file(test_file, configuration, backend):
     """ Test file based separation. """
     instruments = MODEL_TO_INST[configuration]
@@ -102,7 +106,9 @@ def test_separate_to_file(test_file, configuration, backend):
             '{}/{}.wav'.format(name, instrument)))
 
 
-@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
+@pytest.mark.parametrize(
+    'test_file, configuration, backend',
+    TEST_CONFIGURATIONS)
 def test_filename_format(test_file, configuration, backend):
     """ Test custom filename format. """
     instruments = MODEL_TO_INST[configuration]
@@ -120,7 +126,9 @@ def test_filename_format(test_file, configuration, backend):
             'export/{}/{}.wav'.format(name, instrument)))
 
 
-@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
+@pytest.mark.parametrize(
+    'test_file, configuration',
+    MODELS_AND_TEST_FILES)
 def test_filename_conflict(test_file, configuration):
     """ Test error handling with static pattern. """
     separator = Separator(configuration, multiprocess=False)
@@ -7,107 +7,102 @@ __email__ = 'research@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
-import filecmp
-import itertools
+import json
 import os
 
 from os import makedirs
-from os.path import splitext, basename, exists, join
+from os.path import join
 from tempfile import TemporaryDirectory
 
 import numpy as np
 import pandas as pd
-import json
 
-import tensorflow as tf
-
-from spleeter.audio.adapter import get_default_audio_adapter
-from spleeter.commands import create_argument_parser
-from spleeter.commands import train
-from spleeter.utils.configuration import load_configuration
+from spleeter.audio.adapter import AudioAdapter
+from spleeter.__main__ import spleeter
+from typer.testing import CliRunner
 
 TRAIN_CONFIG = {
-    "mix_name": "mix",
-    "instrument_list": ["vocals", "other"],
-    "sample_rate":44100,
-    "frame_length":4096,
-    "frame_step":1024,
-    "T":128,
-    "F":128,
-    "n_channels":2,
-    "chunk_duration":4,
-    "n_chunks_per_song":1,
-    "separation_exponent":2,
-    "mask_extension":"zeros",
-    "learning_rate": 1e-4,
-    "batch_size":2,
-    "train_max_steps": 10,
-    "throttle_secs":20,
-    "save_checkpoints_steps":100,
-    "save_summary_steps":5,
-    "random_seed":0,
-    "model":{
-        "type":"unet.unet",
-        "params":{
-            "conv_activation":"ELU",
-            "deconv_activation":"ELU"
+    'mix_name': 'mix',
+    'instrument_list': ['vocals', 'other'],
+    'sample_rate': 44100,
+    'frame_length': 4096,
+    'frame_step': 1024,
+    'T': 128,
+    'F': 128,
+    'n_channels': 2,
+    'chunk_duration': 4,
+    'n_chunks_per_song': 1,
+    'separation_exponent': 2,
+    'mask_extension': 'zeros',
+    'learning_rate': 1e-4,
+    'batch_size': 2,
+    'train_max_steps': 10,
+    'throttle_secs': 20,
+    'save_checkpoints_steps': 100,
+    'save_summary_steps': 5,
+    'random_seed': 0,
+    'model': {
+        'type': 'unet.unet',
+        'params': {
+            'conv_activation': 'ELU',
+            'deconv_activation': 'ELU'
         }
     }
 }
 
 
-def generate_fake_training_dataset(path, instrument_list=["vocals", "other"]):
+def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
     """
     generates a fake training dataset in path:
     - generates audio files
     - generates a csv file describing the dataset
     """
-    aa = get_default_audio_adapter()
+    aa = AudioAdapter.default()
     n_songs = 2
     fs = 44100
     duration = 6
     n_channels = 2
     rng = np.random.RandomState(seed=0)
-    dataset_df = pd.DataFrame(columns=["mix_path"]+[f"{instr}_path" for instr in instrument_list]+["duration"])
+    dataset_df = pd.DataFrame(
+        columns=['mix_path'] + [
+            f'{instr}_path' for instr in instrument_list] + ['duration'])
     for song in range(n_songs):
-        song_path = join(path, "train", f"song{song}")
+        song_path = join(path, 'train', f'song{song}')
         makedirs(song_path, exist_ok=True)
-        dataset_df.loc[song, f"duration"] = duration
-        for instr in instrument_list+["mix"]:
-            filename = join(song_path, f"{instr}.wav")
+        dataset_df.loc[song, f'duration'] = duration
+        for instr in instrument_list+['mix']:
+            filename = join(song_path, f'{instr}.wav')
             data = rng.rand(duration*fs, n_channels)-0.5
             aa.save(filename, data, fs)
-            dataset_df.loc[song, f"{instr}_path"] = join("train", f"song{song}", f"{instr}.wav")
-    dataset_df.to_csv(join(path, "train", "train.csv"), index=False)
+            dataset_df.loc[song, f'{instr}_path'] = join(
+                'train',
+                f'song{song}',
+                f'{instr}.wav')
+    dataset_df.to_csv(join(path, 'train', 'train.csv'), index=False)
 
 
 def test_train():
 
     with TemporaryDirectory() as path:
 
         # generate training dataset
         generate_fake_training_dataset(path)
 
         # set training command aruments
-        p = create_argument_parser()
-        arguments = p.parse_args(["train", "-p", "useless_config.json", "-d", path])
-        TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv")
-        TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv")
-        TRAIN_CONFIG["model_dir"] = join(path, "model")
-        TRAIN_CONFIG["training_cache"] = join(path, "cache", "training")
-        TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation")
+        runner = CliRunner()
+        TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
+        TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
+        TRAIN_CONFIG['model_dir'] = join(path, 'model')
+        TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
+        TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
+        with open('useless_config.json', 'w') as stream:
+            json.dump(TRAIN_CONFIG, stream)
         # execute training
-        res = train.entrypoint(arguments, TRAIN_CONFIG)
+        result = runner.invoke(spleeter, [
+            'train',
+            '-p', 'useless_config.json',
+            '-d', path
+        ])
        # assert that model checkpoint was created.
-        assert os.path.exists(join(path,'model','model.ckpt-10.index'))
-        assert os.path.exists(join(path,'model','checkpoint'))
-        assert os.path.exists(join(path,'model','model.ckpt-0.meta'))
-
-if __name__=="__main__":
-    test_train()
+        assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
+        assert os.path.exists(join(path, 'model', 'checkpoint'))
+        assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
+        assert result.exit_code == 0
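The rewritten test drives the Typer application in-process rather than parsing arguments by hand; the same pattern works for smoke-testing any subcommand (the config path and data directory here are placeholders):

    from typer.testing import CliRunner

    from spleeter.__main__ import spleeter

    runner = CliRunner()
    result = runner.invoke(spleeter, ['train', '-p', 'config.json', '-d', '/tmp/dataset'])
    assert result.exit_code == 0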