From aa7c208b3903192ce92584c42e99263eda764302 Mon Sep 17 00:00:00 2001
From: akhlif
Date: Wed, 19 Feb 2020 23:11:16 +0100
Subject: [PATCH] First draft implementing correct reconstruction

---
 spleeter/commands/__init__.py |  5 +-
 spleeter/commands/separate.py |  1 +
 spleeter/model/__init__.py    | 56 +++++++++++++-------
 spleeter/separator.py         | 97 +++++++++++++++++++++++++++++++----
 spleeter/utils/estimator.py   | 31 ++++++++---
 5 files changed, 152 insertions(+), 38 deletions(-)

diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py
index 9461a07..21b70ef 100644
--- a/spleeter/commands/__init__.py
+++ b/spleeter/commands/__init__.py
@@ -4,7 +4,7 @@
 """ This modules provides spleeter command as well as CLI parsing methods. """
 
 import json
-
+import logging
 from argparse import ArgumentParser
 from tempfile import gettempdir
 from os.path import exists, join
@@ -13,6 +13,9 @@ __email__ = 'research@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
+logging.basicConfig()
+
+
 # -i opt specification (separate).
 OPT_INPUT = {
     'dest': 'inputs',
diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py
index b1f41ab..b3de4b6 100644
--- a/spleeter/commands/separate.py
+++ b/spleeter/commands/separate.py
@@ -35,6 +35,7 @@ def entrypoint(arguments, params):
             filename,
             arguments.output_path,
             audio_adapter=audio_adapter,
+            chunk_duration=arguments.chunk_duration,
             offset=arguments.offset,
             duration=arguments.duration,
             codec=arguments.codec,
diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py
index 384e838..4d18eb5 100644
--- a/spleeter/model/__init__.py
+++ b/spleeter/model/__init__.py
@@ -106,7 +106,7 @@ class EstimatorSpecBuilder(object):
         self._frame_length = params['frame_length']
         self._frame_step = params['frame_step']
 
-    def _build_output_dict(self):
+    def _build_output_dict(self, input_tensor=None):
         """ Created a batch_sizexTxFxn_channels input tensor containing
         mix magnitude spectrogram, then an output dict from it according
         to the selected model in internal parameters.
@@ -114,7 +114,8 @@ class EstimatorSpecBuilder(object):
         :returns: Build output dict.
         :raise ValueError: If required model_type is not supported.
         """
-        input_tensor = self._features[f'{self._mix_name}_spectrogram']
+        if input_tensor is None:
+            input_tensor = self._features[f'{self._mix_name}_spectrogram']
         model = self._params.get('model', None)
         if model is not None:
             model_type = model.get('type', self.DEFAULT_MODEL)
@@ -194,6 +195,12 @@ class EstimatorSpecBuilder(object):
         self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
             pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]
 
+    def get_stft_feature(self):
+        return self._features[f'{self._mix_name}_stft']
+
+    def get_spectrogram_feature(self):
+        return self._features[f'{self._mix_name}_spectrogram']
+
     def _inverse_stft(self, stft):
         """ Inverse and reshape the given STFT
 
@@ -269,26 +276,18 @@ class EstimatorSpecBuilder(object):
         extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
         return tf.concat([mask, extension], axis=2)
 
-    def _build_manual_output_waveform(self, output_dict):
-        """ Perform ratio mask separation
-
-        :param output_dict: dictionary of estimated spectrogram (key: instrument
-            name, value: estimated spectrogram of the instrument)
-        :returns: dictionary of separated waveforms (key: instrument name,
-            value: estimated waveform of the instrument)
-        """
+    def _build_masks(self, output_dict):
         separation_exponent = self._params['separation_exponent']
         output_sum = tf.reduce_sum(
             [e ** separation_exponent for e in output_dict.values()],
             axis=0
         ) + self.EPSILON
-        output_waveform = {}
+        out = {}
         for instrument in self._instruments:
             output = output_dict[f'{instrument}_spectrogram']
             # Compute mask with the model.
-            instrument_mask = (
-                output ** separation_exponent
-                + (self.EPSILON / len(output_dict))) / output_sum
+            instrument_mask = (output ** separation_exponent
+                               + (self.EPSILON / len(output_dict))) / output_sum
             # Extend mask;
             instrument_mask = self._extend_mask(instrument_mask)
             # Stack back mask.
@@ -300,10 +299,31 @@ class EstimatorSpecBuilder(object):
             # Remove padded part (for mask having the same size as STFT);
             stft_feature = self._features[f'{self._mix_name}_stft']
             instrument_mask = instrument_mask[
-                :tf.shape(stft_feature)[0], ...]
-            # Compute masked STFT and normalize it.
-            output_waveform[instrument] = self._inverse_stft(
-                tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature)
+                              :tf.shape(stft_feature)[0], ...]
+            out[instrument] = instrument_mask
+        return out
+
+    def _build_masked_stft(self, mask_dict, input_stft=None):
+        if input_stft is None:
+            input_stft = self._features[f'{self._mix_name}_stft']
+        out = {}
+        for instrument, mask in mask_dict.items():
+            out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
+        return out
+
+    def _build_manual_output_waveform(self, output_dict):
+        """ Perform ratio mask separation
+
+        :param output_dict: dictionary of estimated spectrogram (key: instrument
+            name, value: estimated spectrogram of the instrument)
+        :returns: dictionary of separated waveforms (key: instrument name,
+            value: estimated waveform of the instrument)
+        """
+
+        output_waveform = {}
+        masked_stft = self._build_masked_stft(self._build_masks(output_dict))
+        for instrument, stft_data in masked_stft.items():
+            output_waveform[instrument] = self._inverse_stft(stft_data)
         return output_waveform
 
     def _build_output_waveform(self, output_dict):
diff --git a/spleeter/separator.py b/spleeter/separator.py
index 4423888..f7eac1b 100644
--- a/spleeter/separator.py
+++ b/spleeter/separator.py
@@ -13,22 +13,31 @@
 """
 
 import os
+import logging
+from time import time
 
 from multiprocessing import Pool
 from os.path import basename, join, splitext
 
 import numpy as np
+import tensorflow as tf
+
 
 from . import SpleeterError
 from .audio.adapter import get_default_audio_adapter
 from .audio.convertor import to_stereo
 from .utils.configuration import load_configuration
-from .utils.estimator import create_estimator, to_predictor
+from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir
+from .model import EstimatorSpecBuilder
+
 
 __email__ = 'research@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
 
+logger = logging.getLogger("spleeter")
+
+
 class Separator(object):
     """ A wrapper class for performing separation. """
@@ -87,16 +96,79 @@ class Separator(object):
         prediction.pop('audio_id')
         return prediction
 
-    def separate_chunked(self, waveform, sample_rate, chunk_duration=-1):
-        chunk_size = waveform.shape[0] if chunk_duration == -1 else chunk_duration*sample_rate
-        n_chunks = int(waveform.shape[0]/chunk_size)
+    def get_valid_chunk_size(self, sample_rate: int, chunk_max_duration: float) -> int:
+        """
+        Given a sample rate, and a maximal duration that a chunk can represent, return the maximum chunk
+        size in samples. The chunk size must be a non-zero multiple of T (temporal dimension of the input spectrogram)
+        times F (number of frequency bins in the input spectrogram). If no such value exist, we return T*F.
+
+        :param sample_rate: sample rate of the pcm data
+        :param chunk_max_duration: maximal duration in seconds of a chunk
+        :return: highest non-zero chunk size of duration less than chunk_max_duration or minimal valid chunk size.
+        """
+        assert chunk_max_duration > 0
+        chunk_size = chunk_max_duration * sample_rate
+        min_sample_size = self._params["T"] * self._params["F"]
+        if chunk_size < min_sample_size:
+            min_duration = min_sample_size / sample_rate
+            logger.warning("chunk_duration must be at least {:.2f} seconds. Ignoring parameter".format(min_duration))
+            chunk_size = min_sample_size
+        return min_sample_size*int(chunk_size/min_sample_size)
+
+    def get_batch_size_for_chunk_size(self, chunk_size):
+        d = self._params["T"] * self._params["F"]
+        assert chunk_size % d == 0
+        return chunk_size//d
+
+    def separate_chunked(self, waveform, sample_rate, chunk_max_duration):
+        chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration)
+        print(f"chunk size is {chunk_size}")
+        batch_size = self.get_batch_size_for_chunk_size(chunk_size)
+        print(f"batch size {batch_size}")
+        T, F = self._params["T"], self._params["F"]
         out = {}
-        for i in range(n_chunks):
-            sources = self.separate(waveform)
-            for inst, data in sources.items():
-                out.setdefault(inst, []).append(data)
-        for inst, data in out.items():
-            out[inst] = np.concatenate(data, axis=0)
+        n_batches = (waveform.shape[0]+batch_size*T*F-1)//(batch_size*T*F)
+        print(f"{n_batches} to compute")
+        features = get_input_dict_placeholders(self._params)
+        spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input")
+        istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input")
+        start_t = tf.placeholder(tf.int32, shape=(), name="start")
+        end_t = tf.placeholder(tf.int32, shape=(), name="end")
+        builder = EstimatorSpecBuilder(features, self._params)
+        latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
+
+        # TODO: fix the logic, build sometimes return, sometimes set attribute
+        builder._build_stft_feature()
+        stft_t = builder.get_stft_feature()
+        output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t)
+        masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t),
+                                                   input_stft=stft_t[start_t:end_t, :, :])
+        output_waveform_t = builder._inverse_stft(istft_input_t)
+        waveform_t = features["waveform"]
+        masked_stfts = {}
+        saver = tf.train.Saver()
+
+        with tf.Session() as sess:
+            print("restoring weights {}".format(time()))
+            saver.restore(sess, latest_checkpoint)
+            print("computing spectrogram {}".format(time()))
+            spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t], feed_dict={waveform_t: waveform})
+            print(spectrogram.shape)
+            print(stft.shape)
+            for i in range(n_batches):
+                print("computing batch {} {}".format(i, time()))
+                start = i*batch_size
+                end = (i+1)*batch_size
+                tmp = sess.run(masked_stft_t,
+                               feed_dict={spectrogram_input_t: spectrogram[start:end, ...],
+                                          start_t: start*T, end_t: end*T, stft_t: stft})
+                for instrument, masked_stft in tmp.items():
+                    masked_stfts.setdefault(instrument, []).append(masked_stft)
+
+            print("inverting spectrogram {}".format(time()))
+            for instrument, masked_stft in masked_stfts.items():
+                out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)})
+            print("done separating {}".format(time()))
         return out
 
     def separate_to_file(
@@ -126,12 +198,15 @@ class Separator(object):
         :param filename_format: (Optional) Filename format.
         :param synchronous: (Optional) True is should by synchronous.
         """
+        print("loading audio {}".format(time()))
         waveform, sample_rate = audio_adapter.load(
             audio_descriptor,
             offset=offset,
             duration=duration,
             sample_rate=self._sample_rate)
-        sources = self.separate_chunked(waveform, sample_rate, chunk_duration=chunk_duration)
+        print("done loading audio {}".format(time()))
+        sources = self.separate_chunked(waveform, sample_rate, chunk_duration)
+        print("saving to file {}".format(time()))
         self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
                           audio_adapter, bitrate, synchronous)
 
diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py
index 95c4219..2b17e96 100644
--- a/spleeter/utils/estimator.py
+++ b/spleeter/utils/estimator.py
@@ -20,20 +20,30 @@ from ..model.provider import get_default_model_provider
 DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
 
+
+def get_default_model_dir(model_dir):
+    """
+    Transforms a string like 'spleeter:2stems' into an actual path.
+    :param model_dir:
+    :return:
+    """
+    model_provider = get_default_model_provider()
+    return model_provider.get(model_dir)
+
 
 def create_estimator(params, MWF):
     """
         Initialize tensorflow estimator that will perform separation
 
         Params:
-        - params: a dictionnary of parameters for building the model
+        - params: a dictionary of parameters for building the model
 
         Returns:
             a tensorflow estimator
     """
     # Load model.
-    model_directory = params['model_dir']
-    model_provider = get_default_model_provider()
-    params['model_dir'] = model_provider.get(model_directory)
+
+
+    params['model_dir'] = get_default_model_dir(params['model_dir'])
     params['MWF'] = MWF
     # Setup config
     session_config = tf.compat.v1.ConfigProto()
@@ -49,6 +59,14 @@ def create_estimator(params, MWF):
     return estimator
 
 
+def get_input_dict_placeholders(params):
+    shape = (None, params['n_channels'])
+    features = {
+        'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"),
+        'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")}
+    return features
+
+
 def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
     """ Exports given estimator as predictor into the given directory
     and returns associated tf.predictor instance.
@@ -57,10 +75,7 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
     :param directory: (Optional) path to write exported model into.
     """
     def receiver():
-        shape = (None, estimator.params['n_channels'])
-        features = {
-            'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
-            'audio_id': tf.compat.v1.placeholder(tf.string)}
+        features = get_input_dict_placeholders(estimator.params)
         return tf.estimator.export.ServingInputReceiver(features, features)
 
     estimator.export_saved_model(directory, receiver)