mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Replace predictor by estimator
This commit is contained in:
@@ -27,7 +27,7 @@ from . import SpleeterError
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .utils.configuration import load_configuration
|
||||
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
|
||||
from .utils.estimator import create_estimator, get_default_model_dir
|
||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||
|
||||
|
||||
@@ -40,8 +40,34 @@ logger = logging.getLogger("spleeter")
|
||||
|
||||
|
||||
|
||||
class DataGenerator():
|
||||
"""
|
||||
generator object that store a sample and generate it once while called.
|
||||
Used to feed a tensorflow estimator without knowing the whole data at build time.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._current_data = None
|
||||
|
||||
def update_data(self, data):
|
||||
"""
|
||||
replace data
|
||||
"""
|
||||
self._current_data = data
|
||||
|
||||
def __call__(self):
|
||||
res = self._current_data
|
||||
while res is not None:
|
||||
yield res
|
||||
res = self._current_data
|
||||
|
||||
|
||||
|
||||
def get_backend(backend):
|
||||
assert backend in ["auto", "tensorflow", "librosa"]
|
||||
# print("USING TENSORFLOW BACKEND !!!!!!")
|
||||
# return "tensorflow"
|
||||
if backend == "auto":
|
||||
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
||||
return backend
|
||||
@@ -61,7 +87,7 @@ class Separator(object):
|
||||
self._sample_rate = self._params['sample_rate']
|
||||
self._MWF = MWF
|
||||
self._tf_graph = tf.Graph()
|
||||
self._predictor = None
|
||||
self._prediction_generator = None
|
||||
self._input_provider = None
|
||||
self._builder = None
|
||||
self._features = None
|
||||
@@ -69,20 +95,30 @@ class Separator(object):
|
||||
self._pool = Pool() if multiprocess else None
|
||||
self._tasks = []
|
||||
self._params["stft_backend"] = get_backend(stft_backend)
|
||||
self._data_generator = DataGenerator()
|
||||
|
||||
|
||||
def __del__(self):
|
||||
|
||||
if self._session:
|
||||
self._session.close()
|
||||
|
||||
def _get_predictor(self):
|
||||
""" Lazy loading access method for internal predictor instance.
|
||||
|
||||
:returns: Predictor to use for source separation.
|
||||
def _get_prediction_generator(self):
|
||||
"""
|
||||
if self._predictor is None:
|
||||
Lazy loading access method for internal prediction generator returned by the predict method of a tensorflow estimator.
|
||||
|
||||
:returns: generator of prediction.
|
||||
"""
|
||||
|
||||
if self._prediction_generator is None:
|
||||
estimator = create_estimator(self._params, self._MWF)
|
||||
self._predictor = to_predictor(estimator)
|
||||
return self._predictor
|
||||
def get_dataset():
|
||||
return tf.data.Dataset.from_generator(self._data_generator, output_types={"waveform":tf.float32, "audio_id":tf.string}, output_shapes={"waveform":(None,2),"audio_id":()})
|
||||
self._prediction_generator = estimator.predict(get_dataset,
|
||||
yield_single_examples=False)
|
||||
|
||||
|
||||
return self._prediction_generator
|
||||
|
||||
def join(self, timeout=200):
|
||||
""" Wait for all pending tasks to be finished.
|
||||
@@ -103,10 +139,14 @@ class Separator(object):
|
||||
"""
|
||||
if not waveform.shape[-1] == 2:
|
||||
waveform = to_stereo(waveform)
|
||||
predictor = self._get_predictor()
|
||||
prediction = predictor({
|
||||
'waveform': waveform,
|
||||
'audio_id': audio_descriptor})
|
||||
prediction_generator = self._get_prediction_generator()
|
||||
|
||||
# update data in generator before performing separation
|
||||
self._data_generator.update_data({"waveform": waveform,
|
||||
'audio_id': np.array(audio_descriptor)})
|
||||
|
||||
# perform separation
|
||||
prediction = next(prediction_generator)
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
@@ -155,9 +195,9 @@ class Separator(object):
|
||||
|
||||
def _get_session(self):
|
||||
if self._session is None:
|
||||
saver = tf.train.Saver()
|
||||
saver = tf.compat.v1.train.Saver()
|
||||
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
||||
self._session = tf.Session()
|
||||
self._session = tf.compat.v1.Session()
|
||||
saver.restore(self._session, latest_checkpoint)
|
||||
return self._session
|
||||
|
||||
|
||||
Reference in New Issue
Block a user