mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
🎨 finalizes model provider and functions
This commit is contained in:
@@ -16,28 +16,22 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
from os.path import exists, join, sep as SEPARATOR
|
|
||||||
|
|
||||||
|
from os.path import exists, sep as SEPARATOR
|
||||||
|
|
||||||
|
from .audio.convertor import db_uint_spectrogram_to_gain
|
||||||
|
from .audio.convertor import spectrogram_to_db_uint
|
||||||
|
from .audio.spectrogram import compute_spectrogram_tf
|
||||||
|
from .audio.spectrogram import random_pitch_shift, random_time_stretch
|
||||||
|
from .utils.logging import get_logger
|
||||||
|
from .utils.tensor import check_tensor_shape, dataset_from_csv
|
||||||
|
from .utils.tensor import set_tensor_shape, sync_apply
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
# pylint: enable=import-error
|
# pylint: enable=import-error
|
||||||
|
|
||||||
from .audio.convertor import (
|
|
||||||
db_uint_spectrogram_to_gain,
|
|
||||||
spectrogram_to_db_uint)
|
|
||||||
from .audio.spectrogram import (
|
|
||||||
compute_spectrogram_tf,
|
|
||||||
random_pitch_shift,
|
|
||||||
random_time_stretch)
|
|
||||||
from .utils.logging import get_logger
|
|
||||||
from .utils.tensor import (
|
|
||||||
check_tensor_shape,
|
|
||||||
dataset_from_csv,
|
|
||||||
set_tensor_shape,
|
|
||||||
sync_apply)
|
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
__email__ = 'spleeter@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|||||||
@@ -3,25 +3,44 @@
|
|||||||
|
|
||||||
""" This package provide model functions. """
|
""" This package provide model functions. """
|
||||||
|
|
||||||
|
from typing import Callable, Dict, Iterable, Optional
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
__email__ = 'spleeter@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
def apply(function, input_tensor, instruments, params={}):
|
def apply(
|
||||||
""" Apply given function to the input tensor.
|
function: Callable,
|
||||||
|
input_tensor: tf.Tensor,
|
||||||
:param function: Function to be applied to tensor.
|
instruments: Iterable[str],
|
||||||
:param input_tensor: Tensor to apply blstm to.
|
params: Optional[Dict] = None) -> Dict:
|
||||||
:param instruments: Iterable that provides a collection of instruments.
|
|
||||||
:param params: (Optional) dict of BLSTM parameters.
|
|
||||||
:returns: Created output tensor dict.
|
|
||||||
"""
|
"""
|
||||||
output_dict = {}
|
Apply given function to the input tensor.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
function:
|
||||||
|
Function to be applied to tensor.
|
||||||
|
input_tensor (tensorflow.Tensor):
|
||||||
|
Tensor to apply blstm to.
|
||||||
|
instruments (Iterable[str]):
|
||||||
|
Iterable that provides a collection of instruments.
|
||||||
|
params:
|
||||||
|
(Optional) dict of BLSTM parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created output tensor dict.
|
||||||
|
"""
|
||||||
|
output_dict: Dict = {}
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
out_name = f'{instrument}_spectrogram'
|
out_name = f'{instrument}_spectrogram'
|
||||||
output_dict[out_name] = function(
|
output_dict[out_name] = function(
|
||||||
input_tensor,
|
input_tensor,
|
||||||
output_name=out_name,
|
output_name=out_name,
|
||||||
params=params)
|
params=params or {})
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|||||||
@@ -20,7 +20,14 @@
|
|||||||
selection (LSTM layer dropout rate, regularization strength).
|
selection (LSTM layer dropout rate, regularization strength).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from . import apply
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.compat.v1.keras.initializers import he_uniform
|
from tensorflow.compat.v1.keras.initializers import he_uniform
|
||||||
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
|
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
|
||||||
from tensorflow.keras.layers import (
|
from tensorflow.keras.layers import (
|
||||||
@@ -31,22 +38,33 @@ from tensorflow.keras.layers import (
|
|||||||
TimeDistributed)
|
TimeDistributed)
|
||||||
# pylint: enable=import-error
|
# pylint: enable=import-error
|
||||||
|
|
||||||
from . import apply
|
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
__email__ = 'spleeter@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
def apply_blstm(input_tensor, output_name='output', params={}):
|
def apply_blstm(
|
||||||
""" Apply BLSTM to the given input_tensor.
|
input_tensor: tf.Tensor,
|
||||||
|
output_name: str = 'output',
|
||||||
:param input_tensor: Input of the model.
|
params: Optional[Dict] = None) -> tf.Tensor:
|
||||||
:param output_name: (Optional) name of the output, default to 'output'.
|
|
||||||
:param params: (Optional) dict of BLSTM parameters.
|
|
||||||
:returns: Output tensor.
|
|
||||||
"""
|
"""
|
||||||
units = params.get('lstm_units', 250)
|
Apply BLSTM to the given input_tensor.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
input_tensor (tensorflow.Tensor):
|
||||||
|
Input of the model.
|
||||||
|
output_name (str):
|
||||||
|
(Optional) name of the output, default to 'output'.
|
||||||
|
params (Optional[Dict]):
|
||||||
|
(Optional) dict of BLSTM parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensorflow.Tensor:
|
||||||
|
Output tensor.
|
||||||
|
"""
|
||||||
|
if params is None:
|
||||||
|
params = {}
|
||||||
|
units: int = params.get('lstm_units', 250)
|
||||||
kernel_initializer = he_uniform(seed=50)
|
kernel_initializer = he_uniform(seed=50)
|
||||||
flatten_input = TimeDistributed(Flatten())((input_tensor))
|
flatten_input = TimeDistributed(Flatten())((input_tensor))
|
||||||
|
|
||||||
@@ -65,12 +83,15 @@ def apply_blstm(input_tensor, output_name='output', params={}):
|
|||||||
int(flatten_input.shape[2]),
|
int(flatten_input.shape[2]),
|
||||||
activation='relu',
|
activation='relu',
|
||||||
kernel_initializer=kernel_initializer))((l3))
|
kernel_initializer=kernel_initializer))((l3))
|
||||||
output = TimeDistributed(
|
output: tf.Tensor = TimeDistributed(
|
||||||
Reshape(input_tensor.shape[2:]),
|
Reshape(input_tensor.shape[2:]),
|
||||||
name=output_name)(dense)
|
name=output_name)(dense)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def blstm(input_tensor, output_name='output', params={}):
|
def blstm(
|
||||||
|
input_tensor: tf.Tensor,
|
||||||
|
output_name: str = 'output',
|
||||||
|
params: Optional[Dict] = None) -> tf.Tensor:
|
||||||
""" Model function applier. """
|
""" Model function applier. """
|
||||||
return apply(apply_blstm, input_tensor, output_name, params)
|
return apply(apply_blstm, input_tensor, output_name, params)
|
||||||
|
|||||||
@@ -2,16 +2,23 @@
|
|||||||
# coding: utf8
|
# coding: utf8
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This module contains building functions for U-net source
|
This module contains building functions for U-net source
|
||||||
separation models in a similar way as in A. Jansson et al. "Singing
|
separation models in a similar way as in A. Jansson et al. :
|
||||||
voice separation with deep u-net convolutional networks", ISMIR 2017.
|
|
||||||
Each instrument is modeled by a single U-net convolutional
|
"Singing voice separation with deep u-net convolutional networks",
|
||||||
/ deconvolutional network that takes a mix spectrogram as input and the
|
ISMIR 2017
|
||||||
estimated sound spectrogram as output.
|
|
||||||
|
Each instrument is modeled by a single U-net
|
||||||
|
convolutional / deconvolutional network that takes a mix spectrogram
|
||||||
|
as input and the estimated sound spectrogram as output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Any, Dict, Iterable, Optional
|
||||||
|
|
||||||
|
from . import apply
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@@ -30,20 +37,23 @@ from tensorflow.compat.v1 import logging
|
|||||||
from tensorflow.compat.v1.keras.initializers import he_uniform
|
from tensorflow.compat.v1.keras.initializers import he_uniform
|
||||||
# pylint: enable=import-error
|
# pylint: enable=import-error
|
||||||
|
|
||||||
from . import apply
|
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
__email__ = 'spleeter@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
def _get_conv_activation_layer(params):
|
def _get_conv_activation_layer(params: Dict) -> Any:
|
||||||
"""
|
"""
|
||||||
|
> To be documented.
|
||||||
|
|
||||||
:param params:
|
Parameters:
|
||||||
:returns: Required Activation function.
|
params (Dict):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any:
|
||||||
|
Required Activation function.
|
||||||
"""
|
"""
|
||||||
conv_activation = params.get('conv_activation')
|
conv_activation: str = params.get('conv_activation')
|
||||||
if conv_activation == 'ReLU':
|
if conv_activation == 'ReLU':
|
||||||
return ReLU()
|
return ReLU()
|
||||||
elif conv_activation == 'ELU':
|
elif conv_activation == 'ELU':
|
||||||
@@ -51,13 +61,18 @@ def _get_conv_activation_layer(params):
|
|||||||
return LeakyReLU(0.2)
|
return LeakyReLU(0.2)
|
||||||
|
|
||||||
|
|
||||||
def _get_deconv_activation_layer(params):
|
def _get_deconv_activation_layer(params: Dict) -> Any:
|
||||||
"""
|
"""
|
||||||
|
> To be documented.
|
||||||
|
|
||||||
:param params:
|
Parameters:
|
||||||
:returns: Required Activation function.
|
params (Dict):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any:
|
||||||
|
Required Activation function.
|
||||||
"""
|
"""
|
||||||
deconv_activation = params.get('deconv_activation')
|
deconv_activation: str = params.get('deconv_activation')
|
||||||
if deconv_activation == 'LeakyReLU':
|
if deconv_activation == 'LeakyReLU':
|
||||||
return LeakyReLU(0.2)
|
return LeakyReLU(0.2)
|
||||||
elif deconv_activation == 'ELU':
|
elif deconv_activation == 'ELU':
|
||||||
@@ -66,17 +81,19 @@ def _get_deconv_activation_layer(params):
|
|||||||
|
|
||||||
|
|
||||||
def apply_unet(
|
def apply_unet(
|
||||||
input_tensor,
|
input_tensor: tf.Tensor,
|
||||||
output_name='output',
|
output_name: str = 'output',
|
||||||
params={},
|
params: Optional[Dict] = None,
|
||||||
output_mask_logit=False):
|
output_mask_logit: bool = False) -> Any:
|
||||||
""" Apply a convolutionnal U-net to model a single instrument (one U-net
|
"""
|
||||||
is used for each instrument).
|
Apply a convolutional U-net to model a single instrument (one U-net
|
||||||
|
is used for each instrument).
|
||||||
|
|
||||||
:param input_tensor:
|
Parameters:
|
||||||
:param output_name: (Optional) , default to 'output'
|
input_tensor (tensorflow.Tensor):
|
||||||
:param params: (Optional) , default to empty dict.
|
output_name (str):
|
||||||
:param output_mask_logit: (Optional) , default to False.
|
params (Optional[Dict]):
|
||||||
|
output_mask_logit (bool):
|
||||||
"""
|
"""
|
||||||
logging.info(f'Apply unet for {output_name}')
|
logging.info(f'Apply unet for {output_name}')
|
||||||
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
|
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
|
||||||
@@ -170,18 +187,32 @@ def apply_unet(
|
|||||||
kernel_initializer=kernel_initializer)((batch12))
|
kernel_initializer=kernel_initializer)((batch12))
|
||||||
|
|
||||||
|
|
||||||
def unet(input_tensor, instruments, params={}):
|
def unet(
|
||||||
|
input_tensor: tf.Tensor,
|
||||||
|
instruments: Iterable[str],
|
||||||
|
params: Optional[Dict] = None) -> Dict:
|
||||||
""" Model function applier. """
|
""" Model function applier. """
|
||||||
return apply(apply_unet, input_tensor, instruments, params)
|
return apply(apply_unet, input_tensor, instruments, params)
|
||||||
|
|
||||||
|
|
||||||
def softmax_unet(input_tensor, instruments, params={}):
|
def softmax_unet(
|
||||||
""" Apply softmax to multitrack unet in order to have mask suming to one.
|
input_tensor: tf.Tensor,
|
||||||
|
instruments: Iterable[str],
|
||||||
|
params: Optional[Dict] = None) -> Dict:
|
||||||
|
"""
|
||||||
|
Apply softmax to multitrack unet in order to have mask summing to one.
|
||||||
|
|
||||||
:param input_tensor: Tensor to apply blstm to.
|
Parameters:
|
||||||
:param instruments: Iterable that provides a collection of instruments.
|
input_tensor (tensorflow.Tensor):
|
||||||
:param params: (Optional) dict of BLSTM parameters.
|
Tensor to apply blstm to.
|
||||||
:returns: Created output tensor dict.
|
instruments (Iterable[str]):
|
||||||
|
Iterable that provides a collection of instruments.
|
||||||
|
params (Optional[Dict]):
|
||||||
|
(Optional) dict of BLSTM parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict:
|
||||||
|
Created output tensor dict.
|
||||||
"""
|
"""
|
||||||
logit_mask_list = []
|
logit_mask_list = []
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
|
|||||||
@@ -5,10 +5,12 @@
|
|||||||
This package provides tools for downloading model from network
|
This package provides tools for downloading model from network
|
||||||
using remote storage abstraction.
|
using remote storage abstraction.
|
||||||
|
|
||||||
:Example:
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
>>> provider = MyProviderImplementation()
|
>>> provider = MyProviderImplementation()
|
||||||
>>> provider.get('/path/to/local/storage', params)
|
>>> provider.get('/path/to/local/storage', params)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -26,39 +28,52 @@ class ModelProvider(ABC):
|
|||||||
file download is not available.
|
file download is not available.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models')
|
DEFAULT_MODEL_PATH: str = environ.get('MODEL_PATH', 'pretrained_models')
|
||||||
MODEL_PROBE_PATH = '.probe'
|
MODEL_PROBE_PATH: str = '.probe'
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def download(self, name, path):
|
def download(_, name: str, path: str) -> None:
|
||||||
""" Download model denoted by the given name to disk.
|
"""
|
||||||
|
Download model denoted by the given name to disk.
|
||||||
|
|
||||||
:param name: Name of the model to download.
|
Parameters:
|
||||||
:param path: Path of the directory to save model into.
|
name (str):
|
||||||
|
Name of the model to download.
|
||||||
|
path (str):
|
||||||
|
Path of the directory to save model into.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def writeProbe(directory):
|
def writeProbe(directory: str) -> None:
|
||||||
""" Write a model probe file into the given directory.
|
|
||||||
|
|
||||||
:param directory: Directory to write probe into.
|
|
||||||
"""
|
"""
|
||||||
probe = join(directory, ModelProvider.MODEL_PROBE_PATH)
|
Write a model probe file into the given directory.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
directory (str):
|
||||||
|
Directory to write probe into.
|
||||||
|
"""
|
||||||
|
probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
|
||||||
with open(probe, 'w') as stream:
|
with open(probe, 'w') as stream:
|
||||||
stream.write('OK')
|
stream.write('OK')
|
||||||
|
|
||||||
def get(self, model_directory):
|
def get(self, model_directory: str) -> str:
|
||||||
""" Ensures required model is available at given location.
|
"""
|
||||||
|
Ensures required model is available at given location.
|
||||||
|
|
||||||
:param model_directory: Expected model_directory to be available.
|
Parameters:
|
||||||
:raise IOError: If model can not be retrieved.
|
model_directory (str):
|
||||||
|
Expected model_directory to be available.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
IOError:
|
||||||
|
If model can not be retrieved.
|
||||||
"""
|
"""
|
||||||
# Expand model directory if needed.
|
# Expand model directory if needed.
|
||||||
if not isabs(model_directory):
|
if not isabs(model_directory):
|
||||||
model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
|
model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
|
||||||
# Download it if it does not exist.
|
# Download it if it does not exist.
|
||||||
model_probe = join(model_directory, self.MODEL_PROBE_PATH)
|
model_probe: str = join(model_directory, self.MODEL_PROBE_PATH)
|
||||||
if not exists(model_probe):
|
if not exists(model_probe):
|
||||||
if not exists(model_directory):
|
if not exists(model_directory):
|
||||||
makedirs(model_directory)
|
makedirs(model_directory)
|
||||||
@@ -68,14 +83,14 @@ class ModelProvider(ABC):
|
|||||||
self.writeProbe(model_directory)
|
self.writeProbe(model_directory)
|
||||||
return model_directory
|
return model_directory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(_: type) -> 'ModelProvider':
|
||||||
|
"""
|
||||||
|
Builds and returns a default model provider.
|
||||||
|
|
||||||
def get_default_model_provider():
|
Returns:
|
||||||
""" Builds and returns a default model provider.
|
ModelProvider:
|
||||||
|
A default model provider instance to use.
|
||||||
:returns: A default model provider instance to use.
|
"""
|
||||||
"""
|
from .github import GithubModelProvider
|
||||||
from .github import GithubModelProvider
|
return GithubModelProvider.from_environ()
|
||||||
host = environ.get('GITHUB_HOST', 'https://github.com')
|
|
||||||
repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
|
|
||||||
release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
|
|
||||||
return GithubModelProvider(host, repository, release)
|
|
||||||
|
|||||||
@@ -4,27 +4,34 @@
|
|||||||
"""
|
"""
|
||||||
A ModelProvider backed by Github Release feature.
|
A ModelProvider backed by Github Release feature.
|
||||||
|
|
||||||
:Example:
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
>>> from spleeter.model.provider import github
|
>>> from spleeter.model.provider import github
|
||||||
>>> provider = github.GithubModelProvider(
|
>>> provider = github.GithubModelProvider(
|
||||||
'github.com',
|
'github.com',
|
||||||
'Deezer/spleeter',
|
'Deezer/spleeter',
|
||||||
'latest')
|
'latest')
|
||||||
>>> provider.download('2stems', '/path/to/local/storage')
|
>>> provider.download('2stems', '/path/to/local/storage')
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import tarfile
|
import tarfile
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from os import environ
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
|
from typing import Dict
|
||||||
import requests
|
|
||||||
|
|
||||||
from . import ModelProvider
|
from . import ModelProvider
|
||||||
from ...utils.logging import get_logger
|
from ...utils.logging import get_logger
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
|
# pylint: disable=import-error
|
||||||
|
import httpx
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
__email__ = 'spleeter@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
@@ -46,69 +53,108 @@ def compute_file_checksum(path):
|
|||||||
class GithubModelProvider(ModelProvider):
|
class GithubModelProvider(ModelProvider):
|
||||||
""" A ModelProvider implementation backed on Github for remote storage. """
|
""" A ModelProvider implementation backed on Github for remote storage. """
|
||||||
|
|
||||||
LATEST_RELEASE = 'v1.4.0'
|
DEFAULT_HOST: str = 'https://github.com'
|
||||||
RELEASE_PATH = 'releases/download'
|
DEFAULT_REPOSITORY: str = 'deezer/spleeter'
|
||||||
CHECKSUM_INDEX = 'checksum.json'
|
|
||||||
|
|
||||||
def __init__(self, host, repository, release):
|
CHECKSUM_INDEX: str = 'checksum.json'
|
||||||
|
LATEST_RELEASE: str = 'v1.4.0'
|
||||||
|
RELEASE_PATH: str = 'releases/download'
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
repository: str,
|
||||||
|
release: str) -> None:
|
||||||
""" Default constructor.
|
""" Default constructor.
|
||||||
|
|
||||||
:param host: Host to the Github instance to reach.
|
Parameters:
|
||||||
:param repository: Repository path within target Github.
|
host (str):
|
||||||
:param release: Release name to get models from.
|
Host to the Github instance to reach.
|
||||||
|
repository (str):
|
||||||
|
Repository path within target Github.
|
||||||
|
release (str):
|
||||||
|
Release name to get models from.
|
||||||
"""
|
"""
|
||||||
self._host = host
|
self._host: str = host
|
||||||
self._repository = repository
|
self._repository: str = repository
|
||||||
self._release = release
|
self._release: str = release
|
||||||
|
|
||||||
def checksum(self, name):
|
@classmethod
|
||||||
""" Downloads and returns reference checksum for the given model name.
|
def from_environ(cls: type) -> 'GithubModelProvider':
|
||||||
|
|
||||||
:param name: Name of the model to get checksum for.
|
|
||||||
:returns: Checksum of the required model.
|
|
||||||
:raise ValueError: If the given model name is not indexed.
|
|
||||||
"""
|
"""
|
||||||
url = '{}/{}/{}/{}/{}'.format(
|
Factory method that creates provider from envvars.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GithubModelProvider:
|
||||||
|
Created instance.
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
environ.get('GITHUB_HOST', cls.DEFAULT_HOST),
|
||||||
|
environ.get('GITHUB_REPOSITORY', cls.DEFAULT_REPOSITORY),
|
||||||
|
environ.get('GITHUB_RELEASE', cls.LATEST_RELEASE))
|
||||||
|
|
||||||
|
def checksum(self, name: str) -> str:
|
||||||
|
"""
|
||||||
|
Downloads and returns reference checksum for the given model name.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
name (str):
|
||||||
|
Name of the model to get checksum for.
|
||||||
|
Returns:
|
||||||
|
str:
|
||||||
|
Checksum of the required model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError:
|
||||||
|
If the given model name is not indexed.
|
||||||
|
"""
|
||||||
|
url: str = '/'.join((
|
||||||
self._host,
|
self._host,
|
||||||
self._repository,
|
self._repository,
|
||||||
self.RELEASE_PATH,
|
self.RELEASE_PATH,
|
||||||
self._release,
|
self._release,
|
||||||
self.CHECKSUM_INDEX)
|
self.CHECKSUM_INDEX))
|
||||||
response = requests.get(url)
|
response: httpx.Response = httpx.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
index = response.json()
|
index: Dict = response.json()
|
||||||
if name not in index:
|
if name not in index:
|
||||||
raise ValueError('No checksum for model {}'.format(name))
|
raise ValueError(f'No checksum for model {name}')
|
||||||
return index[name]
|
return index[name]
|
||||||
|
|
||||||
def download(self, name, path):
|
def download(self, name: str, path: str) -> None:
|
||||||
""" Download model denoted by the given name to disk.
|
|
||||||
|
|
||||||
:param name: Name of the model to download.
|
|
||||||
:param path: Path of the directory to save model into.
|
|
||||||
"""
|
"""
|
||||||
url = '{}/{}/{}/{}/{}.tar.gz'.format(
|
Download model denoted by the given name to disk.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
name (str):
|
||||||
|
Name of the model to download.
|
||||||
|
path (str):
|
||||||
|
Path of the directory to save model into.
|
||||||
|
"""
|
||||||
|
url: str = '/'.join((
|
||||||
self._host,
|
self._host,
|
||||||
self._repository,
|
self._repository,
|
||||||
self.RELEASE_PATH,
|
self.RELEASE_PATH,
|
||||||
self._release,
|
self._release,
|
||||||
name)
|
name))
|
||||||
get_logger().info('Downloading model archive %s', url)
|
url = f'{url}.tar.gz'
|
||||||
with requests.get(url, stream=True) as response:
|
get_logger().info(f'Downloading model archive {url}')
|
||||||
response.raise_for_status()
|
with httpx.Client(http2=True) as client:
|
||||||
archive = NamedTemporaryFile(delete=False)
|
with client.stream('GET', url) as response:
|
||||||
try:
|
response.raise_for_status()
|
||||||
with archive as stream:
|
archive = NamedTemporaryFile(delete=False)
|
||||||
# Note: check for chunk size parameters ?
|
try:
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
with archive as stream:
|
||||||
if chunk:
|
for chunk in response.iter_raw():
|
||||||
stream.write(chunk)
|
stream.write(chunk)
|
||||||
get_logger().info('Validating archive checksum')
|
get_logger().info('Validating archive checksum')
|
||||||
if compute_file_checksum(archive.name) != self.checksum(name):
|
checksum: str = compute_file_checksum(archive.name)
|
||||||
raise IOError('Downloaded file is corrupted, please retry')
|
if checksum != self.checksum(name):
|
||||||
get_logger().info('Extracting downloaded %s archive', name)
|
raise IOError(
|
||||||
with tarfile.open(name=archive.name) as tar:
|
'Downloaded file is corrupted, please retry')
|
||||||
tar.extractall(path=path)
|
get_logger().info(f'Extracting downloaded {name} archive')
|
||||||
finally:
|
with tarfile.open(name=archive.name) as tar:
|
||||||
os.unlink(archive.name)
|
tar.extractall(path=path)
|
||||||
get_logger().info('%s model file(s) extracted', name)
|
finally:
|
||||||
|
os.unlink(archive.name)
|
||||||
|
get_logger().info(f'{name} model file(s) extracted')
|
||||||
|
|||||||
@@ -4,62 +4,60 @@
|
|||||||
"""
|
"""
|
||||||
Module that provides a class wrapper for source separation.
|
Module that provides a class wrapper for source separation.
|
||||||
|
|
||||||
:Example:
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
>>> from spleeter.separator import Separator
|
>>> from spleeter.separator import Separator
|
||||||
>>> separator = Separator('spleeter:2stems')
|
>>> separator = Separator('spleeter:2stems')
|
||||||
>>> separator.separate(waveform, lambda instrument, data: ...)
|
>>> separator.separate(waveform, lambda instrument, data: ...)
|
||||||
>>> separator.separate_to_file(...)
|
>>> separator.separate_to_file(...)
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import os
|
import os
|
||||||
import logging
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from os.path import basename, join, splitext, dirname
|
from os.path import basename, join, splitext, dirname
|
||||||
from time import time
|
from typing import Generator, Optional
|
||||||
from typing import Container, NoReturn
|
|
||||||
|
|
||||||
|
from . import SpleeterError
|
||||||
|
from .audio import STFTBackend
|
||||||
|
from .audio.adapter import get_default_audio_adapter
|
||||||
|
from .audio.convertor import to_stereo
|
||||||
|
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||||
|
from .utils.configuration import load_configuration
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
|
# pylint: disable=import-error
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from librosa.core import stft, istft
|
from librosa.core import stft, istft
|
||||||
from scipy.signal.windows import hann
|
from scipy.signal.windows import hann
|
||||||
|
# pylint: enable=import-error
|
||||||
from . import SpleeterError
|
|
||||||
from .audio.adapter import get_default_audio_adapter
|
|
||||||
from .audio.convertor import to_stereo
|
|
||||||
from .utils.configuration import load_configuration
|
|
||||||
from .utils.estimator import create_estimator, get_default_model_dir
|
|
||||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
__email__ = 'spleeter@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
|
|
||||||
""" """
|
|
||||||
|
|
||||||
|
class DataGenerator(object):
|
||||||
class DataGenerator():
|
|
||||||
"""
|
"""
|
||||||
Generator object that stores a sample and generates it once when called.
|
Generator object that stores a sample and generates it once when called.
|
||||||
Used to feed a tensorflow estimator without knowing the whole data at
|
Used to feed a tensorflow estimator without knowing the whole data at
|
||||||
build time.
|
build time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
""" Default constructor. """
|
""" Default constructor. """
|
||||||
self._current_data = None
|
self._current_data = None
|
||||||
|
|
||||||
def update_data(self, data):
|
def update_data(self, data) -> None:
|
||||||
""" Replace internal data. """
|
""" Replace internal data. """
|
||||||
self._current_data = data
|
self._current_data = data
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self) -> Generator:
|
||||||
""" Generation process. """
|
""" Generation process. """
|
||||||
buffer = self._current_data
|
buffer = self._current_data
|
||||||
while buffer:
|
while buffer:
|
||||||
@@ -79,19 +77,50 @@ def get_backend(backend: str) -> str:
|
|||||||
return backend
|
return backend
|
||||||
|
|
||||||
|
|
||||||
|
def create_estimator(params, MWF):
|
||||||
|
"""
|
||||||
|
Initialize tensorflow estimator that will perform separation
|
||||||
|
|
||||||
|
Params:
|
||||||
|
- params: a dictionary of parameters for building the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a tensorflow estimator
|
||||||
|
"""
|
||||||
|
# Load model.
|
||||||
|
provider: ModelProvider = ModelProvider.default()
|
||||||
|
params['model_dir'] = provider.get(params['model_dir'])
|
||||||
|
params['MWF'] = MWF
|
||||||
|
# Setup config
|
||||||
|
session_config = tf.compat.v1.ConfigProto()
|
||||||
|
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
|
||||||
|
config = tf.estimator.RunConfig(session_config=session_config)
|
||||||
|
# Setup estimator
|
||||||
|
estimator = tf.estimator.Estimator(
|
||||||
|
model_fn=model_fn,
|
||||||
|
model_dir=params['model_dir'],
|
||||||
|
params=params,
|
||||||
|
config=config)
|
||||||
|
return estimator
|
||||||
|
|
||||||
|
|
||||||
class Separator(object):
|
class Separator(object):
|
||||||
""" A wrapper class for performing separation. """
|
""" A wrapper class for performing separation. """
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params_descriptor,
|
params_descriptor: str,
|
||||||
MWF: bool = False,
|
MWF: bool = False,
|
||||||
stft_backend: str = 'auto',
|
stft_backend: STFTBackend = STFTBackend.AUTO,
|
||||||
multiprocess: bool = True):
|
multiprocess: bool = True) -> None:
|
||||||
""" Default constructor.
|
"""
|
||||||
|
Default constructor.
|
||||||
|
|
||||||
:param params_descriptor: Descriptor for TF params to be used.
|
Parameters:
|
||||||
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
params_descriptor (str):
|
||||||
|
Descriptor for TF params to be used.
|
||||||
|
MWF (bool):
|
||||||
|
(Optional) `True` if MWF should be used, `False` otherwise.
|
||||||
"""
|
"""
|
||||||
self._params = load_configuration(params_descriptor)
|
self._params = load_configuration(params_descriptor)
|
||||||
self._sample_rate = self._params['sample_rate']
|
self._sample_rate = self._params['sample_rate']
|
||||||
@@ -111,8 +140,7 @@ class Separator(object):
|
|||||||
self._params['stft_backend'] = get_backend(stft_backend)
|
self._params['stft_backend'] = get_backend(stft_backend)
|
||||||
self._data_generator = DataGenerator()
|
self._data_generator = DataGenerator()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self) -> None:
|
||||||
""" """
|
|
||||||
if self._session:
|
if self._session:
|
||||||
self._session.close()
|
self._session.close()
|
||||||
|
|
||||||
@@ -140,35 +168,19 @@ class Separator(object):
|
|||||||
yield_single_examples=False)
|
yield_single_examples=False)
|
||||||
return self._prediction_generator
|
return self._prediction_generator
|
||||||
|
|
||||||
def join(self, timeout: int = 200) -> NoReturn:
|
def join(self, timeout: int = 200) -> None:
|
||||||
""" Wait for all pending tasks to be finished.
|
"""
|
||||||
|
Wait for all pending tasks to be finished.
|
||||||
|
|
||||||
:param timeout: (Optional) task waiting timeout.
|
Parameters:
|
||||||
|
timeout (int):
|
||||||
|
(Optional) task waiting timeout.
|
||||||
"""
|
"""
|
||||||
while len(self._tasks) > 0:
|
while len(self._tasks) > 0:
|
||||||
task = self._tasks.pop()
|
task = self._tasks.pop()
|
||||||
task.get()
|
task.get()
|
||||||
task.wait(timeout=timeout)
|
task.wait(timeout=timeout)
|
||||||
|
|
||||||
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
|
|
||||||
""" Performs source separation over the given waveform with tensorflow
|
|
||||||
backend.
|
|
||||||
|
|
||||||
:param waveform: Waveform to apply separation on.
|
|
||||||
:returns: Separated waveforms.
|
|
||||||
"""
|
|
||||||
if not waveform.shape[-1] == 2:
|
|
||||||
waveform = to_stereo(waveform)
|
|
||||||
prediction_generator = self._get_prediction_generator()
|
|
||||||
# NOTE: update data in generator before performing separation.
|
|
||||||
self._data_generator.update_data({
|
|
||||||
'waveform': waveform,
|
|
||||||
'audio_id': np.array(audio_descriptor)})
|
|
||||||
# NOTE: perform separation.
|
|
||||||
prediction = next(prediction_generator)
|
|
||||||
prediction.pop('audio_id')
|
|
||||||
return prediction
|
|
||||||
|
|
||||||
def _stft(self, data, inverse: bool = False, length=None):
|
def _stft(self, data, inverse: bool = False, length=None):
|
||||||
""" Single entrypoint for both stft and istft. This computes stft and
|
""" Single entrypoint for both stft and istft. This computes stft and
|
||||||
istft with librosa on stereo data. The two channels are processed
|
istft with librosa on stereo data. The two channels are processed
|
||||||
@@ -233,7 +245,12 @@ class Separator(object):
|
|||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
def _separate_librosa(self, waveform: np.ndarray, audio_id):
|
def _separate_librosa(self, waveform: np.ndarray, audio_id):
|
||||||
""" Performs separation with librosa backend for STFT.
|
"""
|
||||||
|
Performs separation with librosa backend for STFT.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
waveform (numpy.ndarray):
|
||||||
|
Waveform to be separated (as a numpy array)
|
||||||
"""
|
"""
|
||||||
with self._tf_graph.as_default():
|
with self._tf_graph.as_default():
|
||||||
out = {}
|
out = {}
|
||||||
@@ -260,12 +277,42 @@ class Separator(object):
|
|||||||
length=waveform.shape[0])
|
length=waveform.shape[0])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def separate(self, waveform: np.ndarray, audio_descriptor=''):
|
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
|
||||||
""" Performs separation on a waveform.
|
"""
|
||||||
|
Performs source separation over the given waveform with tensorflow
|
||||||
|
backend.
|
||||||
|
|
||||||
:param waveform: Waveform to be separated (as a numpy array)
|
Parameters:
|
||||||
:param audio_descriptor: (Optional) string describing the waveform
|
waveform (numpy.ndarray):
|
||||||
(e.g. filename).
|
Waveform to be separated (as a numpy array)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Separated waveforms.
|
||||||
|
"""
|
||||||
|
if not waveform.shape[-1] == 2:
|
||||||
|
waveform = to_stereo(waveform)
|
||||||
|
prediction_generator = self._get_prediction_generator()
|
||||||
|
# NOTE: update data in generator before performing separation.
|
||||||
|
self._data_generator.update_data({
|
||||||
|
'waveform': waveform,
|
||||||
|
'audio_id': np.array(audio_descriptor)})
|
||||||
|
# NOTE: perform separation.
|
||||||
|
prediction = next(prediction_generator)
|
||||||
|
prediction.pop('audio_id')
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
def separate(
|
||||||
|
self,
|
||||||
|
waveform: np.ndarray,
|
||||||
|
audio_descriptor: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Performs separation on a waveform.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
waveform (numpy.ndarray):
|
||||||
|
Waveform to be separated (as a numpy array)
|
||||||
|
audio_descriptor (str):
|
||||||
|
(Optional) string describing the waveform (e.g. filename).
|
||||||
"""
|
"""
|
||||||
if self._params['stft_backend'] == 'tensorflow':
|
if self._params['stft_backend'] == 'tensorflow':
|
||||||
return self._separate_tensorflow(waveform, audio_descriptor)
|
return self._separate_tensorflow(waveform, audio_descriptor)
|
||||||
|
|||||||
@@ -4,14 +4,10 @@
|
|||||||
""" Module that provides configuration loading function. """
|
""" Module that provides configuration loading function. """
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import importlib.resources as loader
|
||||||
try:
|
|
||||||
import importlib.resources as loader
|
|
||||||
except ImportError:
|
|
||||||
# Try backported to PY<37 `importlib_resources`.
|
|
||||||
import importlib_resources as loader
|
|
||||||
|
|
||||||
from os.path import exists
|
from os.path import exists
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from .. import resources, SpleeterError
|
from .. import resources, SpleeterError
|
||||||
|
|
||||||
@@ -20,18 +16,28 @@ __email__ = 'spleeter@deezer.com'
|
|||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
|
_EMBEDDED_CONFIGURATION_PREFIX: str = 'spleeter:'
|
||||||
|
|
||||||
|
|
||||||
def load_configuration(descriptor):
|
def load_configuration(descriptor: str) -> Dict:
|
||||||
""" Load configuration from the given descriptor. Could be
|
"""
|
||||||
either a `spleeter:` prefixed embedded configuration name
|
Load configuration from the given descriptor. Could be either a
|
||||||
or a file system path to read configuration from.
|
`spleeter:` prefixed embedded configuration name or a file system path
|
||||||
|
to read configuration from.
|
||||||
|
|
||||||
:param descriptor: Configuration descriptor to use for lookup.
|
Parameters:
|
||||||
:returns: Loaded description as dict.
|
descriptor (str):
|
||||||
:raise ValueError: If required embedded configuration does not exists.
|
Configuration descriptor to use for lookup.
|
||||||
:raise SpleeterError: If required configuration file does not exists.
|
|
||||||
|
Returns:
|
||||||
|
Dict:
|
||||||
|
Loaded description as dict.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError:
|
||||||
|
If required embedded configuration does not exists.
|
||||||
|
SpleeterError:
|
||||||
|
If required configuration file does not exists.
|
||||||
"""
|
"""
|
||||||
# Embedded configuration reading.
|
# Embedded configuration reading.
|
||||||
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
|
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# coding: utf8
|
|
||||||
|
|
||||||
""" Utility functions for creating estimator. """
|
|
||||||
|
|
||||||
import tensorflow as tf # pylint: disable=import-error
|
|
||||||
|
|
||||||
from ..model import model_fn
|
|
||||||
from ..model.provider import get_default_model_provider
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_model_dir(model_dir):
|
|
||||||
"""
|
|
||||||
Transforms a string like 'spleeter:2stems' into an actual path.
|
|
||||||
:param model_dir:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
model_provider = get_default_model_provider()
|
|
||||||
return model_provider.get(model_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def create_estimator(params, MWF):
|
|
||||||
"""
|
|
||||||
Initialize tensorflow estimator that will perform separation
|
|
||||||
|
|
||||||
Params:
|
|
||||||
- params: a dictionary of parameters for building the model
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
a tensorflow estimator
|
|
||||||
"""
|
|
||||||
# Load model.
|
|
||||||
params['model_dir'] = get_default_model_dir(params['model_dir'])
|
|
||||||
params['MWF'] = MWF
|
|
||||||
# Setup config
|
|
||||||
session_config = tf.compat.v1.ConfigProto()
|
|
||||||
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
|
|
||||||
config = tf.estimator.RunConfig(session_config=session_config)
|
|
||||||
# Setup estimator
|
|
||||||
estimator = tf.estimator.Estimator(
|
|
||||||
model_fn=model_fn,
|
|
||||||
model_dir=params['model_dir'],
|
|
||||||
params=params,
|
|
||||||
config=config
|
|
||||||
)
|
|
||||||
return estimator
|
|
||||||
Reference in New Issue
Block a user