🐛 add pool close callback

Faylixe
2020-10-02 17:29:59 +02:00
parent 4a629c6bac
commit c0e45dee72


@@ -12,14 +12,18 @@
>>> separator.separate_to_file(...)
"""
import atexit
import os
import logging
from time import time
from multiprocessing import Pool
from os.path import basename, join, splitext, dirname
from time import time
from typing import Container, NoReturn
import numpy as np
import tensorflow as tf
from librosa.core import stft, istft
from scipy.signal.windows import hann
@@ -30,59 +34,63 @@ from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
logger = logging.getLogger("spleeter")
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
""" """
class DataGenerator():
"""
generator object that store a sample and generate it once while called.
Used to feed a tensorflow estimator without knowing the whole data at build time.
Generator object that stores a sample and generates it once when
called. Used to feed a tensorflow estimator without knowing the whole
data at build time.
"""
def __init__(self):
""" Default constructor. """
self._current_data = None
def update_data(self, data):
"""
replace data
"""
""" Replace internal data. """
self._current_data = data
def __call__(self):
res = self._current_data
while res is not None:
yield res
res = self._current_data
""" Generation process. """
buffer = self._current_data
while buffer:
yield buffer
buffer = self._current_data
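The generator re-reads its internal reference on every pull, so it can be drained by setting the data back to None. A minimal behaviour sketch (the sample dict below is hypothetical):

gen = DataGenerator()
gen.update_data({'waveform': [[0.0, 0.0]], 'audio_id': 'demo'})
stream = gen()
sample = next(stream)    # yields the stored sample
gen.update_data(None)    # the next pull now ends the stream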
def get_backend(backend):
assert backend in ["auto", "tensorflow", "librosa"]
# print("USING TENSORFLOW BACKEND !!!!!!")
# return "tensorflow"
if backend == "auto":
return "tensorflow" if len(tf.config.list_physical_devices('GPU')) else "librosa"
def get_backend(backend: str) -> str:
"""
"""
if backend not in SUPPORTED_BACKEND:
raise ValueError(f'Unsupported backend {backend}')
if backend == 'auto':
if len(tf.config.list_physical_devices('GPU')):
return 'tensorflow'
return 'librosa'
return backend
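A hedged illustration of the resolution logic (return values assume no GPU is visible to tensorflow):

get_backend('librosa')    # -> 'librosa', explicit choice is kept
get_backend('auto')       # -> 'librosa' without a GPU,
                          #    'tensorflow' otherwise
get_backend('cupy')       # raises ValueError('Unsupported backend cupy')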
class Separator(object):
""" A wrapper class for performing separation. """
def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
def __init__(
self,
params_descriptor,
MWF: bool = False,
stft_backend: str = 'auto',
multiprocess: bool = True):
""" Default constructor.
:param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise.
:param stft_backend: (Optional) STFT backend descriptor
('auto', 'tensorflow' or 'librosa').
:param multiprocess: (Optional) Enable multiprocessing for export.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
@@ -92,35 +100,45 @@ class Separator(object):
self._builder = None
self._features = None
self._session = None
self._pool = Pool() if multiprocess else None
if multiprocess:
self._pool = Pool()
atexit.register(self._pool.close)
else:
self._pool = None
self._tasks = []
self._params["stft_backend"] = get_backend(stft_backend)
self._params['stft_backend'] = get_backend(stft_backend)
self._data_generator = DataGenerator()
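The fix itself is the atexit registration above: closing the pool at interpreter exit keeps worker processes from outliving the program. A standalone sketch of the pattern, independent of spleeter:

import atexit
from multiprocessing import Pool

pool = Pool()
# Refuse new tasks and let workers terminate once pending work
# is done; registered so it also runs at interpreter exit.
atexit.register(pool.close)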
def __del__(self):
""" """
if self._session:
self._session.close()
def _get_prediction_generator(self):
"""
Lazy loading access method for internal prediction generator returned by the predict method of a tensorflow estimator.
""" Lazy loading access method for internal prediction generator
returned by the predict method of a tensorflow estimator.
:returns: generator of prediction.
"""
if self._prediction_generator is None:
estimator = create_estimator(self._params, self._MWF)
def get_dataset():
return tf.data.Dataset.from_generator(self._data_generator, output_types={"waveform":tf.float32, "audio_id":tf.string}, output_shapes={"waveform":(None,2),"audio_id":()})
self._prediction_generator = estimator.predict(get_dataset,
yield_single_examples=False)
return tf.data.Dataset.from_generator(
self._data_generator,
output_types={
'waveform': tf.float32,
'audio_id': tf.string},
output_shapes={
'waveform': (None, 2),
'audio_id': ()})
self._prediction_generator = estimator.predict(
get_dataset,
yield_single_examples=False)
return self._prediction_generator
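The dataset wiring above can be reproduced in isolation; a minimal sketch assuming tensorflow 2.x and a dummy sample:

import tensorflow as tf

def sample():
    # Single hypothetical stereo frame with its identifier.
    yield {'waveform': [[0.0, 0.0]], 'audio_id': 'demo'}

dataset = tf.data.Dataset.from_generator(
    sample,
    output_types={'waveform': tf.float32, 'audio_id': tf.string},
    output_shapes={'waveform': (None, 2), 'audio_id': ()})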
def join(self, timeout=200):
def join(self, timeout: int = 200) -> NoReturn:
""" Wait for all pending tasks to be finished.
:param timeout: (Optional) task waiting timeout.
@@ -130,9 +148,9 @@ class Separator(object):
task.get()
task.wait(timeout=timeout)
def _separate_tensorflow(self, waveform, audio_descriptor):
"""
Performs source separation over the given waveform with tensorflow backend.
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
""" Performs source separation over the given waveform with tensorflow
backend.
:param waveform: Waveform to apply separation on.
:returns: Separated waveforms.
@@ -140,38 +158,42 @@ class Separator(object):
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
prediction_generator = self._get_prediction_generator()
# update data in generator before performing separation
self._data_generator.update_data({"waveform": waveform,
'audio_id': np.array(audio_descriptor)})
# perform separation
# NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform,
'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
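A note on the stereo guard above: mono input is widened before prediction. A minimal stand-in for what to_stereo achieves (the real helper lives in spleeter's audio utilities):

import numpy as np

mono = np.zeros((44100, 1), dtype=np.float32)
if mono.shape[-1] != 2:
    # Duplicate the single channel so the shape becomes (n_samples, 2).
    stereo = np.repeat(mono, 2, axis=-1)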
def _stft(self, data, inverse=False, length=None):
"""
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
channels are processed separately and are concatenated together in the result. The expected input formats are:
(n_samples, 2) for stft and (T, F, 2) for istft.
:param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse
def _stft(self, data, inverse: bool = False, length=None):
""" Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The expected
input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
:param data: np.array with either the waveform or the complex
spectrogram depending on the parameter inverse
:param inverse: should a stft or an istft be computed.
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
:returns: Stereo data as numpy array for the transform.
The channels are stored in the last dimension.
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
N = self._params["frame_length"]
H = self._params["frame_step"]
N = self._params['frame_length']
H = self._params['frame_step']
win = hann(N, sym=False)
fstft = istft if inverse else stft
win_len_arg = {"win_length": None,
"length": None} if inverse else {"n_fft": N}
win_len_arg = {
'win_length': None,
'length': None} if inverse else {'n_fft': N}
n_channels = data.shape[-1]
out = []
for c in range(n_channels):
d = np.concatenate((np.zeros((N, )), data[:, c], np.zeros((N, )))) if not inverse else data[:, :, c].T
d = np.concatenate(
(np.zeros((N, )), data[:, c], np.zeros((N, )))
) if not inverse else data[:, :, c].T
s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
if inverse:
s = s[N:N+length]
@@ -181,7 +203,6 @@ class Separator(object):
return out[0]
return np.concatenate(out, axis=2-inverse)
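The padding and slicing in _stft can be checked with a mono round trip; a sketch assuming the usual 4096/1024 frame settings:

import numpy as np
from librosa.core import stft, istft
from scipy.signal.windows import hann

N, H = 4096, 1024
win = hann(N, sym=False)
y = np.random.randn(44100)
padded = np.concatenate((np.zeros(N), y, np.zeros(N)))
s = stft(padded, n_fft=N, hop_length=H, window=win, center=False)
y_back = istft(s, hop_length=H, window=win, center=False)[N:N + len(y)]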
def _get_input_provider(self):
if self._input_provider is None:
self._input_provider = InputProviderFactory.get(self._params)
@@ -189,66 +210,83 @@ class Separator(object):
def _get_features(self):
if self._features is None:
self._features = self._get_input_provider().get_input_dict_placeholders()
provider = self._get_input_provider()
self._features = provider.get_input_dict_placeholders()
return self._features
def _get_builder(self):
if self._builder is None:
self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
self._builder = EstimatorSpecBuilder(
self._get_features(),
self._params)
return self._builder
def _get_session(self):
if self._session is None:
saver = tf.compat.v1.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
latest_checkpoint = tf.train.latest_checkpoint(
get_default_model_dir(self._params['model_dir']))
self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint)
return self._session
def _separate_librosa(self, waveform, audio_id):
"""
Performs separation with librosa backend for STFT.
def _separate_librosa(self, waveform: np.ndarray, audio_id):
""" Performs separation with librosa backend for STFT.
"""
with self._tf_graph.as_default():
out = {}
features = self._get_features()
# TODO: fix the logic, build sometimes return, sometimes set attribute
# TODO: fix the logic, build sometimes returns,
# sometimes sets attribute.
outputs = self._get_builder().outputs
stft = self._stft(waveform)
if stft.shape[-1] == 1:
stft = np.concatenate([stft, stft], axis=-1)
elif stft.shape[-1] > 2:
stft = stft[:, :2]
sess = self._get_session()
outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id))
outputs = sess.run(
outputs,
feed_dict=self._get_input_provider().get_feed_dict(
features,
stft,
audio_id))
for inst in self._get_builder().instruments:
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
out[inst] = self._stft(
outputs[inst],
inverse=True,
length=waveform.shape[0])
return out
def separate(self, waveform, audio_descriptor=""):
def separate(self, waveform: np.ndarray, audio_descriptor=''):
""" Performs separation on a waveform.
:param waveform: Waveform to be separated (as a numpy array)
:param audio_descriptor: (Optional) string describing the waveform (e.g. filename).
:param audio_descriptor: (Optional) string describing the waveform
(e.g. filename).
"""
if self._params["stft_backend"] == "tensorflow":
if self._params['stft_backend'] == 'tensorflow':
return self._separate_tensorflow(waveform, audio_descriptor)
else:
return self._separate_librosa(waveform, audio_descriptor)
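End-to-end usage sketch of separate, assuming the pretrained spleeter:2stems configuration can be resolved locally:

import numpy as np

separator = Separator('spleeter:2stems')
waveform = np.zeros((44100 * 5, 2), dtype=np.float32)  # 5s of silence
prediction = separator.separate(waveform, audio_descriptor='demo')
# prediction maps instrument names (e.g. 'vocals', 'accompaniment')
# to separated waveforms with the same shape as the input.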
def separate_to_file(
self, audio_descriptor, destination,
self,
audio_descriptor,
destination,
audio_adapter=get_default_audio_adapter(),
offset=0, duration=600., codec='wav', bitrate='128k',
offset=0,
duration=600.,
codec='wav',
bitrate='128k',
filename_format='{filename}/{instrument}.{codec}',
synchronous=True):
""" Performs source separation and export result to file using
given audio adapter.
Filename format should be a Python formattable string that could use
following parameters : {instrument}, {filename}, {foldername} and {codec}.
following parameters: {instrument}, {filename}, {foldername} and
{codec}.
:param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data,
@@ -257,8 +295,8 @@ class Separator(object):
:param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param offset: (Optional) Offset of loaded song.
:param duration: (Optional) Duration of loaded song (default:
600s).
:param duration: (Optional) Duration of loaded song
(default: 600s).
:param codec: (Optional) Export codec.
:param bitrate: (Optional) Export bitrate.
:param filename_format: (Optional) Filename format.
@@ -270,16 +308,27 @@ class Separator(object):
duration=duration,
sample_rate=self._sample_rate)
sources = self.separate(waveform, audio_descriptor)
self.save_to_file( sources, audio_descriptor, destination,
filename_format, codec, audio_adapter,
bitrate, synchronous)
self.save_to_file(
sources,
audio_descriptor,
destination,
filename_format,
codec,
audio_adapter,
bitrate,
synchronous)
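The filename_format documented above is plain str.format templating; a hypothetical illustration:

fmt = '{foldername}/{filename}/{instrument}.{codec}'
fmt.format(foldername='album', filename='track01',
           instrument='vocals', codec='wav')
# -> 'album/track01/vocals.wav'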
def save_to_file(
self, sources, audio_descriptor, destination,
self,
sources,
audio_descriptor,
destination,
filename_format='{filename}/{instrument}.{codec}',
codec='wav', audio_adapter=get_default_audio_adapter(),
bitrate='128k', synchronous=True):
""" export dictionary of sources to files.
codec='wav',
audio_adapter=get_default_audio_adapter(),
bitrate='128k',
synchronous=True):
""" Export dictionary of sources to files.
:param sources: Dictionary of sources to be exported. The
keys are the name of the instruments, and
@@ -298,7 +347,6 @@ class Separator(object):
:param synchronous: (Optional) True if export should be synchronous.
"""
foldername = basename(dirname(audio_descriptor))
filename = splitext(basename(audio_descriptor))[0]
generated = []
@@ -326,6 +374,11 @@ class Separator(object):
bitrate))
self._tasks.append(task)
else:
audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
audio_adapter.save(
path,
data,
self._sample_rate,
codec,
bitrate)
if synchronous and self._pool:
self.join()
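Putting the pool, the export and the new atexit safety net together; an asynchronous usage sketch with hypothetical paths:

separator = Separator('spleeter:2stems', multiprocess=True)
separator.separate_to_file(
    'audio_example.mp3', 'output/',
    codec='wav', synchronous=False)
separator.join(timeout=300)   # wait for the pending export tasks
# the atexit hook added by this commit closes the pool on exit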