diff --git a/spleeter/dataset.py b/spleeter/dataset.py index 5b11969..84a55d0 100644 --- a/spleeter/dataset.py +++ b/spleeter/dataset.py @@ -16,28 +16,22 @@ import time import os -from os.path import exists, join, sep as SEPARATOR +from os.path import exists, sep as SEPARATOR + +from .audio.convertor import db_uint_spectrogram_to_gain +from .audio.convertor import spectrogram_to_db_uint +from .audio.spectrogram import compute_spectrogram_tf +from .audio.spectrogram import random_pitch_shift, random_time_stretch +from .utils.logging import get_logger +from .utils.tensor import check_tensor_shape, dataset_from_csv +from .utils.tensor import set_tensor_shape, sync_apply + +# pyright: reportMissingImports=false # pylint: disable=import-error -import pandas as pd -import numpy as np import tensorflow as tf # pylint: enable=import-error -from .audio.convertor import ( - db_uint_spectrogram_to_gain, - spectrogram_to_db_uint) -from .audio.spectrogram import ( - compute_spectrogram_tf, - random_pitch_shift, - random_time_stretch) -from .utils.logging import get_logger -from .utils.tensor import ( - check_tensor_shape, - dataset_from_csv, - set_tensor_shape, - sync_apply) - __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' diff --git a/spleeter/model/functions/__init__.py b/spleeter/model/functions/__init__.py index 684f923..5f8c102 100644 --- a/spleeter/model/functions/__init__.py +++ b/spleeter/model/functions/__init__.py @@ -3,25 +3,44 @@ """ This package provide model functions. """ +from typing import Callable, Dict, Iterable, Optional + +# pyright: reportMissingImports=false +# pylint: disable=import-error +import tensorflow as tf +# pylint: enable=import-error + __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' -def apply(function, input_tensor, instruments, params={}): - """ Apply given function to the input tensor. - - :param function: Function to be applied to tensor. - :param input_tensor: Tensor to apply blstm to. - :param instruments: Iterable that provides a collection of instruments. - :param params: (Optional) dict of BLSTM parameters. - :returns: Created output tensor dict. +def apply( + function: Callable, + input_tensor: tf.Tensor, + instruments: Iterable[str], + params: Optional[Dict] = None) -> Dict: """ - output_dict = {} + Apply given function to the input tensor. + + Parameters: + function: + Function to be applied to tensor. + input_tensor (tensorflow.Tensor): + Tensor to apply blstm to. + instruments (Iterable[str]): + Iterable that provides a collection of instruments. + params: + (Optional) dict of BLSTM parameters. + + Returns: + Created output tensor dict. + """ + output_dict: Dict = {} for instrument in instruments: out_name = f'{instrument}_spectrogram' output_dict[out_name] = function( input_tensor, output_name=out_name, - params=params) + params=params or {}) return output_dict diff --git a/spleeter/model/functions/blstm.py b/spleeter/model/functions/blstm.py index b81122b..2bb0cc0 100644 --- a/spleeter/model/functions/blstm.py +++ b/spleeter/model/functions/blstm.py @@ -20,7 +20,14 @@ selection (LSTM layer dropout rate, regularization strength). """ +from typing import Dict, Optional + +from . 
import apply + +# pyright: reportMissingImports=false # pylint: disable=import-error +import tensorflow as tf + from tensorflow.compat.v1.keras.initializers import he_uniform from tensorflow.compat.v1.keras.layers import CuDNNLSTM from tensorflow.keras.layers import ( @@ -31,22 +38,33 @@ from tensorflow.keras.layers import ( TimeDistributed) # pylint: enable=import-error -from . import apply - __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' -def apply_blstm(input_tensor, output_name='output', params={}): - """ Apply BLSTM to the given input_tensor. - - :param input_tensor: Input of the model. - :param output_name: (Optional) name of the output, default to 'output'. - :param params: (Optional) dict of BLSTM parameters. - :returns: Output tensor. +def apply_blstm( + input_tensor: tf.Tensor, + output_name: str = 'output', + params: Optional[Dict] = None) -> tf.Tensor: """ - units = params.get('lstm_units', 250) + Apply BLSTM to the given input_tensor. + + Parameters: + input_tensor (tensorflow.Tensor): + Input of the model. + output_name (str): + (Optional) name of the output, default to 'output'. + params (Optional[Dict]): + (Optional) dict of BLSTM parameters. + + Returns: + tensorflow.Tensor: + Output tensor. + """ + if params is None: + params = {} + units: int = params.get('lstm_units', 250) kernel_initializer = he_uniform(seed=50) flatten_input = TimeDistributed(Flatten())((input_tensor)) @@ -65,12 +83,15 @@ def apply_blstm(input_tensor, output_name='output', params={}): int(flatten_input.shape[2]), activation='relu', kernel_initializer=kernel_initializer))((l3)) - output = TimeDistributed( + output: tf.Tensor = TimeDistributed( Reshape(input_tensor.shape[2:]), name=output_name)(dense) return output -def blstm(input_tensor, output_name='output', params={}): +def blstm( + input_tensor: tf.Tensor, + output_name: str = 'output', + params: Optional[Dict] = None) -> tf.Tensor: """ Model function applier. """ return apply(apply_blstm, input_tensor, output_name, params) diff --git a/spleeter/model/functions/unet.py b/spleeter/model/functions/unet.py index 7f9dbea..ccb7225 100644 --- a/spleeter/model/functions/unet.py +++ b/spleeter/model/functions/unet.py @@ -2,16 +2,23 @@ # coding: utf8 """ -This module contains building functions for U-net source -separation models in a similar way as in A. Jansson et al. "Singing -voice separation with deep u-net convolutional networks", ISMIR 2017. -Each instrument is modeled by a single U-net convolutional -/ deconvolutional network that take a mix spectrogram as input and the -estimated sound spectrogram as output. + This module contains building functions for U-net source + separation models in a similar way as in A. Jansson et al. : + + "Singing voice separation with deep u-net convolutional networks", + ISMIR 2017 + + Each instrument is modeled by a single U-net + convolutional / deconvolutional network that take a mix spectrogram + as input and the estimated sound spectrogram as output. """ from functools import partial +from typing import Any, Dict, Iterable, Optional +from . import apply + +# pyright: reportMissingImports=false # pylint: disable=import-error import tensorflow as tf @@ -30,20 +37,23 @@ from tensorflow.compat.v1 import logging from tensorflow.compat.v1.keras.initializers import he_uniform # pylint: enable=import-error -from . 
import apply
-
 __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'


-def _get_conv_activation_layer(params):
+def _get_conv_activation_layer(params: Dict) -> Any:
     """
+    > To be documented.

-    :param params:
-    :returns: Required Activation function.
+    Parameters:
+        params (Dict):
+
+    Returns:
+        Any:
+            Required Activation function.
     """
-    conv_activation = params.get('conv_activation')
+    conv_activation: str = params.get('conv_activation')
     if conv_activation == 'ReLU':
         return ReLU()
     elif conv_activation == 'ELU':
@@ -51,13 +61,18 @@ def _get_conv_activation_layer(params):
         return LeakyReLU(0.2)


-def _get_deconv_activation_layer(params):
+def _get_deconv_activation_layer(params: Dict) -> Any:
     """
+    > To be documented.

-    :param params:
-    :returns: Required Activation function.
+    Parameters:
+        params (Dict):
+
+    Returns:
+        Any:
+            Required Activation function.
     """
-    deconv_activation = params.get('deconv_activation')
+    deconv_activation: str = params.get('deconv_activation')
     if deconv_activation == 'LeakyReLU':
         return LeakyReLU(0.2)
     elif deconv_activation == 'ELU':
@@ -66,17 +81,19 @@ def _get_deconv_activation_layer(params):


 def apply_unet(
-        input_tensor,
-        output_name='output',
-        params={},
-        output_mask_logit=False):
-    """ Apply a convolutionnal U-net to model a single instrument (one U-net
-    is used for each instrument).
+        input_tensor: tf.Tensor,
+        output_name: str = 'output',
+        params: Optional[Dict] = None,
+        output_mask_logit: bool = False) -> Any:
+    """
+    Apply a convolutional U-net to model a single instrument (one U-net
+    is used for each instrument).

-    :param input_tensor:
-    :param output_name: (Optional) , default to 'output'
-    :param params: (Optional) , default to empty dict.
-    :param output_mask_logit: (Optional) , default to False.
+    Parameters:
+        input_tensor (tensorflow.Tensor):
+        output_name (str):
+        params (Optional[Dict]):
+        output_mask_logit (bool):
     """
     logging.info(f'Apply unet for {output_name}')
     conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
@@ -170,18 +187,32 @@ def apply_unet(
         kernel_initializer=kernel_initializer)((batch12))


-def unet(input_tensor, instruments, params={}):
+def unet(
+        input_tensor: tf.Tensor,
+        instruments: Iterable[str],
+        params: Optional[Dict] = None) -> Dict:
     """ Model function applier. """
     return apply(apply_unet, input_tensor, instruments, params)


-def softmax_unet(input_tensor, instruments, params={}):
-    """ Apply softmax to multitrack unet in order to have mask suming to one.
+def softmax_unet(
+        input_tensor: tf.Tensor,
+        instruments: Iterable[str],
+        params: Optional[Dict] = None) -> Dict:
+    """
+    Apply softmax to multitrack unet in order to have masks summing to one.

-    :param input_tensor: Tensor to apply blstm to.
-    :param instruments: Iterable that provides a collection of instruments.
-    :param params: (Optional) dict of BLSTM parameters.
-    :returns: Created output tensor dict.
+    Parameters:
+        input_tensor (tensorflow.Tensor):
+            Tensor to apply the model to.
+        instruments (Iterable[str]):
+            Iterable that provides a collection of instruments.
+        params (Optional[Dict]):
+            (Optional) dict of U-net parameters.
+
+    Returns:
+        Dict:
+            Created output tensor dict.
""" logit_mask_list = [] for instrument in instruments: diff --git a/spleeter/model/provider/__init__.py b/spleeter/model/provider/__init__.py index 3921907..7a5430e 100644 --- a/spleeter/model/provider/__init__.py +++ b/spleeter/model/provider/__init__.py @@ -5,10 +5,12 @@ This package provides tools for downloading model from network using remote storage abstraction. - :Example: + Examples: + ```python >>> provider = MyProviderImplementation() >>> provider.get('/path/to/local/storage', params) + ``` """ from abc import ABC, abstractmethod @@ -26,39 +28,52 @@ class ModelProvider(ABC): file download is not available. """ - DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models') - MODEL_PROBE_PATH = '.probe' + DEFAULT_MODEL_PATH: str = environ.get('MODEL_PATH', 'pretrained_models') + MODEL_PROBE_PATH: str = '.probe' @abstractmethod - def download(self, name, path): - """ Download model denoted by the given name to disk. + def download(_, name: str, path: str) -> None: + """ + Download model denoted by the given name to disk. - :param name: Name of the model to download. - :param path: Path of the directory to save model into. + Parameters: + name (str): + Name of the model to download. + path (str): + Path of the directory to save model into. """ pass @staticmethod - def writeProbe(directory): - """ Write a model probe file into the given directory. - - :param directory: Directory to write probe into. + def writeProbe(directory: str) -> None: """ - probe = join(directory, ModelProvider.MODEL_PROBE_PATH) + Write a model probe file into the given directory. + + Parameters: + directory (str): + Directory to write probe into. + """ + probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH) with open(probe, 'w') as stream: stream.write('OK') - def get(self, model_directory): - """ Ensures required model is available at given location. + def get(self, model_directory: str) -> str: + """ + Ensures required model is available at given location. - :param model_directory: Expected model_directory to be available. - :raise IOError: If model can not be retrieved. + Parameters: + model_directory (str): + Expected model_directory to be available. + + Raises: + IOError: + If model can not be retrieved. """ # Expend model directory if needed. if not isabs(model_directory): model_directory = join(self.DEFAULT_MODEL_PATH, model_directory) # Download it if not exists. - model_probe = join(model_directory, self.MODEL_PROBE_PATH) + model_probe: str = join(model_directory, self.MODEL_PROBE_PATH) if not exists(model_probe): if not exists(model_directory): makedirs(model_directory) @@ -68,14 +83,14 @@ class ModelProvider(ABC): self.writeProbe(model_directory) return model_directory + @classmethod + def default(_: type) -> 'ModelProvider': + """ + Builds and returns a default model provider. -def get_default_model_provider(): - """ Builds and returns a default model provider. - - :returns: A default model provider instance to use. - """ - from .github import GithubModelProvider - host = environ.get('GITHUB_HOST', 'https://github.com') - repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter') - release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE) - return GithubModelProvider(host, repository, release) + Returns: + ModelProvider: + A default model provider instance to use. 
+ """ + from .github import GithubModelProvider + return GithubModelProvider.from_environ() diff --git a/spleeter/model/provider/github.py b/spleeter/model/provider/github.py index 65a10b4..f423d9c 100644 --- a/spleeter/model/provider/github.py +++ b/spleeter/model/provider/github.py @@ -4,27 +4,34 @@ """ A ModelProvider backed by Github Release feature. - :Example: + Examples: + ```python >>> from spleeter.model.provider import github >>> provider = github.GithubModelProvider( 'github.com', 'Deezer/spleeter', 'latest') >>> provider.download('2stems', '/path/to/local/storage') + ``` """ import hashlib import tarfile import os +from os import environ from tempfile import NamedTemporaryFile - -import requests +from typing import Dict from . import ModelProvider from ...utils.logging import get_logger +# pyright: reportMissingImports=false +# pylint: disable=import-error +import httpx +# pylint: enable=import-error + __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' @@ -46,69 +53,108 @@ def compute_file_checksum(path): class GithubModelProvider(ModelProvider): """ A ModelProvider implementation backed on Github for remote storage. """ - LATEST_RELEASE = 'v1.4.0' - RELEASE_PATH = 'releases/download' - CHECKSUM_INDEX = 'checksum.json' + DEFAULT_HOST: str = 'https://github.com' + DEFAULT_REPOSITORY: str = 'deezer/spleeter' - def __init__(self, host, repository, release): + CHECKSUM_INDEX: str = 'checksum.json' + LATEST_RELEASE: str = 'v1.4.0' + RELEASE_PATH: str = 'releases/download' + + def __init__( + self, + host: str, + repository: str, + release: str) -> None: """ Default constructor. - :param host: Host to the Github instance to reach. - :param repository: Repository path within target Github. - :param release: Release name to get models from. + Parameters: + host (str): + Host to the Github instance to reach. + repository (str): + Repository path within target Github. + release (str): + Release name to get models from. """ - self._host = host - self._repository = repository - self._release = release + self._host: str = host + self._repository: str = repository + self._release: str = release - def checksum(self, name): - """ Downloads and returns reference checksum for the given model name. - - :param name: Name of the model to get checksum for. - :returns: Checksum of the required model. - :raise ValueError: If the given model name is not indexed. + @classmethod + def from_environ(cls: type) -> 'GithubModelProvider': """ - url = '{}/{}/{}/{}/{}'.format( + Factory method that creates provider from envvars. + + Returns: + GithubModelProvider: + Created instance. + """ + return cls( + environ.get('GITHUB_HOST', cls.DEFAULT_HOST), + environ.get('GITHUB_REPOSITORY', cls.DEFAULT_REPOSITORY), + environ.get('GITHUB_RELEASE', cls.LATEST_RELEASE)) + + def checksum(self, name: str) -> str: + """ + Downloads and returns reference checksum for the given model name. + + Parameters: + name (str): + Name of the model to get checksum for. + Returns: + str: + Checksum of the required model. + + Raises: + ValueError: + If the given model name is not indexed. 
+        """
+        url: str = '/'.join((
             self._host,
             self._repository,
             self.RELEASE_PATH,
             self._release,
-            self.CHECKSUM_INDEX)
-        response = requests.get(url)
+            self.CHECKSUM_INDEX))
+        response: httpx.Response = httpx.get(url)
         response.raise_for_status()
-        index = response.json()
+        index: Dict = response.json()
         if name not in index:
-            raise ValueError('No checksum for model {}'.format(name))
+            raise ValueError(f'No checksum for model {name}')
         return index[name]

-    def download(self, name, path):
-        """ Download model denoted by the given name to disk.
-
-        :param name: Name of the model to download.
-        :param path: Path of the directory to save model into.
+    def download(self, name: str, path: str) -> None:
         """
-        url = '{}/{}/{}/{}/{}.tar.gz'.format(
+        Download model denoted by the given name to disk.
+
+        Parameters:
+            name (str):
+                Name of the model to download.
+            path (str):
+                Path of the directory to save model into.
+        """
+        url: str = '/'.join((
             self._host,
             self._repository,
             self.RELEASE_PATH,
             self._release,
-            name)
-        get_logger().info('Downloading model archive %s', url)
-        with requests.get(url, stream=True) as response:
-            response.raise_for_status()
-            archive = NamedTemporaryFile(delete=False)
-            try:
-                with archive as stream:
-                    # Note: check for chunk size parameters ?
-                    for chunk in response.iter_content(chunk_size=8192):
-                        if chunk:
+            name))
+        url = f'{url}.tar.gz'
+        get_logger().info(f'Downloading model archive {url}')
+        with httpx.Client(http2=True) as client:
+            with client.stream('GET', url) as response:
+                response.raise_for_status()
+                archive = NamedTemporaryFile(delete=False)
+                try:
+                    with archive as stream:
+                        for chunk in response.iter_raw():
                             stream.write(chunk)
-                get_logger().info('Validating archive checksum')
-                if compute_file_checksum(archive.name) != self.checksum(name):
-                    raise IOError('Downloaded file is corrupted, please retry')
-                get_logger().info('Extracting downloaded %s archive', name)
-                with tarfile.open(name=archive.name) as tar:
-                    tar.extractall(path=path)
-            finally:
-                os.unlink(archive.name)
-        get_logger().info('%s model file(s) extracted', name)
+                    get_logger().info('Validating archive checksum')
+                    checksum: str = compute_file_checksum(archive.name)
+                    if checksum != self.checksum(name):
+                        raise IOError(
+                            'Downloaded file is corrupted, please retry')
+                    get_logger().info(f'Extracting downloaded {name} archive')
+                    with tarfile.open(name=archive.name) as tar:
+                        tar.extractall(path=path)
+                finally:
+                    os.unlink(archive.name)
+        get_logger().info(f'{name} model file(s) extracted')
diff --git a/spleeter/separator.py b/spleeter/separator.py
index cc87b68..ad7dc3f 100644
--- a/spleeter/separator.py
+++ b/spleeter/separator.py
@@ -4,62 +4,60 @@
 """ Module that provides a class wrapper for source separation.

-    :Example:
+    Examples:
+    ```python

     >>> from spleeter.separator import Separator
     >>> separator = Separator('spleeter:2stems')
     >>> separator.separate(waveform, lambda instrument, data: ...)
     >>> separator.separate_to_file(...)
+    ```
 """

 import atexit
 import os
-import logging
-
-from enum import Enum
 from multiprocessing import Pool
 from os.path import basename, join, splitext, dirname
-from time import time
-from typing import Container, NoReturn
+from typing import Dict, Generator, Optional

+from . import SpleeterError
+from .audio import STFTBackend
+from .audio.adapter import get_default_audio_adapter
+from .audio.convertor import to_stereo
+from .model import EstimatorSpecBuilder, InputProviderFactory, model_fn
+from .model.provider import ModelProvider
+from .utils.configuration import load_configuration
+
+# pyright: reportMissingImports=false
+# pylint: disable=import-error
 import numpy as np
 import tensorflow as tf

 from librosa.core import stft, istft
 from scipy.signal.windows import hann
-
-from . import SpleeterError
-from .audio.adapter import get_default_audio_adapter
-from .audio.convertor import to_stereo
-from .utils.configuration import load_configuration
-from .utils.estimator import create_estimator, get_default_model_dir
-from .model import EstimatorSpecBuilder, InputProviderFactory
+# pylint: enable=import-error

 __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'


-SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
-""" """
-

-class DataGenerator():
+class DataGenerator(object):
     """ Generator object that store a sample and generate it once while called.

         Used to feed a tensorflow estimator without knowing the whole data at
         build time.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         """ Default constructor. """
         self._current_data = None

-    def update_data(self, data):
+    def update_data(self, data) -> None:
         """ Replace internal data. """
         self._current_data = data

-    def __call__(self):
+    def __call__(self) -> Generator:
         """ Generation process. """
         buffer = self._current_data
         while buffer:
@@ -79,19 +77,50 @@ def get_backend(backend: str) -> str:
     return backend


+def create_estimator(params, MWF):
+    """
+    Initialize tensorflow estimator that will perform separation.
+
+    Params:
+    - params: a dictionary of parameters for building the model
+    - MWF: True if the multichannel Wiener filter should be used
+
+    Returns:
+        a tensorflow estimator
+    """
+    # Load model.
+    provider: ModelProvider = ModelProvider.default()
+    params['model_dir'] = provider.get(params['model_dir'])
+    params['MWF'] = MWF
+    # Setup config
+    session_config = tf.compat.v1.ConfigProto()
+    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
+    config = tf.estimator.RunConfig(session_config=session_config)
+    # Setup estimator
+    estimator = tf.estimator.Estimator(
+        model_fn=model_fn,
+        model_dir=params['model_dir'],
+        params=params,
+        config=config)
+    return estimator
+
+
 class Separator(object):
     """ A wrapper class for performing separation. """

     def __init__(
             self,
-            params_descriptor,
+            params_descriptor: str,
             MWF: bool = False,
-            stft_backend: str = 'auto',
-            multiprocess: bool = True):
-        """ Default constructor.
+            stft_backend: STFTBackend = STFTBackend.AUTO,
+            multiprocess: bool = True) -> None:
+        """
+        Default constructor.

-        :param params_descriptor: Descriptor for TF params to be used.
-        :param MWF: (Optional) True if MWF should be used, False otherwise.
+        Parameters:
+            params_descriptor (str):
+                Descriptor for TF params to be used.
+            MWF (bool):
+                (Optional) `True` if MWF should be used, `False` otherwise.
""" self._params = load_configuration(params_descriptor) self._sample_rate = self._params['sample_rate'] @@ -111,8 +140,7 @@ class Separator(object): self._params['stft_backend'] = get_backend(stft_backend) self._data_generator = DataGenerator() - def __del__(self): - """ """ + def __del__(self) -> None: if self._session: self._session.close() @@ -140,35 +168,19 @@ class Separator(object): yield_single_examples=False) return self._prediction_generator - def join(self, timeout: int = 200) -> NoReturn: - """ Wait for all pending tasks to be finished. + def join(self, timeout: int = 200) -> None: + """ + Wait for all pending tasks to be finished. - :param timeout: (Optional) task waiting timeout. + Parameters: + timeout (int): + (Optional) task waiting timeout. """ while len(self._tasks) > 0: task = self._tasks.pop() task.get() task.wait(timeout=timeout) - def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor): - """ Performs source separation over the given waveform with tensorflow - backend. - - :param waveform: Waveform to apply separation on. - :returns: Separated waveforms. - """ - if not waveform.shape[-1] == 2: - waveform = to_stereo(waveform) - prediction_generator = self._get_prediction_generator() - # NOTE: update data in generator before performing separation. - self._data_generator.update_data({ - 'waveform': waveform, - 'audio_id': np.array(audio_descriptor)}) - # NOTE: perform separation. - prediction = next(prediction_generator) - prediction.pop('audio_id') - return prediction - def _stft(self, data, inverse: bool = False, length=None): """ Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two channels are processed @@ -233,7 +245,12 @@ class Separator(object): return self._session def _separate_librosa(self, waveform: np.ndarray, audio_id): - """ Performs separation with librosa backend for STFT. + """ + Performs separation with librosa backend for STFT. + + Parameters: + waveform (numpy.ndarray): + Waveform to be separated (as a numpy array) """ with self._tf_graph.as_default(): out = {} @@ -260,12 +277,42 @@ class Separator(object): length=waveform.shape[0]) return out - def separate(self, waveform: np.ndarray, audio_descriptor=''): - """ Performs separation on a waveform. + def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor): + """ + Performs source separation over the given waveform with tensorflow + backend. - :param waveform: Waveform to be separated (as a numpy array) - :param audio_descriptor: (Optional) string describing the waveform - (e.g. filename). + Parameters: + waveform (numpy.ndarray): + Waveform to be separated (as a numpy array) + + Returns: + Separated waveforms. + """ + if not waveform.shape[-1] == 2: + waveform = to_stereo(waveform) + prediction_generator = self._get_prediction_generator() + # NOTE: update data in generator before performing separation. + self._data_generator.update_data({ + 'waveform': waveform, + 'audio_id': np.array(audio_descriptor)}) + # NOTE: perform separation. + prediction = next(prediction_generator) + prediction.pop('audio_id') + return prediction + + def separate( + self, + waveform: np.ndarray, + audio_descriptor: Optional[str] = None) -> None: + """ + Performs separation on a waveform. + + Parameters: + waveform (numpy.ndarray): + Waveform to be separated (as a numpy array) + audio_descriptor (str): + (Optional) string describing the waveform (e.g. filename). 
""" if self._params['stft_backend'] == 'tensorflow': return self._separate_tensorflow(waveform, audio_descriptor) diff --git a/spleeter/utils/configuration.py b/spleeter/utils/configuration.py index 36f1043..90250a6 100644 --- a/spleeter/utils/configuration.py +++ b/spleeter/utils/configuration.py @@ -4,14 +4,10 @@ """ Module that provides configuration loading function. """ import json - -try: - import importlib.resources as loader -except ImportError: - # Try backported to PY<37 `importlib_resources`. - import importlib_resources as loader +import importlib.resources as loader from os.path import exists +from typing import Dict from .. import resources, SpleeterError @@ -20,18 +16,28 @@ __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' -_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:' +_EMBEDDED_CONFIGURATION_PREFIX: str = 'spleeter:' -def load_configuration(descriptor): - """ Load configuration from the given descriptor. Could be - either a `spleeter:` prefixed embedded configuration name - or a file system path to read configuration from. +def load_configuration(descriptor: str) -> Dict: + """ + Load configuration from the given descriptor. Could be either a + `spleeter:` prefixed embedded configuration name or a file system path + to read configuration from. - :param descriptor: Configuration descriptor to use for lookup. - :returns: Loaded description as dict. - :raise ValueError: If required embedded configuration does not exists. - :raise SpleeterError: If required configuration file does not exists. + Parameters: + descriptor (str): + Configuration descriptor to use for lookup. + + Returns: + Dict: + Loaded description as dict. + + Raises: + ValueError: + If required embedded configuration does not exists. + SpleeterError: + If required configuration file does not exists. """ # Embedded configuration reading. if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX): diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py deleted file mode 100644 index aefc355..0000000 --- a/spleeter/utils/estimator.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python -# coding: utf8 - -""" Utility functions for creating estimator. """ - -import tensorflow as tf # pylint: disable=import-error - -from ..model import model_fn -from ..model.provider import get_default_model_provider - - -def get_default_model_dir(model_dir): - """ - Transforms a string like 'spleeter:2stems' into an actual path. - :param model_dir: - :return: - """ - model_provider = get_default_model_provider() - return model_provider.get(model_dir) - - -def create_estimator(params, MWF): - """ - Initialize tensorflow estimator that will perform separation - - Params: - - params: a dictionary of parameters for building the model - - Returns: - a tensorflow estimator - """ - # Load model. - params['model_dir'] = get_default_model_dir(params['model_dir']) - params['MWF'] = MWF - # Setup config - session_config = tf.compat.v1.ConfigProto() - session_config.gpu_options.per_process_gpu_memory_fraction = 0.7 - config = tf.estimator.RunConfig(session_config=session_config) - # Setup estimator - estimator = tf.estimator.Estimator( - model_fn=model_fn, - model_dir=params['model_dir'], - params=params, - config=config - ) - return estimator