Fixing tests when running in single process

This commit is contained in:
akhlif
2020-03-27 11:12:05 +01:00
parent 4e16e0fe23
commit 38bfff833e
2 changed files with 23 additions and 10 deletions

View File

@@ -127,7 +127,7 @@ class Separator(object):
out = [] out = []
for c in range(n_channels): for c in range(n_channels):
d = data[:, :, c].T if inverse else data[:, c] d = data[:, :, c].T if inverse else data[:, c]
s = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg) s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
s = np.expand_dims(s.T, 2-inverse) s = np.expand_dims(s.T, 2-inverse)
out.append(s) out.append(s)
if len(out) == 1: if len(out) == 1:
@@ -144,9 +144,11 @@ class Separator(object):
# TODO: fix the logic, build sometimes return, sometimes set attribute # TODO: fix the logic, build sometimes return, sometimes set attribute
outputs = builder.outputs outputs = builder.outputs
stft = self.stft(waveform)
if stft.shape[-1] != 2:
stft = np.concatenate([stft, stft], axis=-1)
saver = tf.train.Saver() saver = tf.train.Saver()
stft = self.stft(waveform)
with tf.Session() as sess: with tf.Session() as sess:
saver.restore(sess, latest_checkpoint) saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))

View File

@@ -15,6 +15,8 @@ from tempfile import TemporaryDirectory
import pytest import pytest
import numpy as np import numpy as np
import tensorflow as tf
from spleeter import SpleeterError from spleeter import SpleeterError
from spleeter.audio.adapter import get_default_audio_adapter from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.separator import Separator from spleeter.separator import Separator
@@ -22,6 +24,7 @@ from spleeter.separator import Separator
TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3'] TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
BACKENDS = ["tensorflow", "librosa"] BACKENDS = ["tensorflow", "librosa"]
MODELS = ['spleeter:2stems', 'spleeter:4stems', 'spleeter:5stems'] MODELS = ['spleeter:2stems', 'spleeter:4stems', 'spleeter:5stems']
MODEL_TO_INST = { MODEL_TO_INST = {
'spleeter:2stems': ('vocals', 'accompaniment'), 'spleeter:2stems': ('vocals', 'accompaniment'),
'spleeter:4stems': ('vocals', 'drums', 'bass', 'other'), 'spleeter:4stems': ('vocals', 'drums', 'bass', 'other'),
@@ -29,12 +32,17 @@ MODEL_TO_INST = {
} }
MODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS))
TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS)) TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS))
print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_separate(test_file, configuration, backend): def test_separate(test_file, configuration, backend):
""" Test separation from raw data. """ """ Test separation from raw data. """
tf.reset_default_graph()
instruments = MODEL_TO_INST[configuration] instruments = MODEL_TO_INST[configuration]
adapter = get_default_audio_adapter() adapter = get_default_audio_adapter()
waveform, _ = adapter.load(test_file) waveform, _ = adapter.load(test_file)
@@ -45,7 +53,7 @@ def test_separate(test_file, configuration, backend):
assert instrument in prediction assert instrument in prediction
for instrument in instruments: for instrument in instruments:
track = prediction[instrument] track = prediction[instrument]
assert waveform.shape == track.shape assert waveform.shape[:-1] == track.shape[:-1]
assert not np.allclose(waveform, track) assert not np.allclose(waveform, track)
for compared in instruments: for compared in instruments:
if instrument != compared: if instrument != compared:
@@ -55,9 +63,10 @@ def test_separate(test_file, configuration, backend):
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_separate_to_file(test_file, configuration, backend): def test_separate_to_file(test_file, configuration, backend):
""" Test file based separation. """ """ Test file based separation. """
tf.reset_default_graph()
instruments = MODEL_TO_INST[configuration] instruments = MODEL_TO_INST[configuration]
separator = Separator(configuration, stft_backend=backend) separator = Separator(configuration, stft_backend=backend)
basename = splitext(basename(test_file)) name = splitext(basename(test_file))[0]
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
separator.separate_to_file( separator.separate_to_file(
test_file, test_file,
@@ -65,15 +74,16 @@ def test_separate_to_file(test_file, configuration, backend):
for instrument in instruments: for instrument in instruments:
assert exists(join( assert exists(join(
directory, directory,
'{}/{}.wav'.format(basename, instrument))) '{}/{}.wav'.format(name, instrument)))
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_filename_format(test_file, configuration, backend): def test_filename_format(test_file, configuration, backend):
""" Test custom filename format. """ """ Test custom filename format. """
tf.reset_default_graph()
instruments = MODEL_TO_INST[configuration] instruments = MODEL_TO_INST[configuration]
separator = Separator(configuration, stft_backend=backend) separator = Separator(configuration, stft_backend=backend)
basename = splitext(basename(test_file)) name = splitext(basename(test_file))[0]
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
separator.separate_to_file( separator.separate_to_file(
test_file, test_file,
@@ -82,13 +92,14 @@ def test_filename_format(test_file, configuration, backend):
for instrument in instruments: for instrument in instruments:
assert exists(join( assert exists(join(
directory, directory,
'export/{}/{}.wav'.format(basename, instrument))) 'export/{}/{}.wav'.format(name, instrument)))
@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS) @pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
def test_filename_conflict(test_file): def test_filename_conflict(test_file, configuration):
""" Test error handling with static pattern. """ """ Test error handling with static pattern. """
separator = Separator(TEST_CONFIGURATIONS[0][0]) tf.reset_default_graph()
separator = Separator(configuration)
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
with pytest.raises(SpleeterError): with pytest.raises(SpleeterError):
separator.separate_to_file( separator.separate_to_file(