Files
spleeter/spleeter/model/__init__.py
2020-02-27 15:38:46 +01:00

550 lines
20 KiB
Python

#!/usr/bin/env python
# coding: utf8
""" This package provides an estimator builder as well as model functions. """
import importlib
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.contrib.signal import stft, inverse_stft, hann_window
# pylint: enable=import-error
from ..utils.tensor import pad_and_partition, pad_and_reshape
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# Short alias for the tf.compat.v1 placeholder constructor, used by the
# input provider classes of this module to declare their input tensors.
placeholder = tf.compat.v1.placeholder
def get_model_function(model_type):
    """
        Resolve a dotted model identifier into the matching model function.

        The last dotted component is the function name; everything before
        it is the submodule path, looked up below the "functions"
        subpackage of the current module. For instance "unet.softmax_unet"
        resolves to the softmax_unet attribute of the "functions.unet"
        submodule.

        Params:
        - model_type: str
        the relative module path to the model function.

        Returns:
        The attribute named by the last component of model_type, fetched
        from the imported submodule.
    """
    # rpartition yields ('', '', name) when there is no dot, which matches
    # the empty relative path produced by splitting on '.'.
    submodule_path, _, function_name = model_type.rpartition('.')
    functions_package = f'{__name__}.functions'
    imported = importlib.import_module(
        f'{functions_package}.{submodule_path}')
    return getattr(imported, function_name)
class InputProvider(object):
    """ Base class describing the inputs fed to the model.

    Concrete providers must override every member below; the base
    implementations only raise NotImplementedError.
    """

    def __init__(self, params):
        """ Keep a reference to the parameter dictionary.

        :param params: Parameters, stored as-is in ``self.params``.
        """
        self.params = params

    @property
    def input_names(self):
        """ Names of the inputs; to be overridden by subclasses. """
        raise NotImplementedError()

    def get_input_dict_placeholders(self):
        """ Build the input placeholders; to be overridden by subclasses. """
        raise NotImplementedError()

    def get_feed_dict(self, features, *args):
        """ Build a feed dict; to be overridden by subclasses.

        :param features: Feature dictionary.
        :param args: Values to associate with the features.
        """
        raise NotImplementedError()
class WaveformInputProvider(InputProvider):
    """ Input provider exposing an "audio_id" and a "waveform" input. """

    @property
    def input_names(self):
        """ Names of the two inputs handled by this provider. """
        return ["audio_id", "waveform"]

    def get_input_dict_placeholders(self):
        """ Create one placeholder per input.

        The waveform placeholder is float32 with shape
        (None, n_channels), n_channels being read from the parameters;
        the audio_id placeholder is a string with no declared shape.

        :returns: Dictionary mapping input name to placeholder.
        """
        waveform_shape = (None, self.params['n_channels'])
        waveform = placeholder(
            tf.float32, shape=waveform_shape, name="waveform")
        audio_id = placeholder(tf.string, name="audio_id")
        return {'waveform': waveform, 'audio_id': audio_id}

    def get_feed_dict(self, features, waveform, audio_id):
        """ Associate the given values with their feature entries.

        :param features: Dictionary holding "audio_id" and "waveform" keys.
        :param waveform: Value fed to the "waveform" feature.
        :param audio_id: Value fed to the "audio_id" feature.
        :returns: Dictionary keyed by the feature entries.
        """
        feed = {}
        feed[features["audio_id"]] = audio_id
        feed[features["waveform"]] = waveform
        return feed
class SpectralInputProvider(InputProvider):
    """ Input provider exposing an "audio_id" and a complex STFT input.

    The STFT input is named "<mix_name>_stft", mix_name being read from
    the parameters.
    """

    def __init__(self, params):
        """ Store the parameters and derive the STFT input name.

        :param params: Parameters; must contain a "mix_name" entry.
        """
        super().__init__(params)
        mix_name = self.params["mix_name"]
        self.stft_input_name = f"{mix_name}_stft"

    @property
    def input_names(self):
        """ Names of the two inputs handled by this provider. """
        return ["audio_id", self.stft_input_name]

    def get_input_dict_placeholders(self):
        """ Create one placeholder per input.

        The STFT placeholder is complex64 with shape
        (None, frame_length // 2 + 1, n_channels); the audio_id
        placeholder is a string with no declared shape.

        :returns: Dictionary mapping input name to placeholder.
        """
        stft_shape = (
            None,
            self.params["frame_length"]//2+1,
            self.params['n_channels'])
        stft_placeholder = placeholder(
            tf.complex64,
            shape=stft_shape,
            name=self.stft_input_name)
        audio_id = placeholder(tf.string, name="audio_id")
        return {self.stft_input_name: stft_placeholder, 'audio_id': audio_id}

    def get_feed_dict(self, features, stft, audio_id):
        """ Associate the given values with their feature entries.

        :param features: Dictionary holding "audio_id" and STFT input keys.
        :param stft: Value fed to the STFT feature.
        :param audio_id: Value fed to the "audio_id" feature.
        :returns: Dictionary keyed by the feature entries.
        """
        feed = {}
        feed[features["audio_id"]] = audio_id
        feed[features[self.stft_input_name]] = stft
        return feed
class InputProviderFactory(object):
    """ Factory selecting the input provider matching the STFT backend. """

    @staticmethod
    def get(params):
        """ Instantiate the provider for ``params["stft_backend"]``.

        :param params: Parameters; must contain a "stft_backend" entry
                       equal to "tensorflow" or "librosa".
        :returns: A WaveformInputProvider for the "tensorflow" backend,
                  a SpectralInputProvider for the "librosa" backend.
        :raise AssertionError: If the backend is neither of the two.
        """
        backend = params["stft_backend"]
        assert backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(backend)
        # Only the selected class name is looked up.
        provider_class = (
            WaveformInputProvider
            if backend == "tensorflow"
            else SpectralInputProvider)
        return provider_class(params)
class EstimatorSpecBuilder(object):
    """ A builder class that allows to build a multitrack unet model
    estimator. The built model estimator has a different behaviour when
    used in a train/eval mode and in predict mode.

    * In train/eval mode: it takes as input and outputs magnitude
      spectrogram.
    * In predict mode: it takes as input and outputs waveform. The whole
      separation process is then done in this function for performance
      reason: it makes it possible to run the whole separation process
      (including STFT and inverse STFT) on GPU.

    :Example:

    >>> from spleeter.model import EstimatorSpecBuilder
    >>> builder = EstimatorSpecBuilder(features, params)
    >>> builder.build_predict_model()
    >>> builder.build_evaluation_model(labels)
    >>> builder.build_train_model(labels)

    >>> from spleeter.model import model_fn
    >>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
    """

    # Supported model functions.
    DEFAULT_MODEL = 'unet.unet'

    # Supported loss functions.
    L1_MASK = 'L1_mask'
    WEIGHTED_L1_MASK = 'weighted_L1_mask'

    # Supported optimizers.
    ADADELTA = 'Adadelta'
    SGD = 'SGD'

    # Math constants.
    # Scaling applied to the output of the inverse STFT.
    WINDOW_COMPENSATION_FACTOR = 2./3.
    # Small constant added to the mask denominator (and spread over the
    # numerators) when computing ratio masks.
    EPSILON = 1e-10

    def __init__(self, features, params):
        """ Default constructor. Depending on built model
        usage, the provided features should be different:

        * In train/eval mode: features is a dictionary with a
          "mix_spectrogram" key, associated to the
          mix magnitude spectrogram.
        * In predict mode: features is a dictionary with a "waveform"
          key, associated to the waveform of the sound
          to be separated.

        :param features: The input features for the estimator.
        :param params: Some hyperparameters as a dictionary. The keys
                       "mix_name", "instrument_list", "n_channels", "T",
                       "F", "frame_length" and "frame_step" are required.
        """
        self._features = features
        self._params = params
        # Get instrument name.
        self._mix_name = params['mix_name']
        self._instruments = params['instrument_list']
        # Get STFT/signals parameters
        self._n_channels = params['n_channels']
        self._T = params['T']
        self._F = params['F']
        self._frame_length = params['frame_length']
        self._frame_step = params['frame_step']

    def include_stft_computations(self):
        """ Tell whether STFT / inverse STFT are computed by this builder.

        :returns: True if the "stft_backend" parameter is "tensorflow".
        """
        return self._params["stft_backend"] == "tensorflow"

    def _build_model_outputs(self):
        """ Apply the selected model function to the mix spectrogram
        feature and store the result in ``self._model_outputs``.

        The model type is read from ``params['model']['type']`` and
        defaults to DEFAULT_MODEL. The model parameters are read from
        ``params['model']['params']``.

        :raise ValueError: If required model_type is not supported.
        """
        input_tensor = self.spectrogram_feature
        model = self._params.get('model', None)
        if model is None:
            model = {}
        model_type = model.get('type', self.DEFAULT_MODEL)
        # The model section is optional for the type, so it must be
        # optional for the parameters too instead of raising KeyError.
        # NOTE(review): assumes model functions accept an empty params
        # dict -- confirm against the functions package.
        model_params = model.get('params', {})
        try:
            apply_model = get_model_function(model_type)
        except (ModuleNotFoundError, AttributeError) as error:
            # Neither a missing module nor a missing function is a
            # supported model type; keep the cause for debugging.
            raise ValueError(
                f'No model function {model_type} found') from error
        self._model_outputs = apply_model(
            input_tensor,
            self._instruments,
            model_params)

    def _build_loss(self, labels):
        """ Construct tensorflow loss and metrics from the model outputs.

        :param labels: dictionary of target outputs (key: instrument
            name, value: ground truth spectrogram of the instrument)
        :returns: tensorflow (loss, metrics) tuple.
        :raise ValueError: If the "loss_type" parameter is not supported.
        """
        output_dict = self.model_outputs
        loss_type = self._params.get('loss_type', self.L1_MASK)
        if loss_type == self.L1_MASK:
            losses = {
                name: tf.reduce_mean(tf.abs(output - labels[name]))
                for name, output in output_dict.items()
            }
        elif loss_type == self.WEIGHTED_L1_MASK:
            # Absolute error weighted by the mean of the label over
            # axes 1, 2 and 3.
            losses = {
                name: tf.reduce_mean(
                    tf.reduce_mean(
                        labels[name],
                        axis=[1, 2, 3],
                        keepdims=True) *
                    tf.abs(output - labels[name]))
                for name, output in output_dict.items()
            }
        else:
            raise ValueError(f"Unkwnown loss type: {loss_type}")
        loss = tf.reduce_sum(list(losses.values()))
        # Add metrics for monitoring each instrument.
        metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
        metrics['absolute_difference'] = tf.compat.v1.metrics.mean(loss)
        return loss, metrics

    def _build_optimizer(self):
        """ Builds an optimizer instance from internal parameter values.

        Default to AdamOptimizer if not specified. The "learning_rate"
        parameter is only read for optimizers other than Adadelta.

        :returns: Optimizer instance from internal configuration.
        """
        name = self._params.get('optimizer')
        if name == self.ADADELTA:
            return tf.compat.v1.train.AdadeltaOptimizer()
        rate = self._params['learning_rate']
        if name == self.SGD:
            return tf.compat.v1.train.GradientDescentOptimizer(rate)
        return tf.compat.v1.train.AdamOptimizer(rate)

    @property
    def instruments(self):
        """ Instrument list given in the parameters. """
        return self._instruments

    @property
    def stft_name(self):
        """ Feature key of the mix STFT. """
        return f'{self._mix_name}_stft'

    @property
    def spectrogram_name(self):
        """ Feature key of the mix magnitude spectrogram. """
        return f'{self._mix_name}_spectrogram'

    def _build_stft_feature(self):
        """ Compute STFT of waveform and slice the STFT in segment
        with the right length to feed the network.

        Both results are stored in the feature dictionary, and each one
        is only computed when its key is not already present.
        """
        stft_name = self.stft_name
        spec_name = self.spectrogram_name
        if stft_name not in self._features:
            stft_feature = tf.transpose(
                stft(
                    tf.transpose(self._features['waveform']),
                    self._frame_length,
                    self._frame_step,
                    window_fn=lambda frame_length, dtype: (
                        hann_window(frame_length, periodic=True, dtype=dtype)),
                    pad_end=True),
                perm=[1, 2, 0])
            self._features[stft_name] = stft_feature
        if spec_name not in self._features:
            # Keep only the first F frequency bins of the magnitude.
            self._features[spec_name] = tf.abs(
                pad_and_partition(
                    self._features[stft_name],
                    self._T))[:, :, :self._F, :]

    @property
    def model_outputs(self):
        """ Model outputs, built on first access. """
        if not hasattr(self, "_model_outputs"):
            self._build_model_outputs()
        return self._model_outputs

    @property
    def outputs(self):
        """ Prediction outputs, built on first access. """
        if not hasattr(self, "_outputs"):
            self._build_outputs()
        return self._outputs

    @property
    def stft_feature(self):
        """ Mix STFT feature, computed on first access if missing. """
        if self.stft_name not in self._features:
            self._build_stft_feature()
        return self._features[self.stft_name]

    @property
    def spectrogram_feature(self):
        """ Mix spectrogram feature, computed on first access if missing. """
        if self.spectrogram_name not in self._features:
            self._build_stft_feature()
        return self._features[self.spectrogram_name]

    @property
    def masks(self):
        """ Per-instrument masks, built on first access. """
        if not hasattr(self, "_masks"):
            self._build_masks()
        return self._masks

    @property
    def masked_stfts(self):
        """ Per-instrument masked STFTs, built on first access. """
        if not hasattr(self, "_masked_stfts"):
            self._build_masked_stfts()
        return self._masked_stfts

    def _inverse_stft(self, stft_t, time_crop=None):
        """ Inverse and reshape the given STFT

        :param stft_t: input STFT
        :param time_crop: number of leading rows kept in the result;
            defaults to the first dimension of the "waveform" feature.
        :returns: inverse STFT (waveform)
        """
        inversed = inverse_stft(
            tf.transpose(stft_t, perm=[2, 0, 1]),
            self._frame_length,
            self._frame_step,
            window_fn=lambda frame_length, dtype: (
                hann_window(frame_length, periodic=True, dtype=dtype))
        ) * self.WINDOW_COMPENSATION_FACTOR
        reshaped = tf.transpose(inversed)
        if time_crop is None:
            time_crop = tf.shape(self._features['waveform'])[0]
        return reshaped[:time_crop, :]

    def _build_mwf_output_waveform(self):
        """ Perform separation with multichannel Wiener Filtering using
        Norbert.

        Note: multichannel Wiener Filtering is not coded in Tensorflow and
        thus may be quite slow.

        :returns: dictionary of separated waveforms (key: instrument name,
            value: estimated waveform of the instrument)
        """
        import norbert  # pylint: disable=import-error
        output_dict = self.model_outputs
        x = self.stft_feature
        # Stack the per-instrument estimates on a new last axis, each one
        # cropped to the first dimension of the mix STFT.
        v = tf.stack(
            [
                pad_and_reshape(
                    output_dict[f'{instrument}_spectrogram'],
                    self._frame_length,
                    self._F)[:tf.shape(x)[0], ...]
                for instrument in self._instruments
            ],
            axis=3)
        input_args = [v, x]
        # norbert works on numpy arrays, hence the py_function wrapper.
        filtered_stft = tf.py_function(
            lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
            input_args,
            tf.complex64)
        return {
            instrument: self._inverse_stft(filtered_stft[:, :, :, k])
            for k, instrument in enumerate(self._instruments)
        }

    def _extend_mask(self, mask):
        """ Extend mask, from reduced number of frequency bin to the number
        of frequency bin in the STFT.

        :param mask: restricted mask
        :returns: extended mask
        :raise ValueError: If invalid mask_extension parameter is set.
        """
        extension = self._params['mask_extension']
        # Extend with average
        # (dispatch according to energy in the processed band)
        if extension == "average":
            extension_row = tf.reduce_mean(mask, axis=2, keepdims=True)
        # Extend with 0
        # (avoid extension artifacts but not conservative separation)
        elif extension == "zeros":
            mask_shape = tf.shape(mask)
            extension_row = tf.zeros((
                mask_shape[0],
                mask_shape[1],
                1,
                mask_shape[-1]))
        else:
            raise ValueError(f'Invalid mask_extension parameter {extension}')
        # Number of bins between F and the frame_length // 2 + 1 STFT bins.
        n_extra_row = self._frame_length // 2 + 1 - self._F
        extra_rows = tf.tile(extension_row, [1, 1, n_extra_row, 1])
        return tf.concat([mask, extra_rows], axis=2)

    def _build_masks(self):
        """ Compute masks from the output spectrograms of the model and
        store them in ``self._masks`` (key: instrument name).
        """
        output_dict = self.model_outputs
        stft_feature = self.stft_feature
        separation_exponent = self._params['separation_exponent']
        output_sum = tf.reduce_sum(
            [e ** separation_exponent for e in output_dict.values()],
            axis=0
        ) + self.EPSILON
        out = {}
        for instrument in self._instruments:
            output = output_dict[f'{instrument}_spectrogram']
            # Compute mask with the model.
            instrument_mask = (
                output ** separation_exponent
                + (self.EPSILON / len(output_dict))) / output_sum
            # Extend mask;
            instrument_mask = self._extend_mask(instrument_mask)
            # Stack back mask: merge the two leading dimensions.
            old_shape = tf.shape(instrument_mask)
            new_shape = tf.concat(
                [[old_shape[0] * old_shape[1]], old_shape[2:]],
                axis=0)
            instrument_mask = tf.reshape(instrument_mask, new_shape)
            # Remove padded part (for mask having the same size as STFT);
            instrument_mask = instrument_mask[
                :tf.shape(stft_feature)[0], ...]
            out[instrument] = instrument_mask
        self._masks = out

    def _build_masked_stfts(self):
        """ Multiply the mix STFT by each instrument mask and store the
        results in ``self._masked_stfts`` (key: instrument name).
        """
        input_stft = self.stft_feature
        out = {}
        for instrument, mask in self.masks.items():
            out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
        self._masked_stfts = out

    def _build_manual_output_waveform(self, masked_stft):
        """ Perform ratio mask separation

        :param masked_stft: dictionary of masked STFT (key: instrument
            name, value: masked STFT of the instrument)
        :returns: dictionary of separated waveforms (key: instrument name,
            value: estimated waveform of the instrument)
        """
        output_waveform = {}
        for instrument, stft_data in masked_stft.items():
            output_waveform[instrument] = self._inverse_stft(stft_data)
        return output_waveform

    def _build_output_waveform(self, masked_stft):
        """ Build output waveform in order to be used in prediction
        context. Multichannel Wiener Filtering is used when the "MWF"
        parameter is set, ratio masking otherwise.

        :param masked_stft: dictionary of masked STFT, only used when
            MWF is disabled.
        :returns: Built output waveform.
        """
        if self._params.get('MWF', False):
            output_waveform = self._build_mwf_output_waveform()
        else:
            output_waveform = self._build_manual_output_waveform(masked_stft)
        return output_waveform

    def _build_outputs(self):
        """ Build the prediction outputs and store them in
        ``self._outputs``: waveforms when STFT computations are included,
        masked STFTs otherwise, plus the "audio_id" feature when present.
        """
        if self.include_stft_computations():
            self._outputs = self._build_output_waveform(self.masked_stfts)
        else:
            # Copy so that adding "audio_id" below does not leak into the
            # cached masked STFT dictionary.
            self._outputs = dict(self.masked_stfts)
        if 'audio_id' in self._features:
            self._outputs['audio_id'] = self._features['audio_id']

    def build_predict_model(self):
        """ Builder interface for creating model instance that aims to
        perform prediction / inference over given track. The output of such
        estimator will be a dictionary with a "<instrument>" key per
        separated instrument, associated to the estimated separated
        waveform of the instrument.

        :returns: An estimator for performing prediction.
        """
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.PREDICT,
            predictions=self.outputs)

    def build_evaluation_model(self, labels):
        """ Builder interface for creating model instance that aims to
        perform model evaluation. The output of such estimator will be a
        dictionary with a key "<instrument>_spectrogram" per separated
        instrument, associated to the estimated separated instrument
        magnitude spectrogram.

        :param labels: Model labels.
        :returns: An estimator for performing model evaluation.
        """
        loss, metrics = self._build_loss(labels)
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            loss=loss,
            eval_metric_ops=metrics)

    def build_train_model(self, labels):
        """ Builder interface for creating model instance that aims to
        perform model training. The output of such estimator will be a
        dictionary with a key "<instrument>_spectrogram" per separated
        instrument, associated to the estimated separated instrument
        magnitude spectrogram.

        :param labels: Model labels.
        :returns: An estimator for performing model training.
        """
        loss, metrics = self._build_loss(labels)
        optimizer = self._build_optimizer()
        train_operation = optimizer.minimize(
            loss=loss,
            global_step=tf.compat.v1.train.get_global_step())
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=train_operation,
            eval_metric_ops=metrics,
        )
def model_fn(features, labels, mode, params, config):
    """ Estimator model function dispatching on the estimator mode.

        :param features: Input features, forwarded to the spec builder.
        :param labels: Labels, forwarded to the evaluation / train builders.
        :param mode: Estimator mode.
        :param params: Hyperparameters, forwarded to the spec builder.
        :param config: TF configuration (not used).
        :returns: Built EstimatorSpec.
        :raise ValueError: If estimator mode is not supported.
    """
    spec_builder = EstimatorSpecBuilder(features, params)
    mode_keys = tf.estimator.ModeKeys
    if mode == mode_keys.PREDICT:
        return spec_builder.build_predict_model()
    if mode == mode_keys.EVAL:
        return spec_builder.build_evaluation_model(labels)
    if mode == mode_keys.TRAIN:
        return spec_builder.build_train_model(labels)
    raise ValueError(f'Unknown mode {mode}')