From e2937f7898d473e0f966bf02621ec12025e55d7e Mon Sep 17 00:00:00 2001 From: romi1502 Date: Wed, 1 Jul 2020 15:49:32 +0200 Subject: [PATCH] Replace predictor by estimator --- spleeter/separator.py | 70 +++++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/spleeter/separator.py b/spleeter/separator.py index e769c28..12f01b7 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -27,7 +27,7 @@ from . import SpleeterError from .audio.adapter import get_default_audio_adapter from .audio.convertor import to_stereo from .utils.configuration import load_configuration -from .utils.estimator import create_estimator, to_predictor, get_default_model_dir +from .utils.estimator import create_estimator, get_default_model_dir from .model import EstimatorSpecBuilder, InputProviderFactory @@ -40,8 +40,34 @@ logger = logging.getLogger("spleeter") +class DataGenerator(): + """ + generator object that store a sample and generate it once while called. + Used to feed a tensorflow estimator without knowing the whole data at build time. + + """ + + def __init__(self): + self._current_data = None + + def update_data(self, data): + """ + replace data + """ + self._current_data = data + + def __call__(self): + res = self._current_data + while res is not None: + yield res + res = self._current_data + + + def get_backend(backend): assert backend in ["auto", "tensorflow", "librosa"] + # print("USING TENSORFLOW BACKEND !!!!!!") + # return "tensorflow" if backend == "auto": return "tensorflow" if tf.test.is_gpu_available() else "librosa" return backend @@ -61,7 +87,7 @@ class Separator(object): self._sample_rate = self._params['sample_rate'] self._MWF = MWF self._tf_graph = tf.Graph() - self._predictor = None + self._prediction_generator = None self._input_provider = None self._builder = None self._features = None @@ -69,20 +95,30 @@ class Separator(object): self._pool = Pool() if multiprocess else None self._tasks = [] self._params["stft_backend"] = get_backend(stft_backend) + self._data_generator = DataGenerator() + def __del__(self): + if self._session: self._session.close() - def _get_predictor(self): - """ Lazy loading access method for internal predictor instance. - - :returns: Predictor to use for source separation. + def _get_prediction_generator(self): """ - if self._predictor is None: + Lazy loading access method for internal prediction generator returned by the predict method of a tensorflow estimator. + + :returns: generator of prediction. + """ + + if self._prediction_generator is None: estimator = create_estimator(self._params, self._MWF) - self._predictor = to_predictor(estimator) - return self._predictor + def get_dataset(): + return tf.data.Dataset.from_generator(self._data_generator, output_types={"waveform":tf.float32, "audio_id":tf.string}, output_shapes={"waveform":(None,2),"audio_id":()}) + self._prediction_generator = estimator.predict(get_dataset, + yield_single_examples=False) + + + return self._prediction_generator def join(self, timeout=200): """ Wait for all pending tasks to be finished. @@ -103,10 +139,14 @@ class Separator(object): """ if not waveform.shape[-1] == 2: waveform = to_stereo(waveform) - predictor = self._get_predictor() - prediction = predictor({ - 'waveform': waveform, - 'audio_id': audio_descriptor}) + prediction_generator = self._get_prediction_generator() + + # update data in generator before performing separation + self._data_generator.update_data({"waveform": waveform, + 'audio_id': np.array(audio_descriptor)}) + + # perform separation + prediction = next(prediction_generator) prediction.pop('audio_id') return prediction @@ -155,9 +195,9 @@ class Separator(object): def _get_session(self): if self._session is None: - saver = tf.train.Saver() + saver = tf.compat.v1.train.Saver() latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) - self._session = tf.Session() + self._session = tf.compat.v1.Session() saver.restore(self._session, latest_checkpoint) return self._session