🎨 finalizes model provider and functions

This commit is contained in:
Faylixe
2020-12-07 19:19:19 +01:00
parent f02bcbd9c7
commit ae9269525d
9 changed files with 398 additions and 265 deletions

View File

@@ -16,28 +16,22 @@
import time import time
import os import os
from os.path import exists, join, sep as SEPARATOR
from os.path import exists, sep as SEPARATOR
from .audio.convertor import db_uint_spectrogram_to_gain
from .audio.convertor import spectrogram_to_db_uint
from .audio.spectrogram import compute_spectrogram_tf
from .audio.spectrogram import random_pitch_shift, random_time_stretch
from .utils.logging import get_logger
from .utils.tensor import check_tensor_shape, dataset_from_csv
from .utils.tensor import set_tensor_shape, sync_apply
# pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import pandas as pd
import numpy as np
import tensorflow as tf import tensorflow as tf
# pylint: enable=import-error # pylint: enable=import-error
from .audio.convertor import (
db_uint_spectrogram_to_gain,
spectrogram_to_db_uint)
from .audio.spectrogram import (
compute_spectrogram_tf,
random_pitch_shift,
random_time_stretch)
from .utils.logging import get_logger
from .utils.tensor import (
check_tensor_shape,
dataset_from_csv,
set_tensor_shape,
sync_apply)
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'

View File

@@ -3,25 +3,44 @@
""" This package provide model functions. """ """ This package provide model functions. """
from typing import Callable, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
def apply(function, input_tensor, instruments, params={}): def apply(
""" Apply given function to the input tensor. function: Callable,
input_tensor: tf.Tensor,
:param function: Function to be applied to tensor. instruments: Iterable[str],
:param input_tensor: Tensor to apply blstm to. params: Optional[Dict] = None) -> Dict:
:param instruments: Iterable that provides a collection of instruments.
:param params: (Optional) dict of BLSTM parameters.
:returns: Created output tensor dict.
""" """
output_dict = {} Apply given function to the input tensor.
Parameters:
function:
Function to be applied to tensor.
input_tensor (tensorflow.Tensor):
Tensor to apply blstm to.
instruments (Iterable[str]):
Iterable that provides a collection of instruments.
params:
(Optional) dict of BLSTM parameters.
Returns:
Created output tensor dict.
"""
output_dict: Dict = {}
for instrument in instruments: for instrument in instruments:
out_name = f'{instrument}_spectrogram' out_name = f'{instrument}_spectrogram'
output_dict[out_name] = function( output_dict[out_name] = function(
input_tensor, input_tensor,
output_name=out_name, output_name=out_name,
params=params) params=params or {})
return output_dict return output_dict

View File

@@ -20,7 +20,14 @@
selection (LSTM layer dropout rate, regularization strength). selection (LSTM layer dropout rate, regularization strength).
""" """
from typing import Dict, Optional
from . import apply
# pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1.keras.initializers import he_uniform from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.compat.v1.keras.layers import CuDNNLSTM from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import ( from tensorflow.keras.layers import (
@@ -31,22 +38,33 @@ from tensorflow.keras.layers import (
TimeDistributed) TimeDistributed)
# pylint: enable=import-error # pylint: enable=import-error
from . import apply
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
def apply_blstm(input_tensor, output_name='output', params={}): def apply_blstm(
""" Apply BLSTM to the given input_tensor. input_tensor: tf.Tensor,
output_name: str = 'output',
:param input_tensor: Input of the model. params: Optional[Dict] = None) -> tf.Tensor:
:param output_name: (Optional) name of the output, default to 'output'.
:param params: (Optional) dict of BLSTM parameters.
:returns: Output tensor.
""" """
units = params.get('lstm_units', 250) Apply BLSTM to the given input_tensor.
Parameters:
input_tensor (tensorflow.Tensor):
Input of the model.
output_name (str):
(Optional) name of the output, default to 'output'.
params (Optional[Dict]):
(Optional) dict of BLSTM parameters.
Returns:
tensorflow.Tensor:
Output tensor.
"""
if params is None:
params = {}
units: int = params.get('lstm_units', 250)
kernel_initializer = he_uniform(seed=50) kernel_initializer = he_uniform(seed=50)
flatten_input = TimeDistributed(Flatten())((input_tensor)) flatten_input = TimeDistributed(Flatten())((input_tensor))
@@ -65,12 +83,15 @@ def apply_blstm(input_tensor, output_name='output', params={}):
int(flatten_input.shape[2]), int(flatten_input.shape[2]),
activation='relu', activation='relu',
kernel_initializer=kernel_initializer))((l3)) kernel_initializer=kernel_initializer))((l3))
output = TimeDistributed( output: tf.Tensor = TimeDistributed(
Reshape(input_tensor.shape[2:]), Reshape(input_tensor.shape[2:]),
name=output_name)(dense) name=output_name)(dense)
return output return output
def blstm(input_tensor, output_name='output', params={}): def blstm(
input_tensor: tf.Tensor,
output_name: str = 'output',
params: Optional[Dict] = None) -> tf.Tensor:
""" Model function applier. """ """ Model function applier. """
return apply(apply_blstm, input_tensor, output_name, params) return apply(apply_blstm, input_tensor, output_name, params)

View File

@@ -2,16 +2,23 @@
# coding: utf8 # coding: utf8
""" """
This module contains building functions for U-net source This module contains building functions for U-net source
separation models in a similar way as in A. Jansson et al. "Singing separation models in a similar way as in A. Jansson et al. :
voice separation with deep u-net convolutional networks", ISMIR 2017.
Each instrument is modeled by a single U-net convolutional "Singing voice separation with deep u-net convolutional networks",
/ deconvolutional network that takes a mix spectrogram as input and the ISMIR 2017
estimated sound spectrogram as output.
Each instrument is modeled by a single U-net
convolutional / deconvolutional network that take a mix spectrogram
as input and the estimated sound spectrogram as output.
""" """
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, Optional
from . import apply
# pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import tensorflow as tf import tensorflow as tf
@@ -30,20 +37,23 @@ from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform from tensorflow.compat.v1.keras.initializers import he_uniform
# pylint: enable=import-error # pylint: enable=import-error
from . import apply
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
def _get_conv_activation_layer(params): def _get_conv_activation_layer(params: Dict) -> Any:
""" """
> To be documented.
:param params: Parameters:
:returns: Required Activation function. params (Dict):
Returns:
Any:
Required Activation function.
""" """
conv_activation = params.get('conv_activation') conv_activation: str = params.get('conv_activation')
if conv_activation == 'ReLU': if conv_activation == 'ReLU':
return ReLU() return ReLU()
elif conv_activation == 'ELU': elif conv_activation == 'ELU':
@@ -51,13 +61,18 @@ def _get_conv_activation_layer(params):
return LeakyReLU(0.2) return LeakyReLU(0.2)
def _get_deconv_activation_layer(params): def _get_deconv_activation_layer(params: Dict) -> Any:
""" """
> To be documented.
:param params: Parameters:
:returns: Required Activation function. params (Dict):
Returns:
Any:
Required Activation function.
""" """
deconv_activation = params.get('deconv_activation') deconv_activation: str = params.get('deconv_activation')
if deconv_activation == 'LeakyReLU': if deconv_activation == 'LeakyReLU':
return LeakyReLU(0.2) return LeakyReLU(0.2)
elif deconv_activation == 'ELU': elif deconv_activation == 'ELU':
@@ -66,17 +81,19 @@ def _get_deconv_activation_layer(params):
def apply_unet( def apply_unet(
input_tensor, input_tensor: tf.Tensor,
output_name='output', output_name: str = 'output',
params={}, params: Optional[Dict] = None,
output_mask_logit=False): output_mask_logit: bool = False) -> Any:
""" Apply a convolutionnal U-net to model a single instrument (one U-net """
is used for each instrument). Apply a convolutionnal U-net to model a single instrument (one U-net
is used for each instrument).
:param input_tensor: Parameters:
:param output_name: (Optional) , default to 'output' input_tensor (tensorflow.Tensor):
:param params: (Optional) , default to empty dict. output_name (str):
:param output_mask_logit: (Optional) , default to False. params (Optional[Dict]):
output_mask_logit (bool):
""" """
logging.info(f'Apply unet for {output_name}') logging.info(f'Apply unet for {output_name}')
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512]) conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
@@ -170,18 +187,32 @@ def apply_unet(
kernel_initializer=kernel_initializer)((batch12)) kernel_initializer=kernel_initializer)((batch12))
def unet(input_tensor, instruments, params={}): def unet(
input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None) -> Dict:
""" Model function applier. """ """ Model function applier. """
return apply(apply_unet, input_tensor, instruments, params) return apply(apply_unet, input_tensor, instruments, params)
def softmax_unet(input_tensor, instruments, params={}): def softmax_unet(
""" Apply softmax to multitrack unet in order to have mask suming to one. input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None) -> Dict:
"""
Apply softmax to multitrack unet in order to have mask summing to one.
:param input_tensor: Tensor to apply blstm to. Parameters:
:param instruments: Iterable that provides a collection of instruments. input_tensor (tensorflow.Tensor):
:param params: (Optional) dict of BLSTM parameters. Tensor to apply blstm to.
:returns: Created output tensor dict. instruments (Iterable[str]):
Iterable that provides a collection of instruments.
params (Optional[Dict]):
(Optional) dict of BLSTM parameters.
Returns:
Dict:
Created output tensor dict.
""" """
logit_mask_list = [] logit_mask_list = []
for instrument in instruments: for instrument in instruments:

View File

@@ -5,10 +5,12 @@
This package provides tools for downloading model from network This package provides tools for downloading model from network
using remote storage abstraction. using remote storage abstraction.
:Example: Examples:
```python
>>> provider = MyProviderImplementation() >>> provider = MyProviderImplementation()
>>> provider.get('/path/to/local/storage', params) >>> provider.get('/path/to/local/storage', params)
```
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -26,39 +28,52 @@ class ModelProvider(ABC):
file download is not available. file download is not available.
""" """
DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models') DEFAULT_MODEL_PATH: str = environ.get('MODEL_PATH', 'pretrained_models')
MODEL_PROBE_PATH = '.probe' MODEL_PROBE_PATH: str = '.probe'
@abstractmethod @abstractmethod
def download(self, name, path): def download(_, name: str, path: str) -> None:
""" Download model denoted by the given name to disk. """
Download model denoted by the given name to disk.
:param name: Name of the model to download. Parameters:
:param path: Path of the directory to save model into. name (str):
Name of the model to download.
path (str):
Path of the directory to save model into.
""" """
pass pass
@staticmethod @staticmethod
def writeProbe(directory): def writeProbe(directory: str) -> None:
""" Write a model probe file into the given directory.
:param directory: Directory to write probe into.
""" """
probe = join(directory, ModelProvider.MODEL_PROBE_PATH) Write a model probe file into the given directory.
Parameters:
directory (str):
Directory to write probe into.
"""
probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
with open(probe, 'w') as stream: with open(probe, 'w') as stream:
stream.write('OK') stream.write('OK')
def get(self, model_directory): def get(self, model_directory: str) -> str:
""" Ensures required model is available at given location. """
Ensures required model is available at given location.
:param model_directory: Expected model_directory to be available. Parameters:
:raise IOError: If model can not be retrieved. model_directory (str):
Expected model_directory to be available.
Raises:
IOError:
If model can not be retrieved.
""" """
# Expand model directory if needed. # Expand model directory if needed.
if not isabs(model_directory): if not isabs(model_directory):
model_directory = join(self.DEFAULT_MODEL_PATH, model_directory) model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
# Download it if not exists. # Download it if not exists.
model_probe = join(model_directory, self.MODEL_PROBE_PATH) model_probe: str = join(model_directory, self.MODEL_PROBE_PATH)
if not exists(model_probe): if not exists(model_probe):
if not exists(model_directory): if not exists(model_directory):
makedirs(model_directory) makedirs(model_directory)
@@ -68,14 +83,14 @@ class ModelProvider(ABC):
self.writeProbe(model_directory) self.writeProbe(model_directory)
return model_directory return model_directory
@classmethod
def default(_: type) -> 'ModelProvider':
"""
Builds and returns a default model provider.
def get_default_model_provider(): Returns:
""" Builds and returns a default model provider. ModelProvider:
A default model provider instance to use.
:returns: A default model provider instance to use. """
""" from .github import GithubModelProvider
from .github import GithubModelProvider return GithubModelProvider.from_environ()
host = environ.get('GITHUB_HOST', 'https://github.com')
repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
return GithubModelProvider(host, repository, release)

View File

@@ -4,27 +4,34 @@
""" """
A ModelProvider backed by Github Release feature. A ModelProvider backed by Github Release feature.
:Example: Examples:
```python
>>> from spleeter.model.provider import github >>> from spleeter.model.provider import github
>>> provider = github.GithubModelProvider( >>> provider = github.GithubModelProvider(
'github.com', 'github.com',
'Deezer/spleeter', 'Deezer/spleeter',
'latest') 'latest')
>>> provider.download('2stems', '/path/to/local/storage') >>> provider.download('2stems', '/path/to/local/storage')
```
""" """
import hashlib import hashlib
import tarfile import tarfile
import os import os
from os import environ
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Dict
import requests
from . import ModelProvider from . import ModelProvider
from ...utils.logging import get_logger from ...utils.logging import get_logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import httpx
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
@@ -46,69 +53,108 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider): class GithubModelProvider(ModelProvider):
""" A ModelProvider implementation backed on Github for remote storage. """ """ A ModelProvider implementation backed on Github for remote storage. """
LATEST_RELEASE = 'v1.4.0' DEFAULT_HOST: str = 'https://github.com'
RELEASE_PATH = 'releases/download' DEFAULT_REPOSITORY: str = 'deezer/spleeter'
CHECKSUM_INDEX = 'checksum.json'
def __init__(self, host, repository, release): CHECKSUM_INDEX: str = 'checksum.json'
LATEST_RELEASE: str = 'v1.4.0'
RELEASE_PATH: str = 'releases/download'
def __init__(
self,
host: str,
repository: str,
release: str) -> None:
""" Default constructor. """ Default constructor.
:param host: Host to the Github instance to reach. Parameters:
:param repository: Repository path within target Github. host (str):
:param release: Release name to get models from. Host to the Github instance to reach.
repository (str):
Repository path within target Github.
release (str):
Release name to get models from.
""" """
self._host = host self._host: str = host
self._repository = repository self._repository: str = repository
self._release = release self._release: str = release
def checksum(self, name): @classmethod
""" Downloads and returns reference checksum for the given model name. def from_environ(cls: type) -> 'GithubModelProvider':
:param name: Name of the model to get checksum for.
:returns: Checksum of the required model.
:raise ValueError: If the given model name is not indexed.
""" """
url = '{}/{}/{}/{}/{}'.format( Factory method that creates provider from envvars.
Returns:
GithubModelProvider:
Created instance.
"""
return cls(
environ.get('GITHUB_HOST', cls.DEFAULT_HOST),
environ.get('GITHUB_REPOSITORY', cls.DEFAULT_REPOSITORY),
environ.get('GITHUB_RELEASE', cls.LATEST_RELEASE))
def checksum(self, name: str) -> str:
"""
Downloads and returns reference checksum for the given model name.
Parameters:
name (str):
Name of the model to get checksum for.
Returns:
str:
Checksum of the required model.
Raises:
ValueError:
If the given model name is not indexed.
"""
url: str = '/'.join((
self._host, self._host,
self._repository, self._repository,
self.RELEASE_PATH, self.RELEASE_PATH,
self._release, self._release,
self.CHECKSUM_INDEX) self.CHECKSUM_INDEX))
response = requests.get(url) response: httpx.Response = httpx.get(url)
response.raise_for_status() response.raise_for_status()
index = response.json() index: Dict = response.json()
if name not in index: if name not in index:
raise ValueError('No checksum for model {}'.format(name)) raise ValueError(f'No checksum for model {name}')
return index[name] return index[name]
def download(self, name, path): def download(self, name: str, path: str) -> None:
""" Download model denoted by the given name to disk.
:param name: Name of the model to download.
:param path: Path of the directory to save model into.
""" """
url = '{}/{}/{}/{}/{}.tar.gz'.format( Download model denoted by the given name to disk.
Parameters:
name (str):
Name of the model to download.
path (str):
Path of the directory to save model into.
"""
url: str = '/'.join((
self._host, self._host,
self._repository, self._repository,
self.RELEASE_PATH, self.RELEASE_PATH,
self._release, self._release,
name) name))
get_logger().info('Downloading model archive %s', url) url = f'{url}.tar.gz'
with requests.get(url, stream=True) as response: get_logger().info(f'Downloading model archive {url}')
response.raise_for_status() with httpx.Client(http2=True) as client:
archive = NamedTemporaryFile(delete=False) with client.strema('GET', url) as response:
try: response.raise_for_status()
with archive as stream: archive = NamedTemporaryFile(delete=False)
# Note: check for chunk size parameters ? try:
for chunk in response.iter_content(chunk_size=8192): with archive as stream:
if chunk: for chunk in response.iter_raw():
stream.write(chunk) stream.write(chunk)
get_logger().info('Validating archive checksum') get_logger().info('Validating archive checksum')
if compute_file_checksum(archive.name) != self.checksum(name): checksum: str = compute_file_checksum(archive.name)
raise IOError('Downloaded file is corrupted, please retry') if checksum != self.checksum(name):
get_logger().info('Extracting downloaded %s archive', name) raise IOError(
with tarfile.open(name=archive.name) as tar: 'Downloaded file is corrupted, please retry')
tar.extractall(path=path) get_logger().info(f'Extracting downloaded {name} archive')
finally: with tarfile.open(name=archive.name) as tar:
os.unlink(archive.name) tar.extractall(path=path)
get_logger().info('%s model file(s) extracted', name) finally:
os.unlink(archive.name)
get_logger().info(f'{name} model file(s) extracted')

View File

@@ -4,62 +4,60 @@
""" """
Module that provides a class wrapper for source separation. Module that provides a class wrapper for source separation.
:Example: Examples:
```python
>>> from spleeter.separator import Separator >>> from spleeter.separator import Separator
>>> separator = Separator('spleeter:2stems') >>> separator = Separator('spleeter:2stems')
>>> separator.separate(waveform, lambda instrument, data: ...) >>> separator.separate(waveform, lambda instrument, data: ...)
>>> separator.separate_to_file(...) >>> separator.separate_to_file(...)
```
""" """
import atexit import atexit
import os import os
import logging
from enum import Enum
from multiprocessing import Pool from multiprocessing import Pool
from os.path import basename, join, splitext, dirname from os.path import basename, join, splitext, dirname
from time import time from typing import Generator, Optional
from typing import Container, NoReturn
from . import SpleeterError
from .audio import STFTBackend
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory
from .utils.configuration import load_configuration
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from librosa.core import stft, istft from librosa.core import stft, istft
from scipy.signal.windows import hann from scipy.signal.windows import hann
# pylint: enable=import-error
from . import SpleeterError
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
""" """
class DataGenerator(object):
class DataGenerator():
""" """
Generator object that store a sample and generate it once while called. Generator object that store a sample and generate it once while called.
Used to feed a tensorflow estimator without knowing the whole data at Used to feed a tensorflow estimator without knowing the whole data at
build time. build time.
""" """
def __init__(self): def __init__(self) -> None:
""" Default constructor. """ """ Default constructor. """
self._current_data = None self._current_data = None
def update_data(self, data): def update_data(self, data) -> None:
""" Replace internal data. """ """ Replace internal data. """
self._current_data = data self._current_data = data
def __call__(self): def __call__(self) -> Generator:
""" Generation process. """ """ Generation process. """
buffer = self._current_data buffer = self._current_data
while buffer: while buffer:
@@ -79,19 +77,50 @@ def get_backend(backend: str) -> str:
return backend return backend
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
provider: ModelProvider = ModelProvider.default()
params['model_dir'] = provider.get(params['model_dir'])
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config)
return estimator
class Separator(object): class Separator(object):
""" A wrapper class for performing separation. """ """ A wrapper class for performing separation. """
def __init__( def __init__(
self, self,
params_descriptor, params_descriptor: str,
MWF: bool = False, MWF: bool = False,
stft_backend: str = 'auto', stft_backend: STFTBackend = STFTBackend.AUTO,
multiprocess: bool = True): multiprocess: bool = True) -> None:
""" Default constructor. """
Default constructor.
:param params_descriptor: Descriptor for TF params to be used. Parameters:
:param MWF: (Optional) True if MWF should be used, False otherwise. params_descriptor (str):
Descriptor for TF params to be used.
MWF (bool):
(Optional) `True` if MWF should be used, `False` otherwise.
""" """
self._params = load_configuration(params_descriptor) self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate'] self._sample_rate = self._params['sample_rate']
@@ -111,8 +140,7 @@ class Separator(object):
self._params['stft_backend'] = get_backend(stft_backend) self._params['stft_backend'] = get_backend(stft_backend)
self._data_generator = DataGenerator() self._data_generator = DataGenerator()
def __del__(self): def __del__(self) -> None:
""" """
if self._session: if self._session:
self._session.close() self._session.close()
@@ -140,35 +168,19 @@ class Separator(object):
yield_single_examples=False) yield_single_examples=False)
return self._prediction_generator return self._prediction_generator
def join(self, timeout: int = 200) -> NoReturn: def join(self, timeout: int = 200) -> None:
""" Wait for all pending tasks to be finished. """
Wait for all pending tasks to be finished.
:param timeout: (Optional) task waiting timeout. Parameters:
timeout (int):
(Optional) task waiting timeout.
""" """
while len(self._tasks) > 0: while len(self._tasks) > 0:
task = self._tasks.pop() task = self._tasks.pop()
task.get() task.get()
task.wait(timeout=timeout) task.wait(timeout=timeout)
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
""" Performs source separation over the given waveform with tensorflow
backend.
:param waveform: Waveform to apply separation on.
:returns: Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
def _stft(self, data, inverse: bool = False, length=None): def _stft(self, data, inverse: bool = False, length=None):
""" Single entrypoint for both stft and istft. This computes stft and """ Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed istft with librosa on stereo data. The two channels are processed
@@ -233,7 +245,12 @@ class Separator(object):
return self._session return self._session
def _separate_librosa(self, waveform: np.ndarray, audio_id): def _separate_librosa(self, waveform: np.ndarray, audio_id):
""" Performs separation with librosa backend for STFT. """
Performs separation with librosa backend for STFT.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
""" """
with self._tf_graph.as_default(): with self._tf_graph.as_default():
out = {} out = {}
@@ -260,12 +277,42 @@ class Separator(object):
length=waveform.shape[0]) length=waveform.shape[0])
return out return out
def separate(self, waveform: np.ndarray, audio_descriptor=''): def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
""" Performs separation on a waveform. """
Performs source separation over the given waveform with tensorflow
backend.
:param waveform: Waveform to be separated (as a numpy array) Parameters:
:param audio_descriptor: (Optional) string describing the waveform waveform (numpy.ndarray):
(e.g. filename). Waveform to be separated (as a numpy array)
Returns:
Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
def separate(
self,
waveform: np.ndarray,
audio_descriptor: Optional[str] = None) -> None:
"""
Performs separation on a waveform.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
""" """
if self._params['stft_backend'] == 'tensorflow': if self._params['stft_backend'] == 'tensorflow':
return self._separate_tensorflow(waveform, audio_descriptor) return self._separate_tensorflow(waveform, audio_descriptor)

View File

@@ -4,14 +4,10 @@
""" Module that provides configuration loading function. """ """ Module that provides configuration loading function. """
import json import json
import importlib.resources as loader
try:
import importlib.resources as loader
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as loader
from os.path import exists from os.path import exists
from typing import Dict
from .. import resources, SpleeterError from .. import resources, SpleeterError
@@ -20,18 +16,28 @@ __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:' _EMBEDDED_CONFIGURATION_PREFIX: str = 'spleeter:'
def load_configuration(descriptor): def load_configuration(descriptor: str) -> Dict:
""" Load configuration from the given descriptor. Could be """
either a `spleeter:` prefixed embedded configuration name Load configuration from the given descriptor. Could be either a
or a file system path to read configuration from. `spleeter:` prefixed embedded configuration name or a file system path
to read configuration from.
:param descriptor: Configuration descriptor to use for lookup. Parameters:
:returns: Loaded description as dict. descriptor (str):
:raise ValueError: If required embedded configuration does not exists. Configuration descriptor to use for lookup.
:raise SpleeterError: If required configuration file does not exists.
Returns:
Dict:
Loaded description as dict.
Raises:
ValueError:
If required embedded configuration does not exists.
SpleeterError:
If required configuration file does not exists.
""" """
# Embedded configuration reading. # Embedded configuration reading.
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX): if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):

View File

@@ -1,46 +0,0 @@
#!/usr/bin/env python
# coding: utf8
""" Utility functions for creating estimator. """
import tensorflow as tf # pylint: disable=import-error
from ..model import model_fn
from ..model.provider import get_default_model_provider
def get_default_model_dir(model_dir):
"""
Transforms a string like 'spleeter:2stems' into an actual path.
:param model_dir:
:return:
"""
model_provider = get_default_model_provider()
return model_provider.get(model_dir)
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
params['model_dir'] = get_default_model_dir(params['model_dir'])
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config
)
return estimator