Updating tests to test for librosa backend

This commit is contained in:
akhlif
2020-02-27 15:38:46 +01:00
parent 922fcd85bb
commit 3cba6985f4
3 changed files with 36 additions and 26 deletions

View File

@@ -105,7 +105,7 @@ class InputProviderFactory(object):
@staticmethod @staticmethod
def get(params): def get(params):
stft_backend = params["stft_backend"] stft_backend = params["stft_backend"]
assert stft_backend in ("tensorflow", "librosa") assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
if stft_backend == "tensorflow": if stft_backend == "tensorflow":
return WaveformInputProvider(params) return WaveformInputProvider(params)
else: else:

View File

@@ -85,7 +85,7 @@ class Separator(object):
task.get() task.get()
task.wait(timeout=timeout) task.wait(timeout=timeout)
def separate(self, waveform): def separate_tensorflow(self, waveform, audio_descriptor):
""" Performs source separation over the given waveform. """ Performs source separation over the given waveform.
The separation is performed synchronously but the result The separation is performed synchronously but the result
@@ -103,11 +103,11 @@ class Separator(object):
predictor = self._get_predictor() predictor = self._get_predictor()
prediction = predictor({ prediction = predictor({
'waveform': waveform, 'waveform': waveform,
'audio_id': ''}) 'audio_id': audio_descriptor})
prediction.pop('audio_id') prediction.pop('audio_id')
return prediction return prediction
def stft(self, data, inverse=False): def stft(self, data, inverse=False, length=None):
""" """
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
channels are processed separately and are concatenated together in the result. The expected input formats are: channels are processed separately and are concatenated together in the result. The expected input formats are:
@@ -116,11 +116,13 @@ class Separator(object):
:param inverse: should a stft or an istft be computed. :param inverse: should a stft or an istft be computed.
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension :return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
""" """
assert not (inverse and length is None)
data = np.asfortranarray(data)
N = self._params["frame_length"] N = self._params["frame_length"]
H = self._params["frame_step"] H = self._params["frame_step"]
win = hann(N, sym=False) win = hann(N, sym=False)
fstft = istft if inverse else stft fstft = istft if inverse else stft
win_len_arg = {"win_length": None} if inverse else {"n_fft": N} win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1]) dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg) s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg) s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
@@ -145,9 +147,15 @@ class Separator(object):
saver.restore(sess, latest_checkpoint) saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
for inst in builder.instruments: for inst in builder.instruments:
out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :] out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
return out return out
def separate(self, waveform, audio_descriptor):
if self._params["stft_backend"] == "tensorflow":
return self.separate_tensorflow(waveform, audio_descriptor)
else:
return self.separate_librosa(waveform, audio_descriptor)
def separate_to_file( def separate_to_file(
self, audio_descriptor, destination, self, audio_descriptor, destination,
audio_adapter=get_default_audio_adapter(), audio_adapter=get_default_audio_adapter(),
@@ -180,10 +188,7 @@ class Separator(object):
offset=offset, offset=offset,
duration=duration, duration=duration,
sample_rate=self._sample_rate) sample_rate=self._sample_rate)
if self._params["stft_backend"] == "tensorflow": sources = self.separate(waveform, audio_descriptor)
sources = self.separate(waveform)
else:
sources = self.separate_librosa(waveform, audio_descriptor)
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous) audio_adapter, bitrate, synchronous)

View File

@@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import pytest import pytest
import numpy as np
from spleeter import SpleeterError from spleeter import SpleeterError
from spleeter.audio.adapter import get_default_audio_adapter from spleeter.audio.adapter import get_default_audio_adapter
@@ -21,34 +22,38 @@ from spleeter.separator import Separator
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3' TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0] TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
TEST_CONFIGURATIONS = [ TEST_CONFIGURATIONS = [
('spleeter:2stems', ('vocals', 'accompaniment')), ('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')), ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other')) ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'),
('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa')
] ]
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) @pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_separate(configuration, instruments): def test_separate(configuration, instruments, backend):
""" Test separation from raw data. """ """ Test separation from raw data. """
adapter = get_default_audio_adapter() adapter = get_default_audio_adapter()
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR) waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
separator = Separator(configuration) separator = Separator(configuration, stft_backend=backend)
prediction = separator.separate(waveform) prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR)
assert len(prediction) == len(instruments) assert len(prediction) == len(instruments)
for instrument in instruments: for instrument in instruments:
assert instrument in prediction assert instrument in prediction
for instrument in instruments: for instrument in instruments:
track = prediction[instrument] track = prediction[instrument]
assert not (waveform == track).all() assert waveform.shape == track.shape
assert not np.allclose(waveform, track)
for compared in instruments: for compared in instruments:
if instrument != compared: if instrument != compared:
assert not (track == prediction[compared]).all() assert not np.allclose(track, prediction[compared])
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) @pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_separate_to_file(configuration, instruments): def test_separate_to_file(configuration, instruments, backend):
""" Test file based separation. """ """ Test file based separation. """
separator = Separator(configuration) separator = Separator(configuration, stft_backend=backend)
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
separator.separate_to_file( separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR, TEST_AUDIO_DESCRIPTOR,
@@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments):
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) '{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) @pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_filename_format(configuration, instruments): def test_filename_format(configuration, instruments, backend):
""" Test custom filename format. """ """ Test custom filename format. """
separator = Separator(configuration) separator = Separator(configuration, stft_backend=backend)
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
separator.separate_to_file( separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR, TEST_AUDIO_DESCRIPTOR,
@@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments):
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) 'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
def test_filename_confilct(): def test_filename_conflict():
""" Test error handling with static pattern. """ """ Test error handling with static pattern. """
separator = Separator(TEST_CONFIGURATIONS[0][0]) separator = Separator(TEST_CONFIGURATIONS[0][0])
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory: