Merge pull request #282 from alreadytaikeune/librosa_backend

Librosa backend
This commit is contained in:
alreadytaikeune
2020-03-20 17:10:41 +01:00
committed by GitHub
8 changed files with 355 additions and 98 deletions

View File

@@ -2,6 +2,7 @@ importlib_resources; python_version<'3.7'
requests
setuptools>=41.0.0
pandas==0.25.1
tensorflow==1.14.0
tensorflow==1.15.0
ffmpeg-python
norbert==0.2.1
norbert==0.2.1
librosa==0.7.2

View File

@@ -14,9 +14,9 @@ __license__ = 'MIT License'
# Default project values.
project_name = 'spleeter'
project_version = '1.4.9'
project_version = '1.5.0'
tensorflow_dependency = 'tensorflow'
tensorflow_version = '1.14.0'
tensorflow_version = '1.15.0'
here = path.abspath(path.dirname(__file__))
readme_path = path.join(here, 'README.md')
with open(readme_path, 'r') as stream:
@@ -56,6 +56,7 @@ setup(
'pandas==0.25.1',
'requests',
'setuptools>=41.0.0',
'librosa==0.7.2',
'{}=={}'.format(tensorflow_dependency, tensorflow_version),
],
extras_require={

View File

@@ -4,7 +4,7 @@
""" This modules provides spleeter command as well as CLI parsing methods. """
import json
import logging
from argparse import ArgumentParser
from tempfile import gettempdir
from os.path import exists, join
@@ -13,6 +13,8 @@ __email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# -i opt specification (separate).
OPT_INPUT = {
'dest': 'inputs',
@@ -68,6 +70,17 @@ OPT_DURATION = {
'the input file)')
}
# -w opt specification (separate)
OPT_STFT_BACKEND = {
'dest': 'stft_backend',
'type': str,
'choices' : ["tensorflow", "librosa", "auto"],
'default': "auto",
'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
}
# -c opt specification (separate).
OPT_CODEC = {
'dest': 'codec',
@@ -176,6 +189,7 @@ def _create_separate_parser(parser_factory):
parser.add_argument('-c', '--codec', **OPT_CODEC)
parser.add_argument('-b', '--birate', **OPT_BITRATE)
parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
return parser

View File

@@ -19,6 +19,7 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License'
def entrypoint(arguments, params):
""" Command entrypoint.
@@ -29,7 +30,8 @@ def entrypoint(arguments, params):
audio_adapter = get_audio_adapter(arguments.audio_adapter)
separator = Separator(
arguments.configuration,
arguments.MWF)
MWF=arguments.MWF,
stft_backend=arguments.stft_backend)
for filename in arguments.inputs:
separator.separate_to_file(
filename,

View File

@@ -18,6 +18,9 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License'
placeholder = tf.compat.v1.placeholder
def get_model_function(model_type):
"""
Get tensorflow function of the model to be applied to the input tensor.
@@ -41,6 +44,74 @@ def get_model_function(model_type):
return model_function
class InputProvider(object):
def __init__(self, params):
self.params = params
def get_input_dict_placeholders(self):
raise NotImplementedError()
@property
def input_names(self):
raise NotImplementedError()
def get_feed_dict(self, features, *args):
raise NotImplementedError()
class WaveformInputProvider(InputProvider):
@property
def input_names(self):
return ["audio_id", "waveform"]
def get_input_dict_placeholders(self):
shape = (None, self.params['n_channels'])
features = {
'waveform': placeholder(tf.float32, shape=shape, name="waveform"),
'audio_id': placeholder(tf.string, name="audio_id")}
return features
def get_feed_dict(self, features, waveform, audio_id):
return {features["audio_id"]: audio_id, features["waveform"]: waveform}
class SpectralInputProvider(InputProvider):
def __init__(self, params):
super().__init__(params)
self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@property
def input_names(self):
return ["audio_id", self.stft_input_name]
def get_input_dict_placeholders(self):
features = {
self.stft_input_name: placeholder(tf.complex64,
shape=(None, self.params["frame_length"]//2+1,
self.params['n_channels']),
name=self.stft_input_name),
'audio_id': placeholder(tf.string, name="audio_id")}
return features
def get_feed_dict(self, features, stft, audio_id):
return {features["audio_id"]: audio_id, features[self.stft_input_name]: stft}
class InputProviderFactory(object):
@staticmethod
def get(params):
stft_backend = params["stft_backend"]
assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
if stft_backend == "tensorflow":
return WaveformInputProvider(params)
else:
return SpectralInputProvider(params)
class EstimatorSpecBuilder(object):
""" A builder class that allows to builds a multitrack unet model
estimator. The built model estimator has a different behaviour when
@@ -57,9 +128,9 @@ class EstimatorSpecBuilder(object):
>>> from spleeter.model import EstimatorSpecBuilder
>>> builder = EstimatorSpecBuilder()
>>> builder.build_prediction_model()
>>> builder.build_predict_model()
>>> builder.build_evaluation_model()
>>> builder.build_training_model()
>>> builder.build_train_model()
>>> from spleeter.model import model_fn
>>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
@@ -94,6 +165,7 @@ class EstimatorSpecBuilder(object):
:param features: The input features for the estimator.
:param params: Some hyperparameters as a dictionary.
"""
self._features = features
self._params = params
# Get instrument name.
@@ -106,7 +178,10 @@ class EstimatorSpecBuilder(object):
self._frame_length = params['frame_length']
self._frame_step = params['frame_step']
def _build_output_dict(self):
def include_stft_computations(self):
return self._params["stft_backend"] == "tensorflow"
def _build_model_outputs(self):
""" Created a batch_sizexTxFxn_channels input tensor containing
mix magnitude spectrogram, then an output dict from it according
to the selected model in internal parameters.
@@ -114,7 +189,8 @@ class EstimatorSpecBuilder(object):
:returns: Build output dict.
:raise ValueError: If required model_type is not supported.
"""
input_tensor = self._features[f'{self._mix_name}_spectrogram']
input_tensor = self.spectrogram_feature
model = self._params.get('model', None)
if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL)
@@ -124,12 +200,12 @@ class EstimatorSpecBuilder(object):
apply_model = get_model_function(model_type)
except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found')
return apply_model(
self._model_outputs = apply_model(
input_tensor,
self._instruments,
self._params['model']['params'])
def _build_loss(self, output_dict, labels):
def _build_loss(self, labels):
""" Construct tensorflow loss and metrics
:param output_dict: dictionary of network outputs (key: instrument
@@ -138,6 +214,7 @@ class EstimatorSpecBuilder(object):
name, value: ground truth spectrogram of the instrument)
:returns: tensorflow (loss, metrics) tuple.
"""
output_dict = self.model_outputs
loss_type = self._params.get('loss_type', self.L1_MASK)
if loss_type == self.L1_MASK:
losses = {
@@ -177,51 +254,106 @@ class EstimatorSpecBuilder(object):
return tf.compat.v1.train.GradientDescentOptimizer(rate)
return tf.compat.v1.train.AdamOptimizer(rate)
@property
def instruments(self):
return self._instruments
@property
def stft_name(self):
return f'{self._mix_name}_stft'
@property
def spectrogram_name(self):
return f'{self._mix_name}_spectrogram'
def _build_stft_feature(self):
""" Compute STFT of waveform and slice the STFT in segment
with the right length to feed the network.
"""
stft_feature = tf.transpose(
stft(
tf.transpose(self._features['waveform']),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype)),
pad_end=True),
perm=[1, 2, 0])
self._features[f'{self._mix_name}_stft'] = stft_feature
self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]
def _inverse_stft(self, stft):
stft_name = self.stft_name
spec_name = self.spectrogram_name
if stft_name not in self._features:
stft_feature = tf.transpose(
stft(
tf.transpose(self._features['waveform']),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype)),
pad_end=True),
perm=[1, 2, 0])
self._features[f'{self._mix_name}_stft'] = stft_feature
if spec_name not in self._features:
self._features[spec_name] = tf.abs(
pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :]
@property
def model_outputs(self):
if not hasattr(self, "_model_outputs"):
self._build_model_outputs()
return self._model_outputs
@property
def outputs(self):
if not hasattr(self, "_outputs"):
self._build_outputs()
return self._outputs
@property
def stft_feature(self):
if self.stft_name not in self._features:
self._build_stft_feature()
return self._features[self.stft_name]
@property
def spectrogram_feature(self):
if self.spectrogram_name not in self._features:
self._build_stft_feature()
return self._features[self.spectrogram_name]
@property
def masks(self):
if not hasattr(self, "_masks"):
self._build_masks()
return self._masks
@property
def masked_stfts(self):
if not hasattr(self, "_masked_stfts"):
self._build_masked_stfts()
return self._masked_stfts
def _inverse_stft(self, stft_t, time_crop=None):
""" Inverse and reshape the given STFT
:param stft: input STFT
:param stft_t: input STFT
:returns: inverse STFT (waveform)
"""
inversed = inverse_stft(
tf.transpose(stft, perm=[2, 0, 1]),
tf.transpose(stft_t, perm=[2, 0, 1]),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype))
) * self.WINDOW_COMPENSATION_FACTOR
reshaped = tf.transpose(inversed)
return reshaped[:tf.shape(self._features['waveform'])[0], :]
if time_crop is None:
time_crop = tf.shape(self._features['waveform'])[0]
return reshaped[:time_crop, :]
def _build_mwf_output_waveform(self, output_dict):
def _build_mwf_output_waveform(self):
""" Perform separation with multichannel Wiener Filtering using Norbert.
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
may be quite slow.
:param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument)
:returns: dictionary of separated waveforms (key: instrument name,
value: estimated waveform of the instrument)
"""
import norbert # pylint: disable=import-error
x = self._features[f'{self._mix_name}_stft']
output_dict = self.model_outputs
x = self.stft_feature
v = tf.stack(
[
pad_and_reshape(
@@ -265,30 +397,28 @@ class EstimatorSpecBuilder(object):
mask_shape[-1]))
else:
raise ValueError(f'Invalid mask_extension parameter {extension}')
n_extra_row = (self._frame_length) // 2 + 1 - self._F
n_extra_row = self._frame_length // 2 + 1 - self._F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
return tf.concat([mask, extension], axis=2)
def _build_manual_output_waveform(self, output_dict):
""" Perform ratio mask separation
:param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument)
:returns: dictionary of separated waveforms (key: instrument name,
value: estimated waveform of the instrument)
def _build_masks(self):
"""
Compute masks from the output spectrograms of the model.
:return:
"""
output_dict = self.model_outputs
stft_feature = self.stft_feature
separation_exponent = self._params['separation_exponent']
output_sum = tf.reduce_sum(
[e ** separation_exponent for e in output_dict.values()],
axis=0
) + self.EPSILON
output_waveform = {}
out = {}
for instrument in self._instruments:
output = output_dict[f'{instrument}_spectrogram']
# Compute mask with the model.
instrument_mask = (
output ** separation_exponent
+ (self.EPSILON / len(output_dict))) / output_sum
instrument_mask = (output ** separation_exponent
+ (self.EPSILON / len(output_dict))) / output_sum
# Extend mask;
instrument_mask = self._extend_mask(instrument_mask)
# Stack back mask.
@@ -298,30 +428,56 @@ class EstimatorSpecBuilder(object):
axis=0)
instrument_mask = tf.reshape(instrument_mask, new_shape)
# Remove padded part (for mask having the same size as STFT);
stft_feature = self._features[f'{self._mix_name}_stft']
instrument_mask = instrument_mask[
:tf.shape(stft_feature)[0], ...]
# Compute masked STFT and normalize it.
output_waveform[instrument] = self._inverse_stft(
tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature)
:tf.shape(stft_feature)[0], ...]
out[instrument] = instrument_mask
self._masks = out
def _build_masked_stfts(self):
input_stft = self.stft_feature
out = {}
for instrument, mask in self.masks.items():
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
self._masked_stfts = out
def _build_manual_output_waveform(self, masked_stft):
""" Perform ratio mask separation
:param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument)
:returns: dictionary of separated waveforms (key: instrument name,
value: estimated waveform of the instrument)
"""
output_waveform = {}
for instrument, stft_data in masked_stft.items():
output_waveform[instrument] = self._inverse_stft(stft_data)
return output_waveform
def _build_output_waveform(self, output_dict):
def _build_output_waveform(self, masked_stft):
""" Build output waveform from given output dict in order to be used in
prediction context. Regarding of the configuration building method will
be using MWF.
:param output_dict: Output dict to build output waveform from.
:returns: Built output waveform.
"""
if self._params.get('MWF', False):
output_waveform = self._build_mwf_output_waveform(output_dict)
output_waveform = self._build_mwf_output_waveform()
else:
output_waveform = self._build_manual_output_waveform(output_dict)
if 'audio_id' in self._features:
output_waveform['audio_id'] = self._features['audio_id']
output_waveform = self._build_manual_output_waveform(masked_stft)
return output_waveform
def _build_outputs(self):
if self.include_stft_computations():
self._outputs = self._build_output_waveform(self.masked_stfts)
else:
self._outputs = self.masked_stfts
if 'audio_id' in self._features:
self._outputs['audio_id'] = self._features['audio_id']
def build_predict_model(self):
""" Builder interface for creating model instance that aims to perform
prediction / inference over given track. The output of such estimator
@@ -330,12 +486,10 @@ class EstimatorSpecBuilder(object):
:returns: An estimator for performing prediction.
"""
self._build_stft_feature()
output_dict = self._build_output_dict()
output_waveform = self._build_output_waveform(output_dict)
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT,
predictions=output_waveform)
predictions=self.outputs)
def build_evaluation_model(self, labels):
""" Builder interface for creating model instance that aims to perform
@@ -346,8 +500,7 @@ class EstimatorSpecBuilder(object):
:param labels: Model labels.
:returns: An estimator for performing model evaluation.
"""
output_dict = self._build_output_dict()
loss, metrics = self._build_loss(output_dict, labels)
loss, metrics = self._build_loss(labels)
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.EVAL,
loss=loss,
@@ -362,8 +515,7 @@ class EstimatorSpecBuilder(object):
:param labels: Model labels.
:returns: An estimator for performing model training.
"""
output_dict = self._build_output_dict()
loss, metrics = self._build_loss(output_dict, labels)
loss, metrics = self._build_loss(labels)
optimizer = self._build_optimizer()
train_operation = optimizer.minimize(
loss=loss,

View File

@@ -13,40 +13,57 @@
"""
import os
import json
import logging
from functools import partial
from time import time
from multiprocessing import Pool
from pathlib import Path
from os.path import basename, join, splitext
import numpy as np
import tensorflow as tf
from librosa.core import stft, istft
from scipy.signal.windows import hann
from . import SpleeterError
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .model import model_fn
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
logger = logging.getLogger("spleeter")
def get_backend(backend):
assert backend in ["auto", "tensorflow", "librosa"]
if backend == "auto":
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
return backend
class Separator(object):
""" A wrapper class for performing separation. """
def __init__(self, params_descriptor, MWF=False, multiprocess=True):
def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
""" Default constructor.
:param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
self._predictor = None
self._pool = Pool() if multiprocess else None
self._tasks = []
self._params["stft_backend"] = get_backend(stft_backend)
def _get_predictor(self):
""" Lazy loading access method for internal predictor instance.
@@ -68,7 +85,7 @@ class Separator(object):
task.get()
task.wait(timeout=timeout)
def separate(self, waveform):
def separate_tensorflow(self, waveform, audio_descriptor):
""" Performs source separation over the given waveform.
The separation is performed synchronously but the result
@@ -86,10 +103,59 @@ class Separator(object):
predictor = self._get_predictor()
prediction = predictor({
'waveform': waveform,
'audio_id': ''})
'audio_id': audio_descriptor})
prediction.pop('audio_id')
return prediction
def stft(self, data, inverse=False, length=None):
"""
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
channels are processed separately and are concatenated together in the result. The expected input formats are:
(n_samples, 2) for stft and (T, F, 2) for istft.
:param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse
:param inverse: should a stft or an istft be computed.
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
N = self._params["frame_length"]
H = self._params["frame_step"]
win = hann(N, sym=False)
fstft = istft if inverse else stft
win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
s1 = np.expand_dims(s1.T, 2-inverse)
s2 = np.expand_dims(s2.T, 2-inverse)
return np.concatenate([s1, s2], axis=2-inverse)
def separate_librosa(self, waveform, audio_id):
out = {}
input_provider = InputProviderFactory.get(self._params)
features = input_provider.get_input_dict_placeholders()
builder = EstimatorSpecBuilder(features, self._params)
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
# TODO: fix the logic, build sometimes return, sometimes set attribute
outputs = builder.outputs
saver = tf.train.Saver()
stft = self.stft(waveform)
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
for inst in builder.instruments:
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
return out
def separate(self, waveform, audio_descriptor):
if self._params["stft_backend"] == "tensorflow":
return self.separate_tensorflow(waveform, audio_descriptor)
else:
return self.separate_librosa(waveform, audio_descriptor)
def separate_to_file(
self, audio_descriptor, destination,
audio_adapter=get_default_audio_adapter(),
@@ -108,6 +174,8 @@ class Separator(object):
descriptor would be a file path.
:param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param chunk_duration: (Optional) Maximum signal duration that is processed
in one pass. Default: all signal.
:param offset: (Optional) Offset of loaded song.
:param duration: (Optional) Duration of loaded song.
:param codec: (Optional) Export codec.
@@ -115,12 +183,17 @@ class Separator(object):
:param filename_format: (Optional) Filename format.
:param synchronous: (Optional) True is should by synchronous.
"""
waveform, _ = audio_adapter.load(
waveform, sample_rate = audio_adapter.load(
audio_descriptor,
offset=offset,
duration=duration,
sample_rate=self._sample_rate)
sources = self.separate(waveform)
sources = self.separate(waveform, audio_descriptor)
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous)
def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous):
filename = splitext(basename(audio_descriptor))[0]
generated = []
for instrument, data in sources.items():

View File

@@ -13,27 +13,37 @@ import tensorflow as tf
from tensorflow.contrib import predictor
# pylint: enable=import-error
from ..model import model_fn
from ..model import model_fn, InputProviderFactory
from ..model.provider import get_default_model_provider
# Default exporting directory for predictor.
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
def get_default_model_dir(model_dir):
"""
Transforms a string like 'spleeter:2stems' into an actual path.
:param model_dir:
:return:
"""
model_provider = get_default_model_provider()
return model_provider.get(model_dir)
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionnary of parameters for building the model
- params: a dictionary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
model_directory = params['model_dir']
model_provider = get_default_model_provider()
params['model_dir'] = model_provider.get(model_directory)
params['model_dir'] = get_default_model_dir(params['model_dir'])
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
@@ -56,11 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
:param estimator: Estimator to export.
:param directory: (Optional) path to write exported model into.
"""
input_provider = InputProviderFactory.get(estimator.params)
def receiver():
shape = (None, estimator.params['n_channels'])
features = {
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
'audio_id': tf.compat.v1.placeholder(tf.string)}
features = input_provider.get_input_dict_placeholders()
return tf.estimator.export.ServingInputReceiver(features, features)
estimator.export_saved_model(directory, receiver)

View File

@@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join
from tempfile import TemporaryDirectory
import pytest
import numpy as np
from spleeter import SpleeterError
from spleeter.audio.adapter import get_default_audio_adapter
@@ -21,34 +22,38 @@ from spleeter.separator import Separator
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
TEST_CONFIGURATIONS = [
('spleeter:2stems', ('vocals', 'accompaniment')),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'))
('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'),
('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'),
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'),
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa')
]
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
def test_separate(configuration, instruments):
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_separate(configuration, instruments, backend):
""" Test separation from raw data. """
adapter = get_default_audio_adapter()
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
separator = Separator(configuration)
prediction = separator.separate(waveform)
separator = Separator(configuration, stft_backend=backend)
prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR)
assert len(prediction) == len(instruments)
for instrument in instruments:
assert instrument in prediction
for instrument in instruments:
track = prediction[instrument]
assert not (waveform == track).all()
assert waveform.shape == track.shape
assert not np.allclose(waveform, track)
for compared in instruments:
if instrument != compared:
assert not (track == prediction[compared]).all()
assert not np.allclose(track, prediction[compared])
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
def test_separate_to_file(configuration, instruments):
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_separate_to_file(configuration, instruments, backend):
""" Test file based separation. """
separator = Separator(configuration)
separator = Separator(configuration, stft_backend=backend)
with TemporaryDirectory() as directory:
separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR,
@@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments):
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
def test_filename_format(configuration, instruments):
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
def test_filename_format(configuration, instruments, backend):
""" Test custom filename format. """
separator = Separator(configuration)
separator = Separator(configuration, stft_backend=backend)
with TemporaryDirectory() as directory:
separator.separate_to_file(
TEST_AUDIO_DESCRIPTOR,
@@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments):
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
def test_filename_confilct():
def test_filename_conflict():
""" Test error handling with static pattern. """
separator = Separator(TEST_CONFIGURATIONS[0][0])
with TemporaryDirectory() as directory: