From 3cba6985f410750f9704303668e4093dde29f7a3 Mon Sep 17 00:00:00 2001 From: akhlif Date: Thu, 27 Feb 2020 15:38:46 +0100 Subject: [PATCH] Updating tests to test for librosa backend --- spleeter/model/__init__.py | 2 +- spleeter/separator.py | 23 ++++++++++++++--------- tests/test_separator.py | 37 +++++++++++++++++++++---------------- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 6c2fe3a..531f84c 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -105,7 +105,7 @@ class InputProviderFactory(object): @staticmethod def get(params): stft_backend = params["stft_backend"] - assert stft_backend in ("tensorflow", "librosa") + assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend) if stft_backend == "tensorflow": return WaveformInputProvider(params) else: diff --git a/spleeter/separator.py b/spleeter/separator.py index 668b59d..174c73f 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -85,7 +85,7 @@ class Separator(object): task.get() task.wait(timeout=timeout) - def separate(self, waveform): + def separate_tensorflow(self, waveform, audio_descriptor): """ Performs source separation over the given waveform. The separation is performed synchronously but the result @@ -103,11 +103,11 @@ class Separator(object): predictor = self._get_predictor() prediction = predictor({ 'waveform': waveform, - 'audio_id': ''}) + 'audio_id': audio_descriptor}) prediction.pop('audio_id') return prediction - def stft(self, data, inverse=False): + def stft(self, data, inverse=False, length=None): """ Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two channels are processed separately and are concatenated together in the result. The expected input formats are: @@ -116,11 +116,13 @@ class Separator(object): :param inverse: should a stft or an istft be computed. :return: Stereo data as numpy array for the transform. The channels are stored in the last dimension """ + assert not (inverse and length is None) + data = np.asfortranarray(data) N = self._params["frame_length"] H = self._params["frame_step"] win = hann(N, sym=False) fstft = istft if inverse else stft - win_len_arg = {"win_length": None} if inverse else {"n_fft": N} + win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N} dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1]) s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg) s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg) @@ -145,9 +147,15 @@ class Separator(object): saver.restore(sess, latest_checkpoint) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) for inst in builder.instruments: - out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :] + out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0]) return out + def separate(self, waveform, audio_descriptor): + if self._params["stft_backend"] == "tensorflow": + return self.separate_tensorflow(waveform, audio_descriptor) + else: + return self.separate_librosa(waveform, audio_descriptor) + def separate_to_file( self, audio_descriptor, destination, audio_adapter=get_default_audio_adapter(), @@ -180,10 +188,7 @@ class Separator(object): offset=offset, duration=duration, sample_rate=self._sample_rate) - if self._params["stft_backend"] == "tensorflow": - sources = self.separate(waveform) - else: - sources = self.separate_librosa(waveform, audio_descriptor) + sources = self.separate(waveform, audio_descriptor) self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, audio_adapter, bitrate, synchronous) diff --git a/tests/test_separator.py b/tests/test_separator.py index 9235731..271fdfb 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join from tempfile import TemporaryDirectory import pytest +import numpy as np from spleeter import SpleeterError from spleeter.audio.adapter import get_default_audio_adapter @@ -21,34 +22,38 @@ from spleeter.separator import Separator TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3' TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0] TEST_CONFIGURATIONS = [ - ('spleeter:2stems', ('vocals', 'accompaniment')), - ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')), - ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other')) + ('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'), + ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'), + ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'), + ('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'), + ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'), + ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa') ] -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_separate(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_separate(configuration, instruments, backend): """ Test separation from raw data. """ adapter = get_default_audio_adapter() waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR) - separator = Separator(configuration) - prediction = separator.separate(waveform) + separator = Separator(configuration, stft_backend=backend) + prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR) assert len(prediction) == len(instruments) for instrument in instruments: assert instrument in prediction for instrument in instruments: track = prediction[instrument] - assert not (waveform == track).all() + assert waveform.shape == track.shape + assert not np.allclose(waveform, track) for compared in instruments: if instrument != compared: - assert not (track == prediction[compared]).all() + assert not np.allclose(track, prediction[compared]) -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_separate_to_file(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_separate_to_file(configuration, instruments, backend): """ Test file based separation. """ - separator = Separator(configuration) + separator = Separator(configuration, stft_backend=backend) with TemporaryDirectory() as directory: separator.separate_to_file( TEST_AUDIO_DESCRIPTOR, @@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments): '{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_filename_format(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_filename_format(configuration, instruments, backend): """ Test custom filename format. """ - separator = Separator(configuration) + separator = Separator(configuration, stft_backend=backend) with TemporaryDirectory() as directory: separator.separate_to_file( TEST_AUDIO_DESCRIPTOR, @@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments): 'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) -def test_filename_confilct(): +def test_filename_conflict(): """ Test error handling with static pattern. """ separator = Separator(TEST_CONFIGURATIONS[0][0]) with TemporaryDirectory() as directory: