mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Updating tests to test for librosa backend
This commit is contained in:
@@ -105,7 +105,7 @@ class InputProviderFactory(object):
|
||||
@staticmethod
|
||||
def get(params):
|
||||
stft_backend = params["stft_backend"]
|
||||
assert stft_backend in ("tensorflow", "librosa")
|
||||
assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
|
||||
if stft_backend == "tensorflow":
|
||||
return WaveformInputProvider(params)
|
||||
else:
|
||||
|
||||
@@ -85,7 +85,7 @@ class Separator(object):
|
||||
task.get()
|
||||
task.wait(timeout=timeout)
|
||||
|
||||
def separate(self, waveform):
|
||||
def separate_tensorflow(self, waveform, audio_descriptor):
|
||||
""" Performs source separation over the given waveform.
|
||||
|
||||
The separation is performed synchronously but the result
|
||||
@@ -103,11 +103,11 @@ class Separator(object):
|
||||
predictor = self._get_predictor()
|
||||
prediction = predictor({
|
||||
'waveform': waveform,
|
||||
'audio_id': ''})
|
||||
'audio_id': audio_descriptor})
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def stft(self, data, inverse=False):
|
||||
def stft(self, data, inverse=False, length=None):
|
||||
"""
|
||||
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
|
||||
channels are processed separately and are concatenated together in the result. The expected input formats are:
|
||||
@@ -116,11 +116,13 @@ class Separator(object):
|
||||
:param inverse: should a stft or an istft be computed.
|
||||
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
|
||||
"""
|
||||
assert not (inverse and length is None)
|
||||
data = np.asfortranarray(data)
|
||||
N = self._params["frame_length"]
|
||||
H = self._params["frame_step"]
|
||||
win = hann(N, sym=False)
|
||||
fstft = istft if inverse else stft
|
||||
win_len_arg = {"win_length": None} if inverse else {"n_fft": N}
|
||||
win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
|
||||
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
|
||||
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
|
||||
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
|
||||
@@ -145,9 +147,15 @@ class Separator(object):
|
||||
saver.restore(sess, latest_checkpoint)
|
||||
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
||||
for inst in builder.instruments:
|
||||
out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :]
|
||||
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
|
||||
return out
|
||||
|
||||
def separate(self, waveform, audio_descriptor):
|
||||
if self._params["stft_backend"] == "tensorflow":
|
||||
return self.separate_tensorflow(waveform, audio_descriptor)
|
||||
else:
|
||||
return self.separate_librosa(waveform, audio_descriptor)
|
||||
|
||||
def separate_to_file(
|
||||
self, audio_descriptor, destination,
|
||||
audio_adapter=get_default_audio_adapter(),
|
||||
@@ -180,10 +188,7 @@ class Separator(object):
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
sample_rate=self._sample_rate)
|
||||
if self._params["stft_backend"] == "tensorflow":
|
||||
sources = self.separate(waveform)
|
||||
else:
|
||||
sources = self.separate_librosa(waveform, audio_descriptor)
|
||||
sources = self.separate(waveform, audio_descriptor)
|
||||
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from spleeter import SpleeterError
|
||||
from spleeter.audio.adapter import get_default_audio_adapter
|
||||
@@ -21,34 +22,38 @@ from spleeter.separator import Separator
|
||||
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
|
||||
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
|
||||
TEST_CONFIGURATIONS = [
|
||||
('spleeter:2stems', ('vocals', 'accompaniment')),
|
||||
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')),
|
||||
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'))
|
||||
('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'),
|
||||
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'),
|
||||
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'),
|
||||
('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'),
|
||||
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'),
|
||||
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa')
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
||||
def test_separate(configuration, instruments):
|
||||
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||
def test_separate(configuration, instruments, backend):
|
||||
""" Test separation from raw data. """
|
||||
adapter = get_default_audio_adapter()
|
||||
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
|
||||
separator = Separator(configuration)
|
||||
prediction = separator.separate(waveform)
|
||||
separator = Separator(configuration, stft_backend=backend)
|
||||
prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR)
|
||||
assert len(prediction) == len(instruments)
|
||||
for instrument in instruments:
|
||||
assert instrument in prediction
|
||||
for instrument in instruments:
|
||||
track = prediction[instrument]
|
||||
assert not (waveform == track).all()
|
||||
assert waveform.shape == track.shape
|
||||
assert not np.allclose(waveform, track)
|
||||
for compared in instruments:
|
||||
if instrument != compared:
|
||||
assert not (track == prediction[compared]).all()
|
||||
assert not np.allclose(track, prediction[compared])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
||||
def test_separate_to_file(configuration, instruments):
|
||||
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||
def test_separate_to_file(configuration, instruments, backend):
|
||||
""" Test file based separation. """
|
||||
separator = Separator(configuration)
|
||||
separator = Separator(configuration, stft_backend=backend)
|
||||
with TemporaryDirectory() as directory:
|
||||
separator.separate_to_file(
|
||||
TEST_AUDIO_DESCRIPTOR,
|
||||
@@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments):
|
||||
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
||||
def test_filename_format(configuration, instruments):
|
||||
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||
def test_filename_format(configuration, instruments, backend):
|
||||
""" Test custom filename format. """
|
||||
separator = Separator(configuration)
|
||||
separator = Separator(configuration, stft_backend=backend)
|
||||
with TemporaryDirectory() as directory:
|
||||
separator.separate_to_file(
|
||||
TEST_AUDIO_DESCRIPTOR,
|
||||
@@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments):
|
||||
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
||||
|
||||
|
||||
def test_filename_confilct():
|
||||
def test_filename_conflict():
|
||||
""" Test error handling with static pattern. """
|
||||
separator = Separator(TEST_CONFIGURATIONS[0][0])
|
||||
with TemporaryDirectory() as directory:
|
||||
|
||||
Reference in New Issue
Block a user