First draft implementing correct reconstruction

akhlif
2020-02-19 23:11:16 +01:00
parent f93dfbc235
commit aa7c208b39
5 changed files with 152 additions and 38 deletions

View File

@@ -4,7 +4,7 @@
""" This modules provides spleeter command as well as CLI parsing methods. """ """ This modules provides spleeter command as well as CLI parsing methods. """
import json import json
import logging
from argparse import ArgumentParser from argparse import ArgumentParser
from tempfile import gettempdir from tempfile import gettempdir
from os.path import exists, join from os.path import exists, join
@@ -13,6 +13,9 @@ __email__ = 'research@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'

+logging.basicConfig()

 # -i opt specification (separate).
 OPT_INPUT = {
     'dest': 'inputs',

View File

@@ -35,6 +35,7 @@ def entrypoint(arguments, params):
         filename,
         arguments.output_path,
         audio_adapter=audio_adapter,
+        chunk_duration=arguments.chunk_duration,
         offset=arguments.offset,
         duration=arguments.duration,
         codec=arguments.codec,

View File

@@ -106,7 +106,7 @@ class EstimatorSpecBuilder(object):
         self._frame_length = params['frame_length']
         self._frame_step = params['frame_step']

-    def _build_output_dict(self):
+    def _build_output_dict(self, input_tensor=None):
         """ Created a batch_sizexTxFxn_channels input tensor containing
         mix magnitude spectrogram, then an output dict from it according
         to the selected model in internal parameters.
@@ -114,7 +114,8 @@ class EstimatorSpecBuilder(object):
         :returns: Build output dict.
         :raise ValueError: If required model_type is not supported.
         """
-        input_tensor = self._features[f'{self._mix_name}_spectrogram']
+        if input_tensor is None:
+            input_tensor = self._features[f'{self._mix_name}_spectrogram']
         model = self._params.get('model', None)
         if model is not None:
             model_type = model.get('type', self.DEFAULT_MODEL)
@@ -194,6 +195,12 @@ class EstimatorSpecBuilder(object):
         self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
             pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]

+    def get_stft_feature(self):
+        return self._features[f'{self._mix_name}_stft']
+
+    def get_spectrogram_feature(self):
+        return self._features[f'{self._mix_name}_spectrogram']
+
     def _inverse_stft(self, stft):
         """ Inverse and reshape the given STFT
@@ -269,26 +276,18 @@ class EstimatorSpecBuilder(object):
         extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
         return tf.concat([mask, extension], axis=2)

-    def _build_manual_output_waveform(self, output_dict):
-        """ Perform ratio mask separation
-
-        :param output_dict: dictionary of estimated spectrogram (key: instrument
-            name, value: estimated spectrogram of the instrument)
-        :returns: dictionary of separated waveforms (key: instrument name,
-            value: estimated waveform of the instrument)
-        """
+    def _build_masks(self, output_dict):
         separation_exponent = self._params['separation_exponent']
         output_sum = tf.reduce_sum(
             [e ** separation_exponent for e in output_dict.values()],
             axis=0
         ) + self.EPSILON
-        output_waveform = {}
+        out = {}
         for instrument in self._instruments:
             output = output_dict[f'{instrument}_spectrogram']
             # Compute mask with the model.
-            instrument_mask = (
-                output ** separation_exponent
-                + (self.EPSILON / len(output_dict))) / output_sum
+            instrument_mask = (output ** separation_exponent
+                               + (self.EPSILON / len(output_dict))) / output_sum
             # Extend mask;
             instrument_mask = self._extend_mask(instrument_mask)
             # Stack back mask.
@@ -300,10 +299,31 @@ class EstimatorSpecBuilder(object):
             # Remove padded part (for mask having the same size as STFT);
             stft_feature = self._features[f'{self._mix_name}_stft']
             instrument_mask = instrument_mask[
                 :tf.shape(stft_feature)[0], ...]
-            # Compute masked STFT and normalize it.
-            output_waveform[instrument] = self._inverse_stft(
-                tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature)
+            out[instrument] = instrument_mask
+        return out
+
+    def _build_masked_stft(self, mask_dict, input_stft=None):
+        if input_stft is None:
+            input_stft = self._features[f'{self._mix_name}_stft']
+        out = {}
+        for instrument, mask in mask_dict.items():
+            out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
+        return out
+
+    def _build_manual_output_waveform(self, output_dict):
+        """ Perform ratio mask separation
+
+        :param output_dict: dictionary of estimated spectrogram (key: instrument
+            name, value: estimated spectrogram of the instrument)
+        :returns: dictionary of separated waveforms (key: instrument name,
+            value: estimated waveform of the instrument)
+        """
+        output_waveform = {}
+        masked_stft = self._build_masked_stft(self._build_masks(output_dict))
+        for instrument, stft_data in masked_stft.items():
+            output_waveform[instrument] = self._inverse_stft(stft_data)
         return output_waveform

     def _build_output_waveform(self, output_dict):
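
Note: the refactor above splits ratio-mask separation into _build_masks (one soft mask per instrument from the estimated magnitude spectrograms) and _build_masked_stft (mask applied to the mixture STFT), with _build_manual_output_waveform recombining the two and running the inverse STFT. A minimal NumPy sketch of the same computation, with illustrative shapes, instrument names and constants (not taken from the model configuration):

import numpy as np

def ratio_masks(magnitudes, separation_exponent=2, epsilon=1e-10):
    # One soft mask per instrument, built from the estimated magnitude spectrograms.
    powered = {name: mag ** separation_exponent for name, mag in magnitudes.items()}
    total = sum(powered.values()) + epsilon
    return {name: (p + epsilon / len(magnitudes)) / total for name, p in powered.items()}

def apply_masks(masks, mix_stft):
    # Multiply each real-valued mask with the complex mixture STFT.
    return {name: mask.astype(np.complex64) * mix_stft for name, mask in masks.items()}

# Toy dimensions: 100 frames x 1025 frequency bins x 2 channels.
rng = np.random.default_rng(0)
mix_stft = rng.standard_normal((100, 1025, 2)) + 1j * rng.standard_normal((100, 1025, 2))
magnitudes = {'vocals': 0.6 * np.abs(mix_stft), 'accompaniment': 0.4 * np.abs(mix_stft)}
separated = apply_masks(ratio_masks(magnitudes), mix_stft)
assert separated['vocals'].shape == mix_stft.shape  # ready for an inverse STFT per instrument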

View File

@@ -13,22 +13,31 @@
""" """
import os import os
import logging
from time import time
from multiprocessing import Pool from multiprocessing import Pool
from os.path import basename, join, splitext from os.path import basename, join, splitext
import numpy as np import numpy as np
import tensorflow as tf
from . import SpleeterError from . import SpleeterError
from .audio.adapter import get_default_audio_adapter from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo from .audio.convertor import to_stereo
from .utils.configuration import load_configuration from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir
from .model import EstimatorSpecBuilder
__email__ = 'research@deezer.com' __email__ = 'research@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
logger = logging.getLogger("spleeter")
class Separator(object): class Separator(object):
""" A wrapper class for performing separation. """ """ A wrapper class for performing separation. """
@@ -87,16 +96,79 @@ class Separator(object):
         prediction.pop('audio_id')
         return prediction

-    def separate_chunked(self, waveform, sample_rate, chunk_duration=-1):
-        chunk_size = waveform.shape[0] if chunk_duration == -1 else chunk_duration*sample_rate
-        n_chunks = int(waveform.shape[0]/chunk_size)
+    def get_valid_chunk_size(self, sample_rate: int, chunk_max_duration: float) -> int:
+        """
+        Given a sample rate and a maximal duration that a chunk can represent, return the maximum
+        chunk size in samples. The chunk size must be a non-zero multiple of T (temporal dimension
+        of the input spectrogram) times F (number of frequency bins in the input spectrogram).
+        If no such value exists, T*F is returned.
+
+        :param sample_rate: sample rate of the PCM data
+        :param chunk_max_duration: maximal duration in seconds of a chunk
+        :return: highest non-zero chunk size of duration less than chunk_max_duration, or the minimal valid chunk size.
+        """
+        assert chunk_max_duration > 0
+        chunk_size = chunk_max_duration * sample_rate
+        min_sample_size = self._params["T"] * self._params["F"]
+        if chunk_size < min_sample_size:
+            min_duration = min_sample_size / sample_rate
+            logger.warning("chunk_duration must be at least {:.2f} seconds. Ignoring parameter".format(min_duration))
+            chunk_size = min_sample_size
+        return min_sample_size * int(chunk_size / min_sample_size)
+
+    def get_batch_size_for_chunk_size(self, chunk_size):
+        d = self._params["T"] * self._params["F"]
+        assert chunk_size % d == 0
+        return chunk_size // d
+
+    def separate_chunked(self, waveform, sample_rate, chunk_max_duration):
+        chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration)
+        print(f"chunk size is {chunk_size}")
+        batch_size = self.get_batch_size_for_chunk_size(chunk_size)
+        print(f"batch size {batch_size}")
+        T, F = self._params["T"], self._params["F"]
         out = {}
-        for i in range(n_chunks):
-            sources = self.separate(waveform)
-            for inst, data in sources.items():
-                out.setdefault(inst, []).append(data)
-        for inst, data in out.items():
-            out[inst] = np.concatenate(data, axis=0)
+        n_batches = (waveform.shape[0] + batch_size*T*F - 1) // (batch_size*T*F)
+        print(f"{n_batches} batches to compute")
+        features = get_input_dict_placeholders(self._params)
+        spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input")
+        istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input")
+        start_t = tf.placeholder(tf.int32, shape=(), name="start")
+        end_t = tf.placeholder(tf.int32, shape=(), name="end")
+        builder = EstimatorSpecBuilder(features, self._params)
+        latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
+        # TODO: fix the logic, the build methods sometimes return a value and sometimes set an attribute.
+        builder._build_stft_feature()
+        stft_t = builder.get_stft_feature()
+        output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t)
+        masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t),
+                                                   input_stft=stft_t[start_t:end_t, :, :])
+        output_waveform_t = builder._inverse_stft(istft_input_t)
+        waveform_t = features["waveform"]
+        masked_stfts = {}
+        saver = tf.train.Saver()
+        with tf.Session() as sess:
+            print("restoring weights {}".format(time()))
+            saver.restore(sess, latest_checkpoint)
+            print("computing spectrogram {}".format(time()))
+            spectrogram, stft = sess.run(
+                [builder.get_spectrogram_feature(), stft_t],
+                feed_dict={waveform_t: waveform})
+            print(spectrogram.shape)
+            print(stft.shape)
+            for i in range(n_batches):
+                print("computing batch {} {}".format(i, time()))
+                start = i*batch_size
+                end = (i+1)*batch_size
+                tmp = sess.run(masked_stft_t,
+                               feed_dict={spectrogram_input_t: spectrogram[start:end, ...],
+                                          start_t: start*T, end_t: end*T, stft_t: stft})
+                for instrument, masked_stft in tmp.items():
+                    masked_stfts.setdefault(instrument, []).append(masked_stft)
+            print("inverting spectrogram {}".format(time()))
+            for instrument, masked_stft in masked_stfts.items():
+                out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)})
+            print("done separating {}".format(time()))
         return out

     def separate_to_file(
@@ -126,12 +198,15 @@ class Separator(object):
         :param filename_format: (Optional) Filename format.
         :param synchronous: (Optional) True is should by synchronous.
         """
+        print("loading audio {}".format(time()))
         waveform, sample_rate = audio_adapter.load(
             audio_descriptor,
             offset=offset,
             duration=duration,
             sample_rate=self._sample_rate)
-        sources = self.separate_chunked(waveform, sample_rate, chunk_duration=chunk_duration)
+        print("done loading audio {}".format(time()))
+        sources = self.separate_chunked(waveform, sample_rate, chunk_duration)
+        print("saving to file {}".format(time()))
         self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
                           audio_adapter, bitrate, synchronous)
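
Note: the chunk sizing in separate_chunked rounds chunk_max_duration * sample_rate down to a multiple of T * F (falling back to T * F itself) and derives the batch size from it. A standalone sketch of the same arithmetic, with illustrative sample rate, T, F and duration values:

def valid_chunk_size(sample_rate, chunk_max_duration, T, F):
    # Largest multiple of T*F that fits in chunk_max_duration seconds of audio,
    # falling back to the minimal valid size T*F.
    min_sample_size = T * F
    chunk_size = max(chunk_max_duration * sample_rate, min_sample_size)
    return min_sample_size * int(chunk_size / min_sample_size)

# Example: 44.1 kHz audio, T=512, F=1024, at most 30 seconds per chunk.
size = valid_chunk_size(44100, 30.0, 512, 1024)
print(size)                  # 1048576 samples, i.e. 2 * 512 * 1024
print(size // (512 * 1024))  # batch size of 2 spectrogram chunks per session run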

View File

@@ -20,20 +20,30 @@ from ..model.provider import get_default_model_provider
 DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')

+def get_default_model_dir(model_dir):
+    """
+    Transforms a string like 'spleeter:2stems' into an actual path.
+
+    :param model_dir:
+    :return:
+    """
+    model_provider = get_default_model_provider()
+    return model_provider.get(model_dir)
+
+
 def create_estimator(params, MWF):
     """
     Initialize tensorflow estimator that will perform separation

     Params:
-    - params: a dictionnary of parameters for building the model
+    - params: a dictionary of parameters for building the model

     Returns:
     a tensorflow estimator
     """
     # Load model.
-    model_directory = params['model_dir']
-    model_provider = get_default_model_provider()
-    params['model_dir'] = model_provider.get(model_directory)
+    params['model_dir'] = get_default_model_dir(params['model_dir'])
     params['MWF'] = MWF
     # Setup config
     session_config = tf.compat.v1.ConfigProto()
@@ -49,6 +59,14 @@ def create_estimator(params, MWF):
     return estimator


+def get_input_dict_placeholders(params):
+    shape = (None, params['n_channels'])
+    features = {
+        'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"),
+        'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")}
+    return features
+
+
 def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
     """ Exports given estimator as predictor into the given directory
     and returns associated tf.predictor instance.
@@ -57,10 +75,7 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
     :param directory: (Optional) path to write exported model into.
     """

     def receiver():
-        shape = (None, estimator.params['n_channels'])
-        features = {
-            'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
-            'audio_id': tf.compat.v1.placeholder(tf.string)}
+        features = get_input_dict_placeholders(estimator.params)
         return tf.estimator.export.ServingInputReceiver(features, features)
     estimator.export_saved_model(directory, receiver)
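
Note: get_input_dict_placeholders is now shared between the serving receiver and the chunked separation path in Separator. A small usage sketch under TF 1.x graph mode (channel count, waveform length and the gain operation are illustrative assumptions, not part of the commit):

import numpy as np
import tensorflow as tf  # TF 1.x, as assumed by the rest of this module

params = {'n_channels': 2}
features = get_input_dict_placeholders(params)
# Any graph built on top of features['waveform'] can be evaluated by feeding raw PCM data,
# which is what both to_predictor and Separator.separate_chunked rely on.
gain = features['waveform'] * 0.5
with tf.compat.v1.Session() as sess:
    pcm = np.ones((4096, params['n_channels']), dtype=np.float32)
    print(sess.run(gain, feed_dict={features['waveform']: pcm}).shape)  # (4096, 2)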