diff --git a/tests/test_separator.py b/tests/test_separator.py index 76011af..976cf6c 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -10,37 +10,39 @@ __license__ = 'MIT License' from os.path import exists, join from tempfile import TemporaryDirectory +import pytest + from spleeter.audio.adapter import get_default_audio_adapter from spleeter.separator import Separator TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3' -TEST_CONFIGURATIONS = { - 'spleeter:2stems': ('vocals', 'accompaniment'), - 'spleeter:4stems': ('vocals', 'drums', 'bass', 'other'), - 'spleeter:5stems': ('vocals', 'drums', 'bass', 'piano', 'other') -} +TEST_CONFIGURATIONS = [ + ('spleeter:2stems', ('vocals', 'accompaniment')), + ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')), + ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other')) +] +@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) def test_separate(): """ Test separation from raw data. """ adapter = get_default_audio_adapter() waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR) - for configuration, instruments in TEST_CONFIGURATIONS.items(): - separator = Separator(configuration) - prediction = separator.separate(waveform) - assert len(prediction) == 2 - for instrument in instruments: - assert instrument in prediction + separator = Separator(configuration) + prediction = separator.separate(waveform) + assert len(prediction) == 2 + for instrument in instruments: + assert instrument in prediction -def test_separate_to_file(): +@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) +def test_separate_to_file(configuration, instruments): """ Test file based separation. """ - for configuration, instruments in TEST_CONFIGURATIONS.items(): - separator = Separator(configuration) - with TemporaryDirectory() as directory: - separator.separate_to_file( - TEST_AUDIO_DESCRIPTOR, - directory) - for instrument in instruments: - assert exists(join(directory, '{}.wav'.format(instrument))) - # TODO: Consider testing generated file as well. + separator = Separator(configuration) + with TemporaryDirectory() as directory: + separator.separate_to_file( + TEST_AUDIO_DESCRIPTOR, + directory) + for instrument in instruments: + assert exists(join(directory, '{}.wav'.format(instrument))) + # TODO: Consider testing generated file as well.