From f93dfbc235138911dcc16a18eb513945e0de2be6 Mon Sep 17 00:00:00 2001 From: akhlif Date: Wed, 19 Feb 2020 10:55:48 +0100 Subject: [PATCH 01/10] Adding a new argument to support chunked inference --- spleeter/commands/__init__.py | 13 +++++++++++++ spleeter/separator.py | 30 +++++++++++++++++++++++------- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py index 2bbc974..9461a07 100644 --- a/spleeter/commands/__init__.py +++ b/spleeter/commands/__init__.py @@ -68,6 +68,18 @@ OPT_DURATION = { 'the input file)') } +# -w opt specification (separate) +OPT_CHUNKED = { + 'dest': 'chunk_duration', + 'type': float, + 'default': -1, + 'help': 'Maximum duration of the segments that are fed to' + ' the network. Use this parameter to limit ' + 'memory usage. Use -1 to process the whole signal' + ' in one pass.' +} + + # -c opt specification (separate). OPT_CODEC = { 'dest': 'codec', @@ -176,6 +188,7 @@ def _create_separate_parser(parser_factory): parser.add_argument('-c', '--codec', **OPT_CODEC) parser.add_argument('-b', '--birate', **OPT_BITRATE) parser.add_argument('-m', '--mwf', **OPT_MWF) + parser.add_argument('-w', '--chunk', **OPT_CHUNKED) return parser diff --git a/spleeter/separator.py b/spleeter/separator.py index 3ef3ecf..4423888 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -13,17 +13,14 @@ """ import os -import json -from functools import partial from multiprocessing import Pool -from pathlib import Path from os.path import basename, join, splitext +import numpy as np from . import SpleeterError from .audio.adapter import get_default_audio_adapter from .audio.convertor import to_stereo -from .model import model_fn from .utils.configuration import load_configuration from .utils.estimator import create_estimator, to_predictor @@ -90,9 +87,21 @@ class Separator(object): prediction.pop('audio_id') return prediction + def separate_chunked(self, waveform, sample_rate, chunk_duration=-1): + chunk_size = waveform.shape[0] if chunk_duration == -1 else chunk_duration*sample_rate + n_chunks = int(waveform.shape[0]/chunk_size) + out = {} + for i in range(n_chunks): + sources = self.separate(waveform) + for inst, data in sources.items(): + out.setdefault(inst, []).append(data) + for inst, data in out.items(): + out[inst] = np.concatenate(data, axis=0) + return out + def separate_to_file( self, audio_descriptor, destination, - audio_adapter=get_default_audio_adapter(), + audio_adapter=get_default_audio_adapter(), chunk_duration=-1, offset=0, duration=600., codec='wav', bitrate='128k', filename_format='{filename}/{instrument}.{codec}', synchronous=True): @@ -108,6 +117,8 @@ class Separator(object): descriptor would be a file path. :param destination: Target directory to write output to. :param audio_adapter: (Optional) Audio adapter to use for I/O. + :param chunk_duration: (Optional) Maximum signal duration that is processed + in one pass. Default: all signal. :param offset: (Optional) Offset of loaded song. :param duration: (Optional) Duration of loaded song. :param codec: (Optional) Export codec. @@ -115,12 +126,17 @@ class Separator(object): :param filename_format: (Optional) Filename format. :param synchronous: (Optional) True is should by synchronous. """ - waveform, _ = audio_adapter.load( + waveform, sample_rate = audio_adapter.load( audio_descriptor, offset=offset, duration=duration, sample_rate=self._sample_rate) - sources = self.separate(waveform) + sources = self.separate_chunked(waveform, sample_rate, chunk_duration=chunk_duration) + self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, + audio_adapter, bitrate, synchronous) + + def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec, + audio_adapter, bitrate, synchronous): filename = splitext(basename(audio_descriptor))[0] generated = [] for instrument, data in sources.items(): From aa7c208b3903192ce92584c42e99263eda764302 Mon Sep 17 00:00:00 2001 From: akhlif Date: Wed, 19 Feb 2020 23:11:16 +0100 Subject: [PATCH 02/10] First draft implementing correct reconstruction --- spleeter/commands/__init__.py | 5 +- spleeter/commands/separate.py | 1 + spleeter/model/__init__.py | 56 +++++++++++++------- spleeter/separator.py | 97 +++++++++++++++++++++++++++++++---- spleeter/utils/estimator.py | 31 ++++++++--- 5 files changed, 152 insertions(+), 38 deletions(-) diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py index 9461a07..21b70ef 100644 --- a/spleeter/commands/__init__.py +++ b/spleeter/commands/__init__.py @@ -4,7 +4,7 @@ """ This modules provides spleeter command as well as CLI parsing methods. """ import json - +import logging from argparse import ArgumentParser from tempfile import gettempdir from os.path import exists, join @@ -13,6 +13,9 @@ __email__ = 'research@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' +logging.basicConfig() + + # -i opt specification (separate). OPT_INPUT = { 'dest': 'inputs', diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py index b1f41ab..b3de4b6 100644 --- a/spleeter/commands/separate.py +++ b/spleeter/commands/separate.py @@ -35,6 +35,7 @@ def entrypoint(arguments, params): filename, arguments.output_path, audio_adapter=audio_adapter, + chunk_duration=arguments.chunk_duration, offset=arguments.offset, duration=arguments.duration, codec=arguments.codec, diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 384e838..4d18eb5 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -106,7 +106,7 @@ class EstimatorSpecBuilder(object): self._frame_length = params['frame_length'] self._frame_step = params['frame_step'] - def _build_output_dict(self): + def _build_output_dict(self, input_tensor=None): """ Created a batch_sizexTxFxn_channels input tensor containing mix magnitude spectrogram, then an output dict from it according to the selected model in internal parameters. @@ -114,7 +114,8 @@ class EstimatorSpecBuilder(object): :returns: Build output dict. :raise ValueError: If required model_type is not supported. """ - input_tensor = self._features[f'{self._mix_name}_spectrogram'] + if input_tensor is None: + input_tensor = self._features[f'{self._mix_name}_spectrogram'] model = self._params.get('model', None) if model is not None: model_type = model.get('type', self.DEFAULT_MODEL) @@ -194,6 +195,12 @@ class EstimatorSpecBuilder(object): self._features[f'{self._mix_name}_spectrogram'] = tf.abs( pad_and_partition(stft_feature, self._T))[:, :, :self._F, :] + def get_stft_feature(self): + return self._features[f'{self._mix_name}_stft'] + + def get_spectrogram_feature(self): + return self._features[f'{self._mix_name}_spectrogram'] + def _inverse_stft(self, stft): """ Inverse and reshape the given STFT @@ -269,26 +276,18 @@ class EstimatorSpecBuilder(object): extension = tf.tile(extension_row, [1, 1, n_extra_row, 1]) return tf.concat([mask, extension], axis=2) - def _build_manual_output_waveform(self, output_dict): - """ Perform ratio mask separation - - :param output_dict: dictionary of estimated spectrogram (key: instrument - name, value: estimated spectrogram of the instrument) - :returns: dictionary of separated waveforms (key: instrument name, - value: estimated waveform of the instrument) - """ + def _build_masks(self, output_dict): separation_exponent = self._params['separation_exponent'] output_sum = tf.reduce_sum( [e ** separation_exponent for e in output_dict.values()], axis=0 ) + self.EPSILON - output_waveform = {} + out = {} for instrument in self._instruments: output = output_dict[f'{instrument}_spectrogram'] # Compute mask with the model. - instrument_mask = ( - output ** separation_exponent - + (self.EPSILON / len(output_dict))) / output_sum + instrument_mask = (output ** separation_exponent + + (self.EPSILON / len(output_dict))) / output_sum # Extend mask; instrument_mask = self._extend_mask(instrument_mask) # Stack back mask. @@ -300,10 +299,31 @@ class EstimatorSpecBuilder(object): # Remove padded part (for mask having the same size as STFT); stft_feature = self._features[f'{self._mix_name}_stft'] instrument_mask = instrument_mask[ - :tf.shape(stft_feature)[0], ...] - # Compute masked STFT and normalize it. - output_waveform[instrument] = self._inverse_stft( - tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature) + :tf.shape(stft_feature)[0], ...] + out[instrument] = instrument_mask + return out + + def _build_masked_stft(self, mask_dict, input_stft=None): + if input_stft is None: + input_stft = self._features[f'{self._mix_name}_stft'] + out = {} + for instrument, mask in mask_dict.items(): + out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft + return out + + def _build_manual_output_waveform(self, output_dict): + """ Perform ratio mask separation + + :param output_dict: dictionary of estimated spectrogram (key: instrument + name, value: estimated spectrogram of the instrument) + :returns: dictionary of separated waveforms (key: instrument name, + value: estimated waveform of the instrument) + """ + + output_waveform = {} + masked_stft = self._build_masked_stft(self._build_masks(output_dict)) + for instrument, stft_data in masked_stft.items(): + output_waveform[instrument] = self._inverse_stft(stft_data) return output_waveform def _build_output_waveform(self, output_dict): diff --git a/spleeter/separator.py b/spleeter/separator.py index 4423888..f7eac1b 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -13,22 +13,31 @@ """ import os +import logging +from time import time from multiprocessing import Pool from os.path import basename, join, splitext import numpy as np +import tensorflow as tf + from . import SpleeterError from .audio.adapter import get_default_audio_adapter from .audio.convertor import to_stereo from .utils.configuration import load_configuration -from .utils.estimator import create_estimator, to_predictor +from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir +from .model import EstimatorSpecBuilder + __email__ = 'research@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' +logger = logging.getLogger("spleeter") + + class Separator(object): """ A wrapper class for performing separation. """ @@ -87,16 +96,79 @@ class Separator(object): prediction.pop('audio_id') return prediction - def separate_chunked(self, waveform, sample_rate, chunk_duration=-1): - chunk_size = waveform.shape[0] if chunk_duration == -1 else chunk_duration*sample_rate - n_chunks = int(waveform.shape[0]/chunk_size) + def get_valid_chunk_size(self, sample_rate: int, chunk_max_duration: float) -> int: + """ + Given a sample rate, and a maximal duration that a chunk can represent, return the maximum chunk + size in samples. The chunk size must be a non-zero multiple of T (temporal dimension of the input spectrogram) + times F (number of frequency bins in the input spectrogram). If no such value exist, we return T*F. + + :param sample_rate: sample rate of the pcm data + :param chunk_max_duration: maximal duration in seconds of a chunk + :return: highest non-zero chunk size of duration less than chunk_max_duration or minimal valid chunk size. + """ + assert chunk_max_duration > 0 + chunk_size = chunk_max_duration * sample_rate + min_sample_size = self._params["T"] * self._params["F"] + if chunk_size < min_sample_size: + min_duration = min_sample_size / sample_rate + logger.warning("chunk_duration must be at least {:.2f} seconds. Ignoring parameter".format(min_duration)) + chunk_size = min_sample_size + return min_sample_size*int(chunk_size/min_sample_size) + + def get_batch_size_for_chunk_size(self, chunk_size): + d = self._params["T"] * self._params["F"] + assert chunk_size % d == 0 + return chunk_size//d + + def separate_chunked(self, waveform, sample_rate, chunk_max_duration): + chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration) + print(f"chunk size is {chunk_size}") + batch_size = self.get_batch_size_for_chunk_size(chunk_size) + print(f"batch size {batch_size}") + T, F = self._params["T"], self._params["F"] out = {} - for i in range(n_chunks): - sources = self.separate(waveform) - for inst, data in sources.items(): - out.setdefault(inst, []).append(data) - for inst, data in out.items(): - out[inst] = np.concatenate(data, axis=0) + n_batches = (waveform.shape[0]+batch_size*T*F-1)//(batch_size*T*F) + print(f"{n_batches} to compute") + features = get_input_dict_placeholders(self._params) + spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input") + istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input") + start_t = tf.placeholder(tf.int32, shape=(), name="start") + end_t = tf.placeholder(tf.int32, shape=(), name="end") + builder = EstimatorSpecBuilder(features, self._params) + latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) + + # TODO: fix the logic, build sometimes return, sometimes set attribute + builder._build_stft_feature() + stft_t = builder.get_stft_feature() + output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t) + masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t), + input_stft=stft_t[start_t:end_t, :, :]) + output_waveform_t = builder._inverse_stft(istft_input_t) + waveform_t = features["waveform"] + masked_stfts = {} + saver = tf.train.Saver() + + with tf.Session() as sess: + print("restoring weights {}".format(time())) + saver.restore(sess, latest_checkpoint) + print("computing spectrogram {}".format(time())) + spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t], feed_dict={waveform_t: waveform}) + print(spectrogram.shape) + print(stft.shape) + for i in range(n_batches): + print("computing batch {} {}".format(i, time())) + start = i*batch_size + end = (i+1)*batch_size + tmp = sess.run(masked_stft_t, + feed_dict={spectrogram_input_t: spectrogram[start:end, ...], + start_t: start*T, end_t: end*T, stft_t: stft}) + for instrument, masked_stft in tmp.items(): + masked_stfts.setdefault(instrument, []).append(masked_stft) + + print("inverting spectrogram {}".format(time())) + for instrument, masked_stft in masked_stfts.items(): + out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)}) + print("done separating {}".format(time())) return out def separate_to_file( @@ -126,12 +198,15 @@ class Separator(object): :param filename_format: (Optional) Filename format. :param synchronous: (Optional) True is should by synchronous. """ + print("loading audio {}".format(time())) waveform, sample_rate = audio_adapter.load( audio_descriptor, offset=offset, duration=duration, sample_rate=self._sample_rate) - sources = self.separate_chunked(waveform, sample_rate, chunk_duration=chunk_duration) + print("done loading audio {}".format(time())) + sources = self.separate_chunked(waveform, sample_rate, chunk_duration) + print("saving to file {}".format(time())) self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, audio_adapter, bitrate, synchronous) diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py index 95c4219..2b17e96 100644 --- a/spleeter/utils/estimator.py +++ b/spleeter/utils/estimator.py @@ -20,20 +20,30 @@ from ..model.provider import get_default_model_provider DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving') + +def get_default_model_dir(model_dir): + """ + Transforms a string like 'spleeter:2stems' into an actual path. + :param model_dir: + :return: + """ + model_provider = get_default_model_provider() + return model_provider.get(model_dir) + def create_estimator(params, MWF): """ Initialize tensorflow estimator that will perform separation Params: - - params: a dictionnary of parameters for building the model + - params: a dictionary of parameters for building the model Returns: a tensorflow estimator """ # Load model. - model_directory = params['model_dir'] - model_provider = get_default_model_provider() - params['model_dir'] = model_provider.get(model_directory) + + + params['model_dir'] = get_default_model_dir(params['model_dir']) params['MWF'] = MWF # Setup config session_config = tf.compat.v1.ConfigProto() @@ -49,6 +59,14 @@ def create_estimator(params, MWF): return estimator +def get_input_dict_placeholders(params): + shape = (None, params['n_channels']) + features = { + 'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"), + 'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")} + return features + + def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): """ Exports given estimator as predictor into the given directory and returns associated tf.predictor instance. @@ -57,10 +75,7 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): :param directory: (Optional) path to write exported model into. """ def receiver(): - shape = (None, estimator.params['n_channels']) - features = { - 'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape), - 'audio_id': tf.compat.v1.placeholder(tf.string)} + features = get_input_dict_placeholders(estimator.params) return tf.estimator.export.ServingInputReceiver(features, features) estimator.export_saved_model(directory, receiver) From fe4634afa68623034423ce256d886c32d0e27d45 Mon Sep 17 00:00:00 2001 From: akhlif Date: Wed, 26 Feb 2020 16:31:24 +0100 Subject: [PATCH 03/10] Adding option to use librosa backend. Changes in the EstimatorBuilder to set attributes instead of returning tensors for the _build methods. InputProvider classes to handle the different backend cases. New method in Separator. --- spleeter/commands/__init__.py | 17 ++- spleeter/commands/separate.py | 12 +- spleeter/model/__init__.py | 230 ++++++++++++++++++++++++++-------- spleeter/separator.py | 83 +++++------- spleeter/utils/estimator.py | 14 +-- 5 files changed, 232 insertions(+), 124 deletions(-) diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py index 21b70ef..452b3a6 100644 --- a/spleeter/commands/__init__.py +++ b/spleeter/commands/__init__.py @@ -72,14 +72,13 @@ OPT_DURATION = { } # -w opt specification (separate) -OPT_CHUNKED = { - 'dest': 'chunk_duration', - 'type': float, - 'default': -1, - 'help': 'Maximum duration of the segments that are fed to' - ' the network. Use this parameter to limit ' - 'memory usage. Use -1 to process the whole signal' - ' in one pass.' +OPT_STFT_BACKEND = { + 'dest': 'stft_backend', + 'type': str, + 'choices' : ["tensorflow", "librosa", "auto"], + 'default': "auto", + 'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses' + ' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.' } @@ -191,7 +190,7 @@ def _create_separate_parser(parser_factory): parser.add_argument('-c', '--codec', **OPT_CODEC) parser.add_argument('-b', '--birate', **OPT_BITRATE) parser.add_argument('-m', '--mwf', **OPT_MWF) - parser.add_argument('-w', '--chunk', **OPT_CHUNKED) + parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND) return parser diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py index b3de4b6..2fac8aa 100644 --- a/spleeter/commands/separate.py +++ b/spleeter/commands/separate.py @@ -11,6 +11,8 @@ -i /path/to/audio1.wav /path/to/audio2.mp3 """ +import tensorflow as tf + from ..audio.adapter import get_audio_adapter from ..separator import Separator @@ -19,6 +21,12 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' +def get_backend(backend): + if backend == "auto": + return "tensorflow" if tf.test.is_gpu_available() else "librosa" + return backend + + def entrypoint(arguments, params): """ Command entrypoint. @@ -29,13 +37,13 @@ def entrypoint(arguments, params): audio_adapter = get_audio_adapter(arguments.audio_adapter) separator = Separator( arguments.configuration, - arguments.MWF) + MWF=arguments.MWF, + stft_backend=get_backend(arguments.stft_backend)) for filename in arguments.inputs: separator.separate_to_file( filename, arguments.output_path, audio_adapter=audio_adapter, - chunk_duration=arguments.chunk_duration, offset=arguments.offset, duration=arguments.duration, codec=arguments.codec, diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 4d18eb5..4c3e200 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -18,6 +18,10 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' + +placeholder = tf.compat.v1.placeholder + + def get_model_function(model_type): """ Get tensorflow function of the model to be applied to the input tensor. @@ -41,6 +45,74 @@ def get_model_function(model_type): return model_function +class InputProvider(object): + + def __init__(self, params): + self.params = params + + def get_input_dict_placeholders(self): + raise NotImplementedError() + + @property + def input_names(self): + raise NotImplementedError() + + def get_feed_dict(self, features, *args): + raise NotImplementedError() + + +class WaveformInputProvider(InputProvider): + + @property + def input_names(self): + return ["audio_id", "waveform"] + + def get_input_dict_placeholders(self): + shape = (None, self.params['n_channels']) + features = { + 'waveform': placeholder(tf.float32, shape=shape, name="waveform"), + 'audio_id': placeholder(tf.string, name="audio_id")} + return features + + def get_feed_dict(self, features, waveform, audio_id): + return {features["audio_id"]: audio_id, features["waveform"]: waveform} + + +class SpectralInputProvider(InputProvider): + + def __init__(self, params): + super().__init__(params) + self.stft_input_name = "{}_stft".format(self.params["mix_name"]) + + @property + def input_names(self): + return ["audio_id", self.stft_input_name] + + def get_input_dict_placeholders(self): + features = { + self.stft_input_name: placeholder(tf.complex64, + shape=(None, self.params["frame_length"]//2+1, + self.params['n_channels']), + name=self.stft_input_name), + 'audio_id': placeholder(tf.string, name="audio_id")} + return features + + def get_feed_dict(self, features, stft, audio_id): + return {features["audio_id"]: audio_id, features[self.stft_input_name]: stft} + + +class InputProviderFactory(object): + + @staticmethod + def get(params): + stft_backend = params["stft_backend"] + assert stft_backend in ("tensorflow", "librosa") + if stft_backend == "tensorflow": + return WaveformInputProvider(params) + else: + return SpectralInputProvider(params) + + class EstimatorSpecBuilder(object): """ A builder class that allows to builds a multitrack unet model estimator. The built model estimator has a different behaviour when @@ -57,9 +129,9 @@ class EstimatorSpecBuilder(object): >>> from spleeter.model import EstimatorSpecBuilder >>> builder = EstimatorSpecBuilder() - >>> builder.build_prediction_model() + >>> builder.build_predict_model() >>> builder.build_evaluation_model() - >>> builder.build_training_model() + >>> builder.build_train_model() >>> from spleeter.model import model_fn >>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...) @@ -94,6 +166,7 @@ class EstimatorSpecBuilder(object): :param features: The input features for the estimator. :param params: Some hyperparameters as a dictionary. """ + self._features = features self._params = params # Get instrument name. @@ -106,7 +179,10 @@ class EstimatorSpecBuilder(object): self._frame_length = params['frame_length'] self._frame_step = params['frame_step'] - def _build_output_dict(self, input_tensor=None): + def include_stft_computations(self): + return self._params["stft_backend"] == "tensorflow" + + def _build_model_outputs(self): """ Created a batch_sizexTxFxn_channels input tensor containing mix magnitude spectrogram, then an output dict from it according to the selected model in internal parameters. @@ -114,8 +190,8 @@ class EstimatorSpecBuilder(object): :returns: Build output dict. :raise ValueError: If required model_type is not supported. """ - if input_tensor is None: - input_tensor = self._features[f'{self._mix_name}_spectrogram'] + + input_tensor = self.spectrogram_feature model = self._params.get('model', None) if model is not None: model_type = model.get('type', self.DEFAULT_MODEL) @@ -125,12 +201,12 @@ class EstimatorSpecBuilder(object): apply_model = get_model_function(model_type) except ModuleNotFoundError: raise ValueError(f'No model function {model_type} found') - return apply_model( + self._model_outputs = apply_model( input_tensor, self._instruments, self._params['model']['params']) - def _build_loss(self, output_dict, labels): + def _build_loss(self, labels): """ Construct tensorflow loss and metrics :param output_dict: dictionary of network outputs (key: instrument @@ -139,6 +215,7 @@ class EstimatorSpecBuilder(object): name, value: ground truth spectrogram of the instrument) :returns: tensorflow (loss, metrics) tuple. """ + output_dict = self.model_outputs loss_type = self._params.get('loss_type', self.L1_MASK) if loss_type == self.L1_MASK: losses = { @@ -178,30 +255,72 @@ class EstimatorSpecBuilder(object): return tf.compat.v1.train.GradientDescentOptimizer(rate) return tf.compat.v1.train.AdamOptimizer(rate) + @property + def instruments(self): + return self._instruments + + @property + def stft_name(self): + return f'{self._mix_name}_stft' + + @property + def spectrogram_name(self): + return f'{self._mix_name}_spectrogram' + def _build_stft_feature(self): """ Compute STFT of waveform and slice the STFT in segment with the right length to feed the network. """ - stft_feature = tf.transpose( - stft( - tf.transpose(self._features['waveform']), - self._frame_length, - self._frame_step, - window_fn=lambda frame_length, dtype: ( - hann_window(frame_length, periodic=True, dtype=dtype)), - pad_end=True), - perm=[1, 2, 0]) - self._features[f'{self._mix_name}_stft'] = stft_feature - self._features[f'{self._mix_name}_spectrogram'] = tf.abs( - pad_and_partition(stft_feature, self._T))[:, :, :self._F, :] - def get_stft_feature(self): - return self._features[f'{self._mix_name}_stft'] + stft_name = self.stft_name + spec_name = self.spectrogram_name - def get_spectrogram_feature(self): - return self._features[f'{self._mix_name}_spectrogram'] + if stft_name not in self._features: + stft_feature = tf.transpose( + stft( + tf.transpose(self._features['waveform']), + self._frame_length, + self._frame_step, + window_fn=lambda frame_length, dtype: ( + hann_window(frame_length, periodic=True, dtype=dtype)), + pad_end=True), + perm=[1, 2, 0]) + self._features[f'{self._mix_name}_stft'] = stft_feature + if spec_name not in self._features: + self._features[spec_name] = tf.abs( + pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :] - def _inverse_stft(self, stft): + @property + def model_outputs(self): + if not hasattr(self, "_model_outputs"): + self._build_model_outputs() + return self._model_outputs + + @property + def outputs(self): + if not hasattr(self, "_outputs"): + self._build_outputs() + return self._outputs + + @property + def stft_feature(self): + if self.stft_name not in self._features: + self._build_stft_feature() + return self._features[self.stft_name] + + @property + def spectrogram_feature(self): + if self.spectrogram_name not in self._features: + self._build_stft_feature() + return self._features[self.spectrogram_name] + + @property + def masks(self): + if not hasattr(self, "_masks"): + self._build_masks() + return self._masks + + def _inverse_stft(self, stft, time_crop=None): """ Inverse and reshape the given STFT :param stft: input STFT @@ -215,20 +334,21 @@ class EstimatorSpecBuilder(object): hann_window(frame_length, periodic=True, dtype=dtype)) ) * self.WINDOW_COMPENSATION_FACTOR reshaped = tf.transpose(inversed) - return reshaped[:tf.shape(self._features['waveform'])[0], :] + if time_crop is None: + time_crop = tf.shape(self._features['waveform'])[0] + return reshaped[:time_crop, :] - def _build_mwf_output_waveform(self, output_dict): + def _build_mwf_output_waveform(self): """ Perform separation with multichannel Wiener Filtering using Norbert. Note: multichannel Wiener Filtering is not coded in Tensorflow and thus may be quite slow. - :param output_dict: dictionary of estimated spectrogram (key: instrument - name, value: estimated spectrogram of the instrument) :returns: dictionary of separated waveforms (key: instrument name, value: estimated waveform of the instrument) """ import norbert # pylint: disable=import-error - x = self._features[f'{self._mix_name}_stft'] + output_dict = self.model_outputs + x = self.stft_feature v = tf.stack( [ pad_and_reshape( @@ -272,11 +392,13 @@ class EstimatorSpecBuilder(object): mask_shape[-1])) else: raise ValueError(f'Invalid mask_extension parameter {extension}') - n_extra_row = (self._frame_length) // 2 + 1 - self._F + n_extra_row = self._frame_length // 2 + 1 - self._F extension = tf.tile(extension_row, [1, 1, n_extra_row, 1]) return tf.concat([mask, extension], axis=2) - def _build_masks(self, output_dict): + def _build_masks(self): + output_dict = self.model_outputs + stft_feature = self.stft_feature separation_exponent = self._params['separation_exponent'] output_sum = tf.reduce_sum( [e ** separation_exponent for e in output_dict.values()], @@ -297,21 +419,20 @@ class EstimatorSpecBuilder(object): axis=0) instrument_mask = tf.reshape(instrument_mask, new_shape) # Remove padded part (for mask having the same size as STFT); - stft_feature = self._features[f'{self._mix_name}_stft'] + instrument_mask = instrument_mask[ :tf.shape(stft_feature)[0], ...] out[instrument] = instrument_mask - return out + self._masks = out - def _build_masked_stft(self, mask_dict, input_stft=None): - if input_stft is None: - input_stft = self._features[f'{self._mix_name}_stft'] + def _build_masked_stft(self): + input_stft = self.stft_feature out = {} - for instrument, mask in mask_dict.items(): + for instrument, mask in self.masks.items(): out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft return out - def _build_manual_output_waveform(self, output_dict): + def _build_manual_output_waveform(self, masked_stft): """ Perform ratio mask separation :param output_dict: dictionary of estimated spectrogram (key: instrument @@ -321,27 +442,34 @@ class EstimatorSpecBuilder(object): """ output_waveform = {} - masked_stft = self._build_masked_stft(self._build_masks(output_dict)) for instrument, stft_data in masked_stft.items(): output_waveform[instrument] = self._inverse_stft(stft_data) return output_waveform - def _build_output_waveform(self, output_dict): + def _build_output_waveform(self, masked_stft): """ Build output waveform from given output dict in order to be used in prediction context. Regarding of the configuration building method will be using MWF. - :param output_dict: Output dict to build output waveform from. :returns: Built output waveform. """ + if self._params.get('MWF', False): - output_waveform = self._build_mwf_output_waveform(output_dict) + output_waveform = self._build_mwf_output_waveform() else: - output_waveform = self._build_manual_output_waveform(output_dict) - if 'audio_id' in self._features: - output_waveform['audio_id'] = self._features['audio_id'] + output_waveform = self._build_manual_output_waveform(masked_stft) return output_waveform + def _build_outputs(self): + masked_stft = self._build_masked_stft() + if self.include_stft_computations(): + self._outputs = self._build_output_waveform(masked_stft) + else: + self._outputs = masked_stft + + if 'audio_id' in self._features: + self._outputs['audio_id'] = self._features['audio_id'] + def build_predict_model(self): """ Builder interface for creating model instance that aims to perform prediction / inference over given track. The output of such estimator @@ -350,12 +478,10 @@ class EstimatorSpecBuilder(object): :returns: An estimator for performing prediction. """ - self._build_stft_feature() - output_dict = self._build_output_dict() - output_waveform = self._build_output_waveform(output_dict) + return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.PREDICT, - predictions=output_waveform) + predictions=self.outputs) def build_evaluation_model(self, labels): """ Builder interface for creating model instance that aims to perform @@ -366,8 +492,7 @@ class EstimatorSpecBuilder(object): :param labels: Model labels. :returns: An estimator for performing model evaluation. """ - output_dict = self._build_output_dict() - loss, metrics = self._build_loss(output_dict, labels) + loss, metrics = self._build_loss(labels) return tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.EVAL, loss=loss, @@ -382,8 +507,7 @@ class EstimatorSpecBuilder(object): :param labels: Model labels. :returns: An estimator for performing model training. """ - output_dict = self._build_output_dict() - loss, metrics = self._build_loss(output_dict, labels) + loss, metrics = self._build_loss(labels) optimizer = self._build_optimizer() train_operation = optimizer.minimize( loss=loss, diff --git a/spleeter/separator.py b/spleeter/separator.py index f7eac1b..1f4d177 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -20,14 +20,15 @@ from multiprocessing import Pool from os.path import basename, join, splitext import numpy as np import tensorflow as tf - +from librosa.core import stft, istft +from scipy.signal.windows import hann from . import SpleeterError from .audio.adapter import get_default_audio_adapter from .audio.convertor import to_stereo from .utils.configuration import load_configuration -from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir -from .model import EstimatorSpecBuilder +from .utils.estimator import create_estimator, to_predictor, get_default_model_dir +from .model import EstimatorSpecBuilder, InputProviderFactory __email__ = 'research@deezer.com' @@ -41,7 +42,7 @@ logger = logging.getLogger("spleeter") class Separator(object): """ A wrapper class for performing separation. """ - def __init__(self, params_descriptor, MWF=False, multiprocess=True): + def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True): """ Default constructor. :param params_descriptor: Descriptor for TF params to be used. @@ -53,6 +54,7 @@ class Separator(object): self._predictor = None self._pool = Pool() if multiprocess else None self._tasks = [] + self._params["stft_backend"] = stft_backend def _get_predictor(self): """ Lazy loading access method for internal predictor instance. @@ -120,60 +122,41 @@ class Separator(object): assert chunk_size % d == 0 return chunk_size//d - def separate_chunked(self, waveform, sample_rate, chunk_max_duration): - chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration) - print(f"chunk size is {chunk_size}") - batch_size = self.get_batch_size_for_chunk_size(chunk_size) - print(f"batch size {batch_size}") - T, F = self._params["T"], self._params["F"] + def stft(self, waveform, inverse=False): + N = self._params["frame_length"] + H = self._params["frame_step"] + win = hann(N, sym=False) + fstft = istft if inverse else stft + win_len_arg = "win_length" if inverse else "n_fft" + s1 = fstft(waveform[:, 0], hop_length=H, window=win, center=False, **{win_len_arg: N}) + s2 = fstft(waveform[:, 1], hop_length=H, window=win, center=False, **{win_len_arg: N}) + s1 = np.expand_dims(s1.T, 2-inverse) + s2 = np.expand_dims(s2.T, 2-inverse) + return np.concatenate([s1, s2], axis=2-inverse) + + def separate_librosa(self, waveform, audio_id): out = {} - n_batches = (waveform.shape[0]+batch_size*T*F-1)//(batch_size*T*F) - print(f"{n_batches} to compute") - features = get_input_dict_placeholders(self._params) - spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input") - istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input") - start_t = tf.placeholder(tf.int32, shape=(), name="start") - end_t = tf.placeholder(tf.int32, shape=(), name="end") + input_provider = InputProviderFactory.get(self._params) + features = input_provider.get_input_dict_placeholders() + builder = EstimatorSpecBuilder(features, self._params) latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) # TODO: fix the logic, build sometimes return, sometimes set attribute - builder._build_stft_feature() - stft_t = builder.get_stft_feature() - output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t) - masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t), - input_stft=stft_t[start_t:end_t, :, :]) - output_waveform_t = builder._inverse_stft(istft_input_t) - waveform_t = features["waveform"] - masked_stfts = {} + outputs = builder.outputs + saver = tf.train.Saver() - + stft = self.stft(waveform) with tf.Session() as sess: - print("restoring weights {}".format(time())) saver.restore(sess, latest_checkpoint) - print("computing spectrogram {}".format(time())) - spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t], feed_dict={waveform_t: waveform}) - print(spectrogram.shape) - print(stft.shape) - for i in range(n_batches): - print("computing batch {} {}".format(i, time())) - start = i*batch_size - end = (i+1)*batch_size - tmp = sess.run(masked_stft_t, - feed_dict={spectrogram_input_t: spectrogram[start:end, ...], - start_t: start*T, end_t: end*T, stft_t: stft}) - for instrument, masked_stft in tmp.items(): - masked_stfts.setdefault(instrument, []).append(masked_stft) - - print("inverting spectrogram {}".format(time())) - for instrument, masked_stft in masked_stfts.items(): - out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)}) - print("done separating {}".format(time())) + outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) + for inst in builder.instruments: + out[inst] = self.stft(outputs[inst], inverse=True) return out def separate_to_file( self, audio_descriptor, destination, - audio_adapter=get_default_audio_adapter(), chunk_duration=-1, + audio_adapter=get_default_audio_adapter(), offset=0, duration=600., codec='wav', bitrate='128k', filename_format='{filename}/{instrument}.{codec}', synchronous=True): @@ -198,15 +181,15 @@ class Separator(object): :param filename_format: (Optional) Filename format. :param synchronous: (Optional) True is should by synchronous. """ - print("loading audio {}".format(time())) waveform, sample_rate = audio_adapter.load( audio_descriptor, offset=offset, duration=duration, sample_rate=self._sample_rate) - print("done loading audio {}".format(time())) - sources = self.separate_chunked(waveform, sample_rate, chunk_duration) - print("saving to file {}".format(time())) + if self._params["stft_backend"] == "tensorflow": + sources = self.separate(waveform) + else: + sources = self.separate_librosa(waveform, audio_descriptor) self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, audio_adapter, bitrate, synchronous) diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py index 2b17e96..a9aa736 100644 --- a/spleeter/utils/estimator.py +++ b/spleeter/utils/estimator.py @@ -13,7 +13,7 @@ import tensorflow as tf from tensorflow.contrib import predictor # pylint: enable=import-error -from ..model import model_fn +from ..model import model_fn, InputProviderFactory from ..model.provider import get_default_model_provider # Default exporting directory for predictor. @@ -59,14 +59,6 @@ def create_estimator(params, MWF): return estimator -def get_input_dict_placeholders(params): - shape = (None, params['n_channels']) - features = { - 'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"), - 'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")} - return features - - def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): """ Exports given estimator as predictor into the given directory and returns associated tf.predictor instance. @@ -74,8 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): :param estimator: Estimator to export. :param directory: (Optional) path to write exported model into. """ + + input_provider = InputProviderFactory.get(estimator.params) def receiver(): - features = get_input_dict_placeholders(estimator.params) + features = input_provider.get_input_dict_placeholders() return tf.estimator.export.ServingInputReceiver(features, features) estimator.export_saved_model(directory, receiver) From 31f823aaa4b671ea55328ccf8551f2febe2d9aa9 Mon Sep 17 00:00:00 2001 From: akhlif Date: Wed, 26 Feb 2020 17:51:28 +0100 Subject: [PATCH 04/10] removing useless code --- spleeter/separator.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/spleeter/separator.py b/spleeter/separator.py index 1f4d177..87accd5 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -98,30 +98,6 @@ class Separator(object): prediction.pop('audio_id') return prediction - def get_valid_chunk_size(self, sample_rate: int, chunk_max_duration: float) -> int: - """ - Given a sample rate, and a maximal duration that a chunk can represent, return the maximum chunk - size in samples. The chunk size must be a non-zero multiple of T (temporal dimension of the input spectrogram) - times F (number of frequency bins in the input spectrogram). If no such value exist, we return T*F. - - :param sample_rate: sample rate of the pcm data - :param chunk_max_duration: maximal duration in seconds of a chunk - :return: highest non-zero chunk size of duration less than chunk_max_duration or minimal valid chunk size. - """ - assert chunk_max_duration > 0 - chunk_size = chunk_max_duration * sample_rate - min_sample_size = self._params["T"] * self._params["F"] - if chunk_size < min_sample_size: - min_duration = min_sample_size / sample_rate - logger.warning("chunk_duration must be at least {:.2f} seconds. Ignoring parameter".format(min_duration)) - chunk_size = min_sample_size - return min_sample_size*int(chunk_size/min_sample_size) - - def get_batch_size_for_chunk_size(self, chunk_size): - d = self._params["T"] * self._params["F"] - assert chunk_size % d == 0 - return chunk_size//d - def stft(self, waveform, inverse=False): N = self._params["frame_length"] H = self._params["frame_step"] From 6001ae12a97583481b0651ce0cc6a5a6984ec062 Mon Sep 17 00:00:00 2001 From: akhlif Date: Thu, 27 Feb 2020 11:05:06 +0100 Subject: [PATCH 05/10] Fixing the stft/istft computations --- spleeter/separator.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/spleeter/separator.py b/spleeter/separator.py index 87accd5..73a1a89 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -98,14 +98,23 @@ class Separator(object): prediction.pop('audio_id') return prediction - def stft(self, waveform, inverse=False): + def stft(self, data, inverse=False): + """ + Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two + channels are processed separately and are concatenated together in the result. The expected input formats are: + (n_samples, 2) for stft and (T, F, 2) for istft. + :param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse + :param inverse: should a stft or an istft be computed. + :return: Stereo data as numpy array for the transform. The channels are stored in the last dimension + """ N = self._params["frame_length"] H = self._params["frame_step"] win = hann(N, sym=False) fstft = istft if inverse else stft - win_len_arg = "win_length" if inverse else "n_fft" - s1 = fstft(waveform[:, 0], hop_length=H, window=win, center=False, **{win_len_arg: N}) - s2 = fstft(waveform[:, 1], hop_length=H, window=win, center=False, **{win_len_arg: N}) + win_len_arg = {"win_length": None} if inverse else {"n_fft": N} + dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1]) + s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg) + s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg) s1 = np.expand_dims(s1.T, 2-inverse) s2 = np.expand_dims(s2.T, 2-inverse) return np.concatenate([s1, s2], axis=2-inverse) @@ -127,7 +136,7 @@ class Separator(object): saver.restore(sess, latest_checkpoint) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) for inst in builder.instruments: - out[inst] = self.stft(outputs[inst], inverse=True) + out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :] return out def separate_to_file( From d177525ea7ed73f816c75a96708bc470d3eb38ef Mon Sep 17 00:00:00 2001 From: akhlif Date: Thu, 27 Feb 2020 14:13:59 +0100 Subject: [PATCH 06/10] Moving get_backend to separator --- spleeter/commands/separate.py | 9 +-------- spleeter/model/__init__.py | 19 ++++++++++++++----- spleeter/separator.py | 11 ++++++++++- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py index 2fac8aa..158190c 100644 --- a/spleeter/commands/separate.py +++ b/spleeter/commands/separate.py @@ -11,8 +11,6 @@ -i /path/to/audio1.wav /path/to/audio2.mp3 """ -import tensorflow as tf - from ..audio.adapter import get_audio_adapter from ..separator import Separator @@ -21,11 +19,6 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' -def get_backend(backend): - if backend == "auto": - return "tensorflow" if tf.test.is_gpu_available() else "librosa" - return backend - def entrypoint(arguments, params): """ Command entrypoint. @@ -38,7 +31,7 @@ def entrypoint(arguments, params): separator = Separator( arguments.configuration, MWF=arguments.MWF, - stft_backend=get_backend(arguments.stft_backend)) + stft_backend=arguments.stft_backend) for filename in arguments.inputs: separator.separate_to_file( filename, diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 4c3e200..b99a9dd 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -320,6 +320,12 @@ class EstimatorSpecBuilder(object): self._build_masks() return self._masks + @property + def masked_stfts(self): + if not hasattr(self, "_masked_stfts"): + self._build_masked_stfts() + return self._masked_stfts + def _inverse_stft(self, stft, time_crop=None): """ Inverse and reshape the given STFT @@ -397,6 +403,10 @@ class EstimatorSpecBuilder(object): return tf.concat([mask, extension], axis=2) def _build_masks(self): + """ + Compute masks from the output spectrograms of the model. + :return: + """ output_dict = self.model_outputs stft_feature = self.stft_feature separation_exponent = self._params['separation_exponent'] @@ -425,12 +435,12 @@ class EstimatorSpecBuilder(object): out[instrument] = instrument_mask self._masks = out - def _build_masked_stft(self): + def _build_masked_stfts(self): input_stft = self.stft_feature out = {} for instrument, mask in self.masks.items(): out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft - return out + self._masked_stfts = out def _build_manual_output_waveform(self, masked_stft): """ Perform ratio mask separation @@ -461,11 +471,10 @@ class EstimatorSpecBuilder(object): return output_waveform def _build_outputs(self): - masked_stft = self._build_masked_stft() if self.include_stft_computations(): - self._outputs = self._build_output_waveform(masked_stft) + self._outputs = self._build_output_waveform(self.masked_stfts) else: - self._outputs = masked_stft + self._outputs = self.masked_stfts if 'audio_id' in self._features: self._outputs['audio_id'] = self._features['audio_id'] diff --git a/spleeter/separator.py b/spleeter/separator.py index 73a1a89..668b59d 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -39,6 +39,14 @@ __license__ = 'MIT License' logger = logging.getLogger("spleeter") + +def get_backend(backend): + assert backend in ["auto", "tensorflow", "librosa"] + if backend == "auto": + return "tensorflow" if tf.test.is_gpu_available() else "librosa" + return backend + + class Separator(object): """ A wrapper class for performing separation. """ @@ -48,13 +56,14 @@ class Separator(object): :param params_descriptor: Descriptor for TF params to be used. :param MWF: (Optional) True if MWF should be used, False otherwise. """ + self._params = load_configuration(params_descriptor) self._sample_rate = self._params['sample_rate'] self._MWF = MWF self._predictor = None self._pool = Pool() if multiprocess else None self._tasks = [] - self._params["stft_backend"] = stft_backend + self._params["stft_backend"] = get_backend(stft_backend) def _get_predictor(self): """ Lazy loading access method for internal predictor instance. From 922fcd85bbde235aac58098e7529f0466fc79fa3 Mon Sep 17 00:00:00 2001 From: akhlif Date: Thu, 27 Feb 2020 14:37:02 +0100 Subject: [PATCH 07/10] add librosa dependency --- requirements.txt | 3 ++- setup.py | 3 ++- spleeter/model/__init__.py | 9 ++++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8e1f7be..cf289ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ setuptools>=41.0.0 pandas==0.25.1 tensorflow==1.14.0 ffmpeg-python -norbert==0.2.1 \ No newline at end of file +norbert==0.2.1 +librosa==0.7.2 \ No newline at end of file diff --git a/setup.py b/setup.py index b70dcdc..f89042f 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ __license__ = 'MIT License' # Default project values. project_name = 'spleeter' -project_version = '1.4.9' +project_version = '1.5.0' tensorflow_dependency = 'tensorflow' tensorflow_version = '1.14.0' here = path.abspath(path.dirname(__file__)) @@ -56,6 +56,7 @@ setup( 'pandas==0.25.1', 'requests', 'setuptools>=41.0.0', + 'librosa==0.7.2', '{}=={}'.format(tensorflow_dependency, tensorflow_version), ], extras_require={ diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index b99a9dd..6c2fe3a 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -18,7 +18,6 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' - placeholder = tf.compat.v1.placeholder @@ -326,14 +325,14 @@ class EstimatorSpecBuilder(object): self._build_masked_stfts() return self._masked_stfts - def _inverse_stft(self, stft, time_crop=None): + def _inverse_stft(self, stft_t, time_crop=None): """ Inverse and reshape the given STFT - :param stft: input STFT + :param stft_t: input STFT :returns: inverse STFT (waveform) """ inversed = inverse_stft( - tf.transpose(stft, perm=[2, 0, 1]), + tf.transpose(stft_t, perm=[2, 0, 1]), self._frame_length, self._frame_step, window_fn=lambda frame_length, dtype: ( @@ -419,7 +418,7 @@ class EstimatorSpecBuilder(object): output = output_dict[f'{instrument}_spectrogram'] # Compute mask with the model. instrument_mask = (output ** separation_exponent - + (self.EPSILON / len(output_dict))) / output_sum + + (self.EPSILON / len(output_dict))) / output_sum # Extend mask; instrument_mask = self._extend_mask(instrument_mask) # Stack back mask. From 3cba6985f410750f9704303668e4093dde29f7a3 Mon Sep 17 00:00:00 2001 From: akhlif Date: Thu, 27 Feb 2020 15:38:46 +0100 Subject: [PATCH 08/10] Updating tests to test for librosa backend --- spleeter/model/__init__.py | 2 +- spleeter/separator.py | 23 ++++++++++++++--------- tests/test_separator.py | 37 +++++++++++++++++++++---------------- 3 files changed, 36 insertions(+), 26 deletions(-) diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 6c2fe3a..531f84c 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -105,7 +105,7 @@ class InputProviderFactory(object): @staticmethod def get(params): stft_backend = params["stft_backend"] - assert stft_backend in ("tensorflow", "librosa") + assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend) if stft_backend == "tensorflow": return WaveformInputProvider(params) else: diff --git a/spleeter/separator.py b/spleeter/separator.py index 668b59d..174c73f 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -85,7 +85,7 @@ class Separator(object): task.get() task.wait(timeout=timeout) - def separate(self, waveform): + def separate_tensorflow(self, waveform, audio_descriptor): """ Performs source separation over the given waveform. The separation is performed synchronously but the result @@ -103,11 +103,11 @@ class Separator(object): predictor = self._get_predictor() prediction = predictor({ 'waveform': waveform, - 'audio_id': ''}) + 'audio_id': audio_descriptor}) prediction.pop('audio_id') return prediction - def stft(self, data, inverse=False): + def stft(self, data, inverse=False, length=None): """ Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two channels are processed separately and are concatenated together in the result. The expected input formats are: @@ -116,11 +116,13 @@ class Separator(object): :param inverse: should a stft or an istft be computed. :return: Stereo data as numpy array for the transform. The channels are stored in the last dimension """ + assert not (inverse and length is None) + data = np.asfortranarray(data) N = self._params["frame_length"] H = self._params["frame_step"] win = hann(N, sym=False) fstft = istft if inverse else stft - win_len_arg = {"win_length": None} if inverse else {"n_fft": N} + win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N} dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1]) s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg) s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg) @@ -145,9 +147,15 @@ class Separator(object): saver.restore(sess, latest_checkpoint) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) for inst in builder.instruments: - out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :] + out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0]) return out + def separate(self, waveform, audio_descriptor): + if self._params["stft_backend"] == "tensorflow": + return self.separate_tensorflow(waveform, audio_descriptor) + else: + return self.separate_librosa(waveform, audio_descriptor) + def separate_to_file( self, audio_descriptor, destination, audio_adapter=get_default_audio_adapter(), @@ -180,10 +188,7 @@ class Separator(object): offset=offset, duration=duration, sample_rate=self._sample_rate) - if self._params["stft_backend"] == "tensorflow": - sources = self.separate(waveform) - else: - sources = self.separate_librosa(waveform, audio_descriptor) + sources = self.separate(waveform, audio_descriptor) self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, audio_adapter, bitrate, synchronous) diff --git a/tests/test_separator.py b/tests/test_separator.py index 9235731..271fdfb 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join from tempfile import TemporaryDirectory import pytest +import numpy as np from spleeter import SpleeterError from spleeter.audio.adapter import get_default_audio_adapter @@ -21,34 +22,38 @@ from spleeter.separator import Separator TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3' TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0] TEST_CONFIGURATIONS = [ - ('spleeter:2stems', ('vocals', 'accompaniment')), - ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')), - ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other')) + ('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'), + ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'), + ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'), + ('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'), + ('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'), + ('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa') ] -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_separate(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_separate(configuration, instruments, backend): """ Test separation from raw data. """ adapter = get_default_audio_adapter() waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR) - separator = Separator(configuration) - prediction = separator.separate(waveform) + separator = Separator(configuration, stft_backend=backend) + prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR) assert len(prediction) == len(instruments) for instrument in instruments: assert instrument in prediction for instrument in instruments: track = prediction[instrument] - assert not (waveform == track).all() + assert waveform.shape == track.shape + assert not np.allclose(waveform, track) for compared in instruments: if instrument != compared: - assert not (track == prediction[compared]).all() + assert not np.allclose(track, prediction[compared]) -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_separate_to_file(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_separate_to_file(configuration, instruments, backend): """ Test file based separation. """ - separator = Separator(configuration) + separator = Separator(configuration, stft_backend=backend) with TemporaryDirectory() as directory: separator.separate_to_file( TEST_AUDIO_DESCRIPTOR, @@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments): '{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) -@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS) -def test_filename_format(configuration, instruments): +@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS) +def test_filename_format(configuration, instruments, backend): """ Test custom filename format. """ - separator = Separator(configuration) + separator = Separator(configuration, stft_backend=backend) with TemporaryDirectory() as directory: separator.separate_to_file( TEST_AUDIO_DESCRIPTOR, @@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments): 'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument))) -def test_filename_confilct(): +def test_filename_conflict(): """ Test error handling with static pattern. """ separator = Separator(TEST_CONFIGURATIONS[0][0]) with TemporaryDirectory() as directory: From 2365a9b37fa27782f48a59bd7689f92c92cdec5d Mon Sep 17 00:00:00 2001 From: akhlif Date: Fri, 20 Mar 2020 15:39:11 +0100 Subject: [PATCH 09/10] Removing duplicate logs --- spleeter/commands/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py index 452b3a6..995e866 100644 --- a/spleeter/commands/__init__.py +++ b/spleeter/commands/__init__.py @@ -13,7 +13,6 @@ __email__ = 'research@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' -logging.basicConfig() # -i opt specification (separate). From 6539bf779b0fbe825718be140cb44138b40326ae Mon Sep 17 00:00:00 2001 From: akhlif Date: Fri, 20 Mar 2020 15:40:55 +0100 Subject: [PATCH 10/10] Bumping tensorflow version to do 1 PR merge instead of two --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index cf289ee..7f3bbf0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ importlib_resources; python_version<'3.7' requests setuptools>=41.0.0 pandas==0.25.1 -tensorflow==1.14.0 +tensorflow==1.15.0 ffmpeg-python norbert==0.2.1 librosa==0.7.2 \ No newline at end of file diff --git a/setup.py b/setup.py index f89042f..569ef49 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ __license__ = 'MIT License' project_name = 'spleeter' project_version = '1.5.0' tensorflow_dependency = 'tensorflow' -tensorflow_version = '1.14.0' +tensorflow_version = '1.15.0' here = path.abspath(path.dirname(__file__)) readme_path = path.join(here, 'README.md') with open(readme_path, 'r') as stream: