Adding option to use librosa backend.

Changes in the EstimatorBuilder to set attributes instead of returning tensors for the _build methods.
InputProvider classes to handle the different backend cases.
New method in Separator.
This commit is contained in:
akhlif
2020-02-26 16:31:24 +01:00
parent aa7c208b39
commit fe4634afa6
5 changed files with 232 additions and 124 deletions

View File

@@ -72,14 +72,13 @@ OPT_DURATION = {
} }
# -w opt specification (separate) # -w opt specification (separate)
OPT_CHUNKED = { OPT_STFT_BACKEND = {
'dest': 'chunk_duration', 'dest': 'stft_backend',
'type': float, 'type': str,
'default': -1, 'choices' : ["tensorflow", "librosa", "auto"],
'help': 'Maximum duration of the segments that are fed to' 'default': "auto",
' the network. Use this parameter to limit ' 'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
'memory usage. Use -1 to process the whole signal' ' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
' in one pass.'
} }
@@ -191,7 +190,7 @@ def _create_separate_parser(parser_factory):
parser.add_argument('-c', '--codec', **OPT_CODEC) parser.add_argument('-c', '--codec', **OPT_CODEC)
parser.add_argument('-b', '--birate', **OPT_BITRATE) parser.add_argument('-b', '--birate', **OPT_BITRATE)
parser.add_argument('-m', '--mwf', **OPT_MWF) parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-w', '--chunk', **OPT_CHUNKED) parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
return parser return parser

View File

@@ -11,6 +11,8 @@
-i /path/to/audio1.wav /path/to/audio2.mp3 -i /path/to/audio1.wav /path/to/audio2.mp3
""" """
import tensorflow as tf
from ..audio.adapter import get_audio_adapter from ..audio.adapter import get_audio_adapter
from ..separator import Separator from ..separator import Separator
@@ -19,6 +21,12 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
def get_backend(backend):
if backend == "auto":
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
return backend
def entrypoint(arguments, params): def entrypoint(arguments, params):
""" Command entrypoint. """ Command entrypoint.
@@ -29,13 +37,13 @@ def entrypoint(arguments, params):
audio_adapter = get_audio_adapter(arguments.audio_adapter) audio_adapter = get_audio_adapter(arguments.audio_adapter)
separator = Separator( separator = Separator(
arguments.configuration, arguments.configuration,
arguments.MWF) MWF=arguments.MWF,
stft_backend=get_backend(arguments.stft_backend))
for filename in arguments.inputs: for filename in arguments.inputs:
separator.separate_to_file( separator.separate_to_file(
filename, filename,
arguments.output_path, arguments.output_path,
audio_adapter=audio_adapter, audio_adapter=audio_adapter,
chunk_duration=arguments.chunk_duration,
offset=arguments.offset, offset=arguments.offset,
duration=arguments.duration, duration=arguments.duration,
codec=arguments.codec, codec=arguments.codec,

View File

@@ -18,6 +18,10 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
placeholder = tf.compat.v1.placeholder
def get_model_function(model_type): def get_model_function(model_type):
""" """
Get tensorflow function of the model to be applied to the input tensor. Get tensorflow function of the model to be applied to the input tensor.
@@ -41,6 +45,74 @@ def get_model_function(model_type):
return model_function return model_function
class InputProvider(object):
def __init__(self, params):
self.params = params
def get_input_dict_placeholders(self):
raise NotImplementedError()
@property
def input_names(self):
raise NotImplementedError()
def get_feed_dict(self, features, *args):
raise NotImplementedError()
class WaveformInputProvider(InputProvider):
@property
def input_names(self):
return ["audio_id", "waveform"]
def get_input_dict_placeholders(self):
shape = (None, self.params['n_channels'])
features = {
'waveform': placeholder(tf.float32, shape=shape, name="waveform"),
'audio_id': placeholder(tf.string, name="audio_id")}
return features
def get_feed_dict(self, features, waveform, audio_id):
return {features["audio_id"]: audio_id, features["waveform"]: waveform}
class SpectralInputProvider(InputProvider):
def __init__(self, params):
super().__init__(params)
self.stft_input_name = "{}_stft".format(self.params["mix_name"])
@property
def input_names(self):
return ["audio_id", self.stft_input_name]
def get_input_dict_placeholders(self):
features = {
self.stft_input_name: placeholder(tf.complex64,
shape=(None, self.params["frame_length"]//2+1,
self.params['n_channels']),
name=self.stft_input_name),
'audio_id': placeholder(tf.string, name="audio_id")}
return features
def get_feed_dict(self, features, stft, audio_id):
return {features["audio_id"]: audio_id, features[self.stft_input_name]: stft}
class InputProviderFactory(object):
@staticmethod
def get(params):
stft_backend = params["stft_backend"]
assert stft_backend in ("tensorflow", "librosa")
if stft_backend == "tensorflow":
return WaveformInputProvider(params)
else:
return SpectralInputProvider(params)
class EstimatorSpecBuilder(object): class EstimatorSpecBuilder(object):
""" A builder class that allows to builds a multitrack unet model """ A builder class that allows to builds a multitrack unet model
estimator. The built model estimator has a different behaviour when estimator. The built model estimator has a different behaviour when
@@ -57,9 +129,9 @@ class EstimatorSpecBuilder(object):
>>> from spleeter.model import EstimatorSpecBuilder >>> from spleeter.model import EstimatorSpecBuilder
>>> builder = EstimatorSpecBuilder() >>> builder = EstimatorSpecBuilder()
>>> builder.build_prediction_model() >>> builder.build_predict_model()
>>> builder.build_evaluation_model() >>> builder.build_evaluation_model()
>>> builder.build_training_model() >>> builder.build_train_model()
>>> from spleeter.model import model_fn >>> from spleeter.model import model_fn
>>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...) >>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
@@ -94,6 +166,7 @@ class EstimatorSpecBuilder(object):
:param features: The input features for the estimator. :param features: The input features for the estimator.
:param params: Some hyperparameters as a dictionary. :param params: Some hyperparameters as a dictionary.
""" """
self._features = features self._features = features
self._params = params self._params = params
# Get instrument name. # Get instrument name.
@@ -106,7 +179,10 @@ class EstimatorSpecBuilder(object):
self._frame_length = params['frame_length'] self._frame_length = params['frame_length']
self._frame_step = params['frame_step'] self._frame_step = params['frame_step']
def _build_output_dict(self, input_tensor=None): def include_stft_computations(self):
return self._params["stft_backend"] == "tensorflow"
def _build_model_outputs(self):
""" Created a batch_sizexTxFxn_channels input tensor containing """ Created a batch_sizexTxFxn_channels input tensor containing
mix magnitude spectrogram, then an output dict from it according mix magnitude spectrogram, then an output dict from it according
to the selected model in internal parameters. to the selected model in internal parameters.
@@ -114,8 +190,8 @@ class EstimatorSpecBuilder(object):
:returns: Build output dict. :returns: Build output dict.
:raise ValueError: If required model_type is not supported. :raise ValueError: If required model_type is not supported.
""" """
if input_tensor is None:
input_tensor = self._features[f'{self._mix_name}_spectrogram'] input_tensor = self.spectrogram_feature
model = self._params.get('model', None) model = self._params.get('model', None)
if model is not None: if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL) model_type = model.get('type', self.DEFAULT_MODEL)
@@ -125,12 +201,12 @@ class EstimatorSpecBuilder(object):
apply_model = get_model_function(model_type) apply_model = get_model_function(model_type)
except ModuleNotFoundError: except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found') raise ValueError(f'No model function {model_type} found')
return apply_model( self._model_outputs = apply_model(
input_tensor, input_tensor,
self._instruments, self._instruments,
self._params['model']['params']) self._params['model']['params'])
def _build_loss(self, output_dict, labels): def _build_loss(self, labels):
""" Construct tensorflow loss and metrics """ Construct tensorflow loss and metrics
:param output_dict: dictionary of network outputs (key: instrument :param output_dict: dictionary of network outputs (key: instrument
@@ -139,6 +215,7 @@ class EstimatorSpecBuilder(object):
name, value: ground truth spectrogram of the instrument) name, value: ground truth spectrogram of the instrument)
:returns: tensorflow (loss, metrics) tuple. :returns: tensorflow (loss, metrics) tuple.
""" """
output_dict = self.model_outputs
loss_type = self._params.get('loss_type', self.L1_MASK) loss_type = self._params.get('loss_type', self.L1_MASK)
if loss_type == self.L1_MASK: if loss_type == self.L1_MASK:
losses = { losses = {
@@ -178,10 +255,27 @@ class EstimatorSpecBuilder(object):
return tf.compat.v1.train.GradientDescentOptimizer(rate) return tf.compat.v1.train.GradientDescentOptimizer(rate)
return tf.compat.v1.train.AdamOptimizer(rate) return tf.compat.v1.train.AdamOptimizer(rate)
@property
def instruments(self):
return self._instruments
@property
def stft_name(self):
return f'{self._mix_name}_stft'
@property
def spectrogram_name(self):
return f'{self._mix_name}_spectrogram'
def _build_stft_feature(self): def _build_stft_feature(self):
""" Compute STFT of waveform and slice the STFT in segment """ Compute STFT of waveform and slice the STFT in segment
with the right length to feed the network. with the right length to feed the network.
""" """
stft_name = self.stft_name
spec_name = self.spectrogram_name
if stft_name not in self._features:
stft_feature = tf.transpose( stft_feature = tf.transpose(
stft( stft(
tf.transpose(self._features['waveform']), tf.transpose(self._features['waveform']),
@@ -192,16 +286,41 @@ class EstimatorSpecBuilder(object):
pad_end=True), pad_end=True),
perm=[1, 2, 0]) perm=[1, 2, 0])
self._features[f'{self._mix_name}_stft'] = stft_feature self._features[f'{self._mix_name}_stft'] = stft_feature
self._features[f'{self._mix_name}_spectrogram'] = tf.abs( if spec_name not in self._features:
pad_and_partition(stft_feature, self._T))[:, :, :self._F, :] self._features[spec_name] = tf.abs(
pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :]
def get_stft_feature(self): @property
return self._features[f'{self._mix_name}_stft'] def model_outputs(self):
if not hasattr(self, "_model_outputs"):
self._build_model_outputs()
return self._model_outputs
def get_spectrogram_feature(self): @property
return self._features[f'{self._mix_name}_spectrogram'] def outputs(self):
if not hasattr(self, "_outputs"):
self._build_outputs()
return self._outputs
def _inverse_stft(self, stft): @property
def stft_feature(self):
if self.stft_name not in self._features:
self._build_stft_feature()
return self._features[self.stft_name]
@property
def spectrogram_feature(self):
if self.spectrogram_name not in self._features:
self._build_stft_feature()
return self._features[self.spectrogram_name]
@property
def masks(self):
if not hasattr(self, "_masks"):
self._build_masks()
return self._masks
def _inverse_stft(self, stft, time_crop=None):
""" Inverse and reshape the given STFT """ Inverse and reshape the given STFT
:param stft: input STFT :param stft: input STFT
@@ -215,20 +334,21 @@ class EstimatorSpecBuilder(object):
hann_window(frame_length, periodic=True, dtype=dtype)) hann_window(frame_length, periodic=True, dtype=dtype))
) * self.WINDOW_COMPENSATION_FACTOR ) * self.WINDOW_COMPENSATION_FACTOR
reshaped = tf.transpose(inversed) reshaped = tf.transpose(inversed)
return reshaped[:tf.shape(self._features['waveform'])[0], :] if time_crop is None:
time_crop = tf.shape(self._features['waveform'])[0]
return reshaped[:time_crop, :]
def _build_mwf_output_waveform(self, output_dict): def _build_mwf_output_waveform(self):
""" Perform separation with multichannel Wiener Filtering using Norbert. """ Perform separation with multichannel Wiener Filtering using Norbert.
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
may be quite slow. may be quite slow.
:param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument)
:returns: dictionary of separated waveforms (key: instrument name, :returns: dictionary of separated waveforms (key: instrument name,
value: estimated waveform of the instrument) value: estimated waveform of the instrument)
""" """
import norbert # pylint: disable=import-error import norbert # pylint: disable=import-error
x = self._features[f'{self._mix_name}_stft'] output_dict = self.model_outputs
x = self.stft_feature
v = tf.stack( v = tf.stack(
[ [
pad_and_reshape( pad_and_reshape(
@@ -272,11 +392,13 @@ class EstimatorSpecBuilder(object):
mask_shape[-1])) mask_shape[-1]))
else: else:
raise ValueError(f'Invalid mask_extension parameter {extension}') raise ValueError(f'Invalid mask_extension parameter {extension}')
n_extra_row = (self._frame_length) // 2 + 1 - self._F n_extra_row = self._frame_length // 2 + 1 - self._F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1]) extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
return tf.concat([mask, extension], axis=2) return tf.concat([mask, extension], axis=2)
def _build_masks(self, output_dict): def _build_masks(self):
output_dict = self.model_outputs
stft_feature = self.stft_feature
separation_exponent = self._params['separation_exponent'] separation_exponent = self._params['separation_exponent']
output_sum = tf.reduce_sum( output_sum = tf.reduce_sum(
[e ** separation_exponent for e in output_dict.values()], [e ** separation_exponent for e in output_dict.values()],
@@ -297,21 +419,20 @@ class EstimatorSpecBuilder(object):
axis=0) axis=0)
instrument_mask = tf.reshape(instrument_mask, new_shape) instrument_mask = tf.reshape(instrument_mask, new_shape)
# Remove padded part (for mask having the same size as STFT); # Remove padded part (for mask having the same size as STFT);
stft_feature = self._features[f'{self._mix_name}_stft']
instrument_mask = instrument_mask[ instrument_mask = instrument_mask[
:tf.shape(stft_feature)[0], ...] :tf.shape(stft_feature)[0], ...]
out[instrument] = instrument_mask out[instrument] = instrument_mask
return out self._masks = out
def _build_masked_stft(self, mask_dict, input_stft=None): def _build_masked_stft(self):
if input_stft is None: input_stft = self.stft_feature
input_stft = self._features[f'{self._mix_name}_stft']
out = {} out = {}
for instrument, mask in mask_dict.items(): for instrument, mask in self.masks.items():
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
return out return out
def _build_manual_output_waveform(self, output_dict): def _build_manual_output_waveform(self, masked_stft):
""" Perform ratio mask separation """ Perform ratio mask separation
:param output_dict: dictionary of estimated spectrogram (key: instrument :param output_dict: dictionary of estimated spectrogram (key: instrument
@@ -321,27 +442,34 @@ class EstimatorSpecBuilder(object):
""" """
output_waveform = {} output_waveform = {}
masked_stft = self._build_masked_stft(self._build_masks(output_dict))
for instrument, stft_data in masked_stft.items(): for instrument, stft_data in masked_stft.items():
output_waveform[instrument] = self._inverse_stft(stft_data) output_waveform[instrument] = self._inverse_stft(stft_data)
return output_waveform return output_waveform
def _build_output_waveform(self, output_dict): def _build_output_waveform(self, masked_stft):
""" Build output waveform from given output dict in order to be used in """ Build output waveform from given output dict in order to be used in
prediction context. Regarding of the configuration building method will prediction context. Regarding of the configuration building method will
be using MWF. be using MWF.
:param output_dict: Output dict to build output waveform from.
:returns: Built output waveform. :returns: Built output waveform.
""" """
if self._params.get('MWF', False): if self._params.get('MWF', False):
output_waveform = self._build_mwf_output_waveform(output_dict) output_waveform = self._build_mwf_output_waveform()
else: else:
output_waveform = self._build_manual_output_waveform(output_dict) output_waveform = self._build_manual_output_waveform(masked_stft)
if 'audio_id' in self._features:
output_waveform['audio_id'] = self._features['audio_id']
return output_waveform return output_waveform
def _build_outputs(self):
masked_stft = self._build_masked_stft()
if self.include_stft_computations():
self._outputs = self._build_output_waveform(masked_stft)
else:
self._outputs = masked_stft
if 'audio_id' in self._features:
self._outputs['audio_id'] = self._features['audio_id']
def build_predict_model(self): def build_predict_model(self):
""" Builder interface for creating model instance that aims to perform """ Builder interface for creating model instance that aims to perform
prediction / inference over given track. The output of such estimator prediction / inference over given track. The output of such estimator
@@ -350,12 +478,10 @@ class EstimatorSpecBuilder(object):
:returns: An estimator for performing prediction. :returns: An estimator for performing prediction.
""" """
self._build_stft_feature()
output_dict = self._build_output_dict()
output_waveform = self._build_output_waveform(output_dict)
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT, tf.estimator.ModeKeys.PREDICT,
predictions=output_waveform) predictions=self.outputs)
def build_evaluation_model(self, labels): def build_evaluation_model(self, labels):
""" Builder interface for creating model instance that aims to perform """ Builder interface for creating model instance that aims to perform
@@ -366,8 +492,7 @@ class EstimatorSpecBuilder(object):
:param labels: Model labels. :param labels: Model labels.
:returns: An estimator for performing model evaluation. :returns: An estimator for performing model evaluation.
""" """
output_dict = self._build_output_dict() loss, metrics = self._build_loss(labels)
loss, metrics = self._build_loss(output_dict, labels)
return tf.estimator.EstimatorSpec( return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.EVAL,
loss=loss, loss=loss,
@@ -382,8 +507,7 @@ class EstimatorSpecBuilder(object):
:param labels: Model labels. :param labels: Model labels.
:returns: An estimator for performing model training. :returns: An estimator for performing model training.
""" """
output_dict = self._build_output_dict() loss, metrics = self._build_loss(labels)
loss, metrics = self._build_loss(output_dict, labels)
optimizer = self._build_optimizer() optimizer = self._build_optimizer()
train_operation = optimizer.minimize( train_operation = optimizer.minimize(
loss=loss, loss=loss,

View File

@@ -20,14 +20,15 @@ from multiprocessing import Pool
from os.path import basename, join, splitext from os.path import basename, join, splitext
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from librosa.core import stft, istft
from scipy.signal.windows import hann
from . import SpleeterError from . import SpleeterError
from .audio.adapter import get_default_audio_adapter from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo from .audio.convertor import to_stereo
from .utils.configuration import load_configuration from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
from .model import EstimatorSpecBuilder from .model import EstimatorSpecBuilder, InputProviderFactory
__email__ = 'research@deezer.com' __email__ = 'research@deezer.com'
@@ -41,7 +42,7 @@ logger = logging.getLogger("spleeter")
class Separator(object): class Separator(object):
""" A wrapper class for performing separation. """ """ A wrapper class for performing separation. """
def __init__(self, params_descriptor, MWF=False, multiprocess=True): def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
""" Default constructor. """ Default constructor.
:param params_descriptor: Descriptor for TF params to be used. :param params_descriptor: Descriptor for TF params to be used.
@@ -53,6 +54,7 @@ class Separator(object):
self._predictor = None self._predictor = None
self._pool = Pool() if multiprocess else None self._pool = Pool() if multiprocess else None
self._tasks = [] self._tasks = []
self._params["stft_backend"] = stft_backend
def _get_predictor(self): def _get_predictor(self):
""" Lazy loading access method for internal predictor instance. """ Lazy loading access method for internal predictor instance.
@@ -120,60 +122,41 @@ class Separator(object):
assert chunk_size % d == 0 assert chunk_size % d == 0
return chunk_size//d return chunk_size//d
def separate_chunked(self, waveform, sample_rate, chunk_max_duration): def stft(self, waveform, inverse=False):
chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration) N = self._params["frame_length"]
print(f"chunk size is {chunk_size}") H = self._params["frame_step"]
batch_size = self.get_batch_size_for_chunk_size(chunk_size) win = hann(N, sym=False)
print(f"batch size {batch_size}") fstft = istft if inverse else stft
T, F = self._params["T"], self._params["F"] win_len_arg = "win_length" if inverse else "n_fft"
s1 = fstft(waveform[:, 0], hop_length=H, window=win, center=False, **{win_len_arg: N})
s2 = fstft(waveform[:, 1], hop_length=H, window=win, center=False, **{win_len_arg: N})
s1 = np.expand_dims(s1.T, 2-inverse)
s2 = np.expand_dims(s2.T, 2-inverse)
return np.concatenate([s1, s2], axis=2-inverse)
def separate_librosa(self, waveform, audio_id):
out = {} out = {}
n_batches = (waveform.shape[0]+batch_size*T*F-1)//(batch_size*T*F) input_provider = InputProviderFactory.get(self._params)
print(f"{n_batches} to compute") features = input_provider.get_input_dict_placeholders()
features = get_input_dict_placeholders(self._params)
spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input")
istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input")
start_t = tf.placeholder(tf.int32, shape=(), name="start")
end_t = tf.placeholder(tf.int32, shape=(), name="end")
builder = EstimatorSpecBuilder(features, self._params) builder = EstimatorSpecBuilder(features, self._params)
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
# TODO: fix the logic, build sometimes return, sometimes set attribute # TODO: fix the logic, build sometimes return, sometimes set attribute
builder._build_stft_feature() outputs = builder.outputs
stft_t = builder.get_stft_feature()
output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t)
masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t),
input_stft=stft_t[start_t:end_t, :, :])
output_waveform_t = builder._inverse_stft(istft_input_t)
waveform_t = features["waveform"]
masked_stfts = {}
saver = tf.train.Saver() saver = tf.train.Saver()
stft = self.stft(waveform)
with tf.Session() as sess: with tf.Session() as sess:
print("restoring weights {}".format(time()))
saver.restore(sess, latest_checkpoint) saver.restore(sess, latest_checkpoint)
print("computing spectrogram {}".format(time())) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t], feed_dict={waveform_t: waveform}) for inst in builder.instruments:
print(spectrogram.shape) out[inst] = self.stft(outputs[inst], inverse=True)
print(stft.shape)
for i in range(n_batches):
print("computing batch {} {}".format(i, time()))
start = i*batch_size
end = (i+1)*batch_size
tmp = sess.run(masked_stft_t,
feed_dict={spectrogram_input_t: spectrogram[start:end, ...],
start_t: start*T, end_t: end*T, stft_t: stft})
for instrument, masked_stft in tmp.items():
masked_stfts.setdefault(instrument, []).append(masked_stft)
print("inverting spectrogram {}".format(time()))
for instrument, masked_stft in masked_stfts.items():
out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)})
print("done separating {}".format(time()))
return out return out
def separate_to_file( def separate_to_file(
self, audio_descriptor, destination, self, audio_descriptor, destination,
audio_adapter=get_default_audio_adapter(), chunk_duration=-1, audio_adapter=get_default_audio_adapter(),
offset=0, duration=600., codec='wav', bitrate='128k', offset=0, duration=600., codec='wav', bitrate='128k',
filename_format='{filename}/{instrument}.{codec}', filename_format='{filename}/{instrument}.{codec}',
synchronous=True): synchronous=True):
@@ -198,15 +181,15 @@ class Separator(object):
:param filename_format: (Optional) Filename format. :param filename_format: (Optional) Filename format.
:param synchronous: (Optional) True is should by synchronous. :param synchronous: (Optional) True is should by synchronous.
""" """
print("loading audio {}".format(time()))
waveform, sample_rate = audio_adapter.load( waveform, sample_rate = audio_adapter.load(
audio_descriptor, audio_descriptor,
offset=offset, offset=offset,
duration=duration, duration=duration,
sample_rate=self._sample_rate) sample_rate=self._sample_rate)
print("done loading audio {}".format(time())) if self._params["stft_backend"] == "tensorflow":
sources = self.separate_chunked(waveform, sample_rate, chunk_duration) sources = self.separate(waveform)
print("saving to file {}".format(time())) else:
sources = self.separate_librosa(waveform, audio_descriptor)
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous) audio_adapter, bitrate, synchronous)

View File

@@ -13,7 +13,7 @@ import tensorflow as tf
from tensorflow.contrib import predictor from tensorflow.contrib import predictor
# pylint: enable=import-error # pylint: enable=import-error
from ..model import model_fn from ..model import model_fn, InputProviderFactory
from ..model.provider import get_default_model_provider from ..model.provider import get_default_model_provider
# Default exporting directory for predictor. # Default exporting directory for predictor.
@@ -59,14 +59,6 @@ def create_estimator(params, MWF):
return estimator return estimator
def get_input_dict_placeholders(params):
shape = (None, params['n_channels'])
features = {
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"),
'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")}
return features
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
""" Exports given estimator as predictor into the given directory """ Exports given estimator as predictor into the given directory
and returns associated tf.predictor instance. and returns associated tf.predictor instance.
@@ -74,8 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
:param estimator: Estimator to export. :param estimator: Estimator to export.
:param directory: (Optional) path to write exported model into. :param directory: (Optional) path to write exported model into.
""" """
input_provider = InputProviderFactory.get(estimator.params)
def receiver(): def receiver():
features = get_input_dict_placeholders(estimator.params) features = input_provider.get_input_dict_placeholders()
return tf.estimator.export.ServingInputReceiver(features, features) return tf.estimator.export.ServingInputReceiver(features, features)
estimator.export_saved_model(directory, receiver) estimator.export_saved_model(directory, receiver)