diff --git a/tests/test_separator.py b/tests/test_separator.py index 85643c1..2c244ad 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -32,9 +32,10 @@ MODEL_TO_INST = { TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS)) -@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) -def test_separate(test_file, configuration, instruments, backend): +@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) +def test_separate(test_file, configuration, backend): """ Test separation from raw data. """ + instruments = MODEL_TO_INST[configuration] adapter = get_default_audio_adapter() waveform, _ = adapter.load(test_file) separator = Separator(configuration, stft_backend=backend) @@ -51,9 +52,10 @@ def test_separate(test_file, configuration, instruments, backend): assert not np.allclose(track, prediction[compared]) -@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) -def test_separate_to_file(test_file, configuration, instruments, backend): +@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) +def test_separate_to_file(test_file, configuration, backend): """ Test file based separation. """ + instruments = MODEL_TO_INST[configuration] separator = Separator(configuration, stft_backend=backend) basename = splitext(basename(test_file)) with TemporaryDirectory() as directory: @@ -66,9 +68,10 @@ def test_separate_to_file(test_file, configuration, instruments, backend): '{}/{}.wav'.format(basename, instrument))) -@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) -def test_filename_format(test_file, configuration, instruments, backend): +@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) +def test_filename_format(test_file, configuration, backend): """ Test custom filename format. """ + instruments = MODEL_TO_INST[configuration] separator = Separator(configuration, stft_backend=backend) basename = splitext(basename(test_file)) with TemporaryDirectory() as directory: @@ -82,6 +85,7 @@ def test_filename_format(test_file, configuration, instruments, backend): 'export/{}/{}.wav'.format(basename, instrument))) +@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS) def test_filename_conflict(test_file): """ Test error handling with static pattern. """ separator = Separator(TEST_CONFIGURATIONS[0][0])