diff --git a/Makefile b/Makefile index e8f710e..8661db9 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ build-gpu: clean python3 setup.py sdist test: - $(PYTEST_CMD) + $(foreach file, $(wildcard tests/test_*.py), $(PYTEST_CMD) $(file);) deploy: diff --git a/tests/test_separator.py b/tests/test_separator.py index c688f0f..2d2686a 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -42,67 +42,68 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__)) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_separate(test_file, configuration, backend): """ Test separation from raw data. """ - with tf.Session() as sess: - instruments = MODEL_TO_INST[configuration] - adapter = get_default_audio_adapter() - waveform, _ = adapter.load(test_file) - separator = Separator(configuration, stft_backend=backend) - prediction = separator.separate(waveform, test_file) - assert len(prediction) == len(instruments) - for instrument in instruments: - assert instrument in prediction - for instrument in instruments: - track = prediction[instrument] - assert waveform.shape[:-1] == track.shape[:-1] - assert not np.allclose(waveform, track) - for compared in instruments: - if instrument != compared: - assert not np.allclose(track, prediction[compared]) + tf.reset_default_graph() + instruments = MODEL_TO_INST[configuration] + adapter = get_default_audio_adapter() + waveform, _ = adapter.load(test_file) + separator = Separator(configuration, stft_backend=backend, multiprocess=False) + prediction = separator.separate(waveform, test_file) + assert len(prediction) == len(instruments) + for instrument in instruments: + assert instrument in prediction + for instrument in instruments: + track = prediction[instrument] + assert waveform.shape[:-1] == track.shape[:-1] + assert not np.allclose(waveform, track) + for compared in instruments: + if instrument != compared: + assert not np.allclose(track, prediction[compared]) + @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_separate_to_file(test_file, configuration, backend): """ Test file based separation. """ - with tf.Session() as sess: - instruments = MODEL_TO_INST[configuration] - separator = Separator(configuration, stft_backend=backend) - name = splitext(basename(test_file))[0] - with TemporaryDirectory() as directory: - separator.separate_to_file( - test_file, - directory) - for instrument in instruments: - assert exists(join( - directory, - '{}/{}.wav'.format(name, instrument))) + tf.reset_default_graph() + instruments = MODEL_TO_INST[configuration] + separator = Separator(configuration, stft_backend=backend, multiprocess=False) + name = splitext(basename(test_file))[0] + with TemporaryDirectory() as directory: + separator.separate_to_file( + test_file, + directory) + for instrument in instruments: + assert exists(join( + directory, + '{}/{}.wav'.format(name, instrument))) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_filename_format(test_file, configuration, backend): """ Test custom filename format. """ - with tf.Session() as sess: - instruments = MODEL_TO_INST[configuration] - separator = Separator(configuration, stft_backend=backend) - name = splitext(basename(test_file))[0] - with TemporaryDirectory() as directory: - separator.separate_to_file( - test_file, + tf.reset_default_graph() + instruments = MODEL_TO_INST[configuration] + separator = Separator(configuration, stft_backend=backend, multiprocess=False) + name = splitext(basename(test_file))[0] + with TemporaryDirectory() as directory: + separator.separate_to_file( + test_file, + directory, + filename_format='export/{filename}/{instrument}.{codec}') + for instrument in instruments: + assert exists(join( directory, - filename_format='export/{filename}/{instrument}.{codec}') - for instrument in instruments: - assert exists(join( - directory, - 'export/{}/{}.wav'.format(name, instrument))) + 'export/{}/{}.wav'.format(name, instrument))) @pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES) def test_filename_conflict(test_file, configuration): """ Test error handling with static pattern. """ - with tf.Session() as sess: - separator = Separator(configuration) - with TemporaryDirectory() as directory: - with pytest.raises(SpleeterError): - separator.separate_to_file( - test_file, - directory, - filename_format='I wanna be your lover') + tf.reset_default_graph() + separator = Separator(configuration, multiprocess=False) + with TemporaryDirectory() as directory: + with pytest.raises(SpleeterError): + separator.separate_to_file( + test_file, + directory, + filename_format='I wanna be your lover')