mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-30 20:24:31 +00:00
Moving get_backend to separator
This commit is contained in:
@@ -39,6 +39,14 @@ __license__ = 'MIT License'
|
||||
logger = logging.getLogger("spleeter")
|
||||
|
||||
|
||||
|
||||
def get_backend(backend):
|
||||
assert backend in ["auto", "tensorflow", "librosa"]
|
||||
if backend == "auto":
|
||||
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
||||
return backend
|
||||
|
||||
|
||||
class Separator(object):
|
||||
""" A wrapper class for performing separation. """
|
||||
|
||||
@@ -48,13 +56,14 @@ class Separator(object):
|
||||
:param params_descriptor: Descriptor for TF params to be used.
|
||||
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
||||
"""
|
||||
|
||||
self._params = load_configuration(params_descriptor)
|
||||
self._sample_rate = self._params['sample_rate']
|
||||
self._MWF = MWF
|
||||
self._predictor = None
|
||||
self._pool = Pool() if multiprocess else None
|
||||
self._tasks = []
|
||||
self._params["stft_backend"] = stft_backend
|
||||
self._params["stft_backend"] = get_backend(stft_backend)
|
||||
|
||||
def _get_predictor(self):
|
||||
""" Lazy loading access method for internal predictor instance.
|
||||
|
||||
Reference in New Issue
Block a user