mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Updating tests to test for librosa backend
This commit is contained in:
@@ -105,7 +105,7 @@ class InputProviderFactory(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get(params):
|
def get(params):
|
||||||
stft_backend = params["stft_backend"]
|
stft_backend = params["stft_backend"]
|
||||||
assert stft_backend in ("tensorflow", "librosa")
|
assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
|
||||||
if stft_backend == "tensorflow":
|
if stft_backend == "tensorflow":
|
||||||
return WaveformInputProvider(params)
|
return WaveformInputProvider(params)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class Separator(object):
|
|||||||
task.get()
|
task.get()
|
||||||
task.wait(timeout=timeout)
|
task.wait(timeout=timeout)
|
||||||
|
|
||||||
def separate(self, waveform):
|
def separate_tensorflow(self, waveform, audio_descriptor):
|
||||||
""" Performs source separation over the given waveform.
|
""" Performs source separation over the given waveform.
|
||||||
|
|
||||||
The separation is performed synchronously but the result
|
The separation is performed synchronously but the result
|
||||||
@@ -103,11 +103,11 @@ class Separator(object):
|
|||||||
predictor = self._get_predictor()
|
predictor = self._get_predictor()
|
||||||
prediction = predictor({
|
prediction = predictor({
|
||||||
'waveform': waveform,
|
'waveform': waveform,
|
||||||
'audio_id': ''})
|
'audio_id': audio_descriptor})
|
||||||
prediction.pop('audio_id')
|
prediction.pop('audio_id')
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
def stft(self, data, inverse=False):
|
def stft(self, data, inverse=False, length=None):
|
||||||
"""
|
"""
|
||||||
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
|
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
|
||||||
channels are processed separately and are concatenated together in the result. The expected input formats are:
|
channels are processed separately and are concatenated together in the result. The expected input formats are:
|
||||||
@@ -116,11 +116,13 @@ class Separator(object):
|
|||||||
:param inverse: should a stft or an istft be computed.
|
:param inverse: should a stft or an istft be computed.
|
||||||
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
|
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
|
||||||
"""
|
"""
|
||||||
|
assert not (inverse and length is None)
|
||||||
|
data = np.asfortranarray(data)
|
||||||
N = self._params["frame_length"]
|
N = self._params["frame_length"]
|
||||||
H = self._params["frame_step"]
|
H = self._params["frame_step"]
|
||||||
win = hann(N, sym=False)
|
win = hann(N, sym=False)
|
||||||
fstft = istft if inverse else stft
|
fstft = istft if inverse else stft
|
||||||
win_len_arg = {"win_length": None} if inverse else {"n_fft": N}
|
win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
|
||||||
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
|
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
|
||||||
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
|
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
|
||||||
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
|
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
|
||||||
@@ -145,9 +147,15 @@ class Separator(object):
|
|||||||
saver.restore(sess, latest_checkpoint)
|
saver.restore(sess, latest_checkpoint)
|
||||||
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
||||||
for inst in builder.instruments:
|
for inst in builder.instruments:
|
||||||
out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :]
|
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def separate(self, waveform, audio_descriptor):
|
||||||
|
if self._params["stft_backend"] == "tensorflow":
|
||||||
|
return self.separate_tensorflow(waveform, audio_descriptor)
|
||||||
|
else:
|
||||||
|
return self.separate_librosa(waveform, audio_descriptor)
|
||||||
|
|
||||||
def separate_to_file(
|
def separate_to_file(
|
||||||
self, audio_descriptor, destination,
|
self, audio_descriptor, destination,
|
||||||
audio_adapter=get_default_audio_adapter(),
|
audio_adapter=get_default_audio_adapter(),
|
||||||
@@ -180,10 +188,7 @@ class Separator(object):
|
|||||||
offset=offset,
|
offset=offset,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
sample_rate=self._sample_rate)
|
sample_rate=self._sample_rate)
|
||||||
if self._params["stft_backend"] == "tensorflow":
|
sources = self.separate(waveform, audio_descriptor)
|
||||||
sources = self.separate(waveform)
|
|
||||||
else:
|
|
||||||
sources = self.separate_librosa(waveform, audio_descriptor)
|
|
||||||
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||||
audio_adapter, bitrate, synchronous)
|
audio_adapter, bitrate, synchronous)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join
|
|||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from spleeter import SpleeterError
|
from spleeter import SpleeterError
|
||||||
from spleeter.audio.adapter import get_default_audio_adapter
|
from spleeter.audio.adapter import get_default_audio_adapter
|
||||||
@@ -21,34 +22,38 @@ from spleeter.separator import Separator
|
|||||||
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
|
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
|
||||||
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
|
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
|
||||||
TEST_CONFIGURATIONS = [
|
TEST_CONFIGURATIONS = [
|
||||||
('spleeter:2stems', ('vocals', 'accompaniment')),
|
('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'),
|
||||||
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')),
|
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'),
|
||||||
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'))
|
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'),
|
||||||
|
('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'),
|
||||||
|
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'),
|
||||||
|
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa')
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||||
def test_separate(configuration, instruments):
|
def test_separate(configuration, instruments, backend):
|
||||||
""" Test separation from raw data. """
|
""" Test separation from raw data. """
|
||||||
adapter = get_default_audio_adapter()
|
adapter = get_default_audio_adapter()
|
||||||
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
|
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
|
||||||
separator = Separator(configuration)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
prediction = separator.separate(waveform)
|
prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR)
|
||||||
assert len(prediction) == len(instruments)
|
assert len(prediction) == len(instruments)
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
assert instrument in prediction
|
assert instrument in prediction
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
track = prediction[instrument]
|
track = prediction[instrument]
|
||||||
assert not (waveform == track).all()
|
assert waveform.shape == track.shape
|
||||||
|
assert not np.allclose(waveform, track)
|
||||||
for compared in instruments:
|
for compared in instruments:
|
||||||
if instrument != compared:
|
if instrument != compared:
|
||||||
assert not (track == prediction[compared]).all()
|
assert not np.allclose(track, prediction[compared])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||||
def test_separate_to_file(configuration, instruments):
|
def test_separate_to_file(configuration, instruments, backend):
|
||||||
""" Test file based separation. """
|
""" Test file based separation. """
|
||||||
separator = Separator(configuration)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
TEST_AUDIO_DESCRIPTOR,
|
TEST_AUDIO_DESCRIPTOR,
|
||||||
@@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments):
|
|||||||
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||||
def test_filename_format(configuration, instruments):
|
def test_filename_format(configuration, instruments, backend):
|
||||||
""" Test custom filename format. """
|
""" Test custom filename format. """
|
||||||
separator = Separator(configuration)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
TEST_AUDIO_DESCRIPTOR,
|
TEST_AUDIO_DESCRIPTOR,
|
||||||
@@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments):
|
|||||||
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
||||||
|
|
||||||
|
|
||||||
def test_filename_confilct():
|
def test_filename_conflict():
|
||||||
""" Test error handling with static pattern. """
|
""" Test error handling with static pattern. """
|
||||||
separator = Separator(TEST_CONFIGURATIONS[0][0])
|
separator = Separator(TEST_CONFIGURATIONS[0][0])
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
|
|||||||
Reference in New Issue
Block a user