🐛 fix backend resolution

This commit is contained in:
Faylixe
2020-12-08 13:02:35 +01:00
parent 232bf0d3b6
commit c96897df2c
2 changed files with 15 additions and 7 deletions

View File

@@ -129,8 +129,7 @@ class Separator(object):
else:
self._pool = None
self._tasks = []
# NOTE: provide type check here ?
self._params['stft_backend'] = stft_backend
self._params['stft_backend'] = STFTBackend.resolve(stft_backend)
self._data_generator = DataGenerator()
def __del__(self) -> None:
@@ -333,11 +332,6 @@ class Separator(object):
(Optional) string describing the waveform (e.g. filename).
"""
backend: str = self._params['stft_backend']
if backend == STFTBackend.AUTO:
if len(tf.config.list_physical_devices('GPU')):
backend = STFTBackend.TENSORFLOW
else:
backend = STFTBackend.LIBROSA
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor)
elif backend == STFTBackend.LIBROSA: