Moving get_backend to separator

This commit is contained in:
akhlif
2020-02-27 14:13:59 +01:00
parent 6001ae12a9
commit d177525ea7
3 changed files with 25 additions and 14 deletions

View File

@@ -39,6 +39,14 @@ __license__ = 'MIT License'
logger = logging.getLogger("spleeter")
def get_backend(backend):
assert backend in ["auto", "tensorflow", "librosa"]
if backend == "auto":
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
return backend
class Separator(object):
""" A wrapper class for performing separation. """
@@ -48,13 +56,14 @@ class Separator(object):
:param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
self._predictor = None
self._pool = Pool() if multiprocess else None
self._tasks = []
self._params["stft_backend"] = stft_backend
self._params["stft_backend"] = get_backend(stft_backend)
def _get_predictor(self):
""" Lazy loading access method for internal predictor instance.