🎨 finalizes model provider and functions

This commit is contained in:
Faylixe
2020-12-07 19:19:19 +01:00
parent f02bcbd9c7
commit ae9269525d
9 changed files with 398 additions and 265 deletions

View File

@@ -16,28 +16,22 @@
import time
import os
from os.path import exists, join, sep as SEPARATOR
from os.path import exists, sep as SEPARATOR
from .audio.convertor import db_uint_spectrogram_to_gain
from .audio.convertor import spectrogram_to_db_uint
from .audio.spectrogram import compute_spectrogram_tf
from .audio.spectrogram import random_pitch_shift, random_time_stretch
from .utils.logging import get_logger
from .utils.tensor import check_tensor_shape, dataset_from_csv
from .utils.tensor import set_tensor_shape, sync_apply
# pyright: reportMissingImports=false
# pylint: disable=import-error
import pandas as pd
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from .audio.convertor import (
db_uint_spectrogram_to_gain,
spectrogram_to_db_uint)
from .audio.spectrogram import (
compute_spectrogram_tf,
random_pitch_shift,
random_time_stretch)
from .utils.logging import get_logger
from .utils.tensor import (
check_tensor_shape,
dataset_from_csv,
set_tensor_shape,
sync_apply)
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

View File

@@ -3,25 +3,44 @@
""" This package provide model functions. """
from typing import Callable, Dict, Iterable, Optional
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def apply(function, input_tensor, instruments, params={}):
""" Apply given function to the input tensor.
:param function: Function to be applied to tensor.
:param input_tensor: Tensor to apply blstm to.
:param instruments: Iterable that provides a collection of instruments.
:param params: (Optional) dict of BLSTM parameters.
:returns: Created output tensor dict.
def apply(
function: Callable,
input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None) -> Dict:
"""
output_dict = {}
Apply given function to the input tensor.
Parameters:
function:
Function to be applied to tensor.
input_tensor (tensorflow.Tensor):
Tensor to apply blstm to.
instruments (Iterable[str]):
Iterable that provides a collection of instruments.
params:
(Optional) dict of BLSTM parameters.
Returns:
Created output tensor dict.
"""
output_dict: Dict = {}
for instrument in instruments:
out_name = f'{instrument}_spectrogram'
output_dict[out_name] = function(
input_tensor,
output_name=out_name,
params=params)
params=params or {})
return output_dict

View File

@@ -20,7 +20,14 @@
selection (LSTM layer dropout rate, regularization strength).
"""
from typing import Dict, Optional
from . import apply
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import (
@@ -31,22 +38,33 @@ from tensorflow.keras.layers import (
TimeDistributed)
# pylint: enable=import-error
from . import apply
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def apply_blstm(input_tensor, output_name='output', params={}):
""" Apply BLSTM to the given input_tensor.
:param input_tensor: Input of the model.
:param output_name: (Optional) name of the output, default to 'output'.
:param params: (Optional) dict of BLSTM parameters.
:returns: Output tensor.
def apply_blstm(
input_tensor: tf.Tensor,
output_name: str = 'output',
params: Optional[Dict] = None) -> tf.Tensor:
"""
units = params.get('lstm_units', 250)
Apply BLSTM to the given input_tensor.
Parameters:
input_tensor (tensorflow.Tensor):
Input of the model.
output_name (str):
(Optional) name of the output, default to 'output'.
params (Optional[Dict]):
(Optional) dict of BLSTM parameters.
Returns:
tensorflow.Tensor:
Output tensor.
"""
if params is None:
params = {}
units: int = params.get('lstm_units', 250)
kernel_initializer = he_uniform(seed=50)
flatten_input = TimeDistributed(Flatten())((input_tensor))
@@ -65,12 +83,15 @@ def apply_blstm(input_tensor, output_name='output', params={}):
int(flatten_input.shape[2]),
activation='relu',
kernel_initializer=kernel_initializer))((l3))
output = TimeDistributed(
output: tf.Tensor = TimeDistributed(
Reshape(input_tensor.shape[2:]),
name=output_name)(dense)
return output
def blstm(
        input_tensor: tf.Tensor,
        output_name: str = 'output',
        params: Optional[Dict] = None) -> Dict:
    """
    Model function applier.

    Parameters:
        input_tensor (tensorflow.Tensor):
            Input of the model.
        output_name (str):
            (Optional) name of the output, default to 'output'. An
            iterable of names is accepted as well: the value is
            forwarded to `apply` as its collection of instruments.
        params (Optional[Dict]):
            (Optional) dict of BLSTM parameters.

    Returns:
        Dict:
            Created output tensor dict, as returned by `apply`.
    """
    # NOTE: `apply` iterates over the collection it receives, so a bare
    #       string would be consumed character by character. Wrap it so
    #       a single name produces a single output.
    if isinstance(output_name, str):
        instruments = (output_name,)
    else:
        instruments = output_name
    return apply(apply_blstm, input_tensor, instruments, params)

View File

@@ -2,16 +2,23 @@
# coding: utf8
"""
This module contains building functions for U-net source
separation models in a similar way as in A. Jansson et al. "Singing
voice separation with deep u-net convolutional networks", ISMIR 2017.
Each instrument is modeled by a single U-net convolutional
/ deconvolutional network that take a mix spectrogram as input and the
estimated sound spectrogram as output.
This module contains building functions for U-net source
separation models in a similar way as in A. Jansson et al. :
"Singing voice separation with deep u-net convolutional networks",
ISMIR 2017
Each instrument is modeled by a single U-net
convolutional / deconvolutional network that take a mix spectrogram
as input and the estimated sound spectrogram as output.
"""
from functools import partial
from typing import Any, Dict, Iterable, Optional
from . import apply
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
@@ -30,20 +37,23 @@ from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
# pylint: enable=import-error
from . import apply
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def _get_conv_activation_layer(params):
def _get_conv_activation_layer(params: Dict) -> Any:
"""
> To be documented.
:param params:
:returns: Required Activation function.
Parameters:
params (Dict):
Returns:
Any:
Required Activation function.
"""
conv_activation = params.get('conv_activation')
conv_activation: str = params.get('conv_activation')
if conv_activation == 'ReLU':
return ReLU()
elif conv_activation == 'ELU':
@@ -51,13 +61,18 @@ def _get_conv_activation_layer(params):
return LeakyReLU(0.2)
def _get_deconv_activation_layer(params):
def _get_deconv_activation_layer(params: Dict) -> Any:
"""
> To be documented.
:param params:
:returns: Required Activation function.
Parameters:
params (Dict):
Returns:
Any:
Required Activation function.
"""
deconv_activation = params.get('deconv_activation')
deconv_activation: str = params.get('deconv_activation')
if deconv_activation == 'LeakyReLU':
return LeakyReLU(0.2)
elif deconv_activation == 'ELU':
@@ -66,17 +81,19 @@ def _get_deconv_activation_layer(params):
def apply_unet(
input_tensor,
output_name='output',
params={},
output_mask_logit=False):
""" Apply a convolutionnal U-net to model a single instrument (one U-net
is used for each instrument).
input_tensor: tf.Tensor,
output_name: str = 'output',
params: Optional[Dict] = None,
output_mask_logit: bool = False) -> Any:
"""
Apply a convolutional U-net to model a single instrument (one U-net
is used for each instrument).
:param input_tensor:
:param output_name: (Optional) , default to 'output'
:param params: (Optional) , default to empty dict.
:param output_mask_logit: (Optional) , default to False.
Parameters:
input_tensor (tensorflow.Tensor):
output_name (str):
params (Optional[Dict]):
output_mask_logit (bool):
"""
logging.info(f'Apply unet for {output_name}')
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
@@ -170,18 +187,32 @@ def apply_unet(
kernel_initializer=kernel_initializer)((batch12))
def unet(input_tensor, instruments, params={}):
def unet(
input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None) -> Dict:
""" Model function applier. """
return apply(apply_unet, input_tensor, instruments, params)
def softmax_unet(input_tensor, instruments, params={}):
""" Apply softmax to multitrack unet in order to have mask suming to one.
def softmax_unet(
input_tensor: tf.Tensor,
instruments: Iterable[str],
params: Optional[Dict] = None) -> Dict:
"""
Apply softmax to multitrack unet in order to have masks summing to one.
:param input_tensor: Tensor to apply blstm to.
:param instruments: Iterable that provides a collection of instruments.
:param params: (Optional) dict of BLSTM parameters.
:returns: Created output tensor dict.
Parameters:
input_tensor (tensorflow.Tensor):
Tensor to apply blstm to.
instruments (Iterable[str]):
Iterable that provides a collection of instruments.
params (Optional[Dict]):
(Optional) dict of BLSTM parameters.
Returns:
Dict:
Created output tensor dict.
"""
logit_mask_list = []
for instrument in instruments:

View File

@@ -5,10 +5,12 @@
This package provides tools for downloading model from network
using remote storage abstraction.
:Example:
Examples:
```python
>>> provider = MyProviderImplementation()
>>> provider.get('/path/to/local/storage', params)
```
"""
from abc import ABC, abstractmethod
@@ -26,39 +28,52 @@ class ModelProvider(ABC):
file download is not available.
"""
DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models')
MODEL_PROBE_PATH = '.probe'
DEFAULT_MODEL_PATH: str = environ.get('MODEL_PATH', 'pretrained_models')
MODEL_PROBE_PATH: str = '.probe'
@abstractmethod
def download(self, name, path):
""" Download model denoted by the given name to disk.
def download(_, name: str, path: str) -> None:
"""
Download model denoted by the given name to disk.
:param name: Name of the model to download.
:param path: Path of the directory to save model into.
Parameters:
name (str):
Name of the model to download.
path (str):
Path of the directory to save model into.
"""
pass
@staticmethod
def writeProbe(directory):
""" Write a model probe file into the given directory.
:param directory: Directory to write probe into.
def writeProbe(directory: str) -> None:
"""
probe = join(directory, ModelProvider.MODEL_PROBE_PATH)
Write a model probe file into the given directory.
Parameters:
directory (str):
Directory to write probe into.
"""
probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
with open(probe, 'w') as stream:
stream.write('OK')
def get(self, model_directory):
""" Ensures required model is available at given location.
def get(self, model_directory: str) -> str:
"""
Ensures required model is available at given location.
:param model_directory: Expected model_directory to be available.
:raise IOError: If model can not be retrieved.
Parameters:
model_directory (str):
Expected model_directory to be available.
Raises:
IOError:
If model can not be retrieved.
"""
# Expand model directory if needed.
if not isabs(model_directory):
model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
# Download it if not exists.
model_probe = join(model_directory, self.MODEL_PROBE_PATH)
model_probe: str = join(model_directory, self.MODEL_PROBE_PATH)
if not exists(model_probe):
if not exists(model_directory):
makedirs(model_directory)
@@ -68,14 +83,14 @@ class ModelProvider(ABC):
self.writeProbe(model_directory)
return model_directory
@classmethod
def default(_: type) -> 'ModelProvider':
"""
Builds and returns a default model provider.
def get_default_model_provider():
""" Builds and returns a default model provider.
:returns: A default model provider instance to use.
"""
from .github import GithubModelProvider
host = environ.get('GITHUB_HOST', 'https://github.com')
repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
return GithubModelProvider(host, repository, release)
Returns:
ModelProvider:
A default model provider instance to use.
"""
from .github import GithubModelProvider
return GithubModelProvider.from_environ()

View File

@@ -4,27 +4,34 @@
"""
A ModelProvider backed by Github Release feature.
:Example:
Examples:
```python
>>> from spleeter.model.provider import github
>>> provider = github.GithubModelProvider(
'github.com',
'Deezer/spleeter',
'latest')
>>> provider.download('2stems', '/path/to/local/storage')
```
"""
import hashlib
import tarfile
import os
from os import environ
from tempfile import NamedTemporaryFile
import requests
from typing import Dict
from . import ModelProvider
from ...utils.logging import get_logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import httpx
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
@@ -46,69 +53,108 @@ def compute_file_checksum(path):
class GithubModelProvider(ModelProvider):
    """ A ModelProvider implementation backed on Github for remote storage. """

    DEFAULT_HOST: str = 'https://github.com'
    DEFAULT_REPOSITORY: str = 'deezer/spleeter'

    CHECKSUM_INDEX: str = 'checksum.json'
    LATEST_RELEASE: str = 'v1.4.0'
    RELEASE_PATH: str = 'releases/download'

    def __init__(
            self,
            host: str,
            repository: str,
            release: str) -> None:
        """ Default constructor.

        Parameters:
            host (str):
                Host to the Github instance to reach.
            repository (str):
                Repository path within target Github.
            release (str):
                Release name to get models from.
        """
        self._host: str = host
        self._repository: str = repository
        self._release: str = release

    @classmethod
    def from_environ(cls: type) -> 'GithubModelProvider':
        """
        Factory method that creates provider from envvars.

        Reads `GITHUB_HOST`, `GITHUB_REPOSITORY` and `GITHUB_RELEASE`,
        falling back respectively on `DEFAULT_HOST`,
        `DEFAULT_REPOSITORY` and `LATEST_RELEASE`.

        Returns:
            GithubModelProvider:
                Created instance.
        """
        return cls(
            environ.get('GITHUB_HOST', cls.DEFAULT_HOST),
            environ.get('GITHUB_REPOSITORY', cls.DEFAULT_REPOSITORY),
            environ.get('GITHUB_RELEASE', cls.LATEST_RELEASE))

    def checksum(self, name: str) -> str:
        """
        Downloads and returns reference checksum for the given model name.

        Parameters:
            name (str):
                Name of the model to get checksum for.

        Returns:
            str:
                Checksum of the required model.

        Raises:
            ValueError:
                If the given model name is not indexed.
        """
        url: str = '/'.join((
            self._host,
            self._repository,
            self.RELEASE_PATH,
            self._release,
            self.CHECKSUM_INDEX))
        response: httpx.Response = httpx.get(url)
        response.raise_for_status()
        index: Dict = response.json()
        if name not in index:
            raise ValueError(f'No checksum for model {name}')
        return index[name]

    def download(self, name: str, path: str) -> None:
        """
        Download model denoted by the given name to disk.

        The archive is streamed to a temporary file, compared against
        the reference checksum, extracted into `path`, and the temporary
        file is removed whatever the outcome.

        Parameters:
            name (str):
                Name of the model to download.
            path (str):
                Path of the directory to save model into.

        Raises:
            IOError:
                If the downloaded archive checksum does not match the
                reference one.
        """
        url: str = '/'.join((
            self._host,
            self._repository,
            self.RELEASE_PATH,
            self._release,
            name))
        url = f'{url}.tar.gz'
        get_logger().info(f'Downloading model archive {url}')
        with httpx.Client(http2=True) as client:
            with client.stream('GET', url) as response:
                response.raise_for_status()
                # NOTE: delete=False as the file is reopened by name
                #       (checksum, extraction) once the stream is closed;
                #       it is unlinked in the finally clause instead.
                archive = NamedTemporaryFile(delete=False)
                try:
                    with archive as stream:
                        for chunk in response.iter_raw():
                            stream.write(chunk)
                    get_logger().info('Validating archive checksum')
                    checksum: str = compute_file_checksum(archive.name)
                    # NOTE: self.checksum() fetches the remote index.
                    if checksum != self.checksum(name):
                        raise IOError(
                            'Downloaded file is corrupted, please retry')
                    get_logger().info(f'Extracting downloaded {name} archive')
                    with tarfile.open(name=archive.name) as tar:
                        tar.extractall(path=path)
                finally:
                    os.unlink(archive.name)
        get_logger().info(f'{name} model file(s) extracted')

View File

@@ -4,62 +4,60 @@
"""
Module that provides a class wrapper for source separation.
:Example:
Examples:
```python
>>> from spleeter.separator import Separator
>>> separator = Separator('spleeter:2stems')
>>> separator.separate(waveform, lambda instrument, data: ...)
>>> separator.separate_to_file(...)
```
"""
import atexit
import os
import logging
from enum import Enum
from multiprocessing import Pool
from os.path import basename, join, splitext, dirname
from time import time
from typing import Container, NoReturn
from typing import Generator, Optional
from . import SpleeterError
from .audio import STFTBackend
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory
from .utils.configuration import load_configuration
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from librosa.core import stft, istft
from scipy.signal.windows import hann
from . import SpleeterError
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
""" """
class DataGenerator():
class DataGenerator(object):
"""
Generator object that store a sample and generate it once while called.
Used to feed a tensorflow estimator without knowing the whole data at
build time.
"""
def __init__(self):
def __init__(self) -> None:
""" Default constructor. """
self._current_data = None
def update_data(self, data):
def update_data(self, data) -> None:
""" Replace internal data. """
self._current_data = data
def __call__(self):
def __call__(self) -> Generator:
""" Generation process. """
buffer = self._current_data
while buffer:
@@ -79,19 +77,50 @@ def get_backend(backend: str) -> str:
return backend
def create_estimator(params, MWF):
    """
    Initialize tensorflow estimator that will perform separation.

    Parameters:
        params:
            Dictionary of parameters for building the model. It is
            updated in place: `model_dir` is replaced by the directory
            returned by the default model provider, and `MWF` is set.
        MWF:
            Value stored under `params['MWF']`.

    Returns:
        A tensorflow estimator.
    """
    # NOTE: imported at function scope so that both names are resolved
    #       when the estimator is built.
    from .model import model_fn
    from .model.provider import ModelProvider
    # Load model.
    provider: ModelProvider = ModelProvider.default()
    params['model_dir'] = provider.get(params['model_dir'])
    params['MWF'] = MWF
    # Setup config
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
    config = tf.estimator.RunConfig(session_config=session_config)
    # Setup estimator
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=config)
    return estimator
class Separator(object):
""" A wrapper class for performing separation. """
def __init__(
self,
params_descriptor,
params_descriptor: str,
MWF: bool = False,
stft_backend: str = 'auto',
multiprocess: bool = True):
""" Default constructor.
stft_backend: STFTBackend = STFTBackend.AUTO,
multiprocess: bool = True) -> None:
"""
Default constructor.
:param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise.
Parameters:
params_descriptor (str):
Descriptor for TF params to be used.
MWF (bool):
(Optional) `True` if MWF should be used, `False` otherwise.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
@@ -111,8 +140,7 @@ class Separator(object):
self._params['stft_backend'] = get_backend(stft_backend)
self._data_generator = DataGenerator()
def __del__(self):
""" """
def __del__(self) -> None:
if self._session:
self._session.close()
@@ -140,35 +168,19 @@ class Separator(object):
yield_single_examples=False)
return self._prediction_generator
def join(self, timeout: int = 200) -> NoReturn:
""" Wait for all pending tasks to be finished.
def join(self, timeout: int = 200) -> None:
"""
Wait for all pending tasks to be finished.
:param timeout: (Optional) task waiting timeout.
Parameters:
timeout (int):
(Optional) task waiting timeout.
"""
while len(self._tasks) > 0:
task = self._tasks.pop()
task.get()
task.wait(timeout=timeout)
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
""" Performs source separation over the given waveform with tensorflow
backend.
:param waveform: Waveform to apply separation on.
:returns: Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
def _stft(self, data, inverse: bool = False, length=None):
""" Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
@@ -233,7 +245,12 @@ class Separator(object):
return self._session
def _separate_librosa(self, waveform: np.ndarray, audio_id):
""" Performs separation with librosa backend for STFT.
"""
Performs separation with librosa backend for STFT.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
"""
with self._tf_graph.as_default():
out = {}
@@ -260,12 +277,42 @@ class Separator(object):
length=waveform.shape[0])
return out
def separate(self, waveform: np.ndarray, audio_descriptor=''):
""" Performs separation on a waveform.
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
"""
Performs source separation over the given waveform with tensorflow
backend.
:param waveform: Waveform to be separated (as a numpy array)
:param audio_descriptor: (Optional) string describing the waveform
(e.g. filename).
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
Returns:
Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
def separate(
self,
waveform: np.ndarray,
audio_descriptor: Optional[str] = None) -> None:
"""
Performs separation on a waveform.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
"""
if self._params['stft_backend'] == 'tensorflow':
return self._separate_tensorflow(waveform, audio_descriptor)

View File

@@ -4,14 +4,10 @@
""" Module that provides configuration loading function. """
import json
try:
import importlib.resources as loader
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as loader
import importlib.resources as loader
from os.path import exists
from typing import Dict
from .. import resources, SpleeterError
@@ -20,18 +16,28 @@ __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
_EMBEDDED_CONFIGURATION_PREFIX: str = 'spleeter:'
def load_configuration(descriptor):
""" Load configuration from the given descriptor. Could be
either a `spleeter:` prefixed embedded configuration name
or a file system path to read configuration from.
def load_configuration(descriptor: str) -> Dict:
"""
Load configuration from the given descriptor. Could be either a
`spleeter:` prefixed embedded configuration name or a file system path
to read configuration from.
:param descriptor: Configuration descriptor to use for lookup.
:returns: Loaded description as dict.
:raise ValueError: If required embedded configuration does not exists.
:raise SpleeterError: If required configuration file does not exists.
Parameters:
descriptor (str):
Configuration descriptor to use for lookup.
Returns:
Dict:
Loaded description as dict.
Raises:
ValueError:
If required embedded configuration does not exist.
SpleeterError:
If required configuration file does not exist.
"""
# Embedded configuration reading.
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):

View File

@@ -1,46 +0,0 @@
#!/usr/bin/env python
# coding: utf8
""" Utility functions for creating estimator. """
import tensorflow as tf # pylint: disable=import-error
from ..model import model_fn
from ..model.provider import get_default_model_provider
def get_default_model_dir(model_dir):
"""
Transforms a string like 'spleeter:2stems' into an actual path.
:param model_dir:
:return:
"""
model_provider = get_default_model_provider()
return model_provider.get(model_dir)
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
params['model_dir'] = get_default_model_dir(params['model_dir'])
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config
)
return estimator