From 0f89b712ed32504a81f8212ed6a811bdae79cc8e Mon Sep 17 00:00:00 2001 From: romi1502 Date: Tue, 12 May 2020 21:48:41 +0200 Subject: [PATCH 1/2] Avoid multiple checkpoint restauration and add a specific tf graph for each Separator which makes it possible to instanciate several ones --- spleeter/separator.py | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/spleeter/separator.py b/spleeter/separator.py index dff0988..50cb287 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -60,14 +60,20 @@ class Separator(object): self._params = load_configuration(params_descriptor) self._sample_rate = self._params['sample_rate'] self._MWF = MWF + self._tf_graph = tf.Graph() self._predictor = None self._input_provider = None self._builder = None self._features = None + self._session = None self._pool = Pool() if multiprocess else None self._tasks = [] self._params["stft_backend"] = get_backend(stft_backend) + def __del__(self): + if self._session: + self._session.close() + def _get_predictor(self): """ Lazy loading access method for internal predictor instance. @@ -147,30 +153,35 @@ class Separator(object): self._builder = EstimatorSpecBuilder(self._get_features(), self._params) return self._builder + def _get_session(self): + if self._session is None: + saver = tf.train.Saver() + latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) + self._session = tf.Session() + saver.restore(self._session, latest_checkpoint) + return self._session + def _separate_librosa(self, waveform, audio_id): """ Performs separation with librosa backend for STFT. """ - out = {} - features = self._get_features() + with self._tf_graph.as_default(): + out = {} + features = self._get_features() - latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) + # TODO: fix the logic, build sometimes return, sometimes set attribute + outputs = self._get_builder().outputs + stft = self._stft(waveform) + if stft.shape[-1] == 1: + stft = np.concatenate([stft, stft], axis=-1) + elif stft.shape[-1] > 2: + stft = stft[:, :2] - # TODO: fix the logic, build sometimes return, sometimes set attribute - outputs = self._get_builder().outputs - stft = self._stft(waveform) - if stft.shape[-1] == 1: - stft = np.concatenate([stft, stft], axis=-1) - elif stft.shape[-1] > 2: - stft = stft[:, :2] - - saver = tf.train.Saver() - with tf.Session() as sess: - saver.restore(sess, latest_checkpoint) + sess = self._get_session() outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id)) for inst in self._get_builder().instruments: out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0]) - return out + return out def separate(self, waveform, audio_descriptor=""): """ Performs separation on a waveform. From bf3baf69c0081efadfbbfa76a40aab70398b28e1 Mon Sep 17 00:00:00 2001 From: romi1502 Date: Fri, 15 May 2020 15:53:20 +0200 Subject: [PATCH 2/2] Increased version number --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7c36ffb..3a9e54e 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ __license__ = 'MIT License' # Default project values. project_name = 'spleeter' -project_version = '1.5.1' +project_version = '1.5.2' tensorflow_dependency = 'tensorflow' tensorflow_version = '1.15.2' here = path.abspath(path.dirname(__file__))