mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Replace predictor by estimator
This commit is contained in:
@@ -27,7 +27,7 @@ from . import SpleeterError
|
|||||||
from .audio.adapter import get_default_audio_adapter
|
from .audio.adapter import get_default_audio_adapter
|
||||||
from .audio.convertor import to_stereo
|
from .audio.convertor import to_stereo
|
||||||
from .utils.configuration import load_configuration
|
from .utils.configuration import load_configuration
|
||||||
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
|
from .utils.estimator import create_estimator, get_default_model_dir
|
||||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||||
|
|
||||||
|
|
||||||
@@ -40,8 +40,34 @@ logger = logging.getLogger("spleeter")
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DataGenerator():
|
||||||
|
"""
|
||||||
|
generator object that store a sample and generate it once while called.
|
||||||
|
Used to feed a tensorflow estimator without knowing the whole data at build time.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._current_data = None
|
||||||
|
|
||||||
|
def update_data(self, data):
|
||||||
|
"""
|
||||||
|
replace data
|
||||||
|
"""
|
||||||
|
self._current_data = data
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
res = self._current_data
|
||||||
|
while res is not None:
|
||||||
|
yield res
|
||||||
|
res = self._current_data
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_backend(backend):
|
def get_backend(backend):
|
||||||
assert backend in ["auto", "tensorflow", "librosa"]
|
assert backend in ["auto", "tensorflow", "librosa"]
|
||||||
|
# print("USING TENSORFLOW BACKEND !!!!!!")
|
||||||
|
# return "tensorflow"
|
||||||
if backend == "auto":
|
if backend == "auto":
|
||||||
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
||||||
return backend
|
return backend
|
||||||
@@ -61,7 +87,7 @@ class Separator(object):
|
|||||||
self._sample_rate = self._params['sample_rate']
|
self._sample_rate = self._params['sample_rate']
|
||||||
self._MWF = MWF
|
self._MWF = MWF
|
||||||
self._tf_graph = tf.Graph()
|
self._tf_graph = tf.Graph()
|
||||||
self._predictor = None
|
self._prediction_generator = None
|
||||||
self._input_provider = None
|
self._input_provider = None
|
||||||
self._builder = None
|
self._builder = None
|
||||||
self._features = None
|
self._features = None
|
||||||
@@ -69,20 +95,30 @@ class Separator(object):
|
|||||||
self._pool = Pool() if multiprocess else None
|
self._pool = Pool() if multiprocess else None
|
||||||
self._tasks = []
|
self._tasks = []
|
||||||
self._params["stft_backend"] = get_backend(stft_backend)
|
self._params["stft_backend"] = get_backend(stft_backend)
|
||||||
|
self._data_generator = DataGenerator()
|
||||||
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|
||||||
if self._session:
|
if self._session:
|
||||||
self._session.close()
|
self._session.close()
|
||||||
|
|
||||||
def _get_predictor(self):
|
def _get_prediction_generator(self):
|
||||||
""" Lazy loading access method for internal predictor instance.
|
|
||||||
|
|
||||||
:returns: Predictor to use for source separation.
|
|
||||||
"""
|
"""
|
||||||
if self._predictor is None:
|
Lazy loading access method for internal prediction generator returned by the predict method of a tensorflow estimator.
|
||||||
|
|
||||||
|
:returns: generator of prediction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._prediction_generator is None:
|
||||||
estimator = create_estimator(self._params, self._MWF)
|
estimator = create_estimator(self._params, self._MWF)
|
||||||
self._predictor = to_predictor(estimator)
|
def get_dataset():
|
||||||
return self._predictor
|
return tf.data.Dataset.from_generator(self._data_generator, output_types={"waveform":tf.float32, "audio_id":tf.string}, output_shapes={"waveform":(None,2),"audio_id":()})
|
||||||
|
self._prediction_generator = estimator.predict(get_dataset,
|
||||||
|
yield_single_examples=False)
|
||||||
|
|
||||||
|
|
||||||
|
return self._prediction_generator
|
||||||
|
|
||||||
def join(self, timeout=200):
|
def join(self, timeout=200):
|
||||||
""" Wait for all pending tasks to be finished.
|
""" Wait for all pending tasks to be finished.
|
||||||
@@ -103,10 +139,14 @@ class Separator(object):
|
|||||||
"""
|
"""
|
||||||
if not waveform.shape[-1] == 2:
|
if not waveform.shape[-1] == 2:
|
||||||
waveform = to_stereo(waveform)
|
waveform = to_stereo(waveform)
|
||||||
predictor = self._get_predictor()
|
prediction_generator = self._get_prediction_generator()
|
||||||
prediction = predictor({
|
|
||||||
'waveform': waveform,
|
# update data in generator before performing separation
|
||||||
'audio_id': audio_descriptor})
|
self._data_generator.update_data({"waveform": waveform,
|
||||||
|
'audio_id': np.array(audio_descriptor)})
|
||||||
|
|
||||||
|
# perform separation
|
||||||
|
prediction = next(prediction_generator)
|
||||||
prediction.pop('audio_id')
|
prediction.pop('audio_id')
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
@@ -155,9 +195,9 @@ class Separator(object):
|
|||||||
|
|
||||||
def _get_session(self):
|
def _get_session(self):
|
||||||
if self._session is None:
|
if self._session is None:
|
||||||
saver = tf.train.Saver()
|
saver = tf.compat.v1.train.Saver()
|
||||||
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
||||||
self._session = tf.Session()
|
self._session = tf.compat.v1.Session()
|
||||||
saver.restore(self._session, latest_checkpoint)
|
saver.restore(self._session, latest_checkpoint)
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user