Adding option to use librosa backend.

Changes in the EstimatorBuilder to set attributes instead of returning tensors for the _build methods.
InputProvider classes to handle the different backend cases.
New method in Separator.
This commit is contained in:
akhlif
2020-02-26 16:31:24 +01:00
parent aa7c208b39
commit fe4634afa6
5 changed files with 232 additions and 124 deletions

View File

@@ -72,14 +72,13 @@ OPT_DURATION = {
}
# -w opt specification (separate)
OPT_CHUNKED = {
'dest': 'chunk_duration',
'type': float,
'default': -1,
'help': 'Maximum duration of the segments that are fed to'
' the network. Use this parameter to limit '
'memory usage. Use -1 to process the whole signal'
' in one pass.'
OPT_STFT_BACKEND = {
'dest': 'stft_backend',
'type': str,
'choices' : ["tensorflow", "librosa", "auto"],
'default': "auto",
'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
}
@@ -191,7 +190,7 @@ def _create_separate_parser(parser_factory):
parser.add_argument('-c', '--codec', **OPT_CODEC)
parser.add_argument('-b', '--birate', **OPT_BITRATE)
parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-w', '--chunk', **OPT_CHUNKED)
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
return parser

View File

@@ -11,6 +11,8 @@
-i /path/to/audio1.wav /path/to/audio2.mp3
"""
import tensorflow as tf
from ..audio.adapter import get_audio_adapter
from ..separator import Separator
@@ -19,6 +21,12 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License'
def get_backend(backend):
if backend == "auto":
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
return backend
def entrypoint(arguments, params):
""" Command entrypoint.
@@ -29,13 +37,13 @@ def entrypoint(arguments, params):
audio_adapter = get_audio_adapter(arguments.audio_adapter)
separator = Separator(
arguments.configuration,
arguments.MWF)
MWF=arguments.MWF,
stft_backend=get_backend(arguments.stft_backend))
for filename in arguments.inputs:
separator.separate_to_file(
filename,
arguments.output_path,
audio_adapter=audio_adapter,
chunk_duration=arguments.chunk_duration,
offset=arguments.offset,
duration=arguments.duration,
codec=arguments.codec,

View File

@@ -18,6 +18,10 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License'
placeholder = tf.compat.v1.placeholder
def get_model_function(model_type):
"""
Get tensorflow function of the model to be applied to the input tensor.
@@ -41,6 +45,74 @@ def get_model_function(model_type):
return model_function
class InputProvider(object):
def __init__(self, params):
self.params = params
def get_input_dict_placeholders(self):
raise NotImplementedError()
@property
def input_names(self):
raise NotImplementedError()
def get_feed_dict(self, features, *args):
raise NotImplementedError()
class WaveformInputProvider(InputProvider):
@property
def input_names(self):
return ["audio_id", "waveform"]
def get_input_dict_placeholders(self):
shape = (None, self.params['n_channels'])
features = {
'waveform': placeholder(tf.float32, shape=shape, name="waveform"),
'audio_id': placeholder(tf.string, name="audio_id")}
return features
def get_feed_dict(self, features, waveform, audio_id):
return {features["audio_id"]: audio_id, features["waveform"]: waveform}
class SpectralInputProvider(InputProvider):
def __init__(self, params):
super().__init__(params)
self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@property
def input_names(self):
return ["audio_id", self.stft_input_name]
def get_input_dict_placeholders(self):
features = {
self.stft_input_name: placeholder(tf.complex64,
shape=(None, self.params["frame_length"]//2+1,
self.params['n_channels']),
name=self.stft_input_name),
'audio_id': placeholder(tf.string, name="audio_id")}
return features
def get_feed_dict(self, features, stft, audio_id):
return {features["audio_id"]: audio_id, features[self.stft_input_name]: stft}
class InputProviderFactory(object):
@staticmethod
def get(params):
stft_backend = params["stft_backend"]
assert stft_backend in ("tensorflow", "librosa")
if stft_backend == "tensorflow":
return WaveformInputProvider(params)
else:
return SpectralInputProvider(params)
class EstimatorSpecBuilder(object):
""" A builder class that allows to builds a multitrack unet model
estimator. The built model estimator has a different behaviour when
@@ -57,9 +129,9 @@ class EstimatorSpecBuilder(object):
>>> from spleeter.model import EstimatorSpecBuilder
>>> builder = EstimatorSpecBuilder()
>>> builder.build_prediction_model()
>>> builder.build_predict_model()
>>> builder.build_evaluation_model()
>>> builder.build_training_model()
>>> builder.build_train_model()
>>> from spleeter.model import model_fn
>>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
@@ -94,6 +166,7 @@ class EstimatorSpecBuilder(object):
:param features: The input features for the estimator.
:param params: Some hyperparameters as a dictionary.
"""
self._features = features
self._params = params
# Get instrument name.
@@ -106,7 +179,10 @@ class EstimatorSpecBuilder(object):
self._frame_length = params['frame_length']
self._frame_step = params['frame_step']
def _build_output_dict(self, input_tensor=None):
def include_stft_computations(self):
return self._params["stft_backend"] == "tensorflow"
def _build_model_outputs(self):
""" Created a batch_sizexTxFxn_channels input tensor containing
mix magnitude spectrogram, then an output dict from it according
to the selected model in internal parameters.
@@ -114,8 +190,8 @@ class EstimatorSpecBuilder(object):
:returns: Build output dict.
:raise ValueError: If required model_type is not supported.
"""
if input_tensor is None:
input_tensor = self._features[f'{self._mix_name}_spectrogram']
input_tensor = self.spectrogram_feature
model = self._params.get('model', None)
if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL)
@@ -125,12 +201,12 @@ class EstimatorSpecBuilder(object):
apply_model = get_model_function(model_type)
except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found')
return apply_model(
self._model_outputs = apply_model(
input_tensor,
self._instruments,
self._params['model']['params'])
def _build_loss(self, output_dict, labels):
def _build_loss(self, labels):
""" Construct tensorflow loss and metrics
:param output_dict: dictionary of network outputs (key: instrument
@@ -139,6 +215,7 @@ class EstimatorSpecBuilder(object):
name, value: ground truth spectrogram of the instrument)
:returns: tensorflow (loss, metrics) tuple.
"""
output_dict = self.model_outputs
loss_type = self._params.get('loss_type', self.L1_MASK)
if loss_type == self.L1_MASK:
losses = {
@@ -178,30 +255,72 @@ class EstimatorSpecBuilder(object):
return tf.compat.v1.train.GradientDescentOptimizer(rate)
return tf.compat.v1.train.AdamOptimizer(rate)
@property
def instruments(self):
return self._instruments
@property
def stft_name(self):
return f'{self._mix_name}_stft'
@property
def spectrogram_name(self):
return f'{self._mix_name}_spectrogram'
def _build_stft_feature(self):
""" Compute STFT of waveform and slice the STFT in segment
with the right length to feed the network.
"""
stft_feature = tf.transpose(
stft(
tf.transpose(self._features['waveform']),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype)),
pad_end=True),
perm=[1, 2, 0])
self._features[f'{self._mix_name}_stft'] = stft_feature
self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]
def get_stft_feature(self):
return self._features[f'{self._mix_name}_stft']
stft_name = self.stft_name
spec_name = self.spectrogram_name
def get_spectrogram_feature(self):
return self._features[f'{self._mix_name}_spectrogram']
if stft_name not in self._features:
stft_feature = tf.transpose(
stft(
tf.transpose(self._features['waveform']),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype)),
pad_end=True),
perm=[1, 2, 0])
self._features[f'{self._mix_name}_stft'] = stft_feature
if spec_name not in self._features:
self._features[spec_name] = tf.abs(
pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :]
def _inverse_stft(self, stft):
@property
def model_outputs(self):
if not hasattr(self, "_model_outputs"):
self._build_model_outputs()
return self._model_outputs
@property
def outputs(self):
if not hasattr(self, "_outputs"):
self._build_outputs()
return self._outputs
@property
def stft_feature(self):
if self.stft_name not in self._features:
self._build_stft_feature()
return self._features[self.stft_name]
@property
def spectrogram_feature(self):
if self.spectrogram_name not in self._features:
self._build_stft_feature()
return self._features[self.spectrogram_name]
@property
def masks(self):
if not hasattr(self, "_masks"):
self._build_masks()
return self._masks
def _inverse_stft(self, stft, time_crop=None):
""" Inverse and reshape the given STFT
:param stft: input STFT
@@ -215,20 +334,21 @@ class EstimatorSpecBuilder(object):
hann_window(frame_length, periodic=True, dtype=dtype))
) * self.WINDOW_COMPENSATION_FACTOR
reshaped = tf.transpose(inversed)
return reshaped[:tf.shape(self._features['waveform'])[0], :]
if time_crop is None:
time_crop = tf.shape(self._features['waveform'])[0]
return reshaped[:time_crop, :]
def _build_mwf_output_waveform(self, output_dict):
def _build_mwf_output_waveform(self):
""" Perform separation with multichannel Wiener Filtering using Norbert.
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
may be quite slow.
:param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument)
:returns: dictionary of separated waveforms (key: instrument name,
value: estimated waveform of the instrument)
"""
import norbert # pylint: disable=import-error
x = self._features[f'{self._mix_name}_stft']
output_dict = self.model_outputs
x = self.stft_feature
v = tf.stack(
[
pad_and_reshape(
@@ -272,11 +392,13 @@ class EstimatorSpecBuilder(object):
mask_shape[-1]))
else:
raise ValueError(f'Invalid mask_extension parameter {extension}')
n_extra_row = (self._frame_length) // 2 + 1 - self._F
n_extra_row = self._frame_length // 2 + 1 - self._F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
return tf.concat([mask, extension], axis=2)
def _build_masks(self, output_dict):
def _build_masks(self):
output_dict = self.model_outputs
stft_feature = self.stft_feature
separation_exponent = self._params['separation_exponent']
output_sum = tf.reduce_sum(
[e ** separation_exponent for e in output_dict.values()],
@@ -297,21 +419,20 @@ class EstimatorSpecBuilder(object):
axis=0)
instrument_mask = tf.reshape(instrument_mask, new_shape)
# Remove padded part (for mask having the same size as STFT);
stft_feature = self._features[f'{self._mix_name}_stft']
instrument_mask = instrument_mask[
:tf.shape(stft_feature)[0], ...]
out[instrument] = instrument_mask
return out
self._masks = out
def _build_masked_stft(self, mask_dict, input_stft=None):
if input_stft is None:
input_stft = self._features[f'{self._mix_name}_stft']
def _build_masked_stft(self):
input_stft = self.stft_feature
out = {}
for instrument, mask in mask_dict.items():
for instrument, mask in self.masks.items():
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
return out
def _build_manual_output_waveform(self, output_dict):
def _build_manual_output_waveform(self, masked_stft):
""" Perform ratio mask separation
:param output_dict: dictionary of estimated spectrogram (key: instrument
@@ -321,27 +442,34 @@ class EstimatorSpecBuilder(object):
"""
output_waveform = {}
masked_stft = self._build_masked_stft(self._build_masks(output_dict))
for instrument, stft_data in masked_stft.items():
output_waveform[instrument] = self._inverse_stft(stft_data)
return output_waveform
def _build_output_waveform(self, output_dict):
def _build_output_waveform(self, masked_stft):
""" Build output waveform from given output dict in order to be used in
prediction context. Regarding of the configuration building method will
be using MWF.
:param output_dict: Output dict to build output waveform from.
:returns: Built output waveform.
"""
if self._params.get('MWF', False):
output_waveform = self._build_mwf_output_waveform(output_dict)
output_waveform = self._build_mwf_output_waveform()
else:
output_waveform = self._build_manual_output_waveform(output_dict)
if 'audio_id' in self._features:
output_waveform['audio_id'] = self._features['audio_id']
output_waveform = self._build_manual_output_waveform(masked_stft)
return output_waveform
def _build_outputs(self):
masked_stft = self._build_masked_stft()
if self.include_stft_computations():
self._outputs = self._build_output_waveform(masked_stft)
else:
self._outputs = masked_stft
if 'audio_id' in self._features:
self._outputs['audio_id'] = self._features['audio_id']
def build_predict_model(self):
""" Builder interface for creating model instance that aims to perform
prediction / inference over given track. The output of such estimator
@@ -350,12 +478,10 @@ class EstimatorSpecBuilder(object):
:returns: An estimator for performing prediction.
"""
self._build_stft_feature()
output_dict = self._build_output_dict()
output_waveform = self._build_output_waveform(output_dict)
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT,
predictions=output_waveform)
predictions=self.outputs)
def build_evaluation_model(self, labels):
""" Builder interface for creating model instance that aims to perform
@@ -366,8 +492,7 @@ class EstimatorSpecBuilder(object):
:param labels: Model labels.
:returns: An estimator for performing model evaluation.
"""
output_dict = self._build_output_dict()
loss, metrics = self._build_loss(output_dict, labels)
loss, metrics = self._build_loss(labels)
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.EVAL,
loss=loss,
@@ -382,8 +507,7 @@ class EstimatorSpecBuilder(object):
:param labels: Model labels.
:returns: An estimator for performing model training.
"""
output_dict = self._build_output_dict()
loss, metrics = self._build_loss(output_dict, labels)
loss, metrics = self._build_loss(labels)
optimizer = self._build_optimizer()
train_operation = optimizer.minimize(
loss=loss,

View File

@@ -20,14 +20,15 @@ from multiprocessing import Pool
from os.path import basename, join, splitext
import numpy as np
import tensorflow as tf
from librosa.core import stft, istft
from scipy.signal.windows import hann
from . import SpleeterError
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir
from .model import EstimatorSpecBuilder
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory
__email__ = 'research@deezer.com'
@@ -41,7 +42,7 @@ logger = logging.getLogger("spleeter")
class Separator(object):
""" A wrapper class for performing separation. """
def __init__(self, params_descriptor, MWF=False, multiprocess=True):
def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
""" Default constructor.
:param params_descriptor: Descriptor for TF params to be used.
@@ -53,6 +54,7 @@ class Separator(object):
self._predictor = None
self._pool = Pool() if multiprocess else None
self._tasks = []
self._params["stft_backend"] = stft_backend
def _get_predictor(self):
""" Lazy loading access method for internal predictor instance.
@@ -120,60 +122,41 @@ class Separator(object):
assert chunk_size % d == 0
return chunk_size//d
def separate_chunked(self, waveform, sample_rate, chunk_max_duration):
chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration)
print(f"chunk size is {chunk_size}")
batch_size = self.get_batch_size_for_chunk_size(chunk_size)
print(f"batch size {batch_size}")
T, F = self._params["T"], self._params["F"]
def stft(self, waveform, inverse=False):
N = self._params["frame_length"]
H = self._params["frame_step"]
win = hann(N, sym=False)
fstft = istft if inverse else stft
win_len_arg = "win_length" if inverse else "n_fft"
s1 = fstft(waveform[:, 0], hop_length=H, window=win, center=False, **{win_len_arg: N})
s2 = fstft(waveform[:, 1], hop_length=H, window=win, center=False, **{win_len_arg: N})
s1 = np.expand_dims(s1.T, 2-inverse)
s2 = np.expand_dims(s2.T, 2-inverse)
return np.concatenate([s1, s2], axis=2-inverse)
def separate_librosa(self, waveform, audio_id):
out = {}
n_batches = (waveform.shape[0]+batch_size*T*F-1)//(batch_size*T*F)
print(f"{n_batches} to compute")
features = get_input_dict_placeholders(self._params)
spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input")
istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input")
start_t = tf.placeholder(tf.int32, shape=(), name="start")
end_t = tf.placeholder(tf.int32, shape=(), name="end")
input_provider = InputProviderFactory.get(self._params)
features = input_provider.get_input_dict_placeholders()
builder = EstimatorSpecBuilder(features, self._params)
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
# TODO: fix the logic, build sometimes return, sometimes set attribute
builder._build_stft_feature()
stft_t = builder.get_stft_feature()
output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t)
masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t),
input_stft=stft_t[start_t:end_t, :, :])
output_waveform_t = builder._inverse_stft(istft_input_t)
waveform_t = features["waveform"]
masked_stfts = {}
outputs = builder.outputs
saver = tf.train.Saver()
stft = self.stft(waveform)
with tf.Session() as sess:
print("restoring weights {}".format(time()))
saver.restore(sess, latest_checkpoint)
print("computing spectrogram {}".format(time()))
spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t], feed_dict={waveform_t: waveform})
print(spectrogram.shape)
print(stft.shape)
for i in range(n_batches):
print("computing batch {} {}".format(i, time()))
start = i*batch_size
end = (i+1)*batch_size
tmp = sess.run(masked_stft_t,
feed_dict={spectrogram_input_t: spectrogram[start:end, ...],
start_t: start*T, end_t: end*T, stft_t: stft})
for instrument, masked_stft in tmp.items():
masked_stfts.setdefault(instrument, []).append(masked_stft)
print("inverting spectrogram {}".format(time()))
for instrument, masked_stft in masked_stfts.items():
out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)})
print("done separating {}".format(time()))
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
for inst in builder.instruments:
out[inst] = self.stft(outputs[inst], inverse=True)
return out
def separate_to_file(
self, audio_descriptor, destination,
audio_adapter=get_default_audio_adapter(), chunk_duration=-1,
audio_adapter=get_default_audio_adapter(),
offset=0, duration=600., codec='wav', bitrate='128k',
filename_format='{filename}/{instrument}.{codec}',
synchronous=True):
@@ -198,15 +181,15 @@ class Separator(object):
:param filename_format: (Optional) Filename format.
:param synchronous: (Optional) True is should by synchronous.
"""
print("loading audio {}".format(time()))
waveform, sample_rate = audio_adapter.load(
audio_descriptor,
offset=offset,
duration=duration,
sample_rate=self._sample_rate)
print("done loading audio {}".format(time()))
sources = self.separate_chunked(waveform, sample_rate, chunk_duration)
print("saving to file {}".format(time()))
if self._params["stft_backend"] == "tensorflow":
sources = self.separate(waveform)
else:
sources = self.separate_librosa(waveform, audio_descriptor)
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous)

View File

@@ -13,7 +13,7 @@ import tensorflow as tf
from tensorflow.contrib import predictor
# pylint: enable=import-error
from ..model import model_fn
from ..model import model_fn, InputProviderFactory
from ..model.provider import get_default_model_provider
# Default exporting directory for predictor.
@@ -59,14 +59,6 @@ def create_estimator(params, MWF):
return estimator
def get_input_dict_placeholders(params):
shape = (None, params['n_channels'])
features = {
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"),
'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")}
return features
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
""" Exports given estimator as predictor into the given directory
and returns associated tf.predictor instance.
@@ -74,8 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
:param estimator: Estimator to export.
:param directory: (Optional) path to write exported model into.
"""
input_provider = InputProviderFactory.get(estimator.params)
def receiver():
features = get_input_dict_placeholders(estimator.params)
features = input_provider.get_input_dict_placeholders()
return tf.estimator.export.ServingInputReceiver(features, features)
estimator.export_saved_model(directory, receiver)