From fe4634afa68623034423ce256d886c32d0e27d45 Mon Sep 17 00:00:00 2001 From: akhlif Date: Wed, 26 Feb 2020 16:31:24 +0100 Subject: [PATCH] Adding option to use librosa backend. Changes in the EstimatorBuilder to set attributes instead of returning tensors for the _build methods. InputProvider classes to handle the different backend cases. New method in Separator. --- spleeter/commands/__init__.py | 17 ++- spleeter/commands/separate.py | 12 +- spleeter/model/__init__.py | 230 ++++++++++++++++++++++++++-------- spleeter/separator.py | 83 +++++------- spleeter/utils/estimator.py | 14 +-- 5 files changed, 232 insertions(+), 124 deletions(-) diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py index 21b70ef..452b3a6 100644 --- a/spleeter/commands/__init__.py +++ b/spleeter/commands/__init__.py @@ -72,14 +72,13 @@ OPT_DURATION = { } # -w opt specification (separate) -OPT_CHUNKED = { - 'dest': 'chunk_duration', - 'type': float, - 'default': -1, - 'help': 'Maximum duration of the segments that are fed to' - ' the network. Use this parameter to limit ' - 'memory usage. Use -1 to process the whole signal' - ' in one pass.' +OPT_STFT_BACKEND = { + 'dest': 'stft_backend', + 'type': str, + 'choices' : ["tensorflow", "librosa", "auto"], + 'default': "auto", + 'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses' + ' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.' } @@ -191,7 +190,7 @@ def _create_separate_parser(parser_factory): parser.add_argument('-c', '--codec', **OPT_CODEC) parser.add_argument('-b', '--birate', **OPT_BITRATE) parser.add_argument('-m', '--mwf', **OPT_MWF) - parser.add_argument('-w', '--chunk', **OPT_CHUNKED) + parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND) return parser diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py index b3de4b6..2fac8aa 100644 --- a/spleeter/commands/separate.py +++ b/spleeter/commands/separate.py @@ -11,6 +11,8 @@ -i /path/to/audio1.wav /path/to/audio2.mp3 """ +import tensorflow as tf + from ..audio.adapter import get_audio_adapter from ..separator import Separator @@ -19,6 +21,12 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' +def get_backend(backend): + if backend == "auto": + return "tensorflow" if tf.test.is_gpu_available() else "librosa" + return backend + + def entrypoint(arguments, params): """ Command entrypoint. @@ -29,13 +37,13 @@ def entrypoint(arguments, params): audio_adapter = get_audio_adapter(arguments.audio_adapter) separator = Separator( arguments.configuration, - arguments.MWF) + MWF=arguments.MWF, + stft_backend=get_backend(arguments.stft_backend)) for filename in arguments.inputs: separator.separate_to_file( filename, arguments.output_path, audio_adapter=audio_adapter, - chunk_duration=arguments.chunk_duration, offset=arguments.offset, duration=arguments.duration, codec=arguments.codec, diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 4d18eb5..4c3e200 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -18,6 +18,10 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' + +placeholder = tf.compat.v1.placeholder + + def get_model_function(model_type): """ Get tensorflow function of the model to be applied to the input tensor. @@ -41,6 +45,74 @@ def get_model_function(model_type): return model_function +class InputProvider(object): + + def __init__(self, params): + self.params = params + + def get_input_dict_placeholders(self): + raise NotImplementedError() + + @property + def input_names(self): + raise NotImplementedError() + + def get_feed_dict(self, features, *args): + raise NotImplementedError() + + +class WaveformInputProvider(InputProvider): + + @property + def input_names(self): + return ["audio_id", "waveform"] + + def get_input_dict_placeholders(self): + shape = (None, self.params['n_channels']) + features = { + 'waveform': placeholder(tf.float32, shape=shape, name="waveform"), + 'audio_id': placeholder(tf.string, name="audio_id")} + return features + + def get_feed_dict(self, features, waveform, audio_id): + return {features["audio_id"]: audio_id, features["waveform"]: waveform} + + +class SpectralInputProvider(InputProvider): + + def __init__(self, params): + super().__init__(params) + self.stft_input_name = "{}_stft".format(self.params["mix_name"]) + + @property + def input_names(self): + return ["audio_id", self.stft_input_name] + + def get_input_dict_placeholders(self): + features = { + self.stft_input_name: placeholder(tf.complex64, + shape=(None, self.params["frame_length"]//2+1, + self.params['n_channels']), + name=self.stft_input_name), + 'audio_id': placeholder(tf.string, name="audio_id")} + return features + + def get_feed_dict(self, features, stft, audio_id): + return {features["audio_id"]: audio_id, features[self.stft_input_name]: stft} + + +class InputProviderFactory(object): + + @staticmethod + def get(params): + stft_backend = params["stft_backend"] + assert stft_backend in ("tensorflow", "librosa") + if stft_backend == "tensorflow": + return WaveformInputProvider(params) + else: + return SpectralInputProvider(params) + + class EstimatorSpecBuilder(object): """ A builder class that allows to builds a multitrack unet model estimator. The built model estimator has a different behaviour when @@ -57,9 +129,9 @@ class EstimatorSpecBuilder(object): >>> from spleeter.model import EstimatorSpecBuilder >>> builder = EstimatorSpecBuilder() - >>> builder.build_prediction_model() + >>> builder.build_predict_model() >>> builder.build_evaluation_model() - >>> builder.build_training_model() + >>> builder.build_train_model() >>> from spleeter.model import model_fn >>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...) @@ -94,6 +166,7 @@ class EstimatorSpecBuilder(object): :param features: The input features for the estimator. :param params: Some hyperparameters as a dictionary. """ + self._features = features self._params = params # Get instrument name. @@ -106,7 +179,10 @@ class EstimatorSpecBuilder(object): self._frame_length = params['frame_length'] self._frame_step = params['frame_step'] - def _build_output_dict(self, input_tensor=None): + def include_stft_computations(self): + return self._params["stft_backend"] == "tensorflow" + + def _build_model_outputs(self): """ Created a batch_sizexTxFxn_channels input tensor containing mix magnitude spectrogram, then an output dict from it according to the selected model in internal parameters. @@ -114,8 +190,8 @@ class EstimatorSpecBuilder(object): :returns: Build output dict. :raise ValueError: If required model_type is not supported. """ - if input_tensor is None: - input_tensor = self._features[f'{self._mix_name}_spectrogram'] + + input_tensor = self.spectrogram_feature model = self._params.get('model', None) if model is not None: model_type = model.get('type', self.DEFAULT_MODEL) @@ -125,12 +201,12 @@ class EstimatorSpecBuilder(object): apply_model = get_model_function(model_type) except ModuleNotFoundError: raise ValueError(f'No model function {model_type} found') - return apply_model( + self._model_outputs = apply_model( input_tensor, self._instruments, self._params['model']['params']) - def _build_loss(self, output_dict, labels): + def _build_loss(self, labels): """ Construct tensorflow loss and metrics :param output_dict: dictionary of network outputs (key: instrument @@ -139,6 +215,7 @@ class EstimatorSpecBuilder(object): name, value: ground truth spectrogram of the instrument) :returns: tensorflow (loss, metrics) tuple. """ + output_dict = self.model_outputs loss_type = self._params.get('loss_type', self.L1_MASK) if loss_type == self.L1_MASK: losses = { @@ -178,30 +255,72 @@ class EstimatorSpecBuilder(object): return tf.compat.v1.train.GradientDescentOptimizer(rate) return tf.compat.v1.train.AdamOptimizer(rate) + @property + def instruments(self): + return self._instruments + + @property + def stft_name(self): + return f'{self._mix_name}_stft' + + @property + def spectrogram_name(self): + return f'{self._mix_name}_spectrogram' + def _build_stft_feature(self): """ Compute STFT of waveform and slice the STFT in segment with the right length to feed the network. """ - stft_feature = tf.transpose( - stft( - tf.transpose(self._features['waveform']), - self._frame_length, - self._frame_step, - window_fn=lambda frame_length, dtype: ( - hann_window(frame_length, periodic=True, dtype=dtype)), - pad_end=True), - perm=[1, 2, 0]) - self._features[f'{self._mix_name}_stft'] = stft_feature - self._features[f'{self._mix_name}_spectrogram'] = tf.abs( - pad_and_partition(stft_feature, self._T))[:, :, :self._F, :] - def get_stft_feature(self): - return self._features[f'{self._mix_name}_stft'] + stft_name = self.stft_name + spec_name = self.spectrogram_name - def get_spectrogram_feature(self): - return self._features[f'{self._mix_name}_spectrogram'] + if stft_name not in self._features: + stft_feature = tf.transpose( + stft( + tf.transpose(self._features['waveform']), + self._frame_length, + self._frame_step, + window_fn=lambda frame_length, dtype: ( + hann_window(frame_length, periodic=True, dtype=dtype)), + pad_end=True), + perm=[1, 2, 0]) + self._features[f'{self._mix_name}_stft'] = stft_feature + if spec_name not in self._features: + self._features[spec_name] = tf.abs( + pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :] - def _inverse_stft(self, stft): + @property + def model_outputs(self): + if not hasattr(self, "_model_outputs"): + self._build_model_outputs() + return self._model_outputs + + @property + def outputs(self): + if not hasattr(self, "_outputs"): + self._build_outputs() + return self._outputs + + @property + def stft_feature(self): + if self.stft_name not in self._features: + self._build_stft_feature() + return self._features[self.stft_name] + + @property + def spectrogram_feature(self): + if self.spectrogram_name not in self._features: + self._build_stft_feature() + return self._features[self.spectrogram_name] + + @property + def masks(self): + if not hasattr(self, "_masks"): + self._build_masks() + return self._masks + + def _inverse_stft(self, stft, time_crop=None): """ Inverse and reshape the given STFT :param stft: input STFT @@ -215,20 +334,21 @@ class EstimatorSpecBuilder(object): hann_window(frame_length, periodic=True, dtype=dtype)) ) * self.WINDOW_COMPENSATION_FACTOR reshaped = tf.transpose(inversed) - return reshaped[:tf.shape(self._features['waveform'])[0], :] + if time_crop is None: + time_crop = tf.shape(self._features['waveform'])[0] + return reshaped[:time_crop, :] - def _build_mwf_output_waveform(self, output_dict): + def _build_mwf_output_waveform(self): """ Perform separation with multichannel Wiener Filtering using Norbert. Note: multichannel Wiener Filtering is not coded in Tensorflow and thus may be quite slow. - :param output_dict: dictionary of estimated spectrogram (key: instrument - name, value: estimated spectrogram of the instrument) :returns: dictionary of separated waveforms (key: instrument name, value: estimated waveform of the instrument) """ import norbert # pylint: disable=import-error - x = self._features[f'{self._mix_name}_stft'] + output_dict = self.model_outputs + x = self.stft_feature v = tf.stack( [ pad_and_reshape( @@ -272,11 +392,13 @@ class EstimatorSpecBuilder(object): mask_shape[-1])) else: raise ValueError(f'Invalid mask_extension parameter {extension}') - n_extra_row = (self._frame_length) // 2 + 1 - self._F + n_extra_row = self._frame_length // 2 + 1 - self._F extension = tf.tile(extension_row, [1, 1, n_extra_row, 1]) return tf.concat([mask, extension], axis=2) - def _build_masks(self, output_dict): + def _build_masks(self): + output_dict = self.model_outputs + stft_feature = self.stft_feature separation_exponent = self._params['separation_exponent'] output_sum = tf.reduce_sum( [e ** separation_exponent for e in output_dict.values()], @@ -297,21 +419,20 @@ class EstimatorSpecBuilder(object): axis=0) instrument_mask = tf.reshape(instrument_mask, new_shape) # Remove padded part (for mask having the same size as STFT); - stft_feature = self._features[f'{self._mix_name}_stft'] + instrument_mask = instrument_mask[ :tf.shape(stft_feature)[0], ...] out[instrument] = instrument_mask - return out + self._masks = out - def _build_masked_stft(self, mask_dict, input_stft=None): - if input_stft is None: - input_stft = self._features[f'{self._mix_name}_stft'] + def _build_masked_stft(self): + input_stft = self.stft_feature out = {} - for instrument, mask in mask_dict.items(): + for instrument, mask in self.masks.items(): out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft return out - def _build_manual_output_waveform(self, output_dict): + def _build_manual_output_waveform(self, masked_stft): """ Perform ratio mask separation :param output_dict: dictionary of estimated spectrogram (key: instrument @@ -321,27 +442,34 @@ class EstimatorSpecBuilder(object): """ output_waveform = {} - masked_stft = self._build_masked_stft(self._build_masks(output_dict)) for instrument, stft_data in masked_stft.items(): output_waveform[instrument] = self._inverse_stft(stft_data) return output_waveform - def _build_output_waveform(self, output_dict): + def _build_output_waveform(self, masked_stft): """ Build output waveform from given output dict in order to be used in prediction context. Regarding of the configuration building method will be using MWF. - :param output_dict: Output dict to build output waveform from. :returns: Built output waveform. """ + if self._params.get('MWF', False): - output_waveform = self._build_mwf_output_waveform(output_dict) + output_waveform = self._build_mwf_output_waveform() else: - output_waveform = self._build_manual_output_waveform(output_dict) - if 'audio_id' in self._features: - output_waveform['audio_id'] = self._features['audio_id'] + output_waveform = self._build_manual_output_waveform(masked_stft) return output_waveform + def _build_outputs(self): + masked_stft = self._build_masked_stft() + if self.include_stft_computations(): + self._outputs = self._build_output_waveform(masked_stft) + else: + self._outputs = masked_stft + + if 'audio_id' in self._features: + self._outputs['audio_id'] = self._features['audio_id'] + def build_predict_model(self): """ Builder interface for creating model instance that aims to perform prediction / inference over given track. The output of such estimator @@ -350,12 +478,10 @@ class EstimatorSpecBuilder(object): :returns: An estimator for performing prediction. """ - self._build_stft_feature() - output_dict = self._build_output_dict() - output_waveform = self._build_output_waveform(output_dict) + return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.PREDICT, - predictions=output_waveform) + predictions=self.outputs) def build_evaluation_model(self, labels): """ Builder interface for creating model instance that aims to perform @@ -366,8 +492,7 @@ class EstimatorSpecBuilder(object): :param labels: Model labels. :returns: An estimator for performing model evaluation. """ - output_dict = self._build_output_dict() - loss, metrics = self._build_loss(output_dict, labels) + loss, metrics = self._build_loss(labels) return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.EVAL, loss=loss, @@ -382,8 +507,7 @@ class EstimatorSpecBuilder(object): :param labels: Model labels. :returns: An estimator for performing model training. """ - output_dict = self._build_output_dict() - loss, metrics = self._build_loss(output_dict, labels) + loss, metrics = self._build_loss(labels) optimizer = self._build_optimizer() train_operation = optimizer.minimize( loss=loss, diff --git a/spleeter/separator.py b/spleeter/separator.py index f7eac1b..1f4d177 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -20,14 +20,15 @@ from multiprocessing import Pool from os.path import basename, join, splitext import numpy as np import tensorflow as tf - +from librosa.core import stft, istft +from scipy.signal.windows import hann from . import SpleeterError from .audio.adapter import get_default_audio_adapter from .audio.convertor import to_stereo from .utils.configuration import load_configuration -from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir -from .model import EstimatorSpecBuilder +from .utils.estimator import create_estimator, to_predictor, get_default_model_dir +from .model import EstimatorSpecBuilder, InputProviderFactory __email__ = 'research@deezer.com' @@ -41,7 +42,7 @@ logger = logging.getLogger("spleeter") class Separator(object): """ A wrapper class for performing separation. """ - def __init__(self, params_descriptor, MWF=False, multiprocess=True): + def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True): """ Default constructor. :param params_descriptor: Descriptor for TF params to be used. @@ -53,6 +54,7 @@ class Separator(object): self._predictor = None self._pool = Pool() if multiprocess else None self._tasks = [] + self._params["stft_backend"] = stft_backend def _get_predictor(self): """ Lazy loading access method for internal predictor instance. @@ -120,60 +122,41 @@ class Separator(object): assert chunk_size % d == 0 return chunk_size//d - def separate_chunked(self, waveform, sample_rate, chunk_max_duration): - chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration) - print(f"chunk size is {chunk_size}") - batch_size = self.get_batch_size_for_chunk_size(chunk_size) - print(f"batch size {batch_size}") - T, F = self._params["T"], self._params["F"] + def stft(self, waveform, inverse=False): + N = self._params["frame_length"] + H = self._params["frame_step"] + win = hann(N, sym=False) + fstft = istft if inverse else stft + win_len_arg = "win_length" if inverse else "n_fft" + s1 = fstft(waveform[:, 0], hop_length=H, window=win, center=False, **{win_len_arg: N}) + s2 = fstft(waveform[:, 1], hop_length=H, window=win, center=False, **{win_len_arg: N}) + s1 = np.expand_dims(s1.T, 2-inverse) + s2 = np.expand_dims(s2.T, 2-inverse) + return np.concatenate([s1, s2], axis=2-inverse) + + def separate_librosa(self, waveform, audio_id): out = {} - n_batches = (waveform.shape[0]+batch_size*T*F-1)//(batch_size*T*F) - print(f"{n_batches} to compute") - features = get_input_dict_placeholders(self._params) - spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input") - istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input") - start_t = tf.placeholder(tf.int32, shape=(), name="start") - end_t = tf.placeholder(tf.int32, shape=(), name="end") + input_provider = InputProviderFactory.get(self._params) + features = input_provider.get_input_dict_placeholders() + builder = EstimatorSpecBuilder(features, self._params) latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) # TODO: fix the logic, build sometimes return, sometimes set attribute - builder._build_stft_feature() - stft_t = builder.get_stft_feature() - output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t) - masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t), - input_stft=stft_t[start_t:end_t, :, :]) - output_waveform_t = builder._inverse_stft(istft_input_t) - waveform_t = features["waveform"] - masked_stfts = {} + outputs = builder.outputs + saver = tf.train.Saver() - + stft = self.stft(waveform) with tf.Session() as sess: - print("restoring weights {}".format(time())) saver.restore(sess, latest_checkpoint) - print("computing spectrogram {}".format(time())) - spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t], feed_dict={waveform_t: waveform}) - print(spectrogram.shape) - print(stft.shape) - for i in range(n_batches): - print("computing batch {} {}".format(i, time())) - start = i*batch_size - end = (i+1)*batch_size - tmp = sess.run(masked_stft_t, - feed_dict={spectrogram_input_t: spectrogram[start:end, ...], - start_t: start*T, end_t: end*T, stft_t: stft}) - for instrument, masked_stft in tmp.items(): - masked_stfts.setdefault(instrument, []).append(masked_stft) - - print("inverting spectrogram {}".format(time())) - for instrument, masked_stft in masked_stfts.items(): - out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)}) - print("done separating {}".format(time())) + outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) + for inst in builder.instruments: + out[inst] = self.stft(outputs[inst], inverse=True) return out def separate_to_file( self, audio_descriptor, destination, - audio_adapter=get_default_audio_adapter(), chunk_duration=-1, + audio_adapter=get_default_audio_adapter(), offset=0, duration=600., codec='wav', bitrate='128k', filename_format='{filename}/{instrument}.{codec}', synchronous=True): @@ -198,15 +181,15 @@ class Separator(object): :param filename_format: (Optional) Filename format. :param synchronous: (Optional) True is should by synchronous. """ - print("loading audio {}".format(time())) waveform, sample_rate = audio_adapter.load( audio_descriptor, offset=offset, duration=duration, sample_rate=self._sample_rate) - print("done loading audio {}".format(time())) - sources = self.separate_chunked(waveform, sample_rate, chunk_duration) - print("saving to file {}".format(time())) + if self._params["stft_backend"] == "tensorflow": + sources = self.separate(waveform) + else: + sources = self.separate_librosa(waveform, audio_descriptor) self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, audio_adapter, bitrate, synchronous) diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py index 2b17e96..a9aa736 100644 --- a/spleeter/utils/estimator.py +++ b/spleeter/utils/estimator.py @@ -13,7 +13,7 @@ import tensorflow as tf from tensorflow.contrib import predictor # pylint: enable=import-error -from ..model import model_fn +from ..model import model_fn, InputProviderFactory from ..model.provider import get_default_model_provider # Default exporting directory for predictor. @@ -59,14 +59,6 @@ def create_estimator(params, MWF): return estimator -def get_input_dict_placeholders(params): - shape = (None, params['n_channels']) - features = { - 'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"), - 'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")} - return features - - def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): """ Exports given estimator as predictor into the given directory and returns associated tf.predictor instance. @@ -74,8 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): :param estimator: Estimator to export. :param directory: (Optional) path to write exported model into. """ + + input_provider = InputProviderFactory.get(estimator.params) def receiver(): - features = get_input_dict_placeholders(estimator.params) + features = input_provider.get_input_dict_placeholders() return tf.estimator.export.ServingInputReceiver(features, features) estimator.export_saved_model(directory, receiver)