mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Initial commit from private spleeter
This commit is contained in:
112
.gitignore
vendored
Normal file
112
.gitignore
vendored
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.vscode
|
||||||
|
.DS_Store
|
||||||
|
__pycache__
|
||||||
|
**/reporting
|
||||||
|
|
||||||
|
pretrained_models
|
||||||
|
docs/build
|
||||||
|
.vscode
|
||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2019-present, Deezer SA.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
3
MANIFEST.in
Normal file
3
MANIFEST.in
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
include src/resources/*.json
|
||||||
|
include README.md
|
||||||
|
include LICENSE
|
||||||
30
Makefile
Normal file
30
Makefile
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# =======================================================
|
||||||
|
# Build script for distribution packaging.
|
||||||
|
#
|
||||||
|
# @author Deezer Research <research@deezer.com>
|
||||||
|
# @licence MIT Licence
|
||||||
|
# =======================================================
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -Rf *.egg-info
|
||||||
|
rm -Rf dist
|
||||||
|
|
||||||
|
build:
|
||||||
|
@echo "=== Build CPU bdist package"
|
||||||
|
@python3 setup.py sdist
|
||||||
|
@echo "=== CPU version checksum"
|
||||||
|
@openssl sha256 dist/*.tar.gz
|
||||||
|
|
||||||
|
build-gpu:
|
||||||
|
@echo "=== Build GPU bdist package"
|
||||||
|
@python3 setup.py sdist --target gpu
|
||||||
|
@echo "=== GPU version checksum"
|
||||||
|
@openssl sha256 dist/*.tar.gz
|
||||||
|
|
||||||
|
upload:
|
||||||
|
twine upload dist/*
|
||||||
|
|
||||||
|
test-upload:
|
||||||
|
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
|
||||||
|
|
||||||
|
all: clean build build-gpu upload
|
||||||
64
README.md
64
README.md
@@ -1,6 +1,64 @@
|
|||||||
# spleeter
|
<img src="https://github.com/deezer/spleeter/raw/master/images/spleeter_logo.png" height="80" />
|
||||||
|
|
||||||
<img src=images/spleeter_logo.png height=100>
|
[](https://badge.fury.io/py/spleeter) 
|
||||||
|
|
||||||
|
## About
|
||||||
|
|
||||||
spleeter will be made available soon!
|
**Spleeter** is the [Deezer](https://www.deezer.com/) source separation library with pretrained models
|
||||||
|
written in [Python](https://www.python.org/) and uses [Tensorflow](tensorflow.org/). It makes it easy
|
||||||
|
to train source separation model (assuming you have a dataset of isolated sources), and provides
|
||||||
|
already trained state of the art model for performing various flavour of separation :
|
||||||
|
|
||||||
|
* Vocals (singing voice) / accompaniment separation ([2 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-2stems-model))
|
||||||
|
* Vocals / drums / bass / other separation ([4 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-4stems-model))
|
||||||
|
* Vocals / drums / bass / piano / other separation ([5 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-5stems-model))
|
||||||
|
|
||||||
|
2 stems and 4 stems models have state of the art performances on the
|
||||||
|
[musdb](https://sigsep.github.io/datasets/musdb.html) dataset. It is also very fast as
|
||||||
|
it can perform separation of audio files to 4 stems 100x faster than real-time when run on a *GPU*.
|
||||||
|
We designed it so you can use it straight from [command line](https://github.com/deezer/spleeter/wiki/2.-Getting-started#usage)
|
||||||
|
as well as directly in your own development pipeline as a
|
||||||
|
[Python library](https://github.com/deezer/spleeter/wiki/4.-API-Reference#separator)
|
||||||
|
|
||||||
|
**Spleeter** can be installed with [Conda](https://github.com/deezer/spleeter/wiki/1.-Installation#using-conda),
|
||||||
|
with [pip](https://github.com/deezer/spleeter/wiki/1.-Installation#using-pip) or be used with
|
||||||
|
[Docker](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-docker-image).
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
Want to try it out ? Just clone the repository and install a
|
||||||
|
[Conda](https://github.com/deezer/spleeter/wiki/1.-Installation#using-conda)
|
||||||
|
environment to start separating audio file as follows:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ git clone https://github.com/Deezer/spleeter
|
||||||
|
$ conda env create -f spleeter/conda/spleeter-cpu.yaml
|
||||||
|
$ conda activate spleeter-cpu
|
||||||
|
$ spleeter separate -i spleeter/audio_example.mp3 -p spleeter:2stems -o output
|
||||||
|
```
|
||||||
|
You should get two separated audio files (`vocals.wav` and `accompaniment.wav`)
|
||||||
|
in the `output/audio_example` folder.
|
||||||
|
|
||||||
|
For a more detailed documentation, please check the [repository wiki](https://github.com/deezer/spleeter/wiki)
|
||||||
|
|
||||||
|
## Reference
|
||||||
|
If you use **Spleeter** in your work, please cite:
|
||||||
|
|
||||||
|
```
|
||||||
|
@misc{spleeter2019,
|
||||||
|
title={Spleeter: A Fast And State-of-the Art Music Source Separation Tool With Pre-trained Models},
|
||||||
|
author={Romain Hennequin and Anis Khlif and Felix Voituret and Manuel Moussallam},
|
||||||
|
howpublished={Late-Breaking/Demo ISMIR 2019},
|
||||||
|
month={November},
|
||||||
|
year={2019}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
The code of **Spleeter** is MIT-licensed.
|
||||||
|
|
||||||
|
## Note
|
||||||
|
This repository include a demo audio file `audio_example.mp3` which is an excerpt
|
||||||
|
from Slow Motion Dream by Steven M Bryant (c) copyright 2011 Licensed under a Creative
|
||||||
|
Commons Attribution (3.0) license. http://dig.ccmixter.org/files/stevieb357/34740
|
||||||
|
Ft: CSoul,Alex Beroza & Robert Siekawitch
|
||||||
|
|||||||
BIN
audio_example.mp3
Normal file
BIN
audio_example.mp3
Normal file
Binary file not shown.
18
conda/spleeter-cpu.yaml
Normal file
18
conda/spleeter-cpu.yaml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
name: spleeter-cpu
|
||||||
|
|
||||||
|
channels:
|
||||||
|
- conda-forge
|
||||||
|
- anaconda
|
||||||
|
|
||||||
|
dependencies:
|
||||||
|
- python=3.7
|
||||||
|
- tensorflow=1.14.0
|
||||||
|
- ffmpeg
|
||||||
|
- pandas==0.25.1
|
||||||
|
- requests
|
||||||
|
- pip
|
||||||
|
- pip:
|
||||||
|
- museval==0.3.0
|
||||||
|
- musdb==0.3.1
|
||||||
|
- norbert==0.2.1
|
||||||
|
- spleeter
|
||||||
19
conda/spleeter-gpu.yaml
Normal file
19
conda/spleeter-gpu.yaml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
name: spleeter-gpu
|
||||||
|
|
||||||
|
channels:
|
||||||
|
- conda-forge
|
||||||
|
- anaconda
|
||||||
|
|
||||||
|
dependencies:
|
||||||
|
- python=3.7
|
||||||
|
- tensorflow-gpu=1.14.0
|
||||||
|
- ffmpeg
|
||||||
|
- pandas==0.25.1
|
||||||
|
- requests
|
||||||
|
- pip
|
||||||
|
- pip:
|
||||||
|
- museval==0.3.0
|
||||||
|
- musdb==0.3.1
|
||||||
|
- norbert==0.2.1
|
||||||
|
- spleeter
|
||||||
|
|
||||||
28
configs/2stems/base_config.json
Normal file
28
configs/2stems/base_config.json
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
{
|
||||||
|
"train_csv": "path/to/train.csv",
|
||||||
|
"validation_csv": "path/to/test.csv",
|
||||||
|
"model_dir": "2stems",
|
||||||
|
"mix_name": "mix",
|
||||||
|
"instrument_list": ["vocals", "accompaniment"],
|
||||||
|
"sample_rate":44100,
|
||||||
|
"frame_length":4096,
|
||||||
|
"frame_step":1024,
|
||||||
|
"T":512,
|
||||||
|
"F":1024,
|
||||||
|
"n_channels":2,
|
||||||
|
"separation_exponent":2,
|
||||||
|
"mask_extension":"zeros",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size":4,
|
||||||
|
"training_cache":"training_cache",
|
||||||
|
"validation_cache":"validation_cache",
|
||||||
|
"train_max_steps": 1000000,
|
||||||
|
"throttle_secs":300,
|
||||||
|
"random_seed":0,
|
||||||
|
"save_checkpoints_steps":150,
|
||||||
|
"save_summary_steps":5,
|
||||||
|
"model":{
|
||||||
|
"type":"unet.unet",
|
||||||
|
"params":{}
|
||||||
|
}
|
||||||
|
}
|
||||||
31
configs/4stems/base_config.json
Normal file
31
configs/4stems/base_config.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"train_csv": "path/to/train.csv",
|
||||||
|
"validation_csv": "path/to/test.csv",
|
||||||
|
"model_dir": "4stems",
|
||||||
|
"mix_name": "mix",
|
||||||
|
"instrument_list": ["vocals", "drums", "bass", "other"],
|
||||||
|
"sample_rate":44100,
|
||||||
|
"frame_length":4096,
|
||||||
|
"frame_step":1024,
|
||||||
|
"T":512,
|
||||||
|
"F":1024,
|
||||||
|
"n_channels":2,
|
||||||
|
"separation_exponent":2,
|
||||||
|
"mask_extension":"zeros",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size":4,
|
||||||
|
"training_cache":"training_cache",
|
||||||
|
"validation_cache":"validation_cache",
|
||||||
|
"train_max_steps": 1500000,
|
||||||
|
"throttle_secs":600,
|
||||||
|
"random_seed":3,
|
||||||
|
"save_checkpoints_steps":300,
|
||||||
|
"save_summary_steps":5,
|
||||||
|
"model":{
|
||||||
|
"type":"unet.unet",
|
||||||
|
"params":{
|
||||||
|
"conv_activation":"ELU",
|
||||||
|
"deconv_activation":"ELU"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
31
configs/5stems/base_config.json
Normal file
31
configs/5stems/base_config.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"train_csv": "path/to/train.csv",
|
||||||
|
"validation_csv": "path/to/test.csv",
|
||||||
|
"model_dir": "5stems",
|
||||||
|
"mix_name": "mix",
|
||||||
|
"instrument_list": ["vocals", "piano", "drums", "bass", "other"],
|
||||||
|
"sample_rate":44100,
|
||||||
|
"frame_length":4096,
|
||||||
|
"frame_step":1024,
|
||||||
|
"T":512,
|
||||||
|
"F":1024,
|
||||||
|
"n_channels":2,
|
||||||
|
"separation_exponent":2,
|
||||||
|
"mask_extension":"zeros",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size":4,
|
||||||
|
"training_cache":"training_cache",
|
||||||
|
"validation_cache":"validation_cache",
|
||||||
|
"train_max_steps": 2500000,
|
||||||
|
"throttle_secs":600,
|
||||||
|
"random_seed":8,
|
||||||
|
"save_checkpoints_steps":300,
|
||||||
|
"save_summary_steps":5,
|
||||||
|
"model":{
|
||||||
|
"type":"unet.softmax_unet",
|
||||||
|
"params":{
|
||||||
|
"conv_activation":"ELU",
|
||||||
|
"deconv_activation":"ELU"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
32
configs/musdb_config.json
Normal file
32
configs/musdb_config.json
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
{
|
||||||
|
"train_csv": "configs/musdb_train.csv",
|
||||||
|
"validation_csv": "configs/musdb_validation.csv",
|
||||||
|
"model_dir": "musdb_model",
|
||||||
|
"mix_name": "mix",
|
||||||
|
"instrument_list": ["vocals", "drums", "bass", "other"],
|
||||||
|
"sample_rate":44100,
|
||||||
|
"frame_length":4096,
|
||||||
|
"frame_step":1024,
|
||||||
|
"T":512,
|
||||||
|
"F":1024,
|
||||||
|
"n_channels":2,
|
||||||
|
"n_chunks_per_song":1,
|
||||||
|
"separation_exponent":2,
|
||||||
|
"mask_extension":"zeros",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size":4,
|
||||||
|
"training_cache":"cache/training",
|
||||||
|
"validation_cache":"cache/validation",
|
||||||
|
"train_max_steps": 100000,
|
||||||
|
"throttle_secs":600,
|
||||||
|
"random_seed":3,
|
||||||
|
"save_checkpoints_steps":300,
|
||||||
|
"save_summary_steps":5,
|
||||||
|
"model":{
|
||||||
|
"type":"unet.unet",
|
||||||
|
"params":{
|
||||||
|
"conv_activation":"ELU",
|
||||||
|
"deconv_activation":"ELU"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
87
configs/musdb_train.csv
Normal file
87
configs/musdb_train.csv
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
mix_path,vocals_path,drums_path,bass_path,other_path,duration
|
||||||
|
train/A Classic Education - NightOwl/mixture.wav,train/A Classic Education - NightOwl/vocals.wav,train/A Classic Education - NightOwl/drums.wav,train/A Classic Education - NightOwl/bass.wav,train/A Classic Education - NightOwl/other.wav,171.247166
|
||||||
|
train/ANiMAL - Clinic A/mixture.wav,train/ANiMAL - Clinic A/vocals.wav,train/ANiMAL - Clinic A/drums.wav,train/ANiMAL - Clinic A/bass.wav,train/ANiMAL - Clinic A/other.wav,237.865215
|
||||||
|
train/ANiMAL - Easy Tiger/mixture.wav,train/ANiMAL - Easy Tiger/vocals.wav,train/ANiMAL - Easy Tiger/drums.wav,train/ANiMAL - Easy Tiger/bass.wav,train/ANiMAL - Easy Tiger/other.wav,205.473379
|
||||||
|
train/Actions - Devil's Words/mixture.wav,train/Actions - Devil's Words/vocals.wav,train/Actions - Devil's Words/drums.wav,train/Actions - Devil's Words/bass.wav,train/Actions - Devil's Words/other.wav,196.626576
|
||||||
|
train/Actions - South Of The Water/mixture.wav,train/Actions - South Of The Water/vocals.wav,train/Actions - South Of The Water/drums.wav,train/Actions - South Of The Water/bass.wav,train/Actions - South Of The Water/other.wav,176.610975
|
||||||
|
train/Aimee Norwich - Child/mixture.wav,train/Aimee Norwich - Child/vocals.wav,train/Aimee Norwich - Child/drums.wav,train/Aimee Norwich - Child/bass.wav,train/Aimee Norwich - Child/other.wav,189.080091
|
||||||
|
train/Alexander Ross - Velvet Curtain/mixture.wav,train/Alexander Ross - Velvet Curtain/vocals.wav,train/Alexander Ross - Velvet Curtain/drums.wav,train/Alexander Ross - Velvet Curtain/bass.wav,train/Alexander Ross - Velvet Curtain/other.wav,514.298776
|
||||||
|
train/Angela Thomas Wade - Milk Cow Blues/mixture.wav,train/Angela Thomas Wade - Milk Cow Blues/vocals.wav,train/Angela Thomas Wade - Milk Cow Blues/drums.wav,train/Angela Thomas Wade - Milk Cow Blues/bass.wav,train/Angela Thomas Wade - Milk Cow Blues/other.wav,210.906848
|
||||||
|
train/Atlantis Bound - It Was My Fault For Waiting/mixture.wav,train/Atlantis Bound - It Was My Fault For Waiting/vocals.wav,train/Atlantis Bound - It Was My Fault For Waiting/drums.wav,train/Atlantis Bound - It Was My Fault For Waiting/bass.wav,train/Atlantis Bound - It Was My Fault For Waiting/other.wav,268.051156
|
||||||
|
train/Auctioneer - Our Future Faces/mixture.wav,train/Auctioneer - Our Future Faces/vocals.wav,train/Auctioneer - Our Future Faces/drums.wav,train/Auctioneer - Our Future Faces/bass.wav,train/Auctioneer - Our Future Faces/other.wav,207.702494
|
||||||
|
train/AvaLuna - Waterduct/mixture.wav,train/AvaLuna - Waterduct/vocals.wav,train/AvaLuna - Waterduct/drums.wav,train/AvaLuna - Waterduct/bass.wav,train/AvaLuna - Waterduct/other.wav,259.111474
|
||||||
|
train/BigTroubles - Phantom/mixture.wav,train/BigTroubles - Phantom/vocals.wav,train/BigTroubles - Phantom/drums.wav,train/BigTroubles - Phantom/bass.wav,train/BigTroubles - Phantom/other.wav,146.750113
|
||||||
|
train/Bill Chudziak - Children Of No-one/mixture.wav,train/Bill Chudziak - Children Of No-one/vocals.wav,train/Bill Chudziak - Children Of No-one/drums.wav,train/Bill Chudziak - Children Of No-one/bass.wav,train/Bill Chudziak - Children Of No-one/other.wav,230.736689
|
||||||
|
train/Black Bloc - If You Want Success/mixture.wav,train/Black Bloc - If You Want Success/vocals.wav,train/Black Bloc - If You Want Success/drums.wav,train/Black Bloc - If You Want Success/bass.wav,train/Black Bloc - If You Want Success/other.wav,398.547302
|
||||||
|
train/Celestial Shore - Die For Us/mixture.wav,train/Celestial Shore - Die For Us/vocals.wav,train/Celestial Shore - Die For Us/drums.wav,train/Celestial Shore - Die For Us/bass.wav,train/Celestial Shore - Die For Us/other.wav,278.476916
|
||||||
|
train/Chris Durban - Celebrate/mixture.wav,train/Chris Durban - Celebrate/vocals.wav,train/Chris Durban - Celebrate/drums.wav,train/Chris Durban - Celebrate/bass.wav,train/Chris Durban - Celebrate/other.wav,301.603991
|
||||||
|
train/Clara Berry And Wooldog - Air Traffic/mixture.wav,train/Clara Berry And Wooldog - Air Traffic/vocals.wav,train/Clara Berry And Wooldog - Air Traffic/drums.wav,train/Clara Berry And Wooldog - Air Traffic/bass.wav,train/Clara Berry And Wooldog - Air Traffic/other.wav,173.267302
|
||||||
|
train/Clara Berry And Wooldog - Stella/mixture.wav,train/Clara Berry And Wooldog - Stella/vocals.wav,train/Clara Berry And Wooldog - Stella/drums.wav,train/Clara Berry And Wooldog - Stella/bass.wav,train/Clara Berry And Wooldog - Stella/other.wav,195.558458
|
||||||
|
train/Cnoc An Tursa - Bannockburn/mixture.wav,train/Cnoc An Tursa - Bannockburn/vocals.wav,train/Cnoc An Tursa - Bannockburn/drums.wav,train/Cnoc An Tursa - Bannockburn/bass.wav,train/Cnoc An Tursa - Bannockburn/other.wav,294.521905
|
||||||
|
train/Creepoid - OldTree/mixture.wav,train/Creepoid - OldTree/vocals.wav,train/Creepoid - OldTree/drums.wav,train/Creepoid - OldTree/bass.wav,train/Creepoid - OldTree/other.wav,302.02195
|
||||||
|
train/Dark Ride - Burning Bridges/mixture.wav,train/Dark Ride - Burning Bridges/vocals.wav,train/Dark Ride - Burning Bridges/drums.wav,train/Dark Ride - Burning Bridges/bass.wav,train/Dark Ride - Burning Bridges/other.wav,232.663946
|
||||||
|
train/Dreamers Of The Ghetto - Heavy Love/mixture.wav,train/Dreamers Of The Ghetto - Heavy Love/vocals.wav,train/Dreamers Of The Ghetto - Heavy Love/drums.wav,train/Dreamers Of The Ghetto - Heavy Love/bass.wav,train/Dreamers Of The Ghetto - Heavy Love/other.wav,294.800544
|
||||||
|
train/Drumtracks - Ghost Bitch/mixture.wav,train/Drumtracks - Ghost Bitch/vocals.wav,train/Drumtracks - Ghost Bitch/drums.wav,train/Drumtracks - Ghost Bitch/bass.wav,train/Drumtracks - Ghost Bitch/other.wav,356.913923
|
||||||
|
train/Faces On Film - Waiting For Ga/mixture.wav,train/Faces On Film - Waiting For Ga/vocals.wav,train/Faces On Film - Waiting For Ga/drums.wav,train/Faces On Film - Waiting For Ga/bass.wav,train/Faces On Film - Waiting For Ga/other.wav,257.439637
|
||||||
|
train/Fergessen - Back From The Start/mixture.wav,train/Fergessen - Back From The Start/vocals.wav,train/Fergessen - Back From The Start/drums.wav,train/Fergessen - Back From The Start/bass.wav,train/Fergessen - Back From The Start/other.wav,168.553651
|
||||||
|
train/Fergessen - The Wind/mixture.wav,train/Fergessen - The Wind/vocals.wav,train/Fergessen - The Wind/drums.wav,train/Fergessen - The Wind/bass.wav,train/Fergessen - The Wind/other.wav,191.820045
|
||||||
|
train/Flags - 54/mixture.wav,train/Flags - 54/vocals.wav,train/Flags - 54/drums.wav,train/Flags - 54/bass.wav,train/Flags - 54/other.wav,315.164444
|
||||||
|
train/Giselle - Moss/mixture.wav,train/Giselle - Moss/vocals.wav,train/Giselle - Moss/drums.wav,train/Giselle - Moss/bass.wav,train/Giselle - Moss/other.wav,201.711746
|
||||||
|
train/Grants - PunchDrunk/mixture.wav,train/Grants - PunchDrunk/vocals.wav,train/Grants - PunchDrunk/drums.wav,train/Grants - PunchDrunk/bass.wav,train/Grants - PunchDrunk/other.wav,204.405261
|
||||||
|
train/Helado Negro - Mitad Del Mundo/mixture.wav,train/Helado Negro - Mitad Del Mundo/vocals.wav,train/Helado Negro - Mitad Del Mundo/drums.wav,train/Helado Negro - Mitad Del Mundo/bass.wav,train/Helado Negro - Mitad Del Mundo/other.wav,181.672925
|
||||||
|
train/Hezekiah Jones - Borrowed Heart/mixture.wav,train/Hezekiah Jones - Borrowed Heart/vocals.wav,train/Hezekiah Jones - Borrowed Heart/drums.wav,train/Hezekiah Jones - Borrowed Heart/bass.wav,train/Hezekiah Jones - Borrowed Heart/other.wav,241.394649
|
||||||
|
train/Hollow Ground - Left Blind/mixture.wav,train/Hollow Ground - Left Blind/vocals.wav,train/Hollow Ground - Left Blind/drums.wav,train/Hollow Ground - Left Blind/bass.wav,train/Hollow Ground - Left Blind/other.wav,159.103129
|
||||||
|
train/Hop Along - Sister Cities/mixture.wav,train/Hop Along - Sister Cities/vocals.wav,train/Hop Along - Sister Cities/drums.wav,train/Hop Along - Sister Cities/bass.wav,train/Hop Along - Sister Cities/other.wav,283.237007
|
||||||
|
train/Invisible Familiars - Disturbing Wildlife/mixture.wav,train/Invisible Familiars - Disturbing Wildlife/vocals.wav,train/Invisible Familiars - Disturbing Wildlife/drums.wav,train/Invisible Familiars - Disturbing Wildlife/bass.wav,train/Invisible Familiars - Disturbing Wildlife/other.wav,218.499773
|
||||||
|
train/James May - All Souls Moon/mixture.wav,train/James May - All Souls Moon/vocals.wav,train/James May - All Souls Moon/drums.wav,train/James May - All Souls Moon/bass.wav,train/James May - All Souls Moon/other.wav,220.844989
|
||||||
|
train/James May - Dont Let Go/mixture.wav,train/James May - Dont Let Go/vocals.wav,train/James May - Dont Let Go/drums.wav,train/James May - Dont Let Go/bass.wav,train/James May - Dont Let Go/other.wav,241.951927
|
||||||
|
train/James May - If You Say/mixture.wav,train/James May - If You Say/vocals.wav,train/James May - If You Say/drums.wav,train/James May - If You Say/bass.wav,train/James May - If You Say/other.wav,258.321995
|
||||||
|
train/Jay Menon - Through My Eyes/mixture.wav,train/Jay Menon - Through My Eyes/vocals.wav,train/Jay Menon - Through My Eyes/drums.wav,train/Jay Menon - Through My Eyes/bass.wav,train/Jay Menon - Through My Eyes/other.wav,253.167166
|
||||||
|
train/Johnny Lokke - Whisper To A Scream/mixture.wav,train/Johnny Lokke - Whisper To A Scream/vocals.wav,train/Johnny Lokke - Whisper To A Scream/drums.wav,train/Johnny Lokke - Whisper To A Scream/bass.wav,train/Johnny Lokke - Whisper To A Scream/other.wav,255.326621
|
||||||
|
"train/Jokers, Jacks & Kings - Sea Of Leaves/mixture.wav","train/Jokers, Jacks & Kings - Sea Of Leaves/vocals.wav","train/Jokers, Jacks & Kings - Sea Of Leaves/drums.wav","train/Jokers, Jacks & Kings - Sea Of Leaves/bass.wav","train/Jokers, Jacks & Kings - Sea Of Leaves/other.wav",191.471746
|
||||||
|
train/Leaf - Come Around/mixture.wav,train/Leaf - Come Around/vocals.wav,train/Leaf - Come Around/drums.wav,train/Leaf - Come Around/bass.wav,train/Leaf - Come Around/other.wav,264.382404
|
||||||
|
train/Leaf - Wicked/mixture.wav,train/Leaf - Wicked/vocals.wav,train/Leaf - Wicked/drums.wav,train/Leaf - Wicked/bass.wav,train/Leaf - Wicked/other.wav,190.635828
|
||||||
|
train/Lushlife - Toynbee Suite/mixture.wav,train/Lushlife - Toynbee Suite/vocals.wav,train/Lushlife - Toynbee Suite/drums.wav,train/Lushlife - Toynbee Suite/bass.wav,train/Lushlife - Toynbee Suite/other.wav,628.378413
|
||||||
|
train/Matthew Entwistle - Dont You Ever/mixture.wav,train/Matthew Entwistle - Dont You Ever/vocals.wav,train/Matthew Entwistle - Dont You Ever/drums.wav,train/Matthew Entwistle - Dont You Ever/bass.wav,train/Matthew Entwistle - Dont You Ever/other.wav,113.824218
|
||||||
|
train/Meaxic - You Listen/mixture.wav,train/Meaxic - You Listen/vocals.wav,train/Meaxic - You Listen/drums.wav,train/Meaxic - You Listen/bass.wav,train/Meaxic - You Listen/other.wav,412.525714
|
||||||
|
train/Music Delta - 80s Rock/mixture.wav,train/Music Delta - 80s Rock/vocals.wav,train/Music Delta - 80s Rock/drums.wav,train/Music Delta - 80s Rock/bass.wav,train/Music Delta - 80s Rock/other.wav,36.733968
|
||||||
|
train/Music Delta - Beatles/mixture.wav,train/Music Delta - Beatles/vocals.wav,train/Music Delta - Beatles/drums.wav,train/Music Delta - Beatles/bass.wav,train/Music Delta - Beatles/other.wav,36.176689
|
||||||
|
train/Music Delta - Britpop/mixture.wav,train/Music Delta - Britpop/vocals.wav,train/Music Delta - Britpop/drums.wav,train/Music Delta - Britpop/bass.wav,train/Music Delta - Britpop/other.wav,36.594649
|
||||||
|
train/Music Delta - Country1/mixture.wav,train/Music Delta - Country1/vocals.wav,train/Music Delta - Country1/drums.wav,train/Music Delta - Country1/bass.wav,train/Music Delta - Country1/other.wav,34.551293
|
||||||
|
train/Music Delta - Country2/mixture.wav,train/Music Delta - Country2/vocals.wav,train/Music Delta - Country2/drums.wav,train/Music Delta - Country2/bass.wav,train/Music Delta - Country2/other.wav,17.275646
|
||||||
|
train/Music Delta - Disco/mixture.wav,train/Music Delta - Disco/vocals.wav,train/Music Delta - Disco/drums.wav,train/Music Delta - Disco/bass.wav,train/Music Delta - Disco/other.wav,124.598277
|
||||||
|
train/Music Delta - Gospel/mixture.wav,train/Music Delta - Gospel/vocals.wav,train/Music Delta - Gospel/drums.wav,train/Music Delta - Gospel/bass.wav,train/Music Delta - Gospel/other.wav,75.557732
|
||||||
|
train/Music Delta - Grunge/mixture.wav,train/Music Delta - Grunge/vocals.wav,train/Music Delta - Grunge/drums.wav,train/Music Delta - Grunge/bass.wav,train/Music Delta - Grunge/other.wav,41.656599
|
||||||
|
train/Music Delta - Hendrix/mixture.wav,train/Music Delta - Hendrix/vocals.wav,train/Music Delta - Hendrix/drums.wav,train/Music Delta - Hendrix/bass.wav,train/Music Delta - Hendrix/other.wav,19.644082
|
||||||
|
train/Music Delta - Punk/mixture.wav,train/Music Delta - Punk/vocals.wav,train/Music Delta - Punk/drums.wav,train/Music Delta - Punk/bass.wav,train/Music Delta - Punk/other.wav,28.583764
|
||||||
|
train/Music Delta - Reggae/mixture.wav,train/Music Delta - Reggae/vocals.wav,train/Music Delta - Reggae/drums.wav,train/Music Delta - Reggae/bass.wav,train/Music Delta - Reggae/other.wav,17.275646
|
||||||
|
train/Music Delta - Rock/mixture.wav,train/Music Delta - Rock/vocals.wav,train/Music Delta - Rock/drums.wav,train/Music Delta - Rock/bass.wav,train/Music Delta - Rock/other.wav,12.910295
|
||||||
|
train/Music Delta - Rockabilly/mixture.wav,train/Music Delta - Rockabilly/vocals.wav,train/Music Delta - Rockabilly/drums.wav,train/Music Delta - Rockabilly/bass.wav,train/Music Delta - Rockabilly/other.wav,25.75093
|
||||||
|
train/Night Panther - Fire/mixture.wav,train/Night Panther - Fire/vocals.wav,train/Night Panther - Fire/drums.wav,train/Night Panther - Fire/bass.wav,train/Night Panther - Fire/other.wav,212.810884
|
||||||
|
train/North To Alaska - All The Same/mixture.wav,train/North To Alaska - All The Same/vocals.wav,train/North To Alaska - All The Same/drums.wav,train/North To Alaska - All The Same/bass.wav,train/North To Alaska - All The Same/other.wav,247.965896
|
||||||
|
train/Patrick Talbot - Set Me Free/mixture.wav,train/Patrick Talbot - Set Me Free/vocals.wav,train/Patrick Talbot - Set Me Free/drums.wav,train/Patrick Talbot - Set Me Free/bass.wav,train/Patrick Talbot - Set Me Free/other.wav,289.785034
|
||||||
|
train/Phre The Eon - Everybody's Falling Apart/mixture.wav,train/Phre The Eon - Everybody's Falling Apart/vocals.wav,train/Phre The Eon - Everybody's Falling Apart/drums.wav,train/Phre The Eon - Everybody's Falling Apart/bass.wav,train/Phre The Eon - Everybody's Falling Apart/other.wav,224.235102
|
||||||
|
train/Port St Willow - Stay Even/mixture.wav,train/Port St Willow - Stay Even/vocals.wav,train/Port St Willow - Stay Even/drums.wav,train/Port St Willow - Stay Even/bass.wav,train/Port St Willow - Stay Even/other.wav,316.836281
|
||||||
|
train/Remember December - C U Next Time/mixture.wav,train/Remember December - C U Next Time/vocals.wav,train/Remember December - C U Next Time/drums.wav,train/Remember December - C U Next Time/bass.wav,train/Remember December - C U Next Time/other.wav,242.532426
|
||||||
|
train/Secret Mountains - High Horse/mixture.wav,train/Secret Mountains - High Horse/vocals.wav,train/Secret Mountains - High Horse/drums.wav,train/Secret Mountains - High Horse/bass.wav,train/Secret Mountains - High Horse/other.wav,355.311746
|
||||||
|
train/Skelpolu - Together Alone/mixture.wav,train/Skelpolu - Together Alone/vocals.wav,train/Skelpolu - Together Alone/drums.wav,train/Skelpolu - Together Alone/bass.wav,train/Skelpolu - Together Alone/other.wav,325.822404
|
||||||
|
train/Snowmine - Curfews/mixture.wav,train/Snowmine - Curfews/vocals.wav,train/Snowmine - Curfews/drums.wav,train/Snowmine - Curfews/bass.wav,train/Snowmine - Curfews/other.wav,275.017143
|
||||||
|
train/Spike Mullings - Mike's Sulking/mixture.wav,train/Spike Mullings - Mike's Sulking/vocals.wav,train/Spike Mullings - Mike's Sulking/drums.wav,train/Spike Mullings - Mike's Sulking/bass.wav,train/Spike Mullings - Mike's Sulking/other.wav,256.696599
|
||||||
|
train/St Vitus - Word Gets Around/mixture.wav,train/St Vitus - Word Gets Around/vocals.wav,train/St Vitus - Word Gets Around/drums.wav,train/St Vitus - Word Gets Around/bass.wav,train/St Vitus - Word Gets Around/other.wav,247.013878
|
||||||
|
train/Steven Clark - Bounty/mixture.wav,train/Steven Clark - Bounty/vocals.wav,train/Steven Clark - Bounty/drums.wav,train/Steven Clark - Bounty/bass.wav,train/Steven Clark - Bounty/other.wav,289.274195
|
||||||
|
train/Strand Of Oaks - Spacestation/mixture.wav,train/Strand Of Oaks - Spacestation/vocals.wav,train/Strand Of Oaks - Spacestation/drums.wav,train/Strand Of Oaks - Spacestation/bass.wav,train/Strand Of Oaks - Spacestation/other.wav,243.670204
|
||||||
|
train/Sweet Lights - You Let Me Down/mixture.wav,train/Sweet Lights - You Let Me Down/vocals.wav,train/Sweet Lights - You Let Me Down/drums.wav,train/Sweet Lights - You Let Me Down/bass.wav,train/Sweet Lights - You Let Me Down/other.wav,391.790295
|
||||||
|
train/Swinging Steaks - Lost My Way/mixture.wav,train/Swinging Steaks - Lost My Way/vocals.wav,train/Swinging Steaks - Lost My Way/drums.wav,train/Swinging Steaks - Lost My Way/bass.wav,train/Swinging Steaks - Lost My Way/other.wav,309.963175
|
||||||
|
train/The Districts - Vermont/mixture.wav,train/The Districts - Vermont/vocals.wav,train/The Districts - Vermont/drums.wav,train/The Districts - Vermont/bass.wav,train/The Districts - Vermont/other.wav,227.973515
|
||||||
|
train/The Long Wait - Back Home To Blue/mixture.wav,train/The Long Wait - Back Home To Blue/vocals.wav,train/The Long Wait - Back Home To Blue/drums.wav,train/The Long Wait - Back Home To Blue/bass.wav,train/The Long Wait - Back Home To Blue/other.wav,260.458231
|
||||||
|
train/The Scarlet Brand - Les Fleurs Du Mal/mixture.wav,train/The Scarlet Brand - Les Fleurs Du Mal/vocals.wav,train/The Scarlet Brand - Les Fleurs Du Mal/drums.wav,train/The Scarlet Brand - Les Fleurs Du Mal/bass.wav,train/The Scarlet Brand - Les Fleurs Du Mal/other.wav,303.438367
|
||||||
|
train/The So So Glos - Emergency/mixture.wav,train/The So So Glos - Emergency/vocals.wav,train/The So So Glos - Emergency/drums.wav,train/The So So Glos - Emergency/bass.wav,train/The So So Glos - Emergency/other.wav,166.812154
|
||||||
|
train/The Wrong'Uns - Rothko/mixture.wav,train/The Wrong'Uns - Rothko/vocals.wav,train/The Wrong'Uns - Rothko/drums.wav,train/The Wrong'Uns - Rothko/bass.wav,train/The Wrong'Uns - Rothko/other.wav,202.152925
|
||||||
|
train/Tim Taler - Stalker/mixture.wav,train/Tim Taler - Stalker/vocals.wav,train/Tim Taler - Stalker/drums.wav,train/Tim Taler - Stalker/bass.wav,train/Tim Taler - Stalker/other.wav,237.633016
|
||||||
|
train/Titanium - Haunted Age/mixture.wav,train/Titanium - Haunted Age/vocals.wav,train/Titanium - Haunted Age/drums.wav,train/Titanium - Haunted Age/bass.wav,train/Titanium - Haunted Age/other.wav,248.105215
|
||||||
|
train/Traffic Experiment - Once More (With Feeling)/mixture.wav,train/Traffic Experiment - Once More (With Feeling)/vocals.wav,train/Traffic Experiment - Once More (With Feeling)/drums.wav,train/Traffic Experiment - Once More (With Feeling)/bass.wav,train/Traffic Experiment - Once More (With Feeling)/other.wav,435.07229
|
||||||
|
train/Triviul - Dorothy/mixture.wav,train/Triviul - Dorothy/vocals.wav,train/Triviul - Dorothy/drums.wav,train/Triviul - Dorothy/bass.wav,train/Triviul - Dorothy/other.wav,187.361814
|
||||||
|
train/Voelund - Comfort Lives In Belief/mixture.wav,train/Voelund - Comfort Lives In Belief/vocals.wav,train/Voelund - Comfort Lives In Belief/drums.wav,train/Voelund - Comfort Lives In Belief/bass.wav,train/Voelund - Comfort Lives In Belief/other.wav,209.90839
|
||||||
|
train/Wall Of Death - Femme/mixture.wav,train/Wall Of Death - Femme/vocals.wav,train/Wall Of Death - Femme/drums.wav,train/Wall Of Death - Femme/bass.wav,train/Wall Of Death - Femme/other.wav,238.933333
|
||||||
|
train/Young Griffo - Blood To Bone/mixture.wav,train/Young Griffo - Blood To Bone/vocals.wav,train/Young Griffo - Blood To Bone/drums.wav,train/Young Griffo - Blood To Bone/bass.wav,train/Young Griffo - Blood To Bone/other.wav,254.397823
|
||||||
|
train/Young Griffo - Facade/mixture.wav,train/Young Griffo - Facade/vocals.wav,train/Young Griffo - Facade/drums.wav,train/Young Griffo - Facade/bass.wav,train/Young Griffo - Facade/other.wav,167.857052
|
||||||
|
15
configs/musdb_validation.csv
Normal file
15
configs/musdb_validation.csv
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
mix_path,vocals_path,drums_path,bass_path,other_path,duration
|
||||||
|
train/ANiMAL - Rockshow/mixture.wav,train/ANiMAL - Rockshow/vocals.wav,train/ANiMAL - Rockshow/drums.wav,train/ANiMAL - Rockshow/bass.wav,train/ANiMAL - Rockshow/other.wav,165.511837
|
||||||
|
train/Actions - One Minute Smile/mixture.wav,train/Actions - One Minute Smile/vocals.wav,train/Actions - One Minute Smile/drums.wav,train/Actions - One Minute Smile/bass.wav,train/Actions - One Minute Smile/other.wav,163.375601
|
||||||
|
train/Alexander Ross - Goodbye Bolero/mixture.wav,train/Alexander Ross - Goodbye Bolero/vocals.wav,train/Alexander Ross - Goodbye Bolero/drums.wav,train/Alexander Ross - Goodbye Bolero/bass.wav,train/Alexander Ross - Goodbye Bolero/other.wav,418.632562
|
||||||
|
train/Clara Berry And Wooldog - Waltz For My Victims/mixture.wav,train/Clara Berry And Wooldog - Waltz For My Victims/vocals.wav,train/Clara Berry And Wooldog - Waltz For My Victims/drums.wav,train/Clara Berry And Wooldog - Waltz For My Victims/bass.wav,train/Clara Berry And Wooldog - Waltz For My Victims/other.wav,175.240998
|
||||||
|
train/Fergessen - Nos Palpitants/mixture.wav,train/Fergessen - Nos Palpitants/vocals.wav,train/Fergessen - Nos Palpitants/drums.wav,train/Fergessen - Nos Palpitants/bass.wav,train/Fergessen - Nos Palpitants/other.wav,198.228753
|
||||||
|
train/James May - On The Line/mixture.wav,train/James May - On The Line/vocals.wav,train/James May - On The Line/drums.wav,train/James May - On The Line/bass.wav,train/James May - On The Line/other.wav,256.09288
|
||||||
|
train/Johnny Lokke - Promises & Lies/mixture.wav,train/Johnny Lokke - Promises & Lies/vocals.wav,train/Johnny Lokke - Promises & Lies/drums.wav,train/Johnny Lokke - Promises & Lies/bass.wav,train/Johnny Lokke - Promises & Lies/other.wav,285.814422
|
||||||
|
train/Leaf - Summerghost/mixture.wav,train/Leaf - Summerghost/vocals.wav,train/Leaf - Summerghost/drums.wav,train/Leaf - Summerghost/bass.wav,train/Leaf - Summerghost/other.wav,231.804807
|
||||||
|
train/Meaxic - Take A Step/mixture.wav,train/Meaxic - Take A Step/vocals.wav,train/Meaxic - Take A Step/drums.wav,train/Meaxic - Take A Step/bass.wav,train/Meaxic - Take A Step/other.wav,282.517188
|
||||||
|
train/Patrick Talbot - A Reason To Leave/mixture.wav,train/Patrick Talbot - A Reason To Leave/vocals.wav,train/Patrick Talbot - A Reason To Leave/drums.wav,train/Patrick Talbot - A Reason To Leave/bass.wav,train/Patrick Talbot - A Reason To Leave/other.wav,259.552653
|
||||||
|
train/Skelpolu - Human Mistakes/mixture.wav,train/Skelpolu - Human Mistakes/vocals.wav,train/Skelpolu - Human Mistakes/drums.wav,train/Skelpolu - Human Mistakes/bass.wav,train/Skelpolu - Human Mistakes/other.wav,324.498866
|
||||||
|
train/Traffic Experiment - Sirens/mixture.wav,train/Traffic Experiment - Sirens/vocals.wav,train/Traffic Experiment - Sirens/drums.wav,train/Traffic Experiment - Sirens/bass.wav,train/Traffic Experiment - Sirens/other.wav,421.279637
|
||||||
|
train/Triviul - Angelsaint/mixture.wav,train/Triviul - Angelsaint/vocals.wav,train/Triviul - Angelsaint/drums.wav,train/Triviul - Angelsaint/bass.wav,train/Triviul - Angelsaint/other.wav,236.704218
|
||||||
|
train/Young Griffo - Pennies/mixture.wav,train/Young Griffo - Pennies/vocals.wav,train/Young Griffo - Pennies/drums.wav,train/Young Griffo - Pennies/bass.wav,train/Young Griffo - Pennies/other.wav,277.803537
|
||||||
|
24
docker/cpu.Dockerfile
Normal file
24
docker/cpu.Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
FROM continuumio/miniconda3:4.7.10

# Install runtime dependencies. The commands are identical to the original
# per-step installs but chained inside a single RUN so the image carries one
# layer instead of five (smaller history, same resulting environment):
# tensorflow (CPU build), ffmpeg for audio loading/writing, pandas,
# libsndfile and ipython.
RUN conda install -y tensorflow==1.14.0 && \
    conda install -y -c conda-forge ffmpeg && \
    conda install -y -c anaconda pandas==0.25.1 && \
    conda install -y -c conda-forge libsndfile && \
    conda install -y ipython

WORKDIR /workspace/
COPY ./ spleeter/

# Cache directory used for downloaded pretrained models.
RUN mkdir /cache/

WORKDIR /workspace/spleeter
RUN pip install .

ENTRYPOINT ["python", "-m", "spleeter"]
|
||||||
35
docker/gpu.Dockerfile
Normal file
35
docker/gpu.Dockerfile
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
FROM nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04

# Set work directory.
WORKDIR /workspace

# Install miniconda (see docker/install_miniconda.sh) and expose it on PATH.
ENV PATH /opt/conda/bin:$PATH
COPY docker/install_miniconda.sh .
RUN bash ./install_miniconda.sh && rm install_miniconda.sh
RUN conda update -n base -c defaults conda

# Install runtime dependencies. Identical commands to the original per-step
# installs, chained into a single RUN to avoid one image layer per package:
# GPU tensorflow, ffmpeg for audio loading/writing, pandas, libsndfile
# and ipython.
RUN conda install -y tensorflow-gpu==1.14.0 && \
    conda install -y -c conda-forge ffmpeg && \
    conda install -y -c anaconda pandas==0.25.1 && \
    conda install -y -c conda-forge libsndfile && \
    conda install -y ipython

# Cache directory used for downloaded pretrained models.
RUN mkdir /cache/

# Copy project sources inside the image.
COPY ./ spleeter/

WORKDIR /workspace/spleeter
RUN pip install .

ENTRYPOINT ["python", "-m", "spleeter"]
|
||||||
13
docker/install_miniconda.sh
Normal file
13
docker/install_miniconda.sh
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
# Install Miniconda 4.6.14 into /opt/conda (used by the GPU Dockerfile).
#
# BUGFIX: without `set -e` a failure in the first command chain was silently
# ignored and the script kept going; abort on the first error instead.
set -e

# System packages needed to download and install Miniconda.
apt-get update --fix-missing && \
    apt-get install -y wget bzip2 ca-certificates curl git && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Download the installer, run it, clean caches and wire conda into the shell.
wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-4.6.14-Linux-x86_64.sh -O ~/miniconda.sh && \
    /bin/bash ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda clean -tipsy && \
    ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
    echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
    echo "conda activate base" >> ~/.bashrc
|
||||||
108
setup.py
Normal file
108
setup.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#!/usr/bin/env python
# coding: utf8

""" Distribution script. """

import sys

from os import path
from setuptools import setup

__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

# Default project values (CPU flavored package).
project_name = 'spleeter'
project_version = '1.4.0'
device_target = 'cpu'
tensorflow_dependency = 'tensorflow'
tensorflow_version = '1.14.0'
here = path.abspath(path.dirname(__file__))
readme_path = path.join(here, 'README.md')
with open(readme_path, 'r') as stream:
    readme = stream.read()

# Check if a device target is specified through a custom '--target' option.
if '--target' in sys.argv:
    target_index = sys.argv.index('--target') + 1
    # BUGFIX (two defects):
    # 1. The parsed value was previously stored in an unused local and never
    #    assigned to device_target, so '--target gpu' had no effect.
    # 2. The flag was removed from sys.argv *before* popping its value,
    #    which shifted indexes and popped the wrong argument. Pop the value
    #    first, then remove the flag.
    device_target = sys.argv.pop(target_index).lower()
    sys.argv.remove('--target')

# GPU target: rename the package and swap the tensorflow dependency.
if device_target == 'gpu':
    project_name = '{}-gpu'.format(project_name)
    tensorflow_dependency = 'tensorflow-gpu'

# Package setup entrypoint.
setup(
    name=project_name,
    version=project_version,
    description='''
    The Deezer source separation library with
    pretrained models based on tensorflow.
    ''',
    long_description=readme,
    long_description_content_type='text/markdown',
    author='Deezer Research',
    author_email='research@deezer.com',
    url='https://github.com/deezer/spleeter',
    license='MIT License',
    packages=[
        'spleeter',
        'spleeter.commands',
        'spleeter.model',
        'spleeter.model.functions',
        'spleeter.model.provider',
        'spleeter.resources',
        'spleeter.utils',
        'spleeter.utils.audio',
    ],
    package_data={'spleeter.resources': ['*.json']},
    python_requires='>=3.6, <3.8',
    include_package_data=True,
    install_requires=[
        'importlib_resources ; python_version<"3.7"',
        'musdb==0.3.1',
        'museval==0.3.0',
        'norbert==0.2.1',
        'pandas==0.25.1',
        'requests',
        '{}=={}'.format(tensorflow_dependency, tensorflow_version),
    ],
    entry_points={
        'console_scripts': ['spleeter=spleeter.__main__:entrypoint']
    },
    classifiers=[
        'Environment :: Console',
        'Environment :: MacOS X',
        'Intended Audience :: Developers',
        'Intended Audience :: Information Technology',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Natural Language :: English',
        'Operating System :: MacOS',
        'Operating System :: Microsoft :: Windows',
        'Operating System :: POSIX :: Linux',
        'Operating System :: Unix',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3 :: Only',
        'Programming Language :: Python :: Implementation :: CPython',
        'Topic :: Artistic Software',
        'Topic :: Multimedia',
        'Topic :: Multimedia :: Sound/Audio',
        'Topic :: Multimedia :: Sound/Audio :: Analysis',
        'Topic :: Multimedia :: Sound/Audio :: Conversion',
        'Topic :: Multimedia :: Sound/Audio :: Sound Synthesis',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Scientific/Engineering :: Information Analysis',
        'Topic :: Software Development',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
        'Topic :: Utilities']
)
|
||||||
18
spleeter/__init__.py
Normal file
18
spleeter/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/usr/bin/env python
# coding: utf8

"""
Spleeter: Deezer's Tensorflow-based source separation library.

The package ships already-trained separation models and makes it easy
to train new ones with tensorflow, provided a dataset of isolated
sources is available. A command line interface exposes the train,
evaluate and separate actions (``python -m spleeter ...``).
"""

__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
|
||||||
52
spleeter/__main__.py
Normal file
52
spleeter/__main__.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
Python oneliner script usage.
|
||||||
|
|
||||||
|
USAGE: python -m spleeter {train,evaluate,separate} ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from .commands import create_argument_parser
|
||||||
|
from .utils.configuration import load_configuration
|
||||||
|
from .utils.logging import enable_logging, enable_verbose_logging
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
    """ Spleeter runner: parse command line arguments and execute the
    entrypoint of the requested command (train, evaluate or separate).

    :param argv: Provided command line arguments (argv[0] is ignored).
    """
    cli_parser = create_argument_parser()
    namespace = cli_parser.parse_args(argv[1:])
    # Configure logging verbosity before anything else runs.
    (enable_verbose_logging if namespace.verbose else enable_logging)()
    # Lazily import only the entrypoint of the selected command, so the
    # (heavy) unselected command modules are never loaded.
    command = namespace.command
    if command == 'separate':
        from .commands.separate import entrypoint
    elif command == 'train':
        from .commands.train import entrypoint
    elif command == 'evaluate':
        from .commands.evaluate import entrypoint
    configuration = load_configuration(namespace.params_filename)
    entrypoint(namespace, configuration)
|
||||||
|
|
||||||
|
|
||||||
|
def entrypoint():
    """ Console-script entrypoint: silence warnings then delegate to main(). """
    # Suppress noisy warnings (mostly tensorflow deprecation notices).
    warnings.filterwarnings('ignore')
    main(sys.argv)


if __name__ == '__main__':
    entrypoint()
|
||||||
182
spleeter/commands/__init__.py
Normal file
182
spleeter/commands/__init__.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" This modules provides spleeter command as well as CLI parsing methods. """
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from tempfile import gettempdir
|
||||||
|
from os.path import exists, join
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
# -i opt specification.
|
||||||
|
# -i opt specification.
OPT_INPUT = {
    'dest': 'audio_filenames',
    'nargs': '+',
    'help': 'List of input audio filenames',
    'required': True
}

# -o opt specification.
OPT_OUTPUT = {
    'dest': 'output_path',
    'default': join(gettempdir(), 'separated_audio'),
    'help': 'Path of the output directory to write audio files in'
}

# -p opt specification.
OPT_PARAMS = {
    'dest': 'params_filename',
    'default': 'spleeter:2stems',
    'type': str,
    'action': 'store',
    'help': 'JSON filename that contains params'
}

# -n opt specification.
OPT_OUTPUT_NAMING = {
    'dest': 'output_naming',
    'default': 'filename',
    'choices': ('directory', 'filename'),
    'help': (
        'Choice for naming the output base path: '
        '"filename" (use the input filename, i.e '
        '/path/to/audio/mix.wav will be separated to '
        '<output_path>/mix/<instument1>.wav, '
        '<output_path>/mix/<instument2>.wav...) or '
        '"directory" (use the name of the input last level'
        ' directory, for instance /path/to/audio/mix.wav '
        'will be separated to <output_path>/audio/<instument1>.wav'
        ', <output_path>/audio/<instument2>.wav)')
}

# -d opt specification (separate).
OPT_DURATION = {
    'dest': 'max_duration',
    'type': float,
    'default': 600.,
    'help': (
        'Set a maximum duration for processing audio '
        '(only separate max_duration first seconds of '
        'the input file)')
}

# -c opt specification.
OPT_CODEC = {
    'dest': 'audio_codec',
    'choices': ('wav', 'mp3', 'ogg', 'm4a', 'wma', 'flac'),
    'default': 'wav',
    'help': 'Audio codec to be used for the separated output'
}

# -m opt specification.
OPT_MWF = {
    'dest': 'MWF',
    'action': 'store_const',
    'const': True,
    'default': False,
    'help': 'Whether to use multichannel Wiener filtering for separation',
}

# --mus_dir opt specification.
OPT_MUSDB = {
    'dest': 'mus_dir',
    'type': str,
    'required': True,
    'help': 'Path to folder with musDB'
}

# -d opt specification (train).
OPT_DATA = {
    'dest': 'audio_path',
    'type': str,
    'required': True,
    'help': 'Path of the folder containing audio data for training'
}

# -a opt specification.
OPT_ADAPTER = {
    'dest': 'audio_adapter',
    'type': str,
    'help': 'Name of the audio adapter to use for audio I/O'
}

# --verbose opt specification.
OPT_VERBOSE = {
    'action': 'store_true',
    'help': 'Shows verbose logs'
}


def _add_common_options(parser):
    """ Registers the options shared by every sub-command on *parser*.

    :param parser: Parser to add common opt to.
    """
    shared_options = (
        (('-a', '--adapter'), OPT_ADAPTER),
        (('-p', '--params_filename'), OPT_PARAMS),
        (('--verbose',), OPT_VERBOSE))
    for flags, specification in shared_options:
        parser.add_argument(*flags, **specification)


def _create_train_parser(subparser_factory):
    """ Builds the argument parser of the 'train' command.

    :param subparser_factory: Factory to use to create parser instance.
    :returns: Created and configured parser.
    """
    train_parser = subparser_factory(
        'train',
        help='Train a source separation model')
    _add_common_options(train_parser)
    train_parser.add_argument('-d', '--data', **OPT_DATA)
    return train_parser


def _create_evaluate_parser(subparser_factory):
    """ Builds the argument parser of the 'evaluate' command.

    :param subparser_factory: Factory to use to create parser instance.
    :returns: Created and configured parser.
    """
    evaluate_parser = subparser_factory(
        'evaluate',
        help='Evaluate a model on the musDB test dataset')
    _add_common_options(evaluate_parser)
    evaluate_parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
    evaluate_parser.add_argument('--mus_dir', **OPT_MUSDB)
    evaluate_parser.add_argument('-m', '--mwf', **OPT_MWF)
    return evaluate_parser


def _create_separate_parser(subparser_factory):
    """ Builds the argument parser of the 'separate' command.

    :param subparser_factory: Factory to use to create parser instance.
    :returns: Created and configured parser.
    """
    separate_parser = subparser_factory(
        'separate',
        help='Separate audio files')
    _add_common_options(separate_parser)
    separate_parser.add_argument('-i', '--audio_filenames', **OPT_INPUT)
    separate_parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
    separate_parser.add_argument('-n', '--output_naming', **OPT_OUTPUT_NAMING)
    separate_parser.add_argument('-d', '--max_duration', **OPT_DURATION)
    separate_parser.add_argument('-c', '--audio_codec', **OPT_CODEC)
    separate_parser.add_argument('-m', '--mwf', **OPT_MWF)
    return separate_parser


def create_argument_parser():
    """ Creates overall command line parser for Spleeter, with one
    required sub-command among 'separate', 'train' and 'evaluate'.

    :returns: Created argument parser.
    """
    root_parser = ArgumentParser(prog='python -m spleeter')
    command_parsers = root_parser.add_subparsers()
    command_parsers.dest = 'command'
    command_parsers.required = True
    for build_subparser in (
            _create_separate_parser,
            _create_train_parser,
            _create_evaluate_parser):
        build_subparser(command_parsers.add_parser)
    return root_parser
|
||||||
154
spleeter/commands/evaluate.py
Normal file
154
spleeter/commands/evaluate.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
Entrypoint provider for performing model evaluation.
|
||||||
|
|
||||||
|
Evaluation is performed against musDB dataset.
|
||||||
|
|
||||||
|
USAGE: python -m spleeter evaluate \
|
||||||
|
-p /path/to/params \
|
||||||
|
-o /path/to/output/dir \
|
||||||
|
[-m] \
|
||||||
|
--mus_dir /path/to/musdb dataset
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from argparse import Namespace
|
||||||
|
from itertools import product
|
||||||
|
from glob import glob
|
||||||
|
from os.path import join, exists
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import musdb
|
||||||
|
import museval
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from .separate import entrypoint as separate_entrypoint
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
# musdb split and file layout constants.
_SPLIT = 'test'
_MIXTURE = 'mixture.wav'
_NAMING = 'directory'
_AUDIO_DIRECTORY = 'audio'
_METRICS_DIRECTORY = 'metrics'
# Instruments and metrics reported by museval.
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')


def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
    """ Performs audio separation on the musdb dataset from
    the given directory and params.

    :param arguments: Entrypoint arguments.
    :param musdb_root_directory: Directory to retrieve dataset from.
    :param params: Spleeter configuration to apply to separation.
    :returns: Separation output directory path.
    """
    songs = glob(join(musdb_root_directory, _SPLIT, '*/'))
    mixtures = [join(song, _MIXTURE) for song in songs]
    audio_output_directory = join(
        arguments.output_path,
        _AUDIO_DIRECTORY)
    # Reuse the 'separate' command entrypoint with a synthetic namespace.
    separate_entrypoint(
        Namespace(
            audio_adapter=arguments.audio_adapter,
            audio_filenames=mixtures,
            audio_codec='wav',
            output_path=join(audio_output_directory, _SPLIT),
            output_naming=_NAMING,
            max_duration=600.,
            MWF=arguments.MWF,
            verbose=arguments.verbose),
        params)
    return audio_output_directory


def _compute_musdb_metrics(
        arguments,
        musdb_root_directory,
        audio_output_directory):
    """ Generates musdb metrics from previously computed audio estimation.

    :param arguments: Entrypoint arguments.
    :param musdb_root_directory: Directory to retrieve dataset from.
    :param audio_output_directory: Directory to get audio estimation from.
    :returns: Path of generated metrics directory.
    """
    metrics_output_directory = join(
        arguments.output_path,
        _METRICS_DIRECTORY)
    get_logger().info('Starting musdb evaluation (this could be long) ...')
    dataset = musdb.DB(
        root=musdb_root_directory,
        is_wav=True,
        subsets=[_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory)
    get_logger().info('musdb evaluation done')
    return metrics_output_directory


def _compile_metrics(metrics_output_directory):
    """ Compiles metrics from given directory and returns
    results as dict mapping instrument -> metric -> list of per-song
    median values.

    :param metrics_output_directory: Directory to get metrics from.
    :returns: Compiled metrics as dict.
    """
    # NOTE: a pd.MultiIndex and a pd.DataFrame were previously built here
    # and immediately discarded (dead code); both have been removed.
    songs = glob(join(metrics_output_directory, 'test/*.json'))
    metrics = {
        instrument: {metric: [] for metric in _METRICS}
        for instrument in _INSTRUMENTS}
    for song in songs:
        with open(song, 'r') as stream:
            data = json.load(stream)
        for target in data['targets']:
            instrument = target['name']
            for metric in _METRICS:
                # Median over frames, ignoring NaN frames.
                sdr_med = np.median([
                    frame['metrics'][metric]
                    for frame in target['frames']
                    if not np.isnan(frame['metrics'][metric])])
                metrics[instrument][metric].append(sdr_med)
    return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def entrypoint(arguments, params):
    """ Command entrypoint for the 'evaluate' command.

    :param arguments: Command line parsed argument as argparse.Namespace.
    :param params: Deserialized JSON configuration file provided in CLI args.
    :raises IOError: If the provided musdb directory does not exist.
    """
    # Parse and check musdb directory.
    musdb_root_directory = arguments.mus_dir
    if not exists(musdb_root_directory):
        raise IOError(f'musdb directory {musdb_root_directory} not found')
    # Run source separation over the musdb split.
    audio_output_directory = _separate_evaluation_dataset(
        arguments, musdb_root_directory, params)
    # Compute raw metrics with musdb/museval.
    metrics_output_directory = _compute_musdb_metrics(
        arguments, musdb_root_directory, audio_output_directory)
    # Aggregate and pretty print per-instrument median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, instrument_metrics in metrics.items():
        get_logger().info('%s:', instrument)
        for metric_name, values in instrument_metrics.items():
            get_logger().info('%s: %s', metric_name, f'{np.median(values):.3f}')
|
||||||
180
spleeter/commands/separate.py
Normal file
180
spleeter/commands/separate.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
Entrypoint provider for performing source separation.
|
||||||
|
|
||||||
|
USAGE: python -m spleeter separate \
|
||||||
|
-p /path/to/params \
|
||||||
|
-i inputfile1 inputfile2 ... inputfilen
|
||||||
|
-o /path/to/output/dir \
|
||||||
|
-i /path/to/audio1.wav /path/to/audio2.mp3
|
||||||
|
"""
|
||||||
|
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from os.path import isabs, join, split, splitext
|
||||||
|
from tempfile import gettempdir
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
import numpy as np
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from ..utils.audio.adapter import get_audio_adapter
|
||||||
|
from ..utils.audio.convertor import to_n_channels
|
||||||
|
from ..utils.estimator import create_estimator
|
||||||
|
from ..utils.tensor import set_tensor_shape
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(audio_adapter, filenames_and_crops, sample_rate, n_channels):
    """ Build a tensorflow dataset of waveforms from a filename list
    with crop information.

    Params:
        - audio_adapter: An AudioAdapter instance to load audio from.
        - filenames_and_crops: list of (audio_filename, start, duration)
          tuples; separation is performed on each filename from start
          (in seconds) to start + duration (in seconds).
        - sample_rate: audio sample rate of the input and output signals.
        - n_channels: int, number of channels of the input and output
          audio signals.

    Returns:
        A tensorflow dataset of waveforms to feed a tensorflow estimator
        in predict mode.
    """
    audio_ids, crop_starts, crop_ends = zip(*filenames_and_crops)
    dataset = tf.data.Dataset.from_tensor_slices({
        'audio_id': list(audio_ids),
        'start': list(crop_starts),
        'end': list(crop_ends)
    })
    # Decode each waveform through the audio adapter.
    dataset = dataset.map(
        lambda entry: dict(
            entry,
            **audio_adapter.load_tf_waveform(
                entry['audio_id'],
                sample_rate=sample_rate,
                offset=entry['start'],
                duration=entry['end'] - entry['start'])),
        num_parallel_calls=2)
    # Drop samples whose waveform loading failed.
    dataset = dataset.filter(
        lambda entry: tf.logical_not(entry['waveform_error']))
    # Convert waveform to the expected number of channels.
    dataset = dataset.map(
        lambda entry: dict(
            entry,
            waveform=to_n_channels(entry['waveform'], n_channels)))
    # Statically set the channel dimension (required by the model).
    dataset = dataset.map(
        lambda entry: dict(
            entry,
            waveform=set_tensor_shape(entry['waveform'], (None, n_channels))))
    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def process_audio(
        audio_adapter,
        filenames_and_crops, estimator, output_path,
        sample_rate, n_channels, codec, output_naming):
    """
    Perform separation on a list of audio files and export the results.

    Params:
        - audio_adapter: Audio adapter to use for audio I/O.
        - filenames_and_crops: list of (audio_filename, start, duration)
          tuples; separation is performed on each filename from start
          (in seconds) to start + duration (in seconds).
        - estimator: the tensorflow estimator that performs the
          source separation.
        - output_path: output_path where to export separated files.
        - sample_rate: audio sample rate of the input and output signals.
        - n_channels: int, number of channels of the input and output
          audio signals.
        - codec: string codec to be used for export ("wav", "mp3",
          "ogg", "m4a"... anything supported by ffmpeg).
        - output_naming: "filename" or "directory"; selects whether the
          per-song output directory is named after the input file basename
          or after its parent directory (use "directory" for MusDB).

    Raises:
        ValueError: if output_naming is not a supported value.
    """
    # Run estimator prediction over the dataset of cropped waveforms.
    prediction = estimator.predict(
        lambda: get_dataset(
            audio_adapter,
            filenames_and_crops,
            sample_rate,
            n_channels),
        yield_single_examples=False)
    # Worker pool for asynchronous audio export.
    pool = Pool(16)
    try:
        tasks = []
        for sample in prediction:
            sample_filename = sample.pop(
                'audio_id', 'unknown_filename').decode()
            input_directory, input_filename = split(sample_filename)
            if output_naming == 'directory':
                output_dirname = split(input_directory)[1]
            elif output_naming == 'filename':
                output_dirname = splitext(input_filename)[0]
            else:
                raise ValueError(f'Unknown output naming {output_naming}')
            # Remaining keys of the sample are instrument waveforms.
            for instrument, waveform in sample.items():
                filename = join(
                    output_path,
                    output_dirname,
                    f'{instrument}.{codec}')
                tasks.append(
                    pool.apply_async(
                        audio_adapter.save,
                        (filename, waveform, sample_rate, codec)))
        # Wait for every pending export to be written.
        for task in tasks:
            task.wait(timeout=20)
    finally:
        # BUGFIX: the pool was previously leaked (never closed nor joined),
        # so still-running export workers could be dropped at interpreter
        # exit, leaving truncated output files.
        pool.close()
        pool.join()
|
||||||
|
|
||||||
|
|
||||||
|
def entrypoint(arguments, params):
    """ Command entrypoint.

    :param arguments: Command line parsed argument as argparse.Namespace.
    :param params: Deserialized JSON configuration file provided in CLI args.
    """
    adapter = get_audio_adapter(arguments.audio_adapter)
    estimator = create_estimator(params, arguments.MWF)
    # Each input file is separated from its beginning up to max_duration.
    filenames_and_crops = [
        (filename, 0., arguments.max_duration)
        for filename in arguments.audio_filenames]
    process_audio(
        adapter,
        filenames_and_crops,
        estimator,
        arguments.output_path,
        params['sample_rate'],
        params['n_channels'],
        codec=arguments.audio_codec,
        output_naming=arguments.output_naming)
|
||||||
98
spleeter/commands/train.py
Normal file
98
spleeter/commands/train.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
Entrypoint provider for performing model training.
|
||||||
|
|
||||||
|
USAGE: python -m spleeter train -p /path/to/params
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from ..dataset import get_training_dataset, get_validation_dataset
|
||||||
|
from ..model import model_fn
|
||||||
|
from ..utils.audio.adapter import get_audio_adapter
|
||||||
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def _create_estimator(params):
    """ Creates estimator.

    :param params: TF params to build estimator from.
    :returns: Built estimator.
    """
    # Cap per-process GPU memory so training can share the device.
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params['save_checkpoints_steps'],
        tf_random_seed=params['random_seed'],
        save_summary_steps=params['save_summary_steps'],
        session_config=session_config,
        log_step_count_steps=10,
        keep_checkpoint_max=2)
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=run_config)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_train_spec(params, audio_adapter, audio_path):
    """ Creates train spec.

    :param params: TF params to build spec from.
    :param audio_adapter: Adapter to load audio from.
    :param audio_path: Path of directory containing audio.
    :returns: Built train spec.
    """
    return tf.estimator.TrainSpec(
        input_fn=partial(
            get_training_dataset, params, audio_adapter, audio_path),
        max_steps=params['train_max_steps'])
|
||||||
|
|
||||||
|
|
||||||
|
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Setup eval spec evaluating ever n seconds

    :param params: TF params to build spec from.
    :param audio_adapter: Adapter to load audio from.
    :param audio_path: Path of directory containing audio.
    :returns: Built evaluation spec.
    """
    return tf.estimator.EvalSpec(
        input_fn=partial(
            get_validation_dataset, params, audio_adapter, audio_path),
        # Evaluate over the full validation set at each evaluation.
        steps=None,
        throttle_secs=params['throttle_secs'])
|
||||||
|
|
||||||
|
|
||||||
|
def entrypoint(arguments, params):
    """ Command entrypoint.

    :param arguments: Command line parsed argument as argparse.Namespace.
    :param params: Deserialized JSON configuration file provided in CLI args.
    """
    audio_adapter = get_audio_adapter(arguments.audio_adapter)
    audio_path = arguments.audio_path
    estimator = _create_estimator(params)
    specs = (
        _create_train_spec(params, audio_adapter, audio_path),
        _create_evaluation_spec(params, audio_adapter, audio_path))
    get_logger().info('Start model training')
    tf.estimator.train_and_evaluate(estimator, *specs)
    get_logger().info('Model training done')
|
||||||
464
spleeter/dataset.py
Normal file
464
spleeter/dataset.py
Normal file
@@ -0,0 +1,464 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
Module for building data preprocessing pipeline using the tensorflow data
|
||||||
|
API.
|
||||||
|
Data preprocessing such as audio loading, spectrogram computation, cropping,
|
||||||
|
feature caching or data augmentation is done using a tensorflow dataset object
|
||||||
|
that output a tuple (input_, output) where:
|
||||||
|
- input_ is a dictionary with a single key that contains the (batched) mix
|
||||||
|
spectrogram of audio samples
|
||||||
|
- output is a dictionary of spectrogram of the isolated tracks (ground truth)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
from os.path import exists, join, sep as SEPARATOR
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from .utils.audio.convertor import (
|
||||||
|
db_uint_spectrogram_to_gain,
|
||||||
|
spectrogram_to_db_uint)
|
||||||
|
from .utils.audio.spectrogram import (
|
||||||
|
compute_spectrogram_tf,
|
||||||
|
random_pitch_shift,
|
||||||
|
random_time_stretch)
|
||||||
|
from .utils.logging import get_logger
|
||||||
|
from .utils.tensor import (
|
||||||
|
check_tensor_shape,
|
||||||
|
dataset_from_csv,
|
||||||
|
set_tensor_shape,
|
||||||
|
sync_apply)
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
# Default datasets path parameter to use.
|
||||||
|
DEFAULT_DATASETS_PATH = join(
|
||||||
|
'audio_database',
|
||||||
|
'separated_sources',
|
||||||
|
'experiments',
|
||||||
|
'karaoke_vocal_extraction',
|
||||||
|
'tensorflow_experiment'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default audio parameters to use.
|
||||||
|
DEFAULT_AUDIO_PARAMS = {
|
||||||
|
'instrument_list': ('vocals', 'accompaniment'),
|
||||||
|
'mix_name': 'mix',
|
||||||
|
'sample_rate': 44100,
|
||||||
|
'frame_length': 4096,
|
||||||
|
'frame_step': 1024,
|
||||||
|
'T': 512,
|
||||||
|
'F': 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_dataset(audio_params, audio_adapter, audio_path):
    """ Builds training dataset.

    :param audio_params: Audio parameters.
    :param audio_adapter: Adapter to load audio from.
    :param audio_path: Path of directory containing audio.
    :returns: Built dataset.
    """
    chunk_duration = audio_params.get('chunk_duration', 20.0)
    random_seed = audio_params.get('random_seed', 0)
    builder = DatasetBuilder(
        audio_params, audio_adapter, audio_path,
        chunk_duration=chunk_duration,
        random_seed=random_seed)
    return builder.build(
        audio_params.get('train_csv'),
        batch_size=audio_params.get('batch_size'),
        cache_directory=audio_params.get('training_cache'),
        convert_to_uint=True,
        n_chunks_per_song=audio_params.get('n_chunks_per_song', 2),
        random_data_augmentation=False,
        wait_for_cache=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_validation_dataset(audio_params, audio_adapter, audio_path):
    """ Builds validation dataset.

    :param audio_params: Audio parameters.
    :param audio_adapter: Adapter to load audio from.
    :param audio_path: Path of directory containing audio.
    :returns: Built dataset.
    """
    builder = DatasetBuilder(
        audio_params, audio_adapter, audio_path,
        chunk_duration=12.0)
    return builder.build(
        audio_params.get('validation_csv'),
        batch_size=audio_params.get('batch_size'),
        cache_directory=audio_params.get('training_cache'),
        convert_to_uint=True,
        infinite_generator=False,
        n_chunks_per_song=1,
        # should not perform data augmentation for eval:
        random_data_augmentation=False,
        random_time_crop=False,
        shuffle=False)
|
||||||
|
|
||||||
|
|
||||||
|
class InstrumentDatasetBuilder(object):
    """ Instrument based filter and mapper provider. """

    def __init__(self, parent, instrument):
        """ Default constructor.

        :param parent: Parent dataset builder.
        :param instrument: Target instrument.
        """
        self._parent = parent
        self._instrument = instrument
        self._spectrogram_key = f'{instrument}_spectrogram'
        self._min_spectrogram_key = f'min_{instrument}_spectrogram'
        self._max_spectrogram_key = f'max_{instrument}_spectrogram'

    def load_waveform(self, sample):
        """ Load waveform for given sample. """
        loaded = self._parent._audio_adapter.load_tf_waveform(
            sample[f'{self._instrument}_path'],
            offset=sample['start'],
            duration=self._parent._chunk_duration,
            sample_rate=self._parent._sample_rate,
            waveform_name='waveform')
        return dict(sample, **loaded)

    def compute_spectrogram(self, sample):
        """ Compute spectrogram of the given sample. """
        spectrogram = compute_spectrogram_tf(
            sample['waveform'],
            frame_length=self._parent._frame_length,
            frame_step=self._parent._frame_step,
            spec_exponent=1.,
            window_exponent=1.)
        return dict(sample, **{self._spectrogram_key: spectrogram})

    def filter_frequencies(self, sample):
        """ Keep only the first F frequency bins of the spectrogram. """
        truncated = sample[self._spectrogram_key][:, :self._parent._F, :]
        return dict(sample, **{self._spectrogram_key: truncated})

    def convert_to_uint(self, sample):
        """ Convert given sample from float to unit. """
        converted = spectrogram_to_db_uint(
            sample[self._spectrogram_key],
            tensor_key=self._spectrogram_key,
            min_key=self._min_spectrogram_key,
            max_key=self._max_spectrogram_key)
        return dict(sample, **converted)

    def filter_infinity(self, sample):
        """ Filter infinity sample. """
        minimum = sample[self._min_spectrogram_key]
        return tf.logical_not(tf.math.is_inf(minimum))

    def convert_to_float32(self, sample):
        """ Convert given sample from unit to float. """
        restored = db_uint_spectrogram_to_gain(
            sample[self._spectrogram_key],
            sample[self._min_spectrogram_key],
            sample[self._max_spectrogram_key])
        return dict(sample, **{self._spectrogram_key: restored})

    def time_crop(self, sample):
        """ Extract the central T-frame segment of the spectrogram. """
        def mid_segment_start(sample):
            # First frame of the centered window, clamped at 0.
            return tf.cast(
                tf.maximum(
                    tf.shape(sample[self._spectrogram_key])[0]
                    / 2 - self._parent._T / 2, 0),
                tf.int32)
        begin = mid_segment_start(sample)
        cropped = sample[self._spectrogram_key][
            begin:begin + self._parent._T, :, :]
        return dict(sample, **{self._spectrogram_key: cropped})

    def filter_shape(self, sample):
        """ Filter badly shaped sample. """
        expected = (self._parent._T, self._parent._F, 2)
        return check_tensor_shape(sample[self._spectrogram_key], expected)

    def reshape_spectrogram(self, sample):
        """ Set the static (T, F, 2) shape on the spectrogram tensor. """
        reshaped = set_tensor_shape(
            sample[self._spectrogram_key],
            (self._parent._T, self._parent._F, 2))
        return dict(sample, **{self._spectrogram_key: reshaped})
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetBuilder(object):
    """ Builds (mix spectrogram, isolated source spectrograms) dataset
    pipelines from a CSV description file, with optional caching,
    shuffling and data augmentation.
    """

    # Margin at beginning and end of songs in seconds.
    MARGIN = 0.5

    # Wait period for cache (in seconds).
    WAIT_PERIOD = 60

    def __init__(
            self,
            audio_params, audio_adapter, audio_path,
            random_seed=0, chunk_duration=20.0):
        """ Default constructor.

        NOTE: Probably need for AudioAdapter.

        :param audio_params: Audio parameters to use.
        :param audio_adapter: Audio adapter to use.
        :param audio_path: Root path prepended to CSV audio paths.
        :param random_seed: (Optional) seed used for shuffling and cropping.
        :param chunk_duration: (Optional) duration of loaded chunks (seconds).
        """
        # Length of segment in frames (if fs=22050 and
        # frame_step=512, then T=512 corresponds to 11.89s)
        self._T = audio_params['T']
        # Number of frequency bins to be used (should
        # be less than frame_length/2 + 1)
        self._F = audio_params['F']
        self._sample_rate = audio_params['sample_rate']
        self._frame_length = audio_params['frame_length']
        self._frame_step = audio_params['frame_step']
        self._mix_name = audio_params['mix_name']
        self._instruments = [self._mix_name] + audio_params['instrument_list']
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
        self._audio_adapter = audio_adapter
        self._audio_params = audio_params
        self._audio_path = audio_path
        self._random_seed = random_seed

    def expand_path(self, sample):
        """ Expands audio paths for the given sample. """
        return dict(sample, **{f'{instrument}_path': tf.string_join(
            (self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
            for instrument in self._instruments})

    def filter_error(self, sample):
        """ Filter errored sample. """
        return tf.logical_not(sample['waveform_error'])

    def filter_waveform(self, sample):
        """ Filter waveform from sample. """
        return {k: v for k, v in sample.items() if not k == 'waveform'}

    def harmonize_spectrogram(self, sample):
        """ Ensure same size for vocals and mix spectrograms. """
        def _reduce(sample):
            # Smallest time dimension among all instrument spectrograms.
            return tf.reduce_min([
                tf.shape(sample[f'{instrument}_spectrogram'])[0]
                for instrument in self._instruments])
        return dict(sample, **{
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :]
            for instrument in self._instruments})

    def filter_short_segments(self, sample):
        """ Filter out too short segment. """
        return tf.reduce_any([
            tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T
            for instrument in self._instruments])

    def random_time_crop(self, sample):
        """ Random time crop of 11.88s. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: tf.image.random_crop(
                x, (self._T, len(self._instruments) * self._F, 2),
                seed=self._random_seed)))

    def random_time_stretch(self, sample):
        """ Randomly time stretch the given sample. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: random_time_stretch(
                x, factor_min=0.9, factor_max=1.1)))

    def random_pitch_shift(self, sample):
        """ Randomly pitch shift the given sample. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: random_pitch_shift(
                x, shift_min=-1.0, shift_max=1.0), concat_axis=0))

    def map_features(self, sample):
        """ Select features and annotation of the given sample. """
        input_ = {
            f'{self._mix_name}_spectrogram':
                sample[f'{self._mix_name}_spectrogram']}
        output = {
            f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
            for instrument in self._audio_params['instrument_list']}
        return (input_, output)

    def compute_segments(self, dataset, n_chunks_per_song):
        """ Computes segments for each song of the dataset.

        :param dataset: Dataset to compute segments for.
        :param n_chunks_per_song: Number of segment per song to compute.
        :returns: Segmented dataset.
        :raise ValueError: If n_chunks_per_song is not positive.
        """
        if n_chunks_per_song <= 0:
            raise ValueError('n_chunks_per_song must be positive')
        datasets = []
        for k in range(n_chunks_per_song):
            if n_chunks_per_song > 1:
                # Spread chunk starts evenly between both song margins.
                # BUG FIX: bind k as a lambda default so each mapped
                # function keeps its own chunk index instead of closing
                # late over the loop variable.
                datasets.append(
                    dataset.map(
                        lambda sample, k=k: dict(sample, start=tf.maximum(
                            k * (
                                sample['duration'] - self._chunk_duration - 2
                                * self.MARGIN) / (n_chunks_per_song - 1)
                            + self.MARGIN, 0))))
            elif n_chunks_per_song == 1:  # Take central segment.
                datasets.append(
                    dataset.map(lambda sample: dict(sample, start=tf.maximum(
                        sample['duration'] / 2 - self._chunk_duration / 2,
                        0))))
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
        return dataset

    @property
    def instruments(self):
        """ Instrument dataset builder generator.

        :yield InstrumentBuilder instance.
        """
        if self._instrument_builders is None:
            # Lazily build one helper per instrument, reused afterwards.
            self._instrument_builders = []
            for instrument in self._instruments:
                self._instrument_builders.append(
                    InstrumentDatasetBuilder(self, instrument))
        for builder in self._instrument_builders:
            yield builder

    def cache(self, dataset, cache, wait):
        """ Cache the given dataset if cache is enabled. Eventually waits for
        cache to be available (useful if another process is already computing
        cache) if provided wait flag is True.

        :param dataset: Dataset to be cached if cache is required.
        :param cache: Path of cache directory to be used, None if no cache.
        :param wait: If caching is enabled, True is cache should be waited.
        :returns: Cached dataset if needed, original dataset otherwise.
        """
        if cache is not None:
            if wait:
                while not exists(f'{cache}.index'):
                    get_logger().info(
                        'Cache not available, wait %s',
                        self.WAIT_PERIOD)
                    time.sleep(self.WAIT_PERIOD)
            cache_path = os.path.split(cache)[0]
            os.makedirs(cache_path, exist_ok=True)
            return dataset.cache(cache)
        return dataset

    def build(
            self, csv_path,
            batch_size=8, shuffle=True, convert_to_uint=True,
            random_data_augmentation=False, random_time_crop=True,
            infinite_generator=True, cache_directory=None,
            wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,):
        """ Builds the full preprocessing pipeline.

        :param csv_path: Path of the CSV file describing the dataset.
        :param batch_size: (Optional) size of output batches.
        :param shuffle: (Optional) shuffle before and after caching.
        :param convert_to_uint: (Optional) store spectrograms as uint in
            cache (saves space), converting back to float32 afterwards.
        :param random_data_augmentation: (Optional) apply random pitch
            shift and time stretch.
        :param random_time_crop: (Optional) random crop instead of central
            crop (the latter is used for validation).
        :param infinite_generator: (Optional) repeat dataset indefinitely.
        :param cache_directory: (Optional) cache location, None to disable.
        :param wait_for_cache: (Optional) wait for an external process to
            fill the cache.
        :param num_parallel_calls: (Optional) parallelism of heavy maps.
        :param n_chunks_per_song: (Optional) chunks sampled per song.
        :returns: Built dataset of (input_, output) dict tuples.
        """
        dataset = dataset_from_csv(csv_path)
        dataset = self.compute_segments(dataset, n_chunks_per_song)
        # Shuffle data
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=200000,
                seed=self._random_seed,
                # useless since it is cached :
                reshuffle_each_iteration=True)
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveform, compute spectrogram, and filtering error,
        # K bins frequencies, and waveform.
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset
                .map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies))
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(instrument.convert_to_uint)
        dataset = self.cache(dataset, cache_directory, wait_for_cache)
        # Check for INFINITY (should not happen)
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitely
        if infinite_generator:
            dataset = dataset.repeat(count=-1)
        # Ensure same size for vocals and mix spectrograms.
        # NOTE: could be done before caching ?
        dataset = dataset.map(self.harmonize_spectrogram)
        # Filter out too short segment.
        # NOTE: could be done before caching ?
        dataset = dataset.filter(self.filter_short_segments)
        # Random time crop of 11.88s
        if random_time_crop:
            dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
        else:
            # frame_duration = 11.88/T
            # take central segment (for validation)
            for instrument in self.instruments:
                dataset = dataset.map(instrument.time_crop)
        # Post cache shuffling. Done where the data are the lightest:
        # after croping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed,
                reshuffle_each_iteration=True)
        # Convert back to float32
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N)
        M = 8  # Parallel call post caching.
        # Must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = (
                dataset
                .map(self.random_time_stretch, num_parallel_calls=M)
                .map(self.random_pitch_shift, num_parallel_calls=M))
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = (
                dataset
                .filter(instrument.filter_shape)
                .map(instrument.reshape_spectrogram))
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid
        # error due to unprocessed instrument spectrogram batching).
        dataset = dataset.batch(batch_size)
        return dataset
|
||||||
397
spleeter/model/__init__.py
Normal file
397
spleeter/model/__init__.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" This package provide an estimator builder as well as model functions. """
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.signal import stft, inverse_stft, hann_window
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from ..utils.tensor import pad_and_partition, pad_and_reshape
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_function(model_type):
    """
    Get tensorflow function of the model to be applied to the input tensor.
    For instance "unet.softmax_unet" will return the softmax_unet function
    in the "unet.py" submodule of the current module (spleeter.model).

    Params:
    - model_type: str
    the relative module path to the model function.

    Returns:
    A tensorflow function to be applied to the input tensor to get the
    multitrack output.
    """
    # Split "pkg.sub.fn" into module path ("pkg.sub") and function ("fn").
    relative = '.'.join(model_type.split('.')[:-1])
    function_name = model_type.split('.')[-1]
    module_path = f'{__name__}.functions.{relative}'
    module = importlib.import_module(module_path)
    return getattr(module, function_name)
|
||||||
|
|
||||||
|
|
||||||
|
class EstimatorSpecBuilder(object):
|
||||||
|
""" A builder class that allows to builds a multitrack unet model
|
||||||
|
estimator. The built model estimator has a different behaviour when
|
||||||
|
used in a train/eval mode and in predict mode.
|
||||||
|
|
||||||
|
* In train/eval mode: it takes as input and outputs magnitude spectrogram
|
||||||
|
* In predict mode: it takes as input and outputs waveform. The whole
|
||||||
|
separation process is then done in this function
|
||||||
|
for performance reason: it makes it possible to run
|
||||||
|
the whole separation process (including STFT and
|
||||||
|
inverse STFT) on GPU.
|
||||||
|
|
||||||
|
:Example:
|
||||||
|
|
||||||
|
>>> from spleeter.model import EstimatorSpecBuilder
|
||||||
|
>>> builder = EstimatorSpecBuilder()
|
||||||
|
>>> builder.build_prediction_model()
|
||||||
|
>>> builder.build_evaluation_model()
|
||||||
|
>>> builder.build_training_model()
|
||||||
|
|
||||||
|
>>> from spleeter.model import model_fn
|
||||||
|
>>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Supported model functions.
|
||||||
|
DEFAULT_MODEL = 'unet.unet'
|
||||||
|
|
||||||
|
# Supported loss functions.
|
||||||
|
L1_MASK = 'L1_mask'
|
||||||
|
WEIGHTED_L1_MASK = 'weighted_L1_mask'
|
||||||
|
|
||||||
|
# Supported optimizers.
|
||||||
|
ADADELTA = 'Adadelta'
|
||||||
|
SGD = 'SGD'
|
||||||
|
|
||||||
|
# Math constants.
|
||||||
|
WINDOW_COMPENSATION_FACTOR = 2./3.
|
||||||
|
EPSILON = 1e-10
|
||||||
|
|
||||||
|
def __init__(self, features, params):
|
||||||
|
""" Default constructor. Depending on built model
|
||||||
|
usage, the provided features should be different:
|
||||||
|
|
||||||
|
* In train/eval mode: features is a dictionary with a
|
||||||
|
"mix_spectrogram" key, associated to the
|
||||||
|
mix magnitude spectrogram.
|
||||||
|
* In predict mode: features is a dictionary with a "waveform"
|
||||||
|
key, associated to the waveform of the sound
|
||||||
|
to be separated.
|
||||||
|
|
||||||
|
:param features: The input features for the estimator.
|
||||||
|
:param params: Some hyperparameters as a dictionary.
|
||||||
|
"""
|
||||||
|
self._features = features
|
||||||
|
self._params = params
|
||||||
|
# Get instrument name.
|
||||||
|
self._mix_name = params['mix_name']
|
||||||
|
self._instruments = params['instrument_list']
|
||||||
|
# Get STFT/signals parameters
|
||||||
|
self._n_channels = params['n_channels']
|
||||||
|
self._T = params['T']
|
||||||
|
self._F = params['F']
|
||||||
|
self._frame_length = params['frame_length']
|
||||||
|
self._frame_step = params['frame_step']
|
||||||
|
|
||||||
|
def _build_output_dict(self):
|
||||||
|
""" Created a batch_sizexTxFxn_channels input tensor containing
|
||||||
|
mix magnitude spectrogram, then an output dict from it according
|
||||||
|
to the selected model in internal parameters.
|
||||||
|
|
||||||
|
:returns: Build output dict.
|
||||||
|
:raise ValueError: If required model_type is not supported.
|
||||||
|
"""
|
||||||
|
input_tensor = self._features[f'{self._mix_name}_spectrogram']
|
||||||
|
model = self._params.get('model', None)
|
||||||
|
if model is not None:
|
||||||
|
model_type = model.get('type', self.DEFAULT_MODEL)
|
||||||
|
else:
|
||||||
|
model_type = self.DEFAULT_MODEL
|
||||||
|
try:
|
||||||
|
apply_model = get_model_function(model_type)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
raise ValueError(f'No model function {model_type} found')
|
||||||
|
return apply_model(
|
||||||
|
input_tensor,
|
||||||
|
self._instruments,
|
||||||
|
self._params['model']['params'])
|
||||||
|
|
||||||
|
def _build_loss(self, output_dict, labels):
|
||||||
|
""" Construct tensorflow loss and metrics
|
||||||
|
|
||||||
|
:param output_dict: dictionary of network outputs (key: instrument
|
||||||
|
name, value: estimated spectrogram of the instrument)
|
||||||
|
:param labels: dictionary of target outputs (key: instrument
|
||||||
|
name, value: ground truth spectrogram of the instrument)
|
||||||
|
:returns: tensorflow (loss, metrics) tuple.
|
||||||
|
"""
|
||||||
|
loss_type = self._params.get('loss_type', self.L1_MASK)
|
||||||
|
if loss_type == self.L1_MASK:
|
||||||
|
losses = {
|
||||||
|
name: tf.reduce_mean(tf.abs(output - labels[name]))
|
||||||
|
for name, output in output_dict.items()
|
||||||
|
}
|
||||||
|
elif loss_type == self.WEIGHTED_L1_MASK:
|
||||||
|
losses = {
|
||||||
|
name: tf.reduce_mean(
|
||||||
|
tf.reduce_mean(
|
||||||
|
labels[name],
|
||||||
|
axis=[1, 2, 3],
|
||||||
|
keep_dims=True) *
|
||||||
|
tf.abs(output - labels[name]))
|
||||||
|
for name, output in output_dict.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unkwnown loss type: {loss_type}")
|
||||||
|
loss = tf.reduce_sum(list(losses.values()))
|
||||||
|
# Add metrics for monitoring each instrument.
|
||||||
|
metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
|
||||||
|
metrics['absolute_difference'] = tf.compat.v1.metrics.mean(loss)
|
||||||
|
return loss, metrics
|
||||||
|
|
||||||
|
def _build_optimizer(self):
|
||||||
|
""" Builds an optimizer instance from internal parameter values.
|
||||||
|
|
||||||
|
Default to AdamOptimizer if not specified.
|
||||||
|
|
||||||
|
:returns: Optimizer instance from internal configuration.
|
||||||
|
"""
|
||||||
|
name = self._params.get('optimizer')
|
||||||
|
if name == self.ADADELTA:
|
||||||
|
return tf.compat.v1.train.AdadeltaOptimizer()
|
||||||
|
rate = self._params['learning_rate']
|
||||||
|
if name == self.SGD:
|
||||||
|
return tf.compat.v1.train.GradientDescentOptimizer(rate)
|
||||||
|
return tf.compat.v1.train.AdamOptimizer(rate)
|
||||||
|
|
||||||
|
def _build_stft_feature(self):
|
||||||
|
""" Compute STFT of waveform and slice the STFT in segment
|
||||||
|
with the right length to feed the network.
|
||||||
|
"""
|
||||||
|
stft_feature = tf.transpose(
|
||||||
|
stft(
|
||||||
|
tf.transpose(self._features['waveform']),
|
||||||
|
self._frame_length,
|
||||||
|
self._frame_step,
|
||||||
|
window_fn=lambda frame_length, dtype: (
|
||||||
|
hann_window(frame_length, periodic=True, dtype=dtype)),
|
||||||
|
pad_end=True),
|
||||||
|
perm=[1, 2, 0])
|
||||||
|
self._features[f'{self._mix_name}_stft'] = stft_feature
|
||||||
|
self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
|
||||||
|
pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]
|
||||||
|
|
||||||
|
def _inverse_stft(self, stft):
    """ Inverse and reshape the given STFT

    :param stft: input STFT
    :returns: inverse STFT (waveform)
    """
    window_fn = lambda frame_length, dtype: hann_window(
        frame_length, periodic=True, dtype=dtype)
    # inverse_stft expects channel-first layout; compensate window
    # overlap with the class constant factor.
    inversed = inverse_stft(
        tf.transpose(stft, perm=[2, 0, 1]),
        self._frame_length,
        self._frame_step,
        window_fn=window_fn) * self.WINDOW_COMPENSATION_FACTOR
    # Back to (samples, channels), trimmed to the input waveform length.
    n_samples = tf.shape(self._features['waveform'])[0]
    return tf.transpose(inversed)[:n_samples, :]
|
||||||
|
|
||||||
|
def _build_mwf_output_waveform(self, output_dict):
    """ Perform separation with multichannel Wiener Filtering using Norbert.

    Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
    may be quite slow.

    :param output_dict: dictionary of estimated spectrogram (key: instrument
        name, value: estimated spectrogram of the instrument)
    :returns: dictionary of separated waveforms (key: instrument name,
        value: estimated waveform of the instrument)
    """
    import norbert  # pylint: disable=import-error
    x = self._features[f'{self._mix_name}_stft']
    # Stack per-instrument spectrograms, padded/reshaped to match the
    # mixture STFT, along a new last axis.
    v = tf.stack(
        [
            pad_and_reshape(
                output_dict[f'{instrument}_spectrogram'],
                self._frame_length,
                self._F)[:tf.shape(x)[0], ...]
            for instrument in self._instruments
        ],
        axis=3)
    input_args = [v, x]
    # NOTE(fix): the original code had a stray trailing comma after the
    # tf.py_function(...) call, turning stft_function into a 1-tuple that
    # was then unwrapped with [0]. Both the comma and the [0] are removed;
    # behavior is unchanged.
    stft_function = tf.py_function(
        lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
        input_args,
        tf.complex64)
    return {
        instrument: self._inverse_stft(stft_function[:, :, :, k])
        for k, instrument in enumerate(self._instruments)
    }
|
||||||
|
|
||||||
|
def _extend_mask(self, mask):
    """ Extend mask, from reduced number of frequency bin to the number of
    frequency bin in the STFT.

    :param mask: restricted mask
    :returns: extended mask
    :raise ValueError: If invalid mask_extension parameter is set.
    """
    extension_mode = self._params['mask_extension']
    if extension_mode == "average":
        # Extend with average (dispatch according to energy in the
        # processed band).
        extension_row = tf.reduce_mean(mask, axis=2, keepdims=True)
    elif extension_mode == "zeros":
        # Extend with 0 (avoid extension artifacts but not conservative
        # separation).
        shape = tf.shape(mask)
        extension_row = tf.zeros((shape[0], shape[1], 1, shape[-1]))
    else:
        raise ValueError(f'Invalid mask_extension parameter {extension_mode}')
    # Number of frequency bins between self._F and the full STFT height.
    n_extra_row = self._frame_length // 2 + 1 - self._F
    extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
    return tf.concat([mask, extension], axis=2)
|
||||||
|
|
||||||
|
def _build_manual_output_waveform(self, output_dict):
    """ Perform ratio mask separation

    :param output_dict: dictionary of estimated spectrogram (key: instrument
        name, value: estimated spectrogram of the instrument)
    :returns: dictionary of separated waveforms (key: instrument name,
        value: estimated waveform of the instrument)
    """
    exponent = self._params['separation_exponent']
    # Denominator shared by every instrument mask; EPSILON avoids a
    # division by zero on silent bins.
    output_sum = tf.reduce_sum(
        [spectrogram ** exponent for spectrogram in output_dict.values()],
        axis=0) + self.EPSILON
    stft_feature = self._features[f'{self._mix_name}_stft']
    output_waveform = {}
    for instrument in self._instruments:
        spectrogram = output_dict[f'{instrument}_spectrogram']
        # Compute the ratio mask for this instrument.
        mask = (
            spectrogram ** exponent
            + (self.EPSILON / len(output_dict))) / output_sum
        # Extend mask to the full STFT frequency range.
        mask = self._extend_mask(mask)
        # Merge segment and time axes back together.
        mask_shape = tf.shape(mask)
        flat_shape = tf.concat(
            [[mask_shape[0] * mask_shape[1]], mask_shape[2:]],
            axis=0)
        mask = tf.reshape(mask, flat_shape)
        # Remove padded part so the mask has the same size as the STFT.
        mask = mask[:tf.shape(stft_feature)[0], ...]
        # Apply the mask and go back to the time domain.
        output_waveform[instrument] = self._inverse_stft(
            tf.cast(mask, dtype=tf.complex64) * stft_feature)
    return output_waveform
|
||||||
|
|
||||||
|
def _build_output_waveform(self, output_dict):
    """ Build output waveform from given output dict in order to be used in
    prediction context. Regarding of the configuration building method will
    be using MWF.

    :param output_dict: Output dict to build output waveform from.
    :returns: Built output waveform.
    """
    builder = (
        self._build_mwf_output_waveform
        if self._params.get('MWF', False)
        else self._build_manual_output_waveform)
    output_waveform = builder(output_dict)
    # Propagate audio identifier when the input pipeline provides one.
    if 'audio_id' in self._features:
        output_waveform['audio_id'] = self._features['audio_id']
    return output_waveform
|
||||||
|
|
||||||
|
def build_predict_model(self):
    """ Builder interface for creating model instance that aims to perform
    prediction / inference over given track. The output of such estimator
    will be a dictionary with a "<instrument>" key per separated instrument
    , associated to the estimated separated waveform of the instrument.

    :returns: An estimator for performing prediction.
    """
    self._build_stft_feature()
    predictions = self._build_output_waveform(self._build_output_dict())
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.PREDICT,
        predictions=predictions)
|
||||||
|
|
||||||
|
def build_evaluation_model(self, labels):
    """ Builder interface for creating model instance that aims to perform
    model evaluation. The output of such estimator will be a dictionary
    with a key "<instrument>_spectrogram" per separated instrument,
    associated to the estimated separated instrument magnitude spectrogram.

    :param labels: Model labels.
    :returns: An estimator for performing model evaluation.
    """
    loss, metrics = self._build_loss(self._build_output_dict(), labels)
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops=metrics)
|
||||||
|
|
||||||
|
def build_train_model(self, labels):
    """ Builder interface for creating model instance that aims to perform
    model training. The output of such estimator will be a dictionary
    with a key "<instrument>_spectrogram" per separated instrument,
    associated to the estimated separated instrument magnitude spectrogram.

    :param labels: Model labels.
    :returns: An estimator for performing model training.
    """
    loss, metrics = self._build_loss(self._build_output_dict(), labels)
    # Minimize against the global step so checkpoints track progress.
    train_operation = self._build_optimizer().minimize(
        loss=loss,
        global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_operation,
        eval_metric_ops=metrics)
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn(features, labels, mode, params, config):
    """ Estimator model function, dispatching on the estimator mode.

    :param features: Input features dict.
    :param labels: Model labels (used in TRAIN and EVAL modes).
    :param mode: Estimator mode.
    :param params: Model parameters dict.
    :param config: TF configuration (not used).
    :returns: Built EstimatorSpec.
    :raise ValueError: If estimator mode is not supported.
    """
    builder = EstimatorSpecBuilder(features, params)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return builder.build_predict_model()
    if mode == tf.estimator.ModeKeys.EVAL:
        return builder.build_evaluation_model(labels)
    if mode == tf.estimator.ModeKeys.TRAIN:
        return builder.build_train_model(labels)
    raise ValueError(f'Unknown mode {mode}')
|
||||||
27
spleeter/model/functions/__init__.py
Normal file
27
spleeter/model/functions/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" This package provide model functions. """
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def apply(function, input_tensor, instruments, params=None):
    """ Apply given function to the input tensor.

    :param function: Function to be applied to tensor; called as
        function(input_tensor, output_name=..., params=...).
    :param input_tensor: Tensor to apply the function to.
    :param instruments: Iterable that provides a collection of instruments.
    :param params: (Optional) dict of model parameters, default to empty
        dict.
    :returns: Created output tensor dict, one '<instrument>_spectrogram'
        entry per instrument.
    """
    # Fix the shared-mutable-default pitfall of the original params={}.
    if params is None:
        params = {}
    output_dict = {}
    for instrument in instruments:
        out_name = f'{instrument}_spectrogram'
        output_dict[out_name] = function(
            input_tensor,
            output_name=out_name,
            params=params)
    return output_dict
|
||||||
76
spleeter/model/functions/blstm.py
Normal file
76
spleeter/model/functions/blstm.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
This system (UHL1) uses a bi-directional LSTM network as described in :
|
||||||
|
|
||||||
|
`S. Uhlich, M. Porcu, F. Giron, M. Enenkl, T. Kemp, N. Takahashi and
|
||||||
|
Y. Mitsufuji.
|
||||||
|
|
||||||
|
"Improving music source separation based on deep neural networks through
|
||||||
|
data augmentation and network blending", Proc. ICASSP, 2017.`
|
||||||
|
|
||||||
|
It has three BLSTM layers, each having 500 cells. For each instrument,
|
||||||
|
a network is trained which predicts the target instrument amplitude from
|
||||||
|
the mixture amplitude in the STFT domain (frame size: 4096, hop size:
|
||||||
|
1024). The raw output of each network is then combined by a multichannel
|
||||||
|
Wiener filter. The network is trained on musdb where we split train into
|
||||||
|
train_train and train_valid with 86 and 14 songs, respectively. The
|
||||||
|
validation set is used to perform early stopping and hyperparameter
|
||||||
|
selection (LSTM layer dropout rate, regularization strength).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
from tensorflow.compat.v1.keras.initializers import he_uniform
|
||||||
|
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
|
||||||
|
from tensorflow.keras.layers import (
|
||||||
|
Bidirectional,
|
||||||
|
Dense,
|
||||||
|
Flatten,
|
||||||
|
Reshape,
|
||||||
|
TimeDistributed)
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from . import apply
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def apply_blstm(input_tensor, output_name='output', params=None):
    """ Apply BLSTM to the given input_tensor.

    :param input_tensor: Input of the model.
    :param output_name: (Optional) name of the output, default to 'output'.
    :param params: (Optional) dict of BLSTM parameters; reads 'lstm_units'
        (default 250).
    :returns: Output tensor.
    """
    # Fix the shared-mutable-default pitfall of the original params={}.
    if params is None:
        params = {}
    units = params.get('lstm_units', 250)
    kernel_initializer = he_uniform(seed=50)
    # Flatten each time frame before feeding the recurrent stack.
    flatten_input = TimeDistributed(Flatten())(input_tensor)

    def create_bidirectional():
        # One BLSTM layer; CuDNNLSTM requires a GPU at runtime.
        return Bidirectional(
            CuDNNLSTM(
                units,
                kernel_initializer=kernel_initializer,
                return_sequences=True))

    l1 = create_bidirectional()(flatten_input)
    l2 = create_bidirectional()(l1)
    l3 = create_bidirectional()(l2)
    # Project back to the flattened frame size, then restore the
    # original per-frame shape.
    dense = TimeDistributed(
        Dense(
            int(flatten_input.shape[2]),
            activation='relu',
            kernel_initializer=kernel_initializer))(l3)
    output = TimeDistributed(
        Reshape(input_tensor.shape[2:]),
        name=output_name)(dense)
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def blstm(input_tensor, instruments, params={}):
    """ Model function applier.

    BUG FIX: the original signature took `output_name` as second argument
    and forwarded it as the `instruments` parameter of `apply`, which
    iterates over it — a string would have been split into single
    characters. The signature is now aligned with the sibling `unet`
    applier.
    """
    return apply(apply_blstm, input_tensor, instruments, params)
|
||||||
201
spleeter/model/functions/unet.py
Normal file
201
spleeter/model/functions/unet.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
This module contains building functions for U-net source separation source
|
||||||
|
separation models.
|
||||||
|
Each instrument is modeled by a single U-net convolutional/deconvolutional
|
||||||
|
network that take a mix spectrogram as input and the estimated sound spectrogram
|
||||||
|
as output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.keras.layers import (
|
||||||
|
BatchNormalization,
|
||||||
|
Concatenate,
|
||||||
|
Conv2D,
|
||||||
|
Conv2DTranspose,
|
||||||
|
Dropout,
|
||||||
|
ELU,
|
||||||
|
LeakyReLU,
|
||||||
|
Multiply,
|
||||||
|
ReLU,
|
||||||
|
Softmax)
|
||||||
|
from tensorflow.compat.v1 import logging
|
||||||
|
from tensorflow.compat.v1.keras.initializers import he_uniform
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from . import apply
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def _get_conv_activation_layer(params):
    """ Return the activation layer used after encoder convolutions.

    :param params: Model parameters dict; reads 'conv_activation'.
    :returns: Required Activation function (LeakyReLU(0.2) by default).
    """
    activation_name = params.get('conv_activation')
    if activation_name == 'ReLU':
        return ReLU()
    if activation_name == 'ELU':
        return ELU()
    return LeakyReLU(0.2)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_deconv_activation_layer(params):
    """ Return the activation layer used after decoder deconvolutions.

    :param params: Model parameters dict; reads 'deconv_activation'.
    :returns: Required Activation function (ReLU by default).
    """
    activation_name = params.get('deconv_activation')
    if activation_name == 'LeakyReLU':
        return LeakyReLU(0.2)
    if activation_name == 'ELU':
        return ELU()
    return ReLU()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_unet(
        input_tensor,
        output_name='output',
        params={},
        output_mask_logit=False):
    """ Apply a convolutionnal U-net to model a single instrument (one U-net
    is used for each instrument).

    :param input_tensor: Input spectrogram tensor of the network.
    :param output_name: (Optional) name of the output layer, default to
        'output'.
    :param params: (Optional) model parameters dict, default to empty dict.
    :param output_mask_logit: (Optional) when True, return the raw mask
        logits instead of the masked input, default to False.
    """
    logging.info(f'Apply unet for {output_name}')
    conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
    conv_activation_layer = _get_conv_activation_layer(params)
    deconv_activation_layer = _get_deconv_activation_layer(params)
    kernel_initializer = he_uniform(seed=50)
    conv2d_factory = partial(
        Conv2D,
        strides=(2, 2),
        padding='same',
        kernel_initializer=kernel_initializer)
    conv2d_transpose_factory = partial(
        Conv2DTranspose,
        strides=(2, 2),
        padding='same',
        kernel_initializer=kernel_initializer)
    # Encoder: six strided conv / batch-norm / activation stages. The raw
    # conv outputs are kept for the decoder skip connections.
    conv_outputs = []
    current = input_tensor
    for stage in range(6):
        conv = conv2d_factory(conv_n_filters[stage], (5, 5))(current)
        conv_outputs.append(conv)
        batch = BatchNormalization(axis=-1)(conv)
        current = conv_activation_layer(batch)
    # NOTE: as in the original implementation, the activation of the last
    # encoder stage is computed but unused — the decoder starts from the
    # raw sixth conv output.
    current = conv_outputs[-1]
    # Decoder: five deconv stages, each concatenated with the matching
    # encoder skip connection; the three deepest stages use dropout.
    for depth, stage in enumerate(range(4, -1, -1)):
        up = conv2d_transpose_factory(conv_n_filters[stage], (5, 5))(current)
        up = deconv_activation_layer(up)
        batch = BatchNormalization(axis=-1)(up)
        if depth < 3:
            batch = Dropout(0.5)(batch)
        current = Concatenate(axis=-1)([conv_outputs[stage], batch])
    # Final upsampling back to the input resolution.
    up = conv2d_transpose_factory(1, (5, 5), strides=(2, 2))(current)
    up = deconv_activation_layer(up)
    batch12 = BatchNormalization(axis=-1)(up)
    # Last layer to ensure initial shape reconstruction.
    if not output_mask_logit:
        mask = Conv2D(
            2,
            (4, 4),
            dilation_rate=(2, 2),
            activation='sigmoid',
            padding='same',
            kernel_initializer=kernel_initializer)(batch12)
        return Multiply(name=output_name)([mask, input_tensor])
    return Conv2D(
        2,
        (4, 4),
        dilation_rate=(2, 2),
        padding='same',
        kernel_initializer=kernel_initializer)(batch12)
|
||||||
|
|
||||||
|
|
||||||
|
def unet(input_tensor, instruments, params={}):
    """ Model function applier: plain U-net output dict, one network per
    instrument. """
    return apply(apply_unet, input_tensor, instruments, params=params)
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_unet(input_tensor, instruments, params={}):
    """ Apply softmax to multitrack unet in order to have mask suming to one.

    :param input_tensor: Tensor to apply blstm to.
    :param instruments: Iterable that provides a collection of instruments.
    :param params: (Optional) dict of BLSTM parameters.
    :returns: Created output tensor dict.
    """
    # One logit mask per instrument, each produced by a dedicated U-net.
    logit_mask_list = [
        apply_unet(
            input_tensor,
            output_name=f'{instrument}_spectrogram',
            params=params,
            output_mask_logit=True)
        for instrument in instruments
    ]
    # Normalize masks across instruments so they sum to one per bin.
    masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
    output_dict = {}
    for i, instrument in enumerate(instruments):
        out_name = f'{instrument}_spectrogram'
        output_dict[out_name] = Multiply(name=out_name)([
            masks[..., i],
            input_tensor])
    return output_dict
|
||||||
79
spleeter/model/provider/__init__.py
Normal file
79
spleeter/model/provider/__init__.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
This package provides tools for downloading model from network
|
||||||
|
using remote storage abstraction.
|
||||||
|
|
||||||
|
:Example:
|
||||||
|
|
||||||
|
>>> provider = MyProviderImplementation()
|
||||||
|
>>> provider.get('/path/to/local/storage', params)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from os import environ, makedirs
|
||||||
|
from os.path import exists, isabs, join, sep
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProvider(ABC):
    """
    A ModelProvider manages model files on disk and
    file download is not available.
    """

    # Local directory models are stored into (overridable through the
    # MODEL_PATH environment variable).
    DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models')
    # Marker file denoting a fully downloaded / extracted model.
    MODEL_PROBE_PATH = '.probe'

    @abstractmethod
    def download(self, name, path):
        """ Download model denoted by the given name to disk.

        :param name: Name of the model to download.
        :param path: Path of the directory to save model into.
        """
        pass

    def writeProbe(self, directory):
        """ Write a model probe file into the given directory.

        :param directory: Directory to write probe into.
        """
        with open(join(directory, self.MODEL_PROBE_PATH), 'w') as stream:
            stream.write('OK')

    def get(self, model_directory):
        """ Ensures required model is available at given location.

        :param model_directory: Expected model_directory to be available.
        :returns: Actual (possibly expanded) model directory path.
        :raise IOError: If model can not be retrieved.
        """
        # Expand relative paths against the default model location.
        if not isabs(model_directory):
            model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
        # Download the model only when its probe file is missing.
        model_probe = join(model_directory, self.MODEL_PROBE_PATH)
        if not exists(model_probe):
            # exist_ok=True fixes the race of the original
            # exists()/makedirs() pair when several processes create the
            # directory concurrently.
            makedirs(model_directory, exist_ok=True)
            self.download(
                model_directory.split(sep)[-1],
                model_directory)
            self.writeProbe(model_directory)
        return model_directory
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_model_provider():
    """ Builds and returns a default model provider.

    Host, repository and release are each overridable through the
    GITHUB_HOST, GITHUB_REPOSITORY and GITHUB_RELEASE environment
    variables.

    :returns: A default model provider instance to use.
    """
    from .github import GithubModelProvider
    return GithubModelProvider(
        environ.get('GITHUB_HOST', 'https://github.com'),
        environ.get('GITHUB_REPOSITORY', 'deezer/spleeter'),
        environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE))
|
||||||
73
spleeter/model/provider/github.py
Normal file
73
spleeter/model/provider/github.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
A ModelProvider backed by Github Release feature.
|
||||||
|
|
||||||
|
:Example:
|
||||||
|
|
||||||
|
>>> from spleeter.model.provider import github
|
||||||
|
>>> provider = github.GithubModelProvider(
|
||||||
|
'github.com',
|
||||||
|
'Deezer/spleeter',
|
||||||
|
'latest')
|
||||||
|
>>> provider.download('2stems', '/path/to/local/storage')
|
||||||
|
"""
|
||||||
|
|
||||||
|
import tarfile
|
||||||
|
|
||||||
|
from os import environ
|
||||||
|
from tempfile import TemporaryFile
|
||||||
|
from shutil import copyfileobj
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from . import ModelProvider
|
||||||
|
from ...utils.logging import get_logger
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
class GithubModelProvider(ModelProvider):
    """ A ModelProvider implementation backed on Github for remote storage. """

    # Default release to fetch model archives from.
    LATEST_RELEASE = 'v1.4.0'
    # URL path fragment of Github release assets.
    RELEASE_PATH = 'releases/download'

    def __init__(self, host, repository, release):
        """ Default constructor.

        :param host: Host to the Github instance to reach.
        :param repository: Repository path within target Github.
        :param release: Release name to get models from.
        """
        self._host = host
        self._repository = repository
        self._release = release

    def download(self, name, path):
        """ Download model denoted by the given name to disk.

        :param name: Name of the model to download.
        :param path: Path of the directory to save model into.
        :raise IOError: If the remote archive can not be found.
        """
        url = '{}/{}/{}/{}/{}.tar.gz'.format(
            self._host,
            self._repository,
            self.RELEASE_PATH,
            self._release,
            name)
        get_logger().info('Downloading model archive %s', url)
        response = requests.get(url, stream=True)
        if response.status_code != 200:
            raise IOError(f'Resource {url} not found')
        # Buffer the whole archive into an anonymous temporary file
        # before extracting, so a failed download never leaves a
        # half-extracted model behind.
        with TemporaryFile() as stream:
            copyfileobj(response.raw, stream)
            get_logger().debug('Extracting downloaded archive')
            stream.seek(0)
            # Context manager closes the tar handle even when
            # extractall raises (the original leaked it on error).
            with tarfile.open(fileobj=stream) as archive:
                archive.extractall(path=path)
        get_logger().debug('Model file extracted')
|
||||||
28
spleeter/resources/2stems.json
Normal file
28
spleeter/resources/2stems.json
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
{
|
||||||
|
"train_csv": "path/to/train.csv",
|
||||||
|
"validation_csv": "path/to/test.csv",
|
||||||
|
"model_dir": "2stems",
|
||||||
|
"mix_name": "mix",
|
||||||
|
"instrument_list": ["vocals", "accompaniment"],
|
||||||
|
"sample_rate":44100,
|
||||||
|
"frame_length":4096,
|
||||||
|
"frame_step":1024,
|
||||||
|
"T":512,
|
||||||
|
"F":1024,
|
||||||
|
"n_channels":2,
|
||||||
|
"separation_exponent":2,
|
||||||
|
"mask_extension":"zeros",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size":4,
|
||||||
|
"training_cache":"training_cache",
|
||||||
|
"validation_cache":"validation_cache",
|
||||||
|
"train_max_steps": 1000000,
|
||||||
|
"throttle_secs":300,
|
||||||
|
"random_seed":0,
|
||||||
|
"save_checkpoints_steps":150,
|
||||||
|
"save_summary_steps":5,
|
||||||
|
"model":{
|
||||||
|
"type":"unet.unet",
|
||||||
|
"params":{}
|
||||||
|
}
|
||||||
|
}
|
||||||
31
spleeter/resources/4stems.json
Normal file
31
spleeter/resources/4stems.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"train_csv": "path/to/train.csv",
|
||||||
|
"validation_csv": "path/to/val.csv",
|
||||||
|
"model_dir": "4stems",
|
||||||
|
"mix_name": "mix",
|
||||||
|
"instrument_list": ["vocals", "drums", "bass", "other"],
|
||||||
|
"sample_rate":44100,
|
||||||
|
"frame_length":4096,
|
||||||
|
"frame_step":1024,
|
||||||
|
"T":512,
|
||||||
|
"F":1024,
|
||||||
|
"n_channels":2,
|
||||||
|
"separation_exponent":2,
|
||||||
|
"mask_extension":"zeros",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size":4,
|
||||||
|
"training_cache":"training_cache",
|
||||||
|
"validation_cache":"validation_cache",
|
||||||
|
"train_max_steps": 1500000,
|
||||||
|
"throttle_secs":600,
|
||||||
|
"random_seed":3,
|
||||||
|
"save_checkpoints_steps":300,
|
||||||
|
"save_summary_steps":5,
|
||||||
|
"model":{
|
||||||
|
"type":"unet.unet",
|
||||||
|
"params":{
|
||||||
|
"conv_activation":"ELU",
|
||||||
|
"deconv_activation":"ELU"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
31
spleeter/resources/5stems.json
Normal file
31
spleeter/resources/5stems.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"train_csv": "path/to/train.csv",
|
||||||
|
"validation_csv": "path/to/test.csv",
|
||||||
|
"model_dir": "5stems",
|
||||||
|
"mix_name": "mix",
|
||||||
|
"instrument_list": ["vocals", "piano", "drums", "bass", "other"],
|
||||||
|
"sample_rate":44100,
|
||||||
|
"frame_length":4096,
|
||||||
|
"frame_step":1024,
|
||||||
|
"T":512,
|
||||||
|
"F":1024,
|
||||||
|
"n_channels":2,
|
||||||
|
"separation_exponent":2,
|
||||||
|
"mask_extension":"zeros",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size":4,
|
||||||
|
"training_cache":"training_cache",
|
||||||
|
"validation_cache":"validation_cache",
|
||||||
|
"train_max_steps": 2500000,
|
||||||
|
"throttle_secs":600,
|
||||||
|
"random_seed":8,
|
||||||
|
"save_checkpoints_steps":300,
|
||||||
|
"save_summary_steps":5,
|
||||||
|
"model":{
|
||||||
|
"type":"unet.softmax_unet",
|
||||||
|
"params":{
|
||||||
|
"conv_activation":"ELU",
|
||||||
|
"deconv_activation":"ELU"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
8
spleeter/resources/__init__.py
Normal file
8
spleeter/resources/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" Packages that provides static resources file for the library. """
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
32
spleeter/resources/musdb.json
Normal file
32
spleeter/resources/musdb.json
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
{
|
||||||
|
"train_csv": "configs/musdb_train.csv",
|
||||||
|
"validation_csv": "configs/musdb_validation.csv",
|
||||||
|
"model_dir": "musdb_model",
|
||||||
|
"mix_name": "mix",
|
||||||
|
"instrument_list": ["vocals", "drums", "bass", "other"],
|
||||||
|
"sample_rate":44100,
|
||||||
|
"frame_length":4096,
|
||||||
|
"frame_step":1024,
|
||||||
|
"T":512,
|
||||||
|
"F":1024,
|
||||||
|
"n_channels":2,
|
||||||
|
"n_chunks_per_song":1,
|
||||||
|
"separation_exponent":2,
|
||||||
|
"mask_extension":"zeros",
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"batch_size":4,
|
||||||
|
"training_cache":"training_cache",
|
||||||
|
"validation_cache":"validation_cache",
|
||||||
|
"train_max_steps": 100000,
|
||||||
|
"throttle_secs":600,
|
||||||
|
"random_seed":3,
|
||||||
|
"save_checkpoints_steps":300,
|
||||||
|
"save_summary_steps":5,
|
||||||
|
"model":{
|
||||||
|
"type":"unet.unet",
|
||||||
|
"params":{
|
||||||
|
"conv_activation":"ELU",
|
||||||
|
"deconv_activation":"ELU"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
127
spleeter/separator.py
Normal file
127
spleeter/separator.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
Module that provides a class wrapper for source separation.
|
||||||
|
|
||||||
|
:Example:
|
||||||
|
|
||||||
|
>>> from spleeter.separator import Separator
|
||||||
|
>>> separator = Separator('spleeter:2stems')
|
||||||
|
>>> separator.separate(waveform, lambda instrument, data: ...)
|
||||||
|
>>> separator.separate_to_file(...)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from pathlib import Path
|
||||||
|
from os.path import join
|
||||||
|
|
||||||
|
from .model import model_fn
|
||||||
|
from .utils.audio.adapter import get_default_audio_adapter
|
||||||
|
from .utils.audio.convertor import to_stereo
|
||||||
|
from .utils.configuration import load_configuration
|
||||||
|
from .utils.estimator import create_estimator, to_predictor
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
class Separator(object):
    """ A wrapper class for performing separation.

    Separation runs synchronously, while stem export is dispatched to a
    multiprocessing pool so several instruments can be written in parallel.
    """

    def __init__(self, params_descriptor, MWF=False):
        """ Default constructor.

        :param params_descriptor: Descriptor for TF params to be used.
        :param MWF: (Optional) True if MWF should be used, False otherwise.
        """
        self._params = load_configuration(params_descriptor)
        # Sample rate every loaded waveform is resampled to before separation.
        self._sample_rate = self._params['sample_rate']
        self._MWF = MWF
        # Created lazily on first separation (see _get_predictor).
        self._predictor = None
        # Worker pool used for asynchronous stem export.
        self._pool = Pool()
        # Pending export tasks (multiprocessing AsyncResult instances).
        self._tasks = []

    def _get_predictor(self):
        """ Lazy loading access method for internal predictor instance.

        :returns: Predictor to use for source separation.
        """
        if self._predictor is None:
            estimator = create_estimator(self._params, self._MWF)
            self._predictor = to_predictor(estimator)
        return self._predictor

    def join(self, timeout=20):
        """ Wait for all pending tasks to be finished.

        :param timeout: (Optional) task waiting timeout.
        """
        while len(self._tasks) > 0:
            task = self._tasks.pop()
            # NOTE(review): get() blocks with no timeout, so the following
            # wait(timeout=...) is effectively a no-op once the result is
            # available — confirm whether get(timeout=timeout) was intended.
            task.get()
            task.wait(timeout=timeout)

    def separate(self, waveform):
        """ Performs source separation over the given waveform.

        The separation is performed synchronously but the result
        processing is done asynchronously, allowing for instance
        to export audio in parallel (through multiprocessing).

        Given result is passed by to the given consumer, which will
        be waited for task finishing if synchronous flag is True.

        :param waveform: Waveform to apply separation on.
        :returns: Separated waveforms.
        """
        # Model expects stereo input; upmix / downmix anything else.
        if not waveform.shape[-1] == 2:
            waveform = to_stereo(waveform)
        predictor = self._get_predictor()
        prediction = predictor({
            'waveform': waveform,
            'audio_id': ''})
        # audio_id is a pass-through key, not a separated source.
        prediction.pop('audio_id')
        return prediction

    def separate_to_file(
            self, audio_descriptor, destination,
            # NOTE(review): default evaluated once at class-definition time;
            # acceptable here since the adapter is a memoized singleton, but
            # it triggers adapter creation at import — confirm intended.
            audio_adapter=get_default_audio_adapter(),
            offset=0, duration=600., codec='wav', bitrate='128k',
            synchronous=True):
        """ Performs source separation and export result to file using
        given audio adapter.

        :param audio_descriptor: Describe song to separate, used by audio
                                 adapter to retrieve and load audio data,
                                 in case of file based audio adapter, such
                                 descriptor would be a file path.
        :param destination: Target directory to write output to.
        :param audio_adapter: (Optional) Audio adapter to use for I/O.
        :param offset: (Optional) Offset of loaded song.
        :param duration: (Optional) Duration of loaded song.
        :param codec: (Optional) Export codec.
        :param bitrate: (Optional) Export bitrate.
        :param synchronous: (Optional) True is should by synchronous.
        """
        waveform, _ = audio_adapter.load(
            audio_descriptor,
            offset=offset,
            duration=duration,
            sample_rate=self._sample_rate)
        sources = self.separate(waveform)
        # One export task per instrument, written as <destination>/<stem>.<codec>.
        for instrument, data in sources.items():
            task = self._pool.apply_async(audio_adapter.save, (
                join(destination, f'{instrument}.{codec}'),
                data,
                self._sample_rate,
                codec,
                bitrate))
            self._tasks.append(task)
        if synchronous:
            self.join()
|
||||||
8
spleeter/utils/__init__.py
Normal file
8
spleeter/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" This package provides utility function and classes. """
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
15
spleeter/utils/audio/__init__.py
Normal file
15
spleeter/utils/audio/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
`spleeter.utils.audio` package provides various
|
||||||
|
tools for manipulating audio content such as :
|
||||||
|
|
||||||
|
- Audio adapter class for abstract interaction with audio file.
|
||||||
|
- FFMPEG implementation for audio adapter.
|
||||||
|
- Waveform convertion and transforming functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
144
spleeter/utils/audio/adapter.py
Normal file
144
spleeter/utils/audio/adapter.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" AudioAdapter class defintion. """
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from importlib import import_module
|
||||||
|
from os.path import exists
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.signal import stft, hann_window
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from ..logging import get_logger
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
class AudioAdapter(ABC):
    """ An abstract class for manipulating audio signal. """

    # Default audio adapter singleton instance (lazily set by
    # get_default_audio_adapter).
    DEFAULT = None

    @abstractmethod
    def load(
            self, audio_descriptor, offset, duration,
            sample_rate, dtype=np.float32):
        """ Loads the audio file denoted by the given audio descriptor
        and returns it data as a waveform. Aims to be implemented
        by client.

        :param audio_descriptor: Describe song to load, in case of file
                                 based audio adapter, such descriptor would
                                 be a file path.
        :param offset: Start offset to load from in seconds.
        :param duration: Duration to load in seconds.
        :param sample_rate: Sample rate to load audio with.
        :param dtype: Numpy data type to use, default to float32.
        :returns: Loaded data as (wf, sample_rate) tuple.
        """
        pass

    def load_tf_waveform(
            self, audio_descriptor,
            offset=0.0, duration=1800., sample_rate=44100,
            # dtype is bytes (not str) because it travels through a TF
            # string tensor into the py_function below.
            dtype=b'float32', waveform_name='waveform'):
        """ Load the audio and convert it to a tensorflow waveform.

        :param audio_descriptor: Describe song to load, in case of file
                                 based audio adapter, such descriptor would
                                 be a file path.
        :param offset: Start offset to load from in seconds.
        :param duration: Duration to load in seconds.
        :param sample_rate: Sample rate to load audio with.
        :param dtype: Numpy data type to use (as bytes), default to float32.
        :param waveform_name: (Optional) Name of the key in output dict.
        :returns: TF output dict with waveform as
                  (T x chan numpy array) and a boolean that
                  tells whether there were an error while
                  trying to load the waveform.
        """
        # Cast parameters to TF format.
        offset = tf.cast(offset, tf.float64)
        duration = tf.cast(duration, tf.float64)

        # Defined safe loading function: never raises inside the TF graph,
        # reports failure through the boolean flag instead.
        def safe_load(path, offset, duration, sample_rate, dtype):
            get_logger().info(
                f'Loading audio {path} from {offset} to {offset + duration}')
            try:
                (data, _) = self.load(
                    path.numpy(),
                    offset.numpy(),
                    duration.numpy(),
                    sample_rate.numpy(),
                    dtype=dtype.numpy())
                return (data, False)
            except Exception as e:
                get_logger().warning(e)
                # Placeholder value + error flag set.
                return (np.float32(-1.0), True)

        # Execute function and format results.
        # NOTE(review): trailing comma makes `results` a 1-tuple wrapping the
        # py_function output, hence the results[0] unpacking below — confirm
        # this is intentional rather than a typo.
        results = tf.py_function(
            safe_load,
            [audio_descriptor, offset, duration, sample_rate, dtype],
            (tf.float32, tf.bool)),
        waveform, error = results[0]
        return {
            waveform_name: waveform,
            f'{waveform_name}_error': error
        }

    @abstractmethod
    def save(
            self, path, data, sample_rate,
            codec=None, bitrate=None):
        """ Save the given audio data to the file denoted by
        the given path.

        :param path: Path of the audio file to save data in.
        :param data: Waveform data to write.
        :param sample_rate: Sample rate to write file in.
        :param codec: (Optional) Writing codec to use.
        :param bitrate: (Optional) Bitrate of the written audio file.
        """
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_audio_adapter():
    """ Builds and returns a default audio adapter instance.

    The FFMPEG based implementation is created lazily on first call
    and memoized into AudioAdapter.DEFAULT.

    :returns: An audio adapter instance.
    """
    if AudioAdapter.DEFAULT is None:
        # Local import avoids a circular dependency at module load time.
        from .ffmpeg import FFMPEGProcessAudioAdapter
        AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
    return AudioAdapter.DEFAULT
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_adapter(descriptor):
    """ Load dynamically an AudioAdapter from given class descriptor.

    :param descriptor: Adapter class descriptor (module.Class), or None
                       to get the default (FFMPEG based) adapter.
    :returns: Created adapter instance.
    :raise ValueError: If the descriptor does not denote an AudioAdapter
                       subclass.
    """
    if descriptor is None:
        return get_default_audio_adapter()
    module_path = descriptor.split('.')
    adapter_class_name = module_path[-1]
    module_path = '.'.join(module_path[:-1])
    adapter_module = import_module(module_path)
    adapter_class = getattr(adapter_module, adapter_class_name)
    # Bug fix: a class object is never an *instance* of AudioAdapter, so the
    # previous isinstance() check rejected every valid adapter class. The
    # correct check is issubclass (guarded so non-class attributes do not
    # raise TypeError).
    if not (isinstance(adapter_class, type)
            and issubclass(adapter_class, AudioAdapter)):
        raise ValueError(
            f'{adapter_class_name} is not a valid AudioAdapter class')
    return adapter_class()
|
||||||
88
spleeter/utils/audio/convertor.py
Normal file
88
spleeter/utils/audio/convertor.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" This module provides audio data convertion functions. """
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from ..tensor import from_float32_to_uint8, from_uint8_to_float32
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def to_n_channels(waveform, n_channels):
    """ Convert a waveform to n_channels by removing or
    duplicating channels if needed (in tensorflow).

    :param waveform: Waveform to transform.
    :param n_channels: Number of channel to reshape waveform in.
    :returns: Reshaped waveform.
    """
    def _truncate():
        # Wide enough already: keep only the leading channels.
        return waveform[:, :n_channels]

    def _extend():
        # Too few channels: tile along the channel axis, then trim.
        return tf.tile(waveform, [1, n_channels])[:, :n_channels]

    return tf.cond(
        tf.shape(waveform)[1] >= n_channels,
        true_fn=_truncate,
        false_fn=_extend)
|
||||||
|
|
||||||
|
|
||||||
|
def to_stereo(waveform):
    """ Convert a waveform to stereo by duplicating if mono,
    or truncating if too many channels.

    A 1D (channel-less) mono array is also accepted and expanded to
    two identical channels. (Docstring fix: output is (N, 2), the
    previous doc wrongly said (N, 1).)

    :param waveform: a (N,) or (N, d) numpy array.
    :returns: A stereo waveform as a (N, 2) numpy array.
    """
    # Generalization: promote a flat (N,) mono signal to a (N, 1) column
    # first, so it follows the regular mono duplication path below.
    if waveform.ndim == 1:
        waveform = np.expand_dims(waveform, axis=-1)
    if waveform.shape[1] == 1:
        # Mono: duplicate the single channel.
        return np.repeat(waveform, 2, axis=-1)
    if waveform.shape[1] > 2:
        # More than two channels: keep the first two only.
        return waveform[:, :2]
    return waveform
|
||||||
|
|
||||||
|
|
||||||
|
def gain_to_db(tensor, espilon=10e-10):
    """ Convert from gain to decibel in tensorflow.

    :param tensor: Tensor to convert.
    :param espilon: (sic — name kept for keyword compatibility) lower
                    clipping bound avoiding log(0).
    :returns: Converted tensor.
    """
    # 20 * log10(x), expressed through the natural log tensorflow provides.
    clipped = tf.maximum(tensor, espilon)
    return (20. / np.log(10)) * tf.math.log(clipped)
|
||||||
|
|
||||||
|
|
||||||
|
def db_to_gain(tensor):
    """ Convert from decibel to gain in tensorflow.

    :param tensor: Decibel tensor to convert.
    :returns: Converted gain tensor.
    """
    # Inverse of gain_to_db: 10 ** (db / 20).
    return tf.pow(10., tensor / 20.)
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
    """ Encodes given spectrogram into uint8 using decibel scale.

    :param spectrogram: Spectrogram to be encoded as TF float tensor.
    :param db_range: Range in decibel for encoding.
    :returns: Encoded decibel spectrogram as uint8 tensor.
    """
    in_db = gain_to_db(spectrogram)
    # Clip everything further than db_range below the peak.
    floor = tf.reduce_max(in_db) - db_range
    return from_float32_to_uint8(tf.maximum(in_db, floor), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
    """ Decode spectrogram from uint8 decibel scale.

    :param db_uint_spectrogram: Decibel spectrogram to decode.
    :param min_db: Lower bound limit for decoding.
    :param max_db: Upper bound limit for decoding.
    :returns: Decoded spectrogram as float32 tensor.
    """
    # Undo the uint8 quantization, then map decibels back to linear gain.
    as_db = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
    return db_to_gain(as_db)
|
||||||
263
spleeter/utils/audio/ffmpeg.py
Normal file
263
spleeter/utils/audio/ffmpeg.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
"""
|
||||||
|
This module provides an AudioAdapter implementation based on FFMPEG
|
||||||
|
process. Such implementation is POSIXish and depends on nothing except
|
||||||
|
standard Python libraries. Thus this implementation is the default one
|
||||||
|
used within this library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import os.path
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import numpy as np # pylint: disable=import-error
|
||||||
|
|
||||||
|
from .adapter import AudioAdapter
|
||||||
|
from ..logging import get_logger
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
# Default FFMPEG binary name.
|
||||||
|
_UNIX_BINARY = 'ffmpeg'
|
||||||
|
_WINDOWS_BINARY = 'ffmpeg.exe'
|
||||||
|
|
||||||
|
|
||||||
|
def _which(program):
|
||||||
|
""" A pure python implementation of `which`command
|
||||||
|
for retrieving absolute path from command name or path.
|
||||||
|
|
||||||
|
@see https://stackoverflow.com/a/377028/1211342
|
||||||
|
|
||||||
|
:param program: Program name or path to expend.
|
||||||
|
:returns: Absolute path of program if any, None otherwise.
|
||||||
|
"""
|
||||||
|
def is_exe(fpath):
|
||||||
|
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
|
||||||
|
|
||||||
|
fpath, _ = os.path.split(program)
|
||||||
|
if fpath:
|
||||||
|
if is_exe(program):
|
||||||
|
return program
|
||||||
|
else:
|
||||||
|
for path in os.environ['PATH'].split(os.pathsep):
|
||||||
|
exe_file = os.path.join(path, program)
|
||||||
|
if is_exe(exe_file):
|
||||||
|
return exe_file
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ffmpeg_path():
    """ Retrieves FFMPEG binary path using ENVVAR if defined
    or default binary name (Windows or UNIX style).

    :returns: Absolute path of FFMPEG binary.
    :raise IOError: If FFMPEG binary cannot be found.
    """
    binary = os.environ.get('FFMPEG_PATH', None)
    if binary is None:
        # No override provided: fall back to the platform's conventional name.
        if platform.system() == 'Windows':
            binary = _WINDOWS_BINARY
        else:
            binary = _UNIX_BINARY
    resolved = _which(binary)
    if resolved is None:
        raise IOError(f'FFMPEG binary ({binary}) not found')
    return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def _to_ffmpeg_time(n):
|
||||||
|
""" Format number of seconds to time expected by FFMPEG.
|
||||||
|
|
||||||
|
:param n: Time in seconds to format.
|
||||||
|
:returns: Formatted time in FFMPEG format.
|
||||||
|
"""
|
||||||
|
m, s = divmod(n, 60)
|
||||||
|
h, m = divmod(m, 60)
|
||||||
|
return '%d:%02d:%09.6f' % (h, m, s)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ffmpg_results(stderr):
|
||||||
|
""" Extract number of channels and sample rate from
|
||||||
|
the given FFMPEG STDERR output line.
|
||||||
|
|
||||||
|
:param stderr: STDERR output line to parse.
|
||||||
|
:returns: Parsed n_channels and sample_rate values.
|
||||||
|
"""
|
||||||
|
# Setup default value.
|
||||||
|
n_channels = 0
|
||||||
|
sample_rate = 0
|
||||||
|
# Find samplerate
|
||||||
|
match = re.search(r'(\d+) hz', stderr)
|
||||||
|
if match:
|
||||||
|
sample_rate = int(match.group(1))
|
||||||
|
# Channel count.
|
||||||
|
match = re.search(r'hz, ([^,]+),', stderr)
|
||||||
|
if match:
|
||||||
|
mode = match.group(1)
|
||||||
|
if mode == 'stereo':
|
||||||
|
n_channels = 2
|
||||||
|
else:
|
||||||
|
match = re.match(r'(\d+) ', mode)
|
||||||
|
n_channels = match and int(match.group(1)) or 1
|
||||||
|
return n_channels, sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
class _CommandBuilder(object):
|
||||||
|
""" A simple builder pattern class for CLI string. """
|
||||||
|
|
||||||
|
def __init__(self, binary):
|
||||||
|
""" Default constructor. """
|
||||||
|
self._command = [binary]
|
||||||
|
|
||||||
|
def flag(self, flag):
|
||||||
|
""" Add flag or unlabelled opt. """
|
||||||
|
self._command.append(flag)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def opt(self, short, value, formatter=str):
|
||||||
|
""" Add option if value not None. """
|
||||||
|
if value is not None:
|
||||||
|
self._command.append(short)
|
||||||
|
self._command.append(formatter(value))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def command(self):
|
||||||
|
""" Build string command. """
|
||||||
|
return self._command
|
||||||
|
|
||||||
|
|
||||||
|
class FFMPEGProcessAudioAdapter(AudioAdapter):
    """ An AudioAdapter implementation that use FFMPEG binary through
    subprocess in order to perform I/O operation for audio processing.

    When created, FFMPEG binary path will be checked and expended,
    raising exception if not found. Such path could be infered using
    FFMPEG_PATH environment variable.
    """

    def __init__(self):
        """ Default constructor. Resolves and memoizes the FFMPEG binary. """
        self._ffmpeg_path = _get_ffmpeg_path()

    def _get_command_builder(self):
        """ Creates and returns a command builder using FFMPEG path.

        :returns: Built command builder.
        """
        return _CommandBuilder(self._ffmpeg_path)

    def load(
            self, path, offset=None, duration=None,
            sample_rate=None, dtype=np.float32):
        """ Loads the audio file denoted by the given path
        and returns it data as a waveform.

        Audio is decoded to raw little-endian float32 PCM ('f32le') on
        FFMPEG's STDOUT; stream metadata is scraped from STDERR.

        :param path: Path of the audio file to load data from.
        :param offset: (Optional) Start offset to load from in seconds.
        :param duration: (Optional) Duration to load in seconds.
        :param sample_rate: (Optional) Sample rate to load audio with.
        :param dtype: (Optional) Numpy data type to use, default to float32.
        :returns: Loaded data a (waveform, sample_rate) tuple.
        :raise OSError: If stream info cannot be found in FFMPEG output.
        :raise IOError: If the file is missing or FFMPEG reports bad data.
        """
        # Path may arrive as bytes (e.g. from a TF string tensor).
        if not isinstance(path, str):
            path = path.decode()
        command = (
            self._get_command_builder()
            .opt('-ss', offset, formatter=_to_ffmpeg_time)
            .opt('-t', duration, formatter=_to_ffmpeg_time)
            .opt('-i', path)
            .opt('-ar', sample_rate)
            .opt('-f', 'f32le')
            .flag('-')
            .command())
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        # Drain all decoded PCM bytes first.
        buffer = process.stdout.read(-1)
        # Read STDERR until end of the process detected.
        while True:
            status = process.stderr.readline()
            if not status:
                raise OSError('Stream info not found')
            if isinstance(status, bytes):  # Note: Python 3 compatibility.
                status = status.decode('utf8', 'ignore')
            status = status.strip().lower()
            if 'no such file' in status:
                raise IOError(f'File {path} not found')
            elif 'invalid data found' in status:
                raise IOError(f'FFMPEG error : {status}')
            elif 'audio:' in status:
                # Stream-info line: recover channel count and actual rate.
                n_channels, ffmpeg_sample_rate = _parse_ffmpg_results(status)
                if sample_rate is None:
                    sample_rate = ffmpeg_sample_rate
                break
        # Load waveform and clean process.
        # '<f4' matches the 'f32le' raw format requested above.
        waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
        if not waveform.dtype == np.dtype(dtype):
            waveform = waveform.astype(dtype)
        process.stdout.close()
        process.stderr.close()
        del process
        return (waveform, sample_rate)

    def save(
            self, path, data, sample_rate,
            codec=None, bitrate=None):
        """ Write waveform data to the file denoted by the given path
        using FFMPEG process.

        Raw float32 PCM is piped to FFMPEG's STDIN and encoded with the
        requested codec. Missing parent directories are created.

        :param path: Path of the audio file to save data in.
        :param data: Waveform data to write as a (N, channels) array.
        :param sample_rate: Sample rate to write file in.
        :param codec: (Optional) Writing codec to use.
        :param bitrate: (Optional) Bitrate of the written audio file.
        :raise IOError: If any error occurs while using FFMPEG to write data.
        """
        directory = os.path.split(path)[0]
        if not os.path.exists(directory):
            os.makedirs(directory)
        get_logger().debug('Writing file %s', path)
        # NOTE: Tweak. 'wav' is FFMPEG's default for .wav output, so the
        # explicit codec option is dropped.
        if codec == 'wav':
            codec = None
        command = (
            self._get_command_builder()
            .flag('-y')
            .opt('-loglevel', 'error')
            .opt('-f', 'f32le')
            .opt('-ar', sample_rate)
            .opt('-ac', data.shape[1])
            .opt('-i', '-')
            .flag('-vn')
            .opt('-acodec', codec)
            .opt('-ar', sample_rate)  # Note: why twice ?
            .opt('-strict', '-2')  # Note: For 'aac' codec support.
            .opt('-ab', bitrate)
            .flag(path)
            .command())
        process = subprocess.Popen(
            command,
            stdout=open(os.devnull, 'wb'),
            stdin=subprocess.PIPE,
            stderr=subprocess.PIPE)
        # Write data to STDIN.
        try:
            # NOTE(review): ndarray.tostring() is deprecated in favour of
            # tobytes() in modern numpy — confirm the pinned numpy version.
            process.stdin.write(
                data.astype('<f4').tostring())
        except IOError:
            raise IOError(f'FFMPEG error: {process.stderr.read()}')
        # Clean process.
        process.stdin.close()
        if process.stderr is not None:
            process.stderr.close()
        process.wait()
        del process
        get_logger().info('File %s written', path)
|
||||||
128
spleeter/utils/audio/spectrogram.py
Normal file
128
spleeter/utils/audio/spectrogram.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" Spectrogram specific data augmentation """
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib.signal import stft, hann_window
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def compute_spectrogram_tf(
        waveform,
        frame_length=2048, frame_step=512,
        spec_exponent=1., window_exponent=1.):
    """ Compute magnitude / power spectrogram from waveform as
    a n_samples x n_channels tensor.

    :param waveform: Input waveform as (times x number of channels)
                     tensor.
    :param frame_length: Length of a STFT frame to use.
    :param frame_step: HOP between successive frames.
    :param spec_exponent: Exponent of the spectrogram (usually 1 for
                          magnitude spectrogram, or 2 for power spectrogram).
    :param window_exponent: Exponent applied to the Hann windowing function
                            (may be useful for making perfect STFT/iSTFT
                            reconstruction).
    :returns: Computed magnitude / power spectrogram as a
              (T x F x n_channels) tensor.
    """
    # stft expects (channels, samples); transpose in, then move the channel
    # axis last again on the way out.
    stft_tensor = tf.transpose(
        stft(
            tf.transpose(waveform),
            frame_length,
            frame_step,
            window_fn=lambda f, dtype: hann_window(
                f,
                periodic=True,
                dtype=waveform.dtype) ** window_exponent),
        perm=[1, 2, 0])
    # np.abs on a TF tensor dispatches to the tensor's __abs__ (tf.abs),
    # so this stays a graph operation.
    return np.abs(stft_tensor) ** spec_exponent
|
||||||
|
|
||||||
|
|
||||||
|
def time_stretch(
        spectrogram,
        factor=1.0,
        method=tf.image.ResizeMethod.BILINEAR):
    """ Time stretch a spectrogram preserving shape in tensorflow. Note that
    this is an approximation in the frequency domain.

    :param spectrogram: Input spectrogram to be time stretched as tensor.
    :param factor: (Optional) Time stretch factor, must be >0, default to 1.
    :param method: (Optional) Interpolation method, default to BILINEAR.
    :returns: Time stretched spectrogram as tensor with same shape.
    """
    T = tf.shape(spectrogram)[0]
    # NOTE(review): the trailing [0] assumes `factor` is a shape-(1,)
    # tensor (as produced by random_time_stretch); the scalar default of
    # 1.0 would make this indexing fail — confirm intended usage.
    T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
    F = tf.shape(spectrogram)[1]
    # Resize along time only, then crop/pad back to the original (T, F).
    ts_spec = tf.image.resize_images(
        spectrogram,
        [T_ts, F],
        method=method,
        align_corners=True)
    return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
|
||||||
|
|
||||||
|
|
||||||
|
def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs):
    """ Time stretch a spectrogram (shape preserving) with a random factor
    drawn uniformly in [factor_min, factor_max], in tensorflow.

    :param spectrogram: Input spectrogram to be time stretched as tensor.
    :param factor_min: (Optional) Min time stretch factor, default to 0.9.
    :param factor_max: (Optional) Max time stretch factor, default to 1.1.
    :returns: Randomly time stretched spectrogram as tensor with same shape.
    """
    # Scale a fixed-seed uniform draw from [0, 1) into the requested range.
    draw = tf.random_uniform(shape=(1,), seed=0)
    factor = draw * (factor_max - factor_min) + factor_min
    return time_stretch(spectrogram, factor=factor, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def pitch_shift(
        spectrogram,
        semitone_shift=0.0,
        method=tf.image.ResizeMethod.BILINEAR):
    """ Pitch shift a spectrogram preserving shape in tensorflow. Note that
    this is an approximation in the frequency domain.

    :param spectrogram: Input spectrogram to be pitch shifted as tensor.
    :param semitone_shift: (Optional) Pitch shift in semitone, default to 0.0.
    :param method: (Optional) Interpolation method, default to BILINEAR.
    :returns: Pitch shifted spectrogram (same shape as spectrogram).
    """
    # One semitone is a frequency ratio of 2^(1/12).
    factor = 2 ** (semitone_shift / 12.)
    T = tf.shape(spectrogram)[0]
    F = tf.shape(spectrogram)[1]
    # NOTE(review): the trailing [0] assumes `semitone_shift` (hence
    # `factor`) is a shape-(1,) tensor, as produced by random_pitch_shift;
    # the scalar default of 0.0 would make this indexing fail — confirm.
    F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
    # Resize along frequency only, then crop/zero-pad back to width F.
    ps_spec = tf.image.resize_images(
        spectrogram,
        [T, F_ps],
        method=method,
        align_corners=True)
    paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
    return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
|
||||||
|
|
||||||
|
|
||||||
|
def random_pitch_shift(spectrogram, shift_min=-1., shift_max=1., **kwargs):
    """ Pitch shift a spectrogram (shape preserving) with a random shift,
    in semitones, drawn uniformly in [shift_min, shift_max], in tensorflow.

    :param spectrogram: Input spectrogram to be pitch shifted as tensor.
    :param shift_min: (Optional) Min pitch shift in semitone, default to -1.
    :param shift_max: (Optional) Max pitch shift in semitone, default to 1.
    :returns: Randomly pitch shifted spectrogram (same shape as spectrogram).
    """
    # Scale a fixed-seed uniform draw from [0, 1) into the requested range.
    draw = tf.random_uniform(shape=(1,), seed=0)
    semitone_shift = draw * (shift_max - shift_min) + shift_min
    return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
|
||||||
47
spleeter/utils/configuration.py
Normal file
47
spleeter/utils/configuration.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" Module that provides configuration loading function. """
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
import importlib.resources as loader
|
||||||
|
except ImportError:
|
||||||
|
# Try backported to PY<37 `importlib_resources`.
|
||||||
|
import importlib_resources as loader
|
||||||
|
|
||||||
|
from os.path import exists
|
||||||
|
|
||||||
|
from .. import resources
|
||||||
|
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
|
||||||
|
|
||||||
|
|
||||||
|
def load_configuration(descriptor):
    """ Load a JSON configuration from the given descriptor.

    The descriptor is either an embedded configuration name prefixed
    with `spleeter:`, or a file system path to read the JSON from.

    :param descriptor: Configuration descriptor to use for lookup.
    :returns: Loaded description as dict.
    :raise ValueError: If required embedded configuration does not exists.
    :raise IOError: If required configuration file does not exists.
    """
    if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
        # Embedded configuration reading.
        name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):]
        resource = f'{name}.json'
        if not loader.is_resource(resources, resource):
            raise ValueError(f'No embedded configuration {name} found')
        with loader.open_text(resources, resource) as stream:
            return json.load(stream)
    # Standard file reading.
    if not exists(descriptor):
        raise IOError(f'Configuration file {descriptor} not found')
    with open(descriptor, 'r') as stream:
        return json.load(stream)
|
||||||
69
spleeter/utils/estimator.py
Normal file
69
spleeter/utils/estimator.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" Utility functions for creating estimator. """
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow.contrib import predictor
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
from ..model import model_fn
|
||||||
|
from ..model.provider import get_default_model_provider
|
||||||
|
|
||||||
|
# Default exporting directory for predictor.
|
||||||
|
DEFAULT_EXPORT_DIRECTORY = '/tmp/serving'
|
||||||
|
|
||||||
|
|
||||||
|
def create_estimator(params, MWF):
    """ Initialize a tensorflow estimator that will perform separation.

    Downloads / resolves the model directory through the default model
    provider, stores the MWF flag into params and builds the estimator
    with a session capped at 70% of GPU memory.

    Params:
    - params: a dictionary of parameters for building the model
    - MWF: boolean-like flag, stored as params['MWF'] for the model_fn

    Returns:
        a tensorflow estimator
    """
    # Resolve the model path through the provider (may fetch pretrained data).
    params['model_dir'] = get_default_model_provider().get(
        params['model_dir'])
    params['MWF'] = MWF
    # Session configuration: keep a fraction of GPU memory free.
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
    run_config = tf.estimator.RunConfig(session_config=session_config)
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=run_config)
|
||||||
|
|
||||||
|
|
||||||
|
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
    """ Exports given estimator as predictor into the given directory
    and returns associated tf.predictor instance.

    :param estimator: Estimator to export.
    :param directory: (Optional) path to write exported model into.
    """

    def _serving_receiver():
        # Both placeholders act as input features and receiver tensors.
        waveform_shape = (None, estimator.params['n_channels'])
        features = {
            'waveform': tf.compat.v1.placeholder(
                tf.float32, shape=waveform_shape),
            'audio_id': tf.compat.v1.placeholder(tf.string)}
        return tf.estimator.export.ServingInputReceiver(features, features)

    estimator.export_saved_model(directory, _serving_receiver)
    # Exported versions are timestamped subdirectories; skip temp dirs
    # and load the most recent one.
    versions = []
    for entry in Path(directory).iterdir():
        if entry.is_dir() and 'temp' not in str(entry):
            versions.append(entry)
    latest = str(max(versions))
    return predictor.from_saved_model(latest)
|
||||||
45
spleeter/utils/logging.py
Normal file
45
spleeter/utils/logging.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" Centralized logging facilities for Spleeter. """
|
||||||
|
|
||||||
|
from os import environ
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
class _LoggerHolder(object):
|
||||||
|
""" Logger singleton instance holder. """
|
||||||
|
|
||||||
|
INSTANCE = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger():
    """ Provide the library scoped logger, creating it on first call.

    :returns: Library logger.
    """
    if _LoggerHolder.INSTANCE is None:
        # Deferred import so tensorflow is only loaded when needed.
        # pylint: disable=import-error
        from tensorflow.compat.v1 import logging
        # pylint: enable=import-error
        logging.set_verbosity(logging.ERROR)
        environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        _LoggerHolder.INSTANCE = logging
    return _LoggerHolder.INSTANCE
|
||||||
|
|
||||||
|
|
||||||
|
def enable_logging():
    """ Enable INFO level logging. """
    environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
    tf_logging = get_logger()
    tf_logging.set_verbosity(tf_logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_verbose_logging():
    """ Enable DEBUG level logging. """
    environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
    tf_logging = get_logger()
    tf_logging.set_verbosity(tf_logging.DEBUG)
|
||||||
191
spleeter/utils/tensor.py
Normal file
191
spleeter/utils/tensor.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf8
|
||||||
|
|
||||||
|
""" Utility function for tensorflow. """
|
||||||
|
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
import pandas as pd
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
|
__email__ = 'research@deezer.com'
|
||||||
|
__author__ = 'Deezer Research'
|
||||||
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def sync_apply(tensor_dict, func, concat_axis=1):
    """ Apply func synchronously to a dictionary of tensors.

    The tensors are concatenated along concat_axis, func is applied
    once to the concatenation, and the result is split back into a
    dictionary with the same keys. This guarantees that a random
    operation (e.g. a random time-crop) draws the same value for
    every tensor (such as input data and label).

    IMPORTANT NOTE: all tensors are assumed to be the same shape.

    Params:
        - tensor_dict: dictionary (key: strings, values: tf.tensor)
            a dictionary of tensor.
        - func: function
            function applied to the concatenation of the tensors.
        - concat_axis: int
            The axis on which to perform the concatenation (0 or 1).

    Returns:
        processed tensors dictionary with the same name (keys) as input
        tensor_dict.
    """
    if concat_axis not in {0, 1}:
        raise NotImplementedError(
            'Function only implemented for concat_axis equal to 0 or 1')
    stacked = tf.concat(list(tensor_dict.values()), concat_axis)
    processed = func(stacked)
    # All tensors share a shape, so the slice width along the concat
    # axis is read from the first one.
    width = tf.shape(next(iter(tensor_dict.values())))[concat_axis]
    result = {}
    for index, name in enumerate(tensor_dict):
        start = index * width
        stop = (index + 1) * width
        if concat_axis == 0:
            result[name] = processed[start:stop, :, :]
        else:
            result[name] = processed[:, start:stop, :]
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def from_float32_to_uint8(
        tensor,
        tensor_key='tensor',
        max_key='max',
        min_key='min',
        **_unused):
    """ Quantize a float32 tensor to uint8, keeping min/max for inversion.

    :param tensor: Float tensor to quantize.
    :param tensor_key: (Optional) key of the quantized tensor in the result.
    :param min_key: (Optional) key of the tensor minimum in the result.
    :param max_key: (Optional) key of the tensor maximum in the result.
    :returns: Dict with quantized tensor and its original min and max.
    """
    lower = tf.reduce_min(tensor)
    upper = tf.reduce_max(tensor)
    # Epsilon in the denominator avoids dividing by zero on a
    # constant tensor; 255.9999 maps the normalized range onto [0, 255].
    quantized = tf.cast(
        (tensor - lower) / (upper - lower + 1e-16)
        * 255.9999, dtype=tf.uint8)
    return {
        tensor_key: quantized,
        min_key: lower,
        max_key: upper
    }
|
||||||
|
|
||||||
|
|
||||||
|
def from_uint8_to_float32(tensor, tensor_min, tensor_max):
    """ Map a uint8 tensor back to float32 using the stored min/max.

    Inverse of from_float32_to_uint8 (up to quantization error).

    :param tensor: uint8 tensor to dequantize.
    :param tensor_min: Minimum of the original float tensor.
    :param tensor_max: Maximum of the original float tensor.
    :returns: Dequantized float32 tensor.
    """
    scaled = tf.cast(tensor, tf.float32) * (tensor_max - tensor_min)
    return scaled / 255.9999 + tensor_min
|
||||||
|
|
||||||
|
|
||||||
|
def pad_and_partition(tensor, segment_len):
    """ Pad and partition a tensor into segments of length segment_len
    along the first dimension. The tensor is zero padded so that its
    first dimension becomes a multiple of segment_len, then reshaped
    into (n_segments, segment_len, ...original trailing dims...).

    Tensor must be of known fixed rank.

    :param tensor: Tensor to pad and partition.
    :param segment_len: Length of each segment along the first axis.
    :returns: Partitioned tensor.
    """
    remainder = tf.math.floormod(tf.shape(tensor)[0], segment_len)
    pad_size = tf.math.floormod(segment_len - remainder, segment_len)
    # Pad only at the end of the first axis; other axes untouched.
    paddings = [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1)
    padded = tf.pad(tensor, paddings)
    n_segments = (tf.shape(padded)[0] + segment_len - 1) // segment_len
    target_shape = tf.concat(
        [[n_segments, segment_len], tf.shape(padded)[1:]],
        axis=0)
    return tf.reshape(padded, target_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_and_reshape(instr_spec, frame_length, F):
    """ Zero-pad a spectrogram on its frequency axis up to the full
    STFT bin count (frame_length // 2 + 1) and merge the two leading
    axes into one.

    :param instr_spec: Instrument spectrogram tensor of rank 4,
        with frequency on axis 2.
    :param frame_length: STFT frame length used to derive the target
        frequency bin count.
    :param F: Current number of frequency bins in instr_spec.
    :returns: Padded spectrogram reshaped to rank 3.
    """
    spec_shape = tf.shape(instr_spec)
    # One zero frequency row, tiled to cover the missing bins.
    zero_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
    n_extra_rows = frame_length // 2 + 1 - F
    padding = tf.tile(zero_row, [1, 1, n_extra_rows, 1])
    extended = tf.concat([instr_spec, padding], axis=2)
    extended_shape = tf.shape(extended)
    # Collapse the first two axes into a single one.
    merged_shape = tf.concat(
        [[extended_shape[0] * extended_shape[1]], extended_shape[2:]],
        axis=0)
    return tf.reshape(extended, merged_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_from_csv(csv_path, **kwargs):
    """ Load dataset from a CSV file using Pandas. kwargs if any are
    forwarded to the `pandas.read_csv` function.

    :param csv_path: Path of the CSV file to load dataset from.
    :returns: Loaded dataset.
    """
    dataframe = pd.read_csv(csv_path, **kwargs)
    # One tensor slice component per CSV column.
    columns = {column: dataframe[column].values for column in dataframe}
    return tf.data.Dataset.from_tensor_slices(columns)
|
||||||
|
|
||||||
|
|
||||||
|
def check_tensor_shape(tensor_tf, target_shape):
    """ Return a Tensorflow boolean graph that indicates whether
    tensor_tf has the specified target shape. Falsy entries of
    target_shape (None, 0) are not checked.

    :param tensor_tf: Tensor to check shape for.
    :param target_shape: Target shape to compare tensor to.
    :returns: True if shape is valid, False otherwise (as TF boolean).
    """
    result = tf.constant(True)
    actual_shape = tf.shape(tensor_tf)
    for axis, expected in enumerate(target_shape):
        if not expected:
            # Falsy entry means any length is accepted on this axis.
            continue
        matches = tf.equal(tf.constant(expected), actual_shape[axis])
        result = tf.logical_and(result, matches)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def set_tensor_shape(tensor, tensor_shape):
    """ Apply a static shape to the given tensor and return it.

    NOTE(review): `Tensor.set_shape` mutates the tensor's static shape
    in place; the tensor is returned so the call can be chained
    (e.g. inside dataset map pipelines).

    :param tensor: Tensor to reshape.
    :param tensor_shape: Shape to apply to the tensor.
    :returns: The same tensor, with its static shape set.
    """
    tensor.set_shape(tensor_shape)
    return tensor
|
||||||
Reference in New Issue
Block a user