Avoid multiple checkpoint restauration and add a specific tf graph for each Separator which makes it possible to instanciate several ones

This commit is contained in:
romi1502
2020-05-12 21:48:41 +02:00
parent 1c1ff80c0b
commit 0f89b712ed

View File

@@ -60,14 +60,20 @@ class Separator(object):
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
self._tf_graph = tf.Graph()
self._predictor = None
self._input_provider = None
self._builder = None
self._features = None
self._session = None
self._pool = Pool() if multiprocess else None
self._tasks = []
self._params["stft_backend"] = get_backend(stft_backend)
def __del__(self):
if self._session:
self._session.close()
def _get_predictor(self):
""" Lazy loading access method for internal predictor instance.
@@ -147,15 +153,22 @@ class Separator(object):
self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
return self._builder
def _get_session(self):
if self._session is None:
saver = tf.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
self._session = tf.Session()
saver.restore(self._session, latest_checkpoint)
return self._session
def _separate_librosa(self, waveform, audio_id):
"""
Performs separation with librosa backend for STFT.
"""
with self._tf_graph.as_default():
out = {}
features = self._get_features()
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
# TODO: fix the logic, build sometimes return, sometimes set attribute
outputs = self._get_builder().outputs
stft = self._stft(waveform)
@@ -164,9 +177,7 @@ class Separator(object):
elif stft.shape[-1] > 2:
stft = stft[:, :2]
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
sess = self._get_session()
outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id))
for inst in self._get_builder().instruments:
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])