diff --git a/spleeter/audio/__init__.py b/spleeter/audio/__init__.py index 18efade..06ecceb 100644 --- a/spleeter/audio/__init__.py +++ b/spleeter/audio/__init__.py @@ -12,6 +12,11 @@ from enum import Enum +# pyright: reportMissingImports=false +# pylint: disable=import-error +import tensorflow as tf +# pylint: enable=import-error + __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' @@ -34,3 +39,12 @@ class STFTBackend(str, Enum): AUTO: str = 'auto' TENSORFLOW: str = 'tensorflow' LIBROSA: str = 'librosa' + + def resolve(cls: type, backend: str) -> str: + if backend not in cls.__members__.items(): + raise ValueError(f'Unsupported backend {backend}') + if backend == cls.AUTO: + if len(tf.config.list_physical_devices('GPU')): + return cls.TENSORFLOW + return STFTBackend.LIBROSA + return backend diff --git a/spleeter/separator.py b/spleeter/separator.py index bf7e1dc..04d96c9 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -129,8 +129,7 @@ class Separator(object): else: self._pool = None self._tasks = [] - # NOTE: provide type check here ? - self._params['stft_backend'] = stft_backend + self._params['stft_backend'] = STFTBackend.resolve(stft_backend) self._data_generator = DataGenerator() def __del__(self) -> None: @@ -333,11 +332,6 @@ class Separator(object): (Optional) string describing the waveform (e.g. filename). """ backend: str = self._params['stft_backend'] - if backend == STFTBackend.AUTO: - if len(tf.config.list_physical_devices('GPU')): - backend = STFTBackend.TENSORFLOW - else: - backend = STFTBackend.LIBROSA if backend == STFTBackend.TENSORFLOW: return self._separate_tensorflow(waveform, audio_descriptor) elif backend == STFTBackend.LIBROSA: