diff --git a/requirements.txt b/requirements.txt index 8e1f7be..7f3bbf0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ importlib_resources; python_version<'3.7' requests setuptools>=41.0.0 pandas==0.25.1 -tensorflow==1.14.0 +tensorflow==1.15.0 ffmpeg-python -norbert==0.2.1 \ No newline at end of file +norbert==0.2.1 +librosa==0.7.2 \ No newline at end of file diff --git a/setup.py b/setup.py index b70dcdc..569ef49 100644 --- a/setup.py +++ b/setup.py @@ -14,9 +14,9 @@ __license__ = 'MIT License' # Default project values. project_name = 'spleeter' -project_version = '1.4.9' +project_version = '1.5.0' tensorflow_dependency = 'tensorflow' -tensorflow_version = '1.14.0' +tensorflow_version = '1.15.0' here = path.abspath(path.dirname(__file__)) readme_path = path.join(here, 'README.md') with open(readme_path, 'r') as stream: @@ -56,6 +56,7 @@ setup( 'pandas==0.25.1', 'requests', 'setuptools>=41.0.0', + 'librosa==0.7.2', '{}=={}'.format(tensorflow_dependency, tensorflow_version), ], extras_require={ diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py index 2bbc974..995e866 100644 --- a/spleeter/commands/__init__.py +++ b/spleeter/commands/__init__.py @@ -4,7 +4,7 @@ """ This modules provides spleeter command as well as CLI parsing methods. """ import json - +import logging from argparse import ArgumentParser from tempfile import gettempdir from os.path import exists, join @@ -13,6 +13,8 @@ __email__ = 'research@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' + + # -i opt specification (separate). OPT_INPUT = { 'dest': 'inputs', @@ -68,6 +70,17 @@ OPT_DURATION = { 'the input file)') } +# -w opt specification (separate) +OPT_STFT_BACKEND = { + 'dest': 'stft_backend', + 'type': str, + 'choices' : ["tensorflow", "librosa", "auto"], + 'default': "auto", + 'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses' + ' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.' +} + + # -c opt specification (separate). OPT_CODEC = { 'dest': 'codec', @@ -176,6 +189,7 @@ def _create_separate_parser(parser_factory): parser.add_argument('-c', '--codec', **OPT_CODEC) parser.add_argument('-b', '--birate', **OPT_BITRATE) parser.add_argument('-m', '--mwf', **OPT_MWF) + parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND) return parser diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py index b1f41ab..158190c 100644 --- a/spleeter/commands/separate.py +++ b/spleeter/commands/separate.py @@ -19,6 +19,7 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' + def entrypoint(arguments, params): """ Command entrypoint. @@ -29,7 +30,8 @@ def entrypoint(arguments, params): audio_adapter = get_audio_adapter(arguments.audio_adapter) separator = Separator( arguments.configuration, - arguments.MWF) + MWF=arguments.MWF, + stft_backend=arguments.stft_backend) for filename in arguments.inputs: separator.separate_to_file( filename, diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 384e838..531f84c 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -18,6 +18,9 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' +placeholder = tf.compat.v1.placeholder + + def get_model_function(model_type): """ Get tensorflow function of the model to be applied to the input tensor. @@ -41,6 +44,74 @@ def get_model_function(model_type): return model_function +class InputProvider(object): + + def __init__(self, params): + self.params = params + + def get_input_dict_placeholders(self): + raise NotImplementedError() + + @property + def input_names(self): + raise NotImplementedError() + + def get_feed_dict(self, features, *args): + raise NotImplementedError() + + +class WaveformInputProvider(InputProvider): + + @property + def input_names(self): + return ["audio_id", "waveform"] + + def get_input_dict_placeholders(self): + shape = (None, self.params['n_channels']) + features = { + 'waveform': placeholder(tf.float32, shape=shape, name="waveform"), + 'audio_id': placeholder(tf.string, name="audio_id")} + return features + + def get_feed_dict(self, features, waveform, audio_id): + return {features["audio_id"]: audio_id, features["waveform"]: waveform} + + +class SpectralInputProvider(InputProvider): + + def __init__(self, params): + super().__init__(params) + self.stft_input_name = "{}_stft".format(self.params["mix_name"]) + + @property + def input_names(self): + return ["audio_id", self.stft_input_name] + + def get_input_dict_placeholders(self): + features = { + self.stft_input_name: placeholder(tf.complex64, + shape=(None, self.params["frame_length"]//2+1, + self.params['n_channels']), + name=self.stft_input_name), + 'audio_id': placeholder(tf.string, name="audio_id")} + return features + + def get_feed_dict(self, features, stft, audio_id): + return {features["audio_id"]: audio_id, features[self.stft_input_name]: stft} + + +class InputProviderFactory(object): + + @staticmethod + def get(params): + stft_backend = params["stft_backend"] + assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend) + if stft_backend == "tensorflow": + return WaveformInputProvider(params) + else: + return SpectralInputProvider(params) + + class EstimatorSpecBuilder(object): """ A builder class that allows to builds a multitrack unet model estimator. The built model estimator has a different behaviour when @@ -57,9 +128,9 @@ class EstimatorSpecBuilder(object): >>> from spleeter.model import EstimatorSpecBuilder >>> builder = EstimatorSpecBuilder() - >>> builder.build_prediction_model() + >>> builder.build_predict_model() >>> builder.build_evaluation_model() - >>> builder.build_training_model() + >>> builder.build_train_model() >>> from spleeter.model import model_fn >>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...) @@ -94,6 +165,7 @@ class EstimatorSpecBuilder(object): :param features: The input features for the estimator. :param params: Some hyperparameters as a dictionary. """ + self._features = features self._params = params # Get instrument name. @@ -106,7 +178,10 @@ class EstimatorSpecBuilder(object): self._frame_length = params['frame_length'] self._frame_step = params['frame_step'] - def _build_output_dict(self): + def include_stft_computations(self): + return self._params["stft_backend"] == "tensorflow" + + def _build_model_outputs(self): """ Created a batch_sizexTxFxn_channels input tensor containing mix magnitude spectrogram, then an output dict from it according to the selected model in internal parameters. @@ -114,7 +189,8 @@ class EstimatorSpecBuilder(object): :returns: Build output dict. :raise ValueError: If required model_type is not supported. """ - input_tensor = self._features[f'{self._mix_name}_spectrogram'] + + input_tensor = self.spectrogram_feature model = self._params.get('model', None) if model is not None: model_type = model.get('type', self.DEFAULT_MODEL) @@ -124,12 +200,12 @@ class EstimatorSpecBuilder(object): apply_model = get_model_function(model_type) except ModuleNotFoundError: raise ValueError(f'No model function {model_type} found') - return apply_model( + self._model_outputs = apply_model( input_tensor, self._instruments, self._params['model']['params']) - def _build_loss(self, output_dict, labels): + def _build_loss(self, labels): """ Construct tensorflow loss and metrics :param output_dict: dictionary of network outputs (key: instrument @@ -138,6 +214,7 @@ class EstimatorSpecBuilder(object): name, value: ground truth spectrogram of the instrument) :returns: tensorflow (loss, metrics) tuple. """ + output_dict = self.model_outputs loss_type = self._params.get('loss_type', self.L1_MASK) if loss_type == self.L1_MASK: losses = { @@ -177,51 +254,106 @@ class EstimatorSpecBuilder(object): return tf.compat.v1.train.GradientDescentOptimizer(rate) return tf.compat.v1.train.AdamOptimizer(rate) + @property + def instruments(self): + return self._instruments + + @property + def stft_name(self): + return f'{self._mix_name}_stft' + + @property + def spectrogram_name(self): + return f'{self._mix_name}_spectrogram' + def _build_stft_feature(self): """ Compute STFT of waveform and slice the STFT in segment with the right length to feed the network. """ - stft_feature = tf.transpose( - stft( - tf.transpose(self._features['waveform']), - self._frame_length, - self._frame_step, - window_fn=lambda frame_length, dtype: ( - hann_window(frame_length, periodic=True, dtype=dtype)), - pad_end=True), - perm=[1, 2, 0]) - self._features[f'{self._mix_name}_stft'] = stft_feature - self._features[f'{self._mix_name}_spectrogram'] = tf.abs( - pad_and_partition(stft_feature, self._T))[:, :, :self._F, :] - def _inverse_stft(self, stft): + stft_name = self.stft_name + spec_name = self.spectrogram_name + + if stft_name not in self._features: + stft_feature = tf.transpose( + stft( + tf.transpose(self._features['waveform']), + self._frame_length, + self._frame_step, + window_fn=lambda frame_length, dtype: ( + hann_window(frame_length, periodic=True, dtype=dtype)), + pad_end=True), + perm=[1, 2, 0]) + self._features[f'{self._mix_name}_stft'] = stft_feature + if spec_name not in self._features: + self._features[spec_name] = tf.abs( + pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :] + + @property + def model_outputs(self): + if not hasattr(self, "_model_outputs"): + self._build_model_outputs() + return self._model_outputs + + @property + def outputs(self): + if not hasattr(self, "_outputs"): + self._build_outputs() + return self._outputs + + @property + def stft_feature(self): + if self.stft_name not in self._features: + self._build_stft_feature() + return self._features[self.stft_name] + + @property + def spectrogram_feature(self): + if self.spectrogram_name not in self._features: + self._build_stft_feature() + return self._features[self.spectrogram_name] + + @property + def masks(self): + if not hasattr(self, "_masks"): + self._build_masks() + return self._masks + + @property + def masked_stfts(self): + if not hasattr(self, "_masked_stfts"): + self._build_masked_stfts() + return self._masked_stfts + + def _inverse_stft(self, stft_t, time_crop=None): """ Inverse and reshape the given STFT - :param stft: input STFT + :param stft_t: input STFT :returns: inverse STFT (waveform) """ inversed = inverse_stft( - tf.transpose(stft, perm=[2, 0, 1]), + tf.transpose(stft_t, perm=[2, 0, 1]), self._frame_length, self._frame_step, window_fn=lambda frame_length, dtype: ( hann_window(frame_length, periodic=True, dtype=dtype)) ) * self.WINDOW_COMPENSATION_FACTOR reshaped = tf.transpose(inversed) - return reshaped[:tf.shape(self._features['waveform'])[0], :] + if time_crop is None: + time_crop = tf.shape(self._features['waveform'])[0] + return reshaped[:time_crop, :] - def _build_mwf_output_waveform(self, output_dict): + def _build_mwf_output_waveform(self): """ Perform separation with multichannel Wiener Filtering using Norbert. Note: multichannel Wiener Filtering is not coded in Tensorflow and thus may be quite slow. - :param output_dict: dictionary of estimated spectrogram (key: instrument - name, value: estimated spectrogram of the instrument) :returns: dictionary of separated waveforms (key: instrument name, value: estimated waveform of the instrument) """ import norbert # pylint: disable=import-error - x = self._features[f'{self._mix_name}_stft'] + output_dict = self.model_outputs + x = self.stft_feature v = tf.stack( [ pad_and_reshape( @@ -265,30 +397,28 @@ class EstimatorSpecBuilder(object): mask_shape[-1])) else: raise ValueError(f'Invalid mask_extension parameter {extension}') - n_extra_row = (self._frame_length) // 2 + 1 - self._F + n_extra_row = self._frame_length // 2 + 1 - self._F extension = tf.tile(extension_row, [1, 1, n_extra_row, 1]) return tf.concat([mask, extension], axis=2) - def _build_manual_output_waveform(self, output_dict): - """ Perform ratio mask separation - - :param output_dict: dictionary of estimated spectrogram (key: instrument - name, value: estimated spectrogram of the instrument) - :returns: dictionary of separated waveforms (key: instrument name, - value: estimated waveform of the instrument) + def _build_masks(self): """ + Compute masks from the output spectrograms of the model. + :return: + """ + output_dict = self.model_outputs + stft_feature = self.stft_feature separation_exponent = self._params['separation_exponent'] output_sum = tf.reduce_sum( [e ** separation_exponent for e in output_dict.values()], axis=0 ) + self.EPSILON - output_waveform = {} + out = {} for instrument in self._instruments: output = output_dict[f'{instrument}_spectrogram'] # Compute mask with the model. - instrument_mask = ( - output ** separation_exponent - + (self.EPSILON / len(output_dict))) / output_sum + instrument_mask = (output ** separation_exponent + + (self.EPSILON / len(output_dict))) / output_sum # Extend mask; instrument_mask = self._extend_mask(instrument_mask) # Stack back mask. @@ -298,30 +428,56 @@ class EstimatorSpecBuilder(object): axis=0) instrument_mask = tf.reshape(instrument_mask, new_shape) # Remove padded part (for mask having the same size as STFT); - stft_feature = self._features[f'{self._mix_name}_stft'] + instrument_mask = instrument_mask[ - :tf.shape(stft_feature)[0], ...] - # Compute masked STFT and normalize it. - output_waveform[instrument] = self._inverse_stft( - tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature) + :tf.shape(stft_feature)[0], ...] + out[instrument] = instrument_mask + self._masks = out + + def _build_masked_stfts(self): + input_stft = self.stft_feature + out = {} + for instrument, mask in self.masks.items(): + out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft + self._masked_stfts = out + + def _build_manual_output_waveform(self, masked_stft): + """ Perform ratio mask separation + + :param output_dict: dictionary of estimated spectrogram (key: instrument + name, value: estimated spectrogram of the instrument) + :returns: dictionary of separated waveforms (key: instrument name, + value: estimated waveform of the instrument) + """ + + output_waveform = {} + for instrument, stft_data in masked_stft.items(): + output_waveform[instrument] = self._inverse_stft(stft_data) return output_waveform - def _build_output_waveform(self, output_dict): + def _build_output_waveform(self, masked_stft): """ Build output waveform from given output dict in order to be used in prediction context. Regarding of the configuration building method will be using MWF. - :param output_dict: Output dict to build output waveform from. :returns: Built output waveform. """ + if self._params.get('MWF', False): - output_waveform = self._build_mwf_output_waveform(output_dict) + output_waveform = self._build_mwf_output_waveform() else: - output_waveform = self._build_manual_output_waveform(output_dict) - if 'audio_id' in self._features: - output_waveform['audio_id'] = self._features['audio_id'] + output_waveform = self._build_manual_output_waveform(masked_stft) return output_waveform + def _build_outputs(self): + if self.include_stft_computations(): + self._outputs = self._build_output_waveform(self.masked_stfts) + else: + self._outputs = self.masked_stfts + + if 'audio_id' in self._features: + self._outputs['audio_id'] = self._features['audio_id'] + def build_predict_model(self): """ Builder interface for creating model instance that aims to perform prediction / inference over given track. The output of such estimator @@ -330,12 +486,10 @@ class EstimatorSpecBuilder(object): :returns: An estimator for performing prediction. """ - self._build_stft_feature() - output_dict = self._build_output_dict() - output_waveform = self._build_output_waveform(output_dict) + return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.PREDICT, - predictions=output_waveform) + predictions=self.outputs) def build_evaluation_model(self, labels): """ Builder interface for creating model instance that aims to perform @@ -346,8 +500,7 @@ class EstimatorSpecBuilder(object): :param labels: Model labels. :returns: An estimator for performing model evaluation. """ - output_dict = self._build_output_dict() - loss, metrics = self._build_loss(output_dict, labels) + loss, metrics = self._build_loss(labels) return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.EVAL, loss=loss, @@ -362,8 +515,7 @@ class EstimatorSpecBuilder(object): :param labels: Model labels. :returns: An estimator for performing model training. """ - output_dict = self._build_output_dict() - loss, metrics = self._build_loss(output_dict, labels) + loss, metrics = self._build_loss(labels) optimizer = self._build_optimizer() train_operation = optimizer.minimize( loss=loss, diff --git a/spleeter/separator.py b/spleeter/separator.py index 3ef3ecf..174c73f 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -13,40 +13,57 @@ """ import os -import json +import logging -from functools import partial +from time import time from multiprocessing import Pool -from pathlib import Path from os.path import basename, join, splitext +import numpy as np +import tensorflow as tf +from librosa.core import stft, istft +from scipy.signal.windows import hann from . import SpleeterError from .audio.adapter import get_default_audio_adapter from .audio.convertor import to_stereo -from .model import model_fn from .utils.configuration import load_configuration -from .utils.estimator import create_estimator, to_predictor +from .utils.estimator import create_estimator, to_predictor, get_default_model_dir +from .model import EstimatorSpecBuilder, InputProviderFactory + __email__ = 'research@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' +logger = logging.getLogger("spleeter") + + + +def get_backend(backend): + assert backend in ["auto", "tensorflow", "librosa"] + if backend == "auto": + return "tensorflow" if tf.test.is_gpu_available() else "librosa" + return backend + + class Separator(object): """ A wrapper class for performing separation. """ - def __init__(self, params_descriptor, MWF=False, multiprocess=True): + def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True): """ Default constructor. :param params_descriptor: Descriptor for TF params to be used. :param MWF: (Optional) True if MWF should be used, False otherwise. """ + self._params = load_configuration(params_descriptor) self._sample_rate = self._params['sample_rate'] self._MWF = MWF self._predictor = None self._pool = Pool() if multiprocess else None self._tasks = [] + self._params["stft_backend"] = get_backend(stft_backend) def _get_predictor(self): """ Lazy loading access method for internal predictor instance. @@ -68,7 +85,7 @@ class Separator(object): task.get() task.wait(timeout=timeout) - def separate(self, waveform): + def separate_tensorflow(self, waveform, audio_descriptor): """ Performs source separation over the given waveform. The separation is performed synchronously but the result @@ -86,10 +103,59 @@ class Separator(object): predictor = self._get_predictor() prediction = predictor({ 'waveform': waveform, - 'audio_id': ''}) + 'audio_id': audio_descriptor}) prediction.pop('audio_id') return prediction + def stft(self, data, inverse=False, length=None): + """ + Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two + channels are processed separately and are concatenated together in the result. The expected input formats are: + (n_samples, 2) for stft and (T, F, 2) for istft. + :param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse + :param inverse: should a stft or an istft be computed. + :return: Stereo data as numpy array for the transform. The channels are stored in the last dimension + """ + assert not (inverse and length is None) + data = np.asfortranarray(data) + N = self._params["frame_length"] + H = self._params["frame_step"] + win = hann(N, sym=False) + fstft = istft if inverse else stft + win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N} + dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1]) + s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg) + s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg) + s1 = np.expand_dims(s1.T, 2-inverse) + s2 = np.expand_dims(s2.T, 2-inverse) + return np.concatenate([s1, s2], axis=2-inverse) + + def separate_librosa(self, waveform, audio_id): + out = {} + input_provider = InputProviderFactory.get(self._params) + features = input_provider.get_input_dict_placeholders() + + builder = EstimatorSpecBuilder(features, self._params) + latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) + + # TODO: fix the logic, build sometimes return, sometimes set attribute + outputs = builder.outputs + + saver = tf.train.Saver() + stft = self.stft(waveform) + with tf.Session() as sess: + saver.restore(sess, latest_checkpoint) + outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) + for inst in builder.instruments: + out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0]) + return out + + def separate(self, waveform, audio_descriptor): + if self._params["stft_backend"] == "tensorflow": + return self.separate_tensorflow(waveform, audio_descriptor) + else: + return self.separate_librosa(waveform, audio_descriptor) + def separate_to_file( self, audio_descriptor, destination, audio_adapter=get_default_audio_adapter(), @@ -108,6 +174,8 @@ class Separator(object): descriptor would be a file path. :param destination: Target directory to write output to. :param audio_adapter: (Optional) Audio adapter to use for I/O. + :param chunk_duration: (Optional) Maximum signal duration that is processed + in one pass. Default: all signal. :param offset: (Optional) Offset of loaded song. :param duration: (Optional) Duration of loaded song. :param codec: (Optional) Export codec. @@ -115,12 +183,17 @@ class Separator(object): :param filename_format: (Optional) Filename format. :param synchronous: (Optional) True is should by synchronous. """ - waveform, _ = audio_adapter.load( + waveform, sample_rate = audio_adapter.load( audio_descriptor, offset=offset, duration=duration, sample_rate=self._sample_rate) - sources = self.separate(waveform) + sources = self.separate(waveform, audio_descriptor) + self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, + audio_adapter, bitrate, synchronous) + + def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec, + audio_adapter, bitrate, synchronous): filename = splitext(basename(audio_descriptor))[0] generated = [] for instrument, data in sources.items(): diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py index 95c4219..a9aa736 100644 --- a/spleeter/utils/estimator.py +++ b/spleeter/utils/estimator.py @@ -13,27 +13,37 @@ import tensorflow as tf from tensorflow.contrib import predictor # pylint: enable=import-error -from ..model import model_fn +from ..model import model_fn, InputProviderFactory from ..model.provider import get_default_model_provider # Default exporting directory for predictor. DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving') + +def get_default_model_dir(model_dir): + """ + Transforms a string like 'spleeter:2stems' into an actual path. + :param model_dir: + :return: + """ + model_provider = get_default_model_provider() + return model_provider.get(model_dir) + def create_estimator(params, MWF): """ Initialize tensorflow estimator that will perform separation Params: - - params: a dictionnary of parameters for building the model + - params: a dictionary of parameters for building the model Returns: a tensorflow estimator """ # Load model. - model_directory = params['model_dir'] - model_provider = get_default_model_provider() - params['model_dir'] = model_provider.get(model_directory) + + + params['model_dir'] = get_default_model_dir(params['model_dir']) params['MWF'] = MWF # Setup config session_config = tf.compat.v1.ConfigProto() @@ -56,11 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): :param estimator: Estimator to export. :param directory: (Optional) path to write exported model into. """ + + input_provider = InputProviderFactory.get(estimator.params) def receiver(): - shape = (None, estimator.params['n_channels']) - features = { - 'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape), - 'audio_id': tf.compat.v1.placeholder(tf.string)} + features = input_provider.get_input_dict_placeholders() return tf.estimator.export.ServingInputReceiver(features, features) estimator.export_saved_model(directory, receiver) diff --git a/tests/test_separator.py b/tests/test_separator.py index 9235731..271fdfb 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join from tempfile import TemporaryDirectory import pytest +import numpy as np from spleeter import SpleeterError from spleeter.audio.adapter import get_default_audio_adapter @@ -21,34 +22,38 @@ from spleeter.separator import Separator TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3' TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0] TEST_CONFIGURATIONS = [ - ('spleeter:2stems', ('vocals', 'accompaniment')), - ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')), - ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other')) + ('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'), + ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'), + ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'), + ('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'), + ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'), + ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa') ] -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_separate(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_separate(configuration, instruments, backend): """ Test separation from raw data. """ adapter = get_default_audio_adapter() waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR) - separator = Separator(configuration) - prediction = separator.separate(waveform) + separator = Separator(configuration, stft_backend=backend) + prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR) assert len(prediction) == len(instruments) for instrument in instruments: assert instrument in prediction for instrument in instruments: track = prediction[instrument] - assert not (waveform == track).all() + assert waveform.shape == track.shape + assert not np.allclose(waveform, track) for compared in instruments: if instrument != compared: - assert not (track == prediction[compared]).all() + assert not np.allclose(track, prediction[compared]) -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_separate_to_file(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_separate_to_file(configuration, instruments, backend): """ Test file based separation. """ - separator = Separator(configuration) + separator = Separator(configuration, stft_backend=backend) with TemporaryDirectory() as directory: separator.separate_to_file( TEST_AUDIO_DESCRIPTOR, @@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments): '{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_filename_format(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_filename_format(configuration, instruments, backend): """ Test custom filename format. """ - separator = Separator(configuration) + separator = Separator(configuration, stft_backend=backend) with TemporaryDirectory() as directory: separator.separate_to_file( TEST_AUDIO_DESCRIPTOR, @@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments): 'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) -def test_filename_confilct(): +def test_filename_conflict(): """ Test error handling with static pattern. """ separator = Separator(TEST_CONFIGURATIONS[0][0]) with TemporaryDirectory() as directory: