Replace predictor by estimator

This commit is contained in:
romi1502
2020-07-01 15:49:32 +02:00
parent dd7ce237ed
commit e2937f7898

View File

@@ -27,7 +27,7 @@ from . import SpleeterError
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
from .utils.estimator import create_estimator, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
@@ -40,8 +40,34 @@ logger = logging.getLogger("spleeter")
class DataGenerator():
"""
generator object that store a sample and generate it once while called.
Used to feed a tensorflow estimator without knowing the whole data at build time.
"""
def __init__(self):
self._current_data = None
def update_data(self, data):
"""
replace data
"""
self._current_data = data
def __call__(self):
res = self._current_data
while res is not None:
yield res
res = self._current_data
def get_backend(backend):
assert backend in ["auto", "tensorflow", "librosa"]
# print("USING TENSORFLOW BACKEND !!!!!!")
# return "tensorflow"
if backend == "auto":
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
return backend
@@ -61,7 +87,7 @@ class Separator(object):
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
self._tf_graph = tf.Graph()
self._predictor = None
self._prediction_generator = None
self._input_provider = None
self._builder = None
self._features = None
@@ -69,20 +95,30 @@ class Separator(object):
self._pool = Pool() if multiprocess else None
self._tasks = []
self._params["stft_backend"] = get_backend(stft_backend)
self._data_generator = DataGenerator()
def __del__(self):
if self._session:
self._session.close()
def _get_predictor(self):
""" Lazy loading access method for internal predictor instance.
:returns: Predictor to use for source separation.
def _get_prediction_generator(self):
"""
if self._predictor is None:
Lazy loading access method for internal prediction generator returned by the predict method of a tensorflow estimator.
:returns: generator of prediction.
"""
if self._prediction_generator is None:
estimator = create_estimator(self._params, self._MWF)
self._predictor = to_predictor(estimator)
return self._predictor
def get_dataset():
return tf.data.Dataset.from_generator(self._data_generator, output_types={"waveform":tf.float32, "audio_id":tf.string}, output_shapes={"waveform":(None,2),"audio_id":()})
self._prediction_generator = estimator.predict(get_dataset,
yield_single_examples=False)
return self._prediction_generator
def join(self, timeout=200):
""" Wait for all pending tasks to be finished.
@@ -103,10 +139,14 @@ class Separator(object):
"""
if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform)
predictor = self._get_predictor()
prediction = predictor({
'waveform': waveform,
'audio_id': audio_descriptor})
prediction_generator = self._get_prediction_generator()
# update data in generator before performing separation
self._data_generator.update_data({"waveform": waveform,
'audio_id': np.array(audio_descriptor)})
# perform separation
prediction = next(prediction_generator)
prediction.pop('audio_id')
return prediction
@@ -155,9 +195,9 @@ class Separator(object):
def _get_session(self):
if self._session is None:
saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
self._session = tf.Session()
self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint)
return self._session