mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Avoid multiple checkpoint restauration and add a specific tf graph for each Separator which makes it possible to instanciate several ones
This commit is contained in:
@@ -60,14 +60,20 @@ class Separator(object):
|
|||||||
self._params = load_configuration(params_descriptor)
|
self._params = load_configuration(params_descriptor)
|
||||||
self._sample_rate = self._params['sample_rate']
|
self._sample_rate = self._params['sample_rate']
|
||||||
self._MWF = MWF
|
self._MWF = MWF
|
||||||
|
self._tf_graph = tf.Graph()
|
||||||
self._predictor = None
|
self._predictor = None
|
||||||
self._input_provider = None
|
self._input_provider = None
|
||||||
self._builder = None
|
self._builder = None
|
||||||
self._features = None
|
self._features = None
|
||||||
|
self._session = None
|
||||||
self._pool = Pool() if multiprocess else None
|
self._pool = Pool() if multiprocess else None
|
||||||
self._tasks = []
|
self._tasks = []
|
||||||
self._params["stft_backend"] = get_backend(stft_backend)
|
self._params["stft_backend"] = get_backend(stft_backend)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self._session:
|
||||||
|
self._session.close()
|
||||||
|
|
||||||
def _get_predictor(self):
|
def _get_predictor(self):
|
||||||
""" Lazy loading access method for internal predictor instance.
|
""" Lazy loading access method for internal predictor instance.
|
||||||
|
|
||||||
@@ -147,30 +153,35 @@ class Separator(object):
|
|||||||
self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
|
self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
|
||||||
return self._builder
|
return self._builder
|
||||||
|
|
||||||
|
def _get_session(self):
|
||||||
|
if self._session is None:
|
||||||
|
saver = tf.train.Saver()
|
||||||
|
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
||||||
|
self._session = tf.Session()
|
||||||
|
saver.restore(self._session, latest_checkpoint)
|
||||||
|
return self._session
|
||||||
|
|
||||||
def _separate_librosa(self, waveform, audio_id):
|
def _separate_librosa(self, waveform, audio_id):
|
||||||
"""
|
"""
|
||||||
Performs separation with librosa backend for STFT.
|
Performs separation with librosa backend for STFT.
|
||||||
"""
|
"""
|
||||||
out = {}
|
with self._tf_graph.as_default():
|
||||||
features = self._get_features()
|
out = {}
|
||||||
|
features = self._get_features()
|
||||||
|
|
||||||
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
# TODO: fix the logic, build sometimes return, sometimes set attribute
|
||||||
|
outputs = self._get_builder().outputs
|
||||||
|
stft = self._stft(waveform)
|
||||||
|
if stft.shape[-1] == 1:
|
||||||
|
stft = np.concatenate([stft, stft], axis=-1)
|
||||||
|
elif stft.shape[-1] > 2:
|
||||||
|
stft = stft[:, :2]
|
||||||
|
|
||||||
# TODO: fix the logic, build sometimes return, sometimes set attribute
|
sess = self._get_session()
|
||||||
outputs = self._get_builder().outputs
|
|
||||||
stft = self._stft(waveform)
|
|
||||||
if stft.shape[-1] == 1:
|
|
||||||
stft = np.concatenate([stft, stft], axis=-1)
|
|
||||||
elif stft.shape[-1] > 2:
|
|
||||||
stft = stft[:, :2]
|
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
|
||||||
with tf.Session() as sess:
|
|
||||||
saver.restore(sess, latest_checkpoint)
|
|
||||||
outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id))
|
outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id))
|
||||||
for inst in self._get_builder().instruments:
|
for inst in self._get_builder().instruments:
|
||||||
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
|
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def separate(self, waveform, audio_descriptor=""):
|
def separate(self, waveform, audio_descriptor=""):
|
||||||
""" Performs separation on a waveform.
|
""" Performs separation on a waveform.
|
||||||
|
|||||||
Reference in New Issue
Block a user