🎨 finalizes model provider and functions

This commit is contained in:
Faylixe
2020-12-07 19:19:19 +01:00
parent f02bcbd9c7
commit ae9269525d
9 changed files with 398 additions and 265 deletions

View File

@@ -4,62 +4,60 @@
"""
Module that provides a class wrapper for source separation.
:Example:
Examples:
```python
>>> from spleeter.separator import Separator
>>> separator = Separator('spleeter:2stems')
>>> separator.separate(waveform, lambda instrument, data: ...)
>>> separator.separate_to_file(...)
```
"""
import atexit
import os
import logging
from enum import Enum
from multiprocessing import Pool
from os.path import basename, join, splitext, dirname
from time import time
from typing import Container, NoReturn
from typing import Generator, Optional
from . import SpleeterError
from .audio import STFTBackend
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory
from .utils.configuration import load_configuration
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from librosa.core import stft, istft
from scipy.signal.windows import hann
from . import SpleeterError
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
""" """
class DataGenerator():
class DataGenerator(object):
"""
Generator object that store a sample and generate it once while called.
Used to feed a tensorflow estimator without knowing the whole data at
build time.
"""
def __init__(self):
def __init__(self) -> None:
""" Default constructor. """
self._current_data = None
def update_data(self, data):
def update_data(self, data) -> None:
""" Replace internal data. """
self._current_data = data
def __call__(self):
def __call__(self) -> Generator:
""" Generation process. """
buffer = self._current_data
while buffer:
@@ -79,19 +77,50 @@ def get_backend(backend: str) -> str:
return backend
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
provider: ModelProvider = ModelProvider.default()
params['model_dir'] = provider.get(params['model_dir'])
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config)
return estimator
class Separator(object):
""" A wrapper class for performing separation. """
def __init__(
self,
params_descriptor,
params_descriptor: str,
MWF: bool = False,
stft_backend: str = 'auto',
multiprocess: bool = True):
""" Default constructor.
stft_backend: STFTBackend = STFTBackend.AUTO,
multiprocess: bool = True) -> None:
"""
Default constructor.
:param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise.
Parameters:
params_descriptor (str):
Descriptor for TF params to be used.
MWF (bool):
(Optional) `True` if MWF should be used, `False` otherwise.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
@@ -111,8 +140,7 @@ class Separator(object):
self._params['stft_backend'] = get_backend(stft_backend)
self._data_generator = DataGenerator()
def __del__(self):
""" """
def __del__(self) -> None:
if self._session:
self._session.close()
@@ -140,35 +168,19 @@ class Separator(object):
yield_single_examples=False)
return self._prediction_generator
def join(self, timeout: int = 200) -> NoReturn:
""" Wait for all pending tasks to be finished.
def join(self, timeout: int = 200) -> None:
"""
Wait for all pending tasks to be finished.
:param timeout: (Optional) task waiting timeout.
Parameters:
timeout (int):
(Optional) task waiting timeout.
"""
while len(self._tasks) > 0:
task = self._tasks.pop()
task.get()
task.wait(timeout=timeout)
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
""" Performs source separation over the given waveform with tensorflow
backend.
:param waveform: Waveform to apply separation on.
:returns: Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
def _stft(self, data, inverse: bool = False, length=None):
""" Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
@@ -233,7 +245,12 @@ class Separator(object):
return self._session
def _separate_librosa(self, waveform: np.ndarray, audio_id):
""" Performs separation with librosa backend for STFT.
"""
Performs separation with librosa backend for STFT.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
"""
with self._tf_graph.as_default():
out = {}
@@ -260,12 +277,42 @@ class Separator(object):
length=waveform.shape[0])
return out
def separate(self, waveform: np.ndarray, audio_descriptor=''):
""" Performs separation on a waveform.
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
"""
Performs source separation over the given waveform with tensorflow
backend.
:param waveform: Waveform to be separated (as a numpy array)
:param audio_descriptor: (Optional) string describing the waveform
(e.g. filename).
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
Returns:
Separated waveforms.
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
def separate(
self,
waveform: np.ndarray,
audio_descriptor: Optional[str] = None) -> None:
"""
Performs separation on a waveform.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
"""
if self._params['stft_backend'] == 'tensorflow':
return self._separate_tensorflow(waveform, audio_descriptor)