mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
🎨 finalizes model provider and functions
This commit is contained in:
@@ -4,62 +4,60 @@
|
||||
"""
|
||||
Module that provides a class wrapper for source separation.
|
||||
|
||||
:Example:
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from spleeter.separator import Separator
|
||||
>>> separator = Separator('spleeter:2stems')
|
||||
>>> separator.separate(waveform, lambda instrument, data: ...)
|
||||
>>> separator.separate_to_file(...)
|
||||
```
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import logging
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from multiprocessing import Pool
|
||||
from os.path import basename, join, splitext, dirname
|
||||
from time import time
|
||||
from typing import Container, NoReturn
|
||||
from typing import Generator, Optional
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio import STFTBackend
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||
from .utils.configuration import load_configuration
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from librosa.core import stft, istft
|
||||
from scipy.signal.windows import hann
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .utils.configuration import load_configuration
|
||||
from .utils.estimator import create_estimator, get_default_model_dir
|
||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||
# pylint: enable=import-error
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
|
||||
""" """
|
||||
|
||||
|
||||
class DataGenerator():
|
||||
class DataGenerator(object):
|
||||
"""
|
||||
Generator object that store a sample and generate it once while called.
|
||||
Used to feed a tensorflow estimator without knowing the whole data at
|
||||
build time.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
""" Default constructor. """
|
||||
self._current_data = None
|
||||
|
||||
def update_data(self, data):
|
||||
def update_data(self, data) -> None:
|
||||
""" Replace internal data. """
|
||||
self._current_data = data
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> Generator:
|
||||
""" Generation process. """
|
||||
buffer = self._current_data
|
||||
while buffer:
|
||||
@@ -79,19 +77,50 @@ def get_backend(backend: str) -> str:
|
||||
return backend
|
||||
|
||||
|
||||
def create_estimator(params, MWF):
|
||||
"""
|
||||
Initialize tensorflow estimator that will perform separation
|
||||
|
||||
Params:
|
||||
- params: a dictionary of parameters for building the model
|
||||
|
||||
Returns:
|
||||
a tensorflow estimator
|
||||
"""
|
||||
# Load model.
|
||||
provider: ModelProvider = ModelProvider.default()
|
||||
params['model_dir'] = provider.get(params['model_dir'])
|
||||
params['MWF'] = MWF
|
||||
# Setup config
|
||||
session_config = tf.compat.v1.ConfigProto()
|
||||
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
|
||||
config = tf.estimator.RunConfig(session_config=session_config)
|
||||
# Setup estimator
|
||||
estimator = tf.estimator.Estimator(
|
||||
model_fn=model_fn,
|
||||
model_dir=params['model_dir'],
|
||||
params=params,
|
||||
config=config)
|
||||
return estimator
|
||||
|
||||
|
||||
class Separator(object):
|
||||
""" A wrapper class for performing separation. """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params_descriptor,
|
||||
params_descriptor: str,
|
||||
MWF: bool = False,
|
||||
stft_backend: str = 'auto',
|
||||
multiprocess: bool = True):
|
||||
""" Default constructor.
|
||||
stft_backend: STFTBackend = STFTBackend.AUTO,
|
||||
multiprocess: bool = True) -> None:
|
||||
"""
|
||||
Default constructor.
|
||||
|
||||
:param params_descriptor: Descriptor for TF params to be used.
|
||||
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
||||
Parameters:
|
||||
params_descriptor (str):
|
||||
Descriptor for TF params to be used.
|
||||
MWF (bool):
|
||||
(Optional) `True` if MWF should be used, `False` otherwise.
|
||||
"""
|
||||
self._params = load_configuration(params_descriptor)
|
||||
self._sample_rate = self._params['sample_rate']
|
||||
@@ -111,8 +140,7 @@ class Separator(object):
|
||||
self._params['stft_backend'] = get_backend(stft_backend)
|
||||
self._data_generator = DataGenerator()
|
||||
|
||||
def __del__(self):
|
||||
""" """
|
||||
def __del__(self) -> None:
|
||||
if self._session:
|
||||
self._session.close()
|
||||
|
||||
@@ -140,35 +168,19 @@ class Separator(object):
|
||||
yield_single_examples=False)
|
||||
return self._prediction_generator
|
||||
|
||||
def join(self, timeout: int = 200) -> NoReturn:
|
||||
""" Wait for all pending tasks to be finished.
|
||||
def join(self, timeout: int = 200) -> None:
|
||||
"""
|
||||
Wait for all pending tasks to be finished.
|
||||
|
||||
:param timeout: (Optional) task waiting timeout.
|
||||
Parameters:
|
||||
timeout (int):
|
||||
(Optional) task waiting timeout.
|
||||
"""
|
||||
while len(self._tasks) > 0:
|
||||
task = self._tasks.pop()
|
||||
task.get()
|
||||
task.wait(timeout=timeout)
|
||||
|
||||
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
|
||||
""" Performs source separation over the given waveform with tensorflow
|
||||
backend.
|
||||
|
||||
:param waveform: Waveform to apply separation on.
|
||||
:returns: Separated waveforms.
|
||||
"""
|
||||
if not waveform.shape[-1] == 2:
|
||||
waveform = to_stereo(waveform)
|
||||
prediction_generator = self._get_prediction_generator()
|
||||
# NOTE: update data in generator before performing separation.
|
||||
self._data_generator.update_data({
|
||||
'waveform': waveform,
|
||||
'audio_id': np.array(audio_descriptor)})
|
||||
# NOTE: perform separation.
|
||||
prediction = next(prediction_generator)
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def _stft(self, data, inverse: bool = False, length=None):
|
||||
""" Single entrypoint for both stft and istft. This computes stft and
|
||||
istft with librosa on stereo data. The two channels are processed
|
||||
@@ -233,7 +245,12 @@ class Separator(object):
|
||||
return self._session
|
||||
|
||||
def _separate_librosa(self, waveform: np.ndarray, audio_id):
|
||||
""" Performs separation with librosa backend for STFT.
|
||||
"""
|
||||
Performs separation with librosa backend for STFT.
|
||||
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
Waveform to be separated (as a numpy array)
|
||||
"""
|
||||
with self._tf_graph.as_default():
|
||||
out = {}
|
||||
@@ -260,12 +277,42 @@ class Separator(object):
|
||||
length=waveform.shape[0])
|
||||
return out
|
||||
|
||||
def separate(self, waveform: np.ndarray, audio_descriptor=''):
|
||||
""" Performs separation on a waveform.
|
||||
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
|
||||
"""
|
||||
Performs source separation over the given waveform with tensorflow
|
||||
backend.
|
||||
|
||||
:param waveform: Waveform to be separated (as a numpy array)
|
||||
:param audio_descriptor: (Optional) string describing the waveform
|
||||
(e.g. filename).
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
Waveform to be separated (as a numpy array)
|
||||
|
||||
Returns:
|
||||
Separated waveforms.
|
||||
"""
|
||||
if not waveform.shape[-1] == 2:
|
||||
waveform = to_stereo(waveform)
|
||||
prediction_generator = self._get_prediction_generator()
|
||||
# NOTE: update data in generator before performing separation.
|
||||
self._data_generator.update_data({
|
||||
'waveform': waveform,
|
||||
'audio_id': np.array(audio_descriptor)})
|
||||
# NOTE: perform separation.
|
||||
prediction = next(prediction_generator)
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def separate(
|
||||
self,
|
||||
waveform: np.ndarray,
|
||||
audio_descriptor: Optional[str] = None) -> None:
|
||||
"""
|
||||
Performs separation on a waveform.
|
||||
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
Waveform to be separated (as a numpy array)
|
||||
audio_descriptor (str):
|
||||
(Optional) string describing the waveform (e.g. filename).
|
||||
"""
|
||||
if self._params['stft_backend'] == 'tensorflow':
|
||||
return self._separate_tensorflow(waveform, audio_descriptor)
|
||||
|
||||
Reference in New Issue
Block a user