Merge pull request #305 from alreadytaikeune/fix_librosa_mono

Fix librosa mono
This commit is contained in:
Moussallam
2020-04-03 16:15:59 +02:00
committed by GitHub
5 changed files with 62 additions and 36 deletions

BIN
audio_example_mono.mp3 Normal file

Binary file not shown.

View File

@@ -2,7 +2,7 @@ importlib_resources; python_version<'3.7'
requests requests
setuptools>=41.0.0 setuptools>=41.0.0
pandas==0.25.1 pandas==0.25.1
tensorflow==1.15 tensorflow==1.15.2
ffmpeg-python ffmpeg-python
norbert==0.2.1 norbert==0.2.1
librosa==0.7.2 librosa==0.7.2

View File

@@ -16,7 +16,7 @@ __license__ = 'MIT License'
project_name = 'spleeter' project_name = 'spleeter'
project_version = '1.5.0' project_version = '1.5.0'
tensorflow_dependency = 'tensorflow' tensorflow_dependency = 'tensorflow'
tensorflow_version = '1.15' tensorflow_version = '1.15.2'
here = path.abspath(path.dirname(__file__)) here = path.abspath(path.dirname(__file__))
readme_path = path.join(here, 'README.md') readme_path = path.join(here, 'README.md')
with open(readme_path, 'r') as stream: with open(readme_path, 'r') as stream:

View File

@@ -123,12 +123,16 @@ class Separator(object):
win = hann(N, sym=False) win = hann(N, sym=False)
fstft = istft if inverse else stft fstft = istft if inverse else stft
win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N} win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1]) n_channels = data.shape[-1]
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg) out = []
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg) for c in range(n_channels):
s1 = np.expand_dims(s1.T, 2-inverse) d = data[:, :, c].T if inverse else data[:, c]
s2 = np.expand_dims(s2.T, 2-inverse) s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
return np.concatenate([s1, s2], axis=2-inverse) s = np.expand_dims(s.T, 2-inverse)
out.append(s)
if len(out) == 1:
return out[0]
return np.concatenate(out, axis=2-inverse)
def separate_librosa(self, waveform, audio_id): def separate_librosa(self, waveform, audio_id):
out = {} out = {}
@@ -140,9 +144,13 @@ class Separator(object):
# TODO: fix the logic, build sometimes return, sometimes set attribute # TODO: fix the logic, build sometimes return, sometimes set attribute
outputs = builder.outputs outputs = builder.outputs
stft = self.stft(waveform)
if stft.shape[-1] == 1:
stft = np.concatenate([stft, stft], axis=-1)
elif stft.shape[-1] > 2:
stft = stft[:, :2]
saver = tf.train.Saver() saver = tf.train.Saver()
stft = self.stft(waveform)
with tf.Session() as sess: with tf.Session() as sess:
saver.restore(sess, latest_checkpoint) saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))

View File

@@ -8,83 +8,101 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
import filecmp import filecmp
import itertools
from os.path import splitext, basename, exists, join from os.path import splitext, basename, exists, join
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import pytest import pytest
import numpy as np import numpy as np
import tensorflow as tf
from spleeter import SpleeterError from spleeter import SpleeterError
from spleeter.audio.adapter import get_default_audio_adapter from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.separator import Separator from spleeter.separator import Separator
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3' TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0] BACKENDS = ["tensorflow", "librosa"]
TEST_CONFIGURATIONS = [ MODELS = ['spleeter:2stems', 'spleeter:4stems', 'spleeter:5stems']
('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'), MODEL_TO_INST = {
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'), 'spleeter:2stems': ('vocals', 'accompaniment'),
('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'), 'spleeter:4stems': ('vocals', 'drums', 'bass', 'other'),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'), 'spleeter:5stems': ('vocals', 'drums', 'bass', 'piano', 'other'),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa') }
]
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) MODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS))
def test_separate(configuration, instruments, backend): TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS))
print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_separate(test_file, configuration, backend):
""" Test separation from raw data. """ """ Test separation from raw data. """
tf.reset_default_graph()
instruments = MODEL_TO_INST[configuration]
adapter = get_default_audio_adapter() adapter = get_default_audio_adapter()
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR) waveform, _ = adapter.load(test_file)
separator = Separator(configuration, stft_backend=backend) separator = Separator(configuration, stft_backend=backend)
prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR) prediction = separator.separate(waveform, test_file)
assert len(prediction) == len(instruments) assert len(prediction) == len(instruments)
for instrument in instruments: for instrument in instruments:
assert instrument in prediction assert instrument in prediction
for instrument in instruments: for instrument in instruments:
track = prediction[instrument] track = prediction[instrument]
assert waveform.shape == track.shape assert waveform.shape[:-1] == track.shape[:-1]
assert not np.allclose(waveform, track) assert not np.allclose(waveform, track)
for compared in instruments: for compared in instruments:
if instrument != compared: if instrument != compared:
assert not np.allclose(track, prediction[compared]) assert not np.allclose(track, prediction[compared])
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_separate_to_file(configuration, instruments, backend): def test_separate_to_file(test_file, configuration, backend):
""" Test file based separation. """ """ Test file based separation. """
tf.reset_default_graph()
instruments = MODEL_TO_INST[configuration]
separator = Separator(configuration, stft_backend=backend) separator = Separator(configuration, stft_backend=backend)
name = splitext(basename(test_file))[0]
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
separator.separate_to_file( separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR, test_file,
directory) directory)
for instrument in instruments: for instrument in instruments:
assert exists(join( assert exists(join(
directory, directory,
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) '{}/{}.wav'.format(name, instrument)))
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_filename_format(configuration, instruments, backend): def test_filename_format(test_file, configuration, backend):
""" Test custom filename format. """ """ Test custom filename format. """
tf.reset_default_graph()
instruments = MODEL_TO_INST[configuration]
separator = Separator(configuration, stft_backend=backend) separator = Separator(configuration, stft_backend=backend)
name = splitext(basename(test_file))[0]
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
separator.separate_to_file( separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR, test_file,
directory, directory,
filename_format='export/{filename}/{instrument}.{codec}') filename_format='export/{filename}/{instrument}.{codec}')
for instrument in instruments: for instrument in instruments:
assert exists(join( assert exists(join(
directory, directory,
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) 'export/{}/{}.wav'.format(name, instrument)))
def test_filename_conflict(): @pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
def test_filename_conflict(test_file, configuration):
""" Test error handling with static pattern. """ """ Test error handling with static pattern. """
separator = Separator(TEST_CONFIGURATIONS[0][0]) tf.reset_default_graph()
separator = Separator(configuration)
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
with pytest.raises(SpleeterError): with pytest.raises(SpleeterError):
separator.separate_to_file( separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR, test_file,
directory, directory,
filename_format='I wanna be your lover') filename_format='I wanna be your lover')