Updating tests to test for librosa backend

This commit is contained in:
akhlif
2020-02-27 15:38:46 +01:00
parent 922fcd85bb
commit 3cba6985f4
3 changed files with 36 additions and 26 deletions

View File

@@ -105,7 +105,7 @@ class InputProviderFactory(object):
@staticmethod
def get(params):
stft_backend = params["stft_backend"]
assert stft_backend in ("tensorflow", "librosa")
assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
if stft_backend == "tensorflow":
return WaveformInputProvider(params)
else:

View File

@@ -85,7 +85,7 @@ class Separator(object):
task.get()
task.wait(timeout=timeout)
def separate(self, waveform):
def separate_tensorflow(self, waveform, audio_descriptor):
""" Performs source separation over the given waveform.
The separation is performed synchronously but the result
@@ -103,11 +103,11 @@ class Separator(object):
predictor = self._get_predictor()
prediction = predictor({
'waveform': waveform,
'audio_id': ''})
'audio_id': audio_descriptor})
prediction.pop('audio_id')
return prediction
def stft(self, data, inverse=False):
def stft(self, data, inverse=False, length=None):
"""
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
channels are processed separately and are concatenated together in the result. The expected input formats are:
@@ -116,11 +116,13 @@ class Separator(object):
:param inverse: should a stft or an istft be computed.
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
N = self._params["frame_length"]
H = self._params["frame_step"]
win = hann(N, sym=False)
fstft = istft if inverse else stft
win_len_arg = {"win_length": None} if inverse else {"n_fft": N}
win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
@@ -145,9 +147,15 @@ class Separator(object):
saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
for inst in builder.instruments:
out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :]
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
return out
def separate(self, waveform, audio_descriptor):
if self._params["stft_backend"] == "tensorflow":
return self.separate_tensorflow(waveform, audio_descriptor)
else:
return self.separate_librosa(waveform, audio_descriptor)
def separate_to_file(
self, audio_descriptor, destination,
audio_adapter=get_default_audio_adapter(),
@@ -180,10 +188,7 @@ class Separator(object):
offset=offset,
duration=duration,
sample_rate=self._sample_rate)
if self._params["stft_backend"] == "tensorflow":
sources = self.separate(waveform)
else:
sources = self.separate_librosa(waveform, audio_descriptor)
sources = self.separate(waveform, audio_descriptor)
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous)

View File

@@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join
from tempfile import TemporaryDirectory
import pytest
import numpy as np
from spleeter import SpleeterError
from spleeter.audio.adapter import get_default_audio_adapter
@@ -21,34 +22,38 @@ from spleeter.separator import Separator
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
TEST_CONFIGURATIONS = [
('spleeter:2stems', ('vocals', 'accompaniment')),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'))
('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'),
('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa')
]
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
def test_separate(configuration, instruments):
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_separate(configuration, instruments, backend):
""" Test separation from raw data. """
adapter = get_default_audio_adapter()
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
separator = Separator(configuration)
prediction = separator.separate(waveform)
separator = Separator(configuration, stft_backend=backend)
prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR)
assert len(prediction) == len(instruments)
for instrument in instruments:
assert instrument in prediction
for instrument in instruments:
track = prediction[instrument]
assert not (waveform == track).all()
assert waveform.shape == track.shape
assert not np.allclose(waveform, track)
for compared in instruments:
if instrument != compared:
assert not (track == prediction[compared]).all()
assert not np.allclose(track, prediction[compared])
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
def test_separate_to_file(configuration, instruments):
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_separate_to_file(configuration, instruments, backend):
""" Test file based separation. """
separator = Separator(configuration)
separator = Separator(configuration, stft_backend=backend)
with TemporaryDirectory() as directory:
separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR,
@@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments):
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
def test_filename_format(configuration, instruments):
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_filename_format(configuration, instruments, backend):
""" Test custom filename format. """
separator = Separator(configuration)
separator = Separator(configuration, stft_backend=backend)
with TemporaryDirectory() as directory:
separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR,
@@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments):
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
def test_filename_confilct():
def test_filename_conflict():
""" Test error handling with static pattern. """
separator = Separator(TEST_CONFIGURATIONS[0][0])
with TemporaryDirectory() as directory: