Merge pull request #319 from deezer/fix_multiple_call_to_separate

Fix repeated calls to Separator.separate by instantiating the TensorFlow graph only once and reusing it across calls.
This commit is contained in:
Romain Hennequin
2020-04-12 21:20:23 +02:00
committed by GitHub

View File

@@ -61,6 +61,9 @@ class Separator(object):
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
self._predictor = None
self._input_provider = None
self._builder = None
self._features = None
self._pool = Pool() if multiprocess else None
self._tasks = []
self._params["stft_backend"] = get_backend(stft_backend)
@@ -128,19 +131,33 @@ class Separator(object):
return out[0]
return np.concatenate(out, axis=2-inverse)
def _get_input_provider(self):
    """
    Returns the input provider, creating it on first access.

    The provider is obtained from ``InputProviderFactory.get`` with
    ``self._params`` the first time this method runs and is stored in
    ``self._input_provider``; subsequent calls return the stored value.
    """
    provider = self._input_provider
    if provider is None:
        # First access: build the provider and remember it for later calls.
        provider = InputProviderFactory.get(self._params)
        self._input_provider = provider
    return provider
def _get_features(self):
    """
    Returns the input placeholders, creating them on first access.

    The placeholders come from ``get_input_dict_placeholders()`` of the
    input provider and are stored in ``self._features`` so that later
    calls reuse the same objects instead of requesting new ones.
    """
    if self._features is not None:
        return self._features
    # First access: ask the (lazily created) provider for placeholders.
    provider = self._get_input_provider()
    self._features = provider.get_input_dict_placeholders()
    return self._features
def _get_builder(self):
    """
    Returns the estimator spec builder, creating it on first access.

    The builder is an ``EstimatorSpecBuilder`` constructed from the
    features returned by ``self._get_features()`` and ``self._params``.
    It is stored in ``self._builder`` and returned as-is on later calls.
    """
    if self._builder is not None:
        return self._builder
    # First access: construct the builder once and keep it on the instance.
    self._builder = EstimatorSpecBuilder(
        self._get_features(), self._params)
    return self._builder
def _separate_librosa(self, waveform, audio_id):
"""
Performs separation with librosa backend for STFT.
"""
out = {}
input_provider = InputProviderFactory.get(self._params)
features = input_provider.get_input_dict_placeholders()
features = self._get_features()
builder = EstimatorSpecBuilder(features, self._params)
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
# TODO: fix the logic: build sometimes returns a value and sometimes sets an attribute
outputs = builder.outputs
outputs = self._get_builder().outputs
stft = self._stft(waveform)
if stft.shape[-1] == 1:
stft = np.concatenate([stft, stft], axis=-1)
@@ -150,8 +167,8 @@ class Separator(object):
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
for inst in builder.instruments:
outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id))
for inst in self._get_builder().instruments:
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
return out