From 1886dde38b9b703439c9f35a699623e5389a58ef Mon Sep 17 00:00:00 2001 From: Faylixe Date: Tue, 8 Dec 2020 12:38:18 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20=20fix=20separator=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- spleeter/__main__.py | 2 +- tests/test_separator.py | 22 +++++++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/spleeter/__main__.py b/spleeter/__main__.py index 52b90f1..9048097 100644 --- a/spleeter/__main__.py +++ b/spleeter/__main__.py @@ -85,7 +85,7 @@ def train( logger.info('Model training done') -@spleeter.commmand() +@spleeter.command() def separate( files: List[Path] = AudioInputArgument, adapter: str = AudioAdapterOption, diff --git a/tests/test_separator.py b/tests/test_separator.py index e757abf..947b037 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -17,7 +17,7 @@ import numpy as np import tensorflow as tf from spleeter import SpleeterError -from spleeter.audio.adapter import get_default_audio_adapter +from spleeter.audio.adapter import AudioAdapter from spleeter.separator import Separator TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3'] @@ -41,7 +41,7 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__)) @pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS) def test_separator_backends(test_file): - adapter = get_default_audio_adapter() + adapter = AudioAdapter.default() waveform, _ = adapter.load(test_file) separator_lib = Separator( @@ -64,11 +64,13 @@ def test_separator_backends(test_file): assert np.allclose(out_tf[instrument], out_lib[instrument], atol=1e-5) -@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) +@pytest.mark.parametrize( + 'test_file, configuration, backend', + TEST_CONFIGURATIONS) def test_separate(test_file, configuration, backend): """ Test separation from raw data. """ instruments = MODEL_TO_INST[configuration] - adapter = get_default_audio_adapter() + adapter = AudioAdapter.default() waveform, _ = adapter.load(test_file) separator = Separator( configuration, stft_backend=backend, multiprocess=False) @@ -85,7 +87,9 @@ def test_separate(test_file, configuration, backend): assert not np.allclose(track, prediction[compared]) -@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) +@pytest.mark.parametrize( + 'test_file, configuration, backend', + TEST_CONFIGURATIONS) def test_separate_to_file(test_file, configuration, backend): """ Test file based separation. """ instruments = MODEL_TO_INST[configuration] @@ -102,7 +106,9 @@ def test_separate_to_file(test_file, configuration, backend): '{}/{}.wav'.format(name, instrument))) -@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) +@pytest.mark.parametrize( + 'test_file, configuration, backend', + TEST_CONFIGURATIONS) def test_filename_format(test_file, configuration, backend): """ Test custom filename format. """ instruments = MODEL_TO_INST[configuration] @@ -120,7 +126,9 @@ def test_filename_format(test_file, configuration, backend): 'export/{}/{}.wav'.format(name, instrument))) -@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES) +@pytest.mark.parametrize( + 'test_file, configuration', + MODELS_AND_TEST_FILES) def test_filename_conflict(test_file, configuration): """ Test error handling with static pattern. """ separator = Separator(configuration, multiprocess=False)