🐛 fix backend resolution

This commit is contained in:
Faylixe
2020-12-08 13:02:35 +01:00
parent 232bf0d3b6
commit c96897df2c
2 changed files with 15 additions and 7 deletions

View File

@@ -12,6 +12,11 @@
from enum import Enum
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
@@ -34,3 +39,12 @@ class STFTBackend(str, Enum):
AUTO: str = 'auto'
TENSORFLOW: str = 'tensorflow'
LIBROSA: str = 'librosa'
def resolve(cls: type, backend: str) -> str:
if backend not in cls.__members__.items():
raise ValueError(f'Unsupported backend {backend}')
if backend == cls.AUTO:
if len(tf.config.list_physical_devices('GPU')):
return cls.TENSORFLOW
return STFTBackend.LIBROSA
return backend

View File

@@ -129,8 +129,7 @@ class Separator(object):
else:
self._pool = None
self._tasks = []
# NOTE: provide type check here ?
self._params['stft_backend'] = stft_backend
self._params['stft_backend'] = STFTBackend.resolve(stft_backend)
self._data_generator = DataGenerator()
def __del__(self) -> None:
@@ -333,11 +332,6 @@ class Separator(object):
(Optional) string describing the waveform (e.g. filename).
"""
backend: str = self._params['stft_backend']
if backend == STFTBackend.AUTO:
if len(tf.config.list_physical_devices('GPU')):
backend = STFTBackend.TENSORFLOW
else:
backend = STFTBackend.LIBROSA
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor)
elif backend == STFTBackend.LIBROSA: