diff --git a/spleeter/separator.py b/spleeter/separator.py index 9b77111..3bc6289 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -127,7 +127,7 @@ class Separator(object): out = [] for c in range(n_channels): d = data[:, :, c].T if inverse else data[:, c] - s = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg) + s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg) s = np.expand_dims(s.T, 2-inverse) out.append(s) if len(out) == 1: @@ -144,9 +144,11 @@ class Separator(object): # TODO: fix the logic, build sometimes return, sometimes set attribute outputs = builder.outputs + stft = self.stft(waveform) + if stft.shape[-1] != 2: + stft = np.concatenate([stft, stft], axis=-1) saver = tf.train.Saver() - stft = self.stft(waveform) with tf.Session() as sess: saver.restore(sess, latest_checkpoint) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) diff --git a/tests/test_separator.py b/tests/test_separator.py index 2c244ad..245571d 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -15,6 +15,8 @@ from tempfile import TemporaryDirectory import pytest import numpy as np +import tensorflow as tf + from spleeter import SpleeterError from spleeter.audio.adapter import get_default_audio_adapter from spleeter.separator import Separator @@ -22,6 +24,7 @@ from spleeter.separator import Separator TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3'] BACKENDS = ["tensorflow", "librosa"] MODELS = ['spleeter:2stems', 'spleeter:4stems', 'spleeter:5stems'] + MODEL_TO_INST = { 'spleeter:2stems': ('vocals', 'accompaniment'), 'spleeter:4stems': ('vocals', 'drums', 'bass', 'other'), @@ -29,12 +32,17 @@ MODEL_TO_INST = { } +MODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS)) TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS)) +print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__)) + + @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_separate(test_file, configuration, backend): """ Test separation from raw data. """ + tf.reset_default_graph() instruments = MODEL_TO_INST[configuration] adapter = get_default_audio_adapter() waveform, _ = adapter.load(test_file) @@ -45,7 +53,7 @@ def test_separate(test_file, configuration, backend): assert instrument in prediction for instrument in instruments: track = prediction[instrument] - assert waveform.shape == track.shape + assert waveform.shape[:-1] == track.shape[:-1] assert not np.allclose(waveform, track) for compared in instruments: if instrument != compared: @@ -55,9 +63,10 @@ def test_separate(test_file, configuration, backend): @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_separate_to_file(test_file, configuration, backend): """ Test file based separation. """ + tf.reset_default_graph() instruments = MODEL_TO_INST[configuration] separator = Separator(configuration, stft_backend=backend) - basename = splitext(basename(test_file)) + name = splitext(basename(test_file))[0] with TemporaryDirectory() as directory: separator.separate_to_file( test_file, @@ -65,15 +74,16 @@ def test_separate_to_file(test_file, configuration, backend): for instrument in instruments: assert exists(join( directory, - '{}/{}.wav'.format(basename, instrument))) + '{}/{}.wav'.format(name, instrument))) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_filename_format(test_file, configuration, backend): """ Test custom filename format. """ + tf.reset_default_graph() instruments = MODEL_TO_INST[configuration] separator = Separator(configuration, stft_backend=backend) - basename = splitext(basename(test_file)) + name = splitext(basename(test_file))[0] with TemporaryDirectory() as directory: separator.separate_to_file( test_file, @@ -82,13 +92,14 @@ def test_filename_format(test_file, configuration, backend): for instrument in instruments: assert exists(join( directory, - 'export/{}/{}.wav'.format(basename, instrument))) + 'export/{}/{}.wav'.format(name, instrument))) -@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS) -def test_filename_conflict(test_file): +@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES) +def test_filename_conflict(test_file, configuration): """ Test error handling with static pattern. """ - separator = Separator(TEST_CONFIGURATIONS[0][0]) + tf.reset_default_graph() + separator = Separator(configuration) with TemporaryDirectory() as directory: with pytest.raises(SpleeterError): separator.separate_to_file(