diff --git a/Makefile b/Makefile index 06ceeaf..e8f710e 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ FEEDSTOCK = spleeter-feedstock FEEDSTOCK_REPOSITORY = https://github.com/deezer/$(FEEDSTOCK) FEEDSTOCK_RECIPE = $(FEEDSTOCK)/recipe/spleeter/meta.yaml -PYTEST_CMD = pytest -W ignore::FutureWarning -W ignore::DeprecationWarning -vv +PYTEST_CMD = pytest -W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked all: clean build test deploy @@ -27,7 +27,7 @@ build-gpu: clean python3 setup.py sdist test: - $(foreach file, $(wildcard tests/test_*.py), $(PYTEST_CMD) $(file);) + $(PYTEST_CMD) deploy: diff --git a/tests/test_separator.py b/tests/test_separator.py index 245571d..c688f0f 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -42,67 +42,67 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__)) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_separate(test_file, configuration, backend): """ Test separation from raw data. """ - tf.reset_default_graph() - instruments = MODEL_TO_INST[configuration] - adapter = get_default_audio_adapter() - waveform, _ = adapter.load(test_file) - separator = Separator(configuration, stft_backend=backend) - prediction = separator.separate(waveform, test_file) - assert len(prediction) == len(instruments) - for instrument in instruments: - assert instrument in prediction - for instrument in instruments: - track = prediction[instrument] - assert waveform.shape[:-1] == track.shape[:-1] - assert not np.allclose(waveform, track) - for compared in instruments: - if instrument != compared: - assert not np.allclose(track, prediction[compared]) + with tf.Session() as sess: + instruments = MODEL_TO_INST[configuration] + adapter = get_default_audio_adapter() + waveform, _ = adapter.load(test_file) + separator = Separator(configuration, stft_backend=backend) + prediction = separator.separate(waveform, test_file) + assert len(prediction) == len(instruments) + for instrument in instruments: + assert instrument in prediction + for instrument in instruments: + track = prediction[instrument] + assert waveform.shape[:-1] == track.shape[:-1] + assert not np.allclose(waveform, track) + for compared in instruments: + if instrument != compared: + assert not np.allclose(track, prediction[compared]) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_separate_to_file(test_file, configuration, backend): """ Test file based separation. """ - tf.reset_default_graph() - instruments = MODEL_TO_INST[configuration] - separator = Separator(configuration, stft_backend=backend) - name = splitext(basename(test_file))[0] - with TemporaryDirectory() as directory: - separator.separate_to_file( - test_file, - directory) - for instrument in instruments: - assert exists(join( - directory, - '{}/{}.wav'.format(name, instrument))) + with tf.Session() as sess: + instruments = MODEL_TO_INST[configuration] + separator = Separator(configuration, stft_backend=backend) + name = splitext(basename(test_file))[0] + with TemporaryDirectory() as directory: + separator.separate_to_file( + test_file, + directory) + for instrument in instruments: + assert exists(join( + directory, + '{}/{}.wav'.format(name, instrument))) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_filename_format(test_file, configuration, backend): """ Test custom filename format. """ - tf.reset_default_graph() - instruments = MODEL_TO_INST[configuration] - separator = Separator(configuration, stft_backend=backend) - name = splitext(basename(test_file))[0] - with TemporaryDirectory() as directory: - separator.separate_to_file( - test_file, - directory, - filename_format='export/{filename}/{instrument}.{codec}') - for instrument in instruments: - assert exists(join( + with tf.Session() as sess: + instruments = MODEL_TO_INST[configuration] + separator = Separator(configuration, stft_backend=backend) + name = splitext(basename(test_file))[0] + with TemporaryDirectory() as directory: + separator.separate_to_file( + test_file, directory, - 'export/{}/{}.wav'.format(name, instrument))) + filename_format='export/{filename}/{instrument}.{codec}') + for instrument in instruments: + assert exists(join( + directory, + 'export/{}/{}.wav'.format(name, instrument))) @pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES) def test_filename_conflict(test_file, configuration): """ Test error handling with static pattern. """ - tf.reset_default_graph() - separator = Separator(configuration) - with TemporaryDirectory() as directory: - with pytest.raises(SpleeterError): - separator.separate_to_file( - test_file, - directory, - filename_format='I wanna be your lover') + with tf.Session() as sess: + separator = Separator(configuration) + with TemporaryDirectory() as directory: + with pytest.raises(SpleeterError): + separator.separate_to_file( + test_file, + directory, + filename_format='I wanna be your lover')