From 232bf0d3b674a523dbb3ffb9d2bd6aba32adafd1 Mon Sep 17 00:00:00 2001 From: Faylixe Date: Tue, 8 Dec 2020 12:48:15 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20=20fix=20AUTO=20backend=20suppor?= =?UTF-8?q?t?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- spleeter/__main__.py | 2 +- spleeter/separator.py | 5 +++++ tests/test_train.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/spleeter/__main__.py b/spleeter/__main__.py index 9048097..757b3da 100644 --- a/spleeter/__main__.py +++ b/spleeter/__main__.py @@ -37,7 +37,7 @@ from typer import Exit, Typer # pylint: enable=import-error spleeter: Typer = Typer() -""" """ +""" CLI application. """ @spleeter.command() diff --git a/spleeter/separator.py b/spleeter/separator.py index 7381e5e..bf7e1dc 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -333,6 +333,11 @@ class Separator(object): (Optional) string describing the waveform (e.g. filename). """ backend: str = self._params['stft_backend'] + if backend == STFTBackend.AUTO: + if len(tf.config.list_physical_devices('GPU')): + backend = STFTBackend.TENSORFLOW + else: + backend = STFTBackend.LIBROSA if backend == STFTBackend.TENSORFLOW: return self._separate_tensorflow(waveform, audio_descriptor) elif backend == STFTBackend.LIBROSA: diff --git a/tests/test_train.py b/tests/test_train.py index 3c63b40..7c4bc48 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -93,7 +93,7 @@ def test_train(): TRAIN_CONFIG['model_dir'] = join(path, 'model') TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training') TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation') - with open('useless_config.json') as stream: + with open('useless_config.json', 'w') as stream: json.dump(TRAIN_CONFIG, stream) # execute training result = runner.invoke(spleeter, [