Merge pull request #375 from deezer/avoid_multiple_restore_librosa

Fix some issues with librosa stft backend: avoid multiple checkpoint restoration + one tf.graph per Separator. Version bumped to 1.5.2.
This commit is contained in:
Romain Hennequin
2020-05-15 16:11:45 +02:00
committed by GitHub
2 changed files with 27 additions and 16 deletions

View File

@@ -14,7 +14,7 @@ __license__ = 'MIT License'
# Default project values.
project_name = 'spleeter'
project_version = '1.5.1'
project_version = '1.5.2'
tensorflow_dependency = 'tensorflow'
tensorflow_version = '1.15.2'
here = path.abspath(path.dirname(__file__))

View File

@@ -60,14 +60,20 @@ class Separator(object):
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
self._tf_graph = tf.Graph()
self._predictor = None
self._input_provider = None
self._builder = None
self._features = None
self._session = None
self._pool = Pool() if multiprocess else None
self._tasks = []
self._params["stft_backend"] = get_backend(stft_backend)
def __del__(self):
if self._session:
self._session.close()
def _get_predictor(self):
""" Lazy loading access method for internal predictor instance.
@@ -147,30 +153,35 @@ class Separator(object):
self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
return self._builder
def _get_session(self):
if self._session is None:
saver = tf.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
self._session = tf.Session()
saver.restore(self._session, latest_checkpoint)
return self._session
def _separate_librosa(self, waveform, audio_id):
"""
Performs separation with librosa backend for STFT.
"""
out = {}
features = self._get_features()
with self._tf_graph.as_default():
out = {}
features = self._get_features()
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
# TODO: fix the logic, build sometimes return, sometimes set attribute
outputs = self._get_builder().outputs
stft = self._stft(waveform)
if stft.shape[-1] == 1:
stft = np.concatenate([stft, stft], axis=-1)
elif stft.shape[-1] > 2:
stft = stft[:, :2]
# TODO: fix the logic, build sometimes return, sometimes set attribute
outputs = self._get_builder().outputs
stft = self._stft(waveform)
if stft.shape[-1] == 1:
stft = np.concatenate([stft, stft], axis=-1)
elif stft.shape[-1] > 2:
stft = stft[:, :2]
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
sess = self._get_session()
outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id))
for inst in self._get_builder().instruments:
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
return out
return out
def separate(self, waveform, audio_descriptor=""):
""" Performs separation on a waveform.