From c0e45dee720899ed1a9776aa6292e273e05b665e Mon Sep 17 00:00:00 2001
From: Faylixe
Date: Fri, 2 Oct 2020 17:29:59 +0200
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20=20add=20pool=20close=20callback?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 spleeter/separator.py | 229 ++++++++++++++++++++++++++----------------
 1 file changed, 141 insertions(+), 88 deletions(-)

diff --git a/spleeter/separator.py b/spleeter/separator.py
index 5c57aa8..131bdac 100644
--- a/spleeter/separator.py
+++ b/spleeter/separator.py
@@ -12,14 +12,18 @@
     >>> separator.separate_to_file(...)
 """
 
+import atexit
 import os
 import logging
 
-from time import time
 from multiprocessing import Pool
 from os.path import basename, join, splitext, dirname
+from time import time
+from typing import Container, NoReturn
+
 import numpy as np
 import tensorflow as tf
+
 from librosa.core import stft, istft
 from scipy.signal.windows import hann
@@ -30,59 +34,63 @@
 from .utils.configuration import load_configuration
 from .utils.estimator import create_estimator, get_default_model_dir
 from .model import EstimatorSpecBuilder, InputProviderFactory
-
 
 __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
-
-logger = logging.getLogger("spleeter")
-
+SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
+""" """
 
 
 class DataGenerator():
     """
-    generator object that store a sample and generate it once while called.
-    Used to feed a tensorflow estimator without knowing the whole data at build time.
-
+    Generator object that store a sample and generate it once while called.
+    Used to feed a tensorflow estimator without knowing the whole data at
+    build time.
     """
 
     def __init__(self):
+        """ Default constructor. """
        self._current_data = None
 
    def update_data(self, data):
-        """
-        replace data
-        """
+        """ Replace internal data. """
        self._current_data = data
 
    def __call__(self):
-        res = self._current_data
-        while res is not None:
-            yield res
-            res = self._current_data
+        """ Generation process. """
+        buffer = self._current_data
+        while buffer:
+            yield buffer
+            buffer = self._current_data
 
-
-def get_backend(backend):
-    assert backend in ["auto", "tensorflow", "librosa"]
-    # print("USING TENSORFLOW BACKEND !!!!!!")
-    # return "tensorflow"
-    if backend == "auto":
-        return "tensorflow" if len(tf.config.list_physical_devices('GPU')) else "librosa"
+def get_backend(backend: str) -> str:
+    """
+    """
+    if backend not in SUPPORTED_BACKEND:
+        raise ValueError(f'Unsupported backend {backend}')
+    if backend == 'auto':
+        if len(tf.config.list_physical_devices('GPU')):
+            return 'tensorflow'
+        return 'librosa'
     return backend
 
 
 class Separator(object):
     """ A wrapper class for performing separation. """
 
-    def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
+    def __init__(
+            self,
+            params_descriptor,
+            MWF: bool = False,
+            stft_backend: str = 'auto',
+            multiprocess: bool = True):
         """ Default constructor.
 
         :param params_descriptor: Descriptor for TF params to be used.
         :param MWF: (Optional) True if MWF should be used,
                     False otherwise.
         """
-
         self._params = load_configuration(params_descriptor)
         self._sample_rate = self._params['sample_rate']
         self._MWF = MWF
@@ -92,35 +100,45 @@ class Separator(object):
         self._builder = None
         self._features = None
         self._session = None
-        self._pool = Pool() if multiprocess else None
+        if multiprocess:
+            self._pool = Pool()
+            atexit.register(self._pool.close)
+        else:
+            self._pool = None
         self._tasks = []
-        self._params["stft_backend"] = get_backend(stft_backend)
+        self._params['stft_backend'] = get_backend(stft_backend)
         self._data_generator = DataGenerator()
-
 
     def __del__(self):
-
+        """ """
         if self._session:
             self._session.close()
 
     def _get_prediction_generator(self):
-        """
-        Lazy loading access method for internal prediction generator returned by the predict method of a tensorflow estimator.
+        """ Lazy loading access method for internal prediction generator
+        returned by the predict method of a tensorflow estimator.
 
-        :returns: generator of prediction.
+        :returns: generator of prediction.
         """
-
         if self._prediction_generator is None:
             estimator = create_estimator(self._params, self._MWF)
+
             def get_dataset():
-                return tf.data.Dataset.from_generator(self._data_generator, output_types={"waveform":tf.float32, "audio_id":tf.string}, output_shapes={"waveform":(None,2),"audio_id":()})
-            self._prediction_generator = estimator.predict(get_dataset,
-                                                           yield_single_examples=False)
-
+                return tf.data.Dataset.from_generator(
+                    self._data_generator,
+                    output_types={
+                        'waveform': tf.float32,
+                        'audio_id': tf.string},
+                    output_shapes={
+                        'waveform': (None, 2),
+                        'audio_id': ()})
+            self._prediction_generator = estimator.predict(
+                get_dataset,
+                yield_single_examples=False)
         return self._prediction_generator
 
-    def join(self, timeout=200):
+    def join(self, timeout: int = 200) -> NoReturn:
         """ Wait for all pending tasks to be finished.
 
         :param timeout: (Optional) task waiting timeout.
@@ -130,9 +148,9 @@ class Separator(object):
             task.get()
             task.wait(timeout=timeout)
 
-    def _separate_tensorflow(self, waveform, audio_descriptor):
-        """
-        Performs source separation over the given waveform with tensorflow backend.
+    def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
+        """ Performs source separation over the given waveform with tensorflow
+        backend.
 
         :param waveform: Waveform to apply separation on.
         :returns: Separated waveforms.
@@ -140,38 +158,42 @@
         if not waveform.shape[-1] == 2:
             waveform = to_stereo(waveform)
         prediction_generator = self._get_prediction_generator()
-
-        # update data in generator before performing separation
-        self._data_generator.update_data({"waveform": waveform,
-                                          'audio_id': np.array(audio_descriptor)})
-
-        # perform separation
+        # NOTE: update data in generator before performing separation.
+        self._data_generator.update_data({
+            'waveform': waveform,
+            'audio_id': np.array(audio_descriptor)})
+        # NOTE: perform separation.
         prediction = next(prediction_generator)
         prediction.pop('audio_id')
         return prediction
 
-    def _stft(self, data, inverse=False, length=None):
-        """
-        Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
-        channels are processed separately and are concatenated together in the result. The expected input formats are:
-        (n_samples, 2) for stft and (T, F, 2) for istft.
-        :param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse
+    def _stft(self, data, inverse: bool = False, length=None):
+        """ Single entrypoint for both stft and istft. This computes stft and
+        istft with librosa on stereo data. The two channels are processed
+        separately and are concatenated together in the result. The expected
+        input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
+
+        :param data: np.array with either the waveform or the complex
+            spectrogram depending on the parameter inverse
         :param inverse: should a stft or an istft be computed.
-        :return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
+        :returns: Stereo data as numpy array for the transform.
+            The channels are stored in the last dimension.
         """
         assert not (inverse and length is None)
         data = np.asfortranarray(data)
-        N = self._params["frame_length"]
-        H = self._params["frame_step"]
-
-        win = hann(N, sym=False)
-        fstft = istft if inverse else stft
-        win_len_arg = {"win_length": None,
-                       "length": None} if inverse else {"n_fft": N}
+        N = self._params['frame_length']
+        H = self._params['frame_step']
+        win = hann(N, sym=False)
+        fstft = istft if inverse else stft
+        win_len_arg = {
+            'win_length': None,
+            'length': None} if inverse else {'n_fft': N}
         n_channels = data.shape[-1]
         out = []
         for c in range(n_channels):
-            d = np.concatenate((np.zeros((N, )), data[:, c], np.zeros((N, )))) if not inverse else data[:, :, c].T
+            d = np.concatenate(
+                (np.zeros((N, )), data[:, c], np.zeros((N, )))
+            ) if not inverse else data[:, :, c].T
             s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
             if inverse:
                 s = s[N:N+length]
@@ -181,7 +203,6 @@
             return out[0]
         return np.concatenate(out, axis=2-inverse)
-
 
     def _get_input_provider(self):
         if self._input_provider is None:
             self._input_provider = InputProviderFactory.get(self._params)
@@ -189,66 +210,83 @@
 
     def _get_features(self):
         if self._features is None:
-            self._features = self._get_input_provider().get_input_dict_placeholders()
+            provider = self._get_input_provider()
+            self._features = provider.get_input_dict_placeholders()
         return self._features
 
     def _get_builder(self):
         if self._builder is None:
-            self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
+            self._builder = EstimatorSpecBuilder(
+                self._get_features(),
+                self._params)
         return self._builder
 
     def _get_session(self):
         if self._session is None:
             saver = tf.compat.v1.train.Saver()
-            latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
+            latest_checkpoint = tf.train.latest_checkpoint(
+                get_default_model_dir(self._params['model_dir']))
             self._session = tf.compat.v1.Session()
             saver.restore(self._session, latest_checkpoint)
         return self._session
 
-    def _separate_librosa(self, waveform, audio_id):
-        """
-        Performs separation with librosa backend for STFT.
+    def _separate_librosa(self, waveform: np.ndarray, audio_id):
+        """ Performs separation with librosa backend for STFT.
         """
         with self._tf_graph.as_default():
             out = {}
             features = self._get_features()
-
-            # TODO: fix the logic, build sometimes return, sometimes set attribute
+            # TODO: fix the logic, build sometimes return,
+            # sometimes set attribute.
             outputs = self._get_builder().outputs
             stft = self._stft(waveform)
             if stft.shape[-1] == 1:
                 stft = np.concatenate([stft, stft], axis=-1)
             elif stft.shape[-1] > 2:
                 stft = stft[:, :2]
-
             sess = self._get_session()
-            outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id))
+            outputs = sess.run(
+                outputs,
+                feed_dict=self._get_input_provider().get_feed_dict(
+                    features,
+                    stft,
+                    audio_id))
             for inst in self._get_builder().instruments:
-                out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
+                out[inst] = self._stft(
+                    outputs[inst],
+                    inverse=True,
+                    length=waveform.shape[0])
             return out
 
-    def separate(self, waveform, audio_descriptor=""):
+    def separate(self, waveform: np.ndarray, audio_descriptor=''):
         """ Performs separation on a waveform.
 
         :param waveform: Waveform to be separated (as a numpy array)
-        :param audio_descriptor: (Optional) string describing the waveform (e.g. filename).
+        :param audio_descriptor: (Optional) string describing the waveform
+            (e.g. filename).
         """
-        if self._params["stft_backend"] == "tensorflow":
+        if self._params['stft_backend'] == 'tensorflow':
             return self._separate_tensorflow(waveform, audio_descriptor)
         else:
             return self._separate_librosa(waveform, audio_descriptor)
 
     def separate_to_file(
-            self, audio_descriptor, destination,
+            self,
+            audio_descriptor,
+            destination,
             audio_adapter=get_default_audio_adapter(),
-            offset=0, duration=600., codec='wav', bitrate='128k',
+            offset=0,
+            duration=600.,
+            codec='wav',
+            bitrate='128k',
             filename_format='{filename}/{instrument}.{codec}',
             synchronous=True):
         """ Performs source separation and export result to file using
         given audio adapter.
 
         Filename format should be a Python formattable string that could use
         following parameters : {instrument}, {filename}, {foldername} and
        {codec}.
 
         :param audio_descriptor: Describe song to separate, used by audio
                                  adapter to retrieve and load audio data,
         :param destination: Target directory to write output to.
         :param audio_adapter: (Optional) Audio adapter to use for I/O.
         :param offset: (Optional) Offset of loaded song.
-        :param duration: (Optional) Duration of loaded song (default:
-            600s).
+        :param duration: (Optional) Duration of loaded song
+            (default: 600s).
         :param codec: (Optional) Export codec.
         :param bitrate: (Optional) Export bitrate.
         :param filename_format: (Optional) Filename format.
@@ -270,16 +308,27 @@
             duration=duration,
             sample_rate=self._sample_rate)
         sources = self.separate(waveform, audio_descriptor)
-        self.save_to_file( sources, audio_descriptor, destination,
-                           filename_format, codec, audio_adapter,
-                           bitrate, synchronous)
+        self.save_to_file(
+            sources,
+            audio_descriptor,
+            destination,
+            filename_format,
+            codec,
+            audio_adapter,
+            bitrate,
+            synchronous)
 
     def save_to_file(
-            self, sources, audio_descriptor, destination,
+            self,
+            sources,
+            audio_descriptor,
+            destination,
             filename_format='{filename}/{instrument}.{codec}',
-            codec='wav', audio_adapter=get_default_audio_adapter(),
-            bitrate='128k', synchronous=True):
-        """ export dictionary of sources to files.
+            codec='wav',
+            audio_adapter=get_default_audio_adapter(),
+            bitrate='128k',
+            synchronous=True):
+        """ Export dictionary of sources to files.
 
         :param sources: Dictionary of sources to be exported.
                         The keys are the name of the instruments, and
@@ -298,7 +347,6 @@ class Separator(object):
         :param synchronous: (Optional) True is should by synchronous.
         """
-
         foldername = basename(dirname(audio_descriptor))
         filename = splitext(basename(audio_descriptor))[0]
         generated = []
@@ -326,6 +374,11 @@
                     bitrate))
                 self._tasks.append(task)
             else:
-                audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
+                audio_adapter.save(
+                    path,
+                    data,
+                    self._sample_rate,
+                    codec,
+                    bitrate)
         if synchronous and self._pool:
             self.join()
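Note (not part of the diff): a minimal usage sketch of the behaviour this patch introduces, assuming the 'spleeter:2stems' configuration and an illustrative input file 'audio_example.mp3'. With multiprocess=True the export Pool is created in the constructor and, through the atexit.register(self._pool.close) call added above, is now closed automatically when the interpreter exits rather than being left to garbage collection.

    from spleeter.separator import Separator

    # The constructor creates the Pool and registers its close() with atexit.
    separator = Separator('spleeter:2stems', multiprocess=True)
    # Asynchronous export: write tasks are queued on the internal pool.
    separator.separate_to_file(
        'audio_example.mp3',   # illustrative input path
        'output',              # illustrative destination directory
        synchronous=False)
    # Wait for pending export tasks; the pool itself is closed at exit.
    separator.join()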