mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
First draft implementing correct reconstruction
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
""" This modules provides spleeter command as well as CLI parsing methods. """
|
||||
|
||||
import json
|
||||
|
||||
import logging
|
||||
from argparse import ArgumentParser
|
||||
from tempfile import gettempdir
|
||||
from os.path import exists, join
|
||||
@@ -13,6 +13,9 @@ __email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
logging.basicConfig()
|
||||
|
||||
|
||||
# -i opt specification (separate).
|
||||
OPT_INPUT = {
|
||||
'dest': 'inputs',
|
||||
|
||||
@@ -35,6 +35,7 @@ def entrypoint(arguments, params):
|
||||
filename,
|
||||
arguments.output_path,
|
||||
audio_adapter=audio_adapter,
|
||||
chunk_duration=arguments.chunk_duration,
|
||||
offset=arguments.offset,
|
||||
duration=arguments.duration,
|
||||
codec=arguments.codec,
|
||||
|
||||
@@ -106,7 +106,7 @@ class EstimatorSpecBuilder(object):
|
||||
self._frame_length = params['frame_length']
|
||||
self._frame_step = params['frame_step']
|
||||
|
||||
def _build_output_dict(self):
|
||||
def _build_output_dict(self, input_tensor=None):
|
||||
""" Created a batch_sizexTxFxn_channels input tensor containing
|
||||
mix magnitude spectrogram, then an output dict from it according
|
||||
to the selected model in internal parameters.
|
||||
@@ -114,7 +114,8 @@ class EstimatorSpecBuilder(object):
|
||||
:returns: Build output dict.
|
||||
:raise ValueError: If required model_type is not supported.
|
||||
"""
|
||||
input_tensor = self._features[f'{self._mix_name}_spectrogram']
|
||||
if input_tensor is None:
|
||||
input_tensor = self._features[f'{self._mix_name}_spectrogram']
|
||||
model = self._params.get('model', None)
|
||||
if model is not None:
|
||||
model_type = model.get('type', self.DEFAULT_MODEL)
|
||||
@@ -194,6 +195,12 @@ class EstimatorSpecBuilder(object):
|
||||
self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
|
||||
pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]
|
||||
|
||||
def get_stft_feature(self):
|
||||
return self._features[f'{self._mix_name}_stft']
|
||||
|
||||
def get_spectrogram_feature(self):
|
||||
return self._features[f'{self._mix_name}_spectrogram']
|
||||
|
||||
def _inverse_stft(self, stft):
|
||||
""" Inverse and reshape the given STFT
|
||||
|
||||
@@ -269,26 +276,18 @@ class EstimatorSpecBuilder(object):
|
||||
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
|
||||
return tf.concat([mask, extension], axis=2)
|
||||
|
||||
def _build_manual_output_waveform(self, output_dict):
|
||||
""" Perform ratio mask separation
|
||||
|
||||
:param output_dict: dictionary of estimated spectrogram (key: instrument
|
||||
name, value: estimated spectrogram of the instrument)
|
||||
:returns: dictionary of separated waveforms (key: instrument name,
|
||||
value: estimated waveform of the instrument)
|
||||
"""
|
||||
def _build_masks(self, output_dict):
|
||||
separation_exponent = self._params['separation_exponent']
|
||||
output_sum = tf.reduce_sum(
|
||||
[e ** separation_exponent for e in output_dict.values()],
|
||||
axis=0
|
||||
) + self.EPSILON
|
||||
output_waveform = {}
|
||||
out = {}
|
||||
for instrument in self._instruments:
|
||||
output = output_dict[f'{instrument}_spectrogram']
|
||||
# Compute mask with the model.
|
||||
instrument_mask = (
|
||||
output ** separation_exponent
|
||||
+ (self.EPSILON / len(output_dict))) / output_sum
|
||||
instrument_mask = (output ** separation_exponent
|
||||
+ (self.EPSILON / len(output_dict))) / output_sum
|
||||
# Extend mask;
|
||||
instrument_mask = self._extend_mask(instrument_mask)
|
||||
# Stack back mask.
|
||||
@@ -300,10 +299,31 @@ class EstimatorSpecBuilder(object):
|
||||
# Remove padded part (for mask having the same size as STFT);
|
||||
stft_feature = self._features[f'{self._mix_name}_stft']
|
||||
instrument_mask = instrument_mask[
|
||||
:tf.shape(stft_feature)[0], ...]
|
||||
# Compute masked STFT and normalize it.
|
||||
output_waveform[instrument] = self._inverse_stft(
|
||||
tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature)
|
||||
:tf.shape(stft_feature)[0], ...]
|
||||
out[instrument] = instrument_mask
|
||||
return out
|
||||
|
||||
def _build_masked_stft(self, mask_dict, input_stft=None):
|
||||
if input_stft is None:
|
||||
input_stft = self._features[f'{self._mix_name}_stft']
|
||||
out = {}
|
||||
for instrument, mask in mask_dict.items():
|
||||
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
|
||||
return out
|
||||
|
||||
def _build_manual_output_waveform(self, output_dict):
|
||||
""" Perform ratio mask separation
|
||||
|
||||
:param output_dict: dictionary of estimated spectrogram (key: instrument
|
||||
name, value: estimated spectrogram of the instrument)
|
||||
:returns: dictionary of separated waveforms (key: instrument name,
|
||||
value: estimated waveform of the instrument)
|
||||
"""
|
||||
|
||||
output_waveform = {}
|
||||
masked_stft = self._build_masked_stft(self._build_masks(output_dict))
|
||||
for instrument, stft_data in masked_stft.items():
|
||||
output_waveform[instrument] = self._inverse_stft(stft_data)
|
||||
return output_waveform
|
||||
|
||||
def _build_output_waveform(self, output_dict):
|
||||
|
||||
@@ -13,22 +13,31 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
|
||||
from time import time
|
||||
from multiprocessing import Pool
|
||||
from os.path import basename, join, splitext
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .utils.configuration import load_configuration
|
||||
from .utils.estimator import create_estimator, to_predictor
|
||||
from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir
|
||||
from .model import EstimatorSpecBuilder
|
||||
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
logger = logging.getLogger("spleeter")
|
||||
|
||||
|
||||
class Separator(object):
|
||||
""" A wrapper class for performing separation. """
|
||||
|
||||
@@ -87,16 +96,79 @@ class Separator(object):
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def separate_chunked(self, waveform, sample_rate, chunk_duration=-1):
|
||||
chunk_size = waveform.shape[0] if chunk_duration == -1 else chunk_duration*sample_rate
|
||||
n_chunks = int(waveform.shape[0]/chunk_size)
|
||||
def get_valid_chunk_size(self, sample_rate: int, chunk_max_duration: float) -> int:
|
||||
"""
|
||||
Given a sample rate, and a maximal duration that a chunk can represent, return the maximum chunk
|
||||
size in samples. The chunk size must be a non-zero multiple of T (temporal dimension of the input spectrogram)
|
||||
times F (number of frequency bins in the input spectrogram). If no such value exist, we return T*F.
|
||||
|
||||
:param sample_rate: sample rate of the pcm data
|
||||
:param chunk_max_duration: maximal duration in seconds of a chunk
|
||||
:return: highest non-zero chunk size of duration less than chunk_max_duration or minimal valid chunk size.
|
||||
"""
|
||||
assert chunk_max_duration > 0
|
||||
chunk_size = chunk_max_duration * sample_rate
|
||||
min_sample_size = self._params["T"] * self._params["F"]
|
||||
if chunk_size < min_sample_size:
|
||||
min_duration = min_sample_size / sample_rate
|
||||
logger.warning("chunk_duration must be at least {:.2f} seconds. Ignoring parameter".format(min_duration))
|
||||
chunk_size = min_sample_size
|
||||
return min_sample_size*int(chunk_size/min_sample_size)
|
||||
|
||||
def get_batch_size_for_chunk_size(self, chunk_size):
|
||||
d = self._params["T"] * self._params["F"]
|
||||
assert chunk_size % d == 0
|
||||
return chunk_size//d
|
||||
|
||||
def separate_chunked(self, waveform, sample_rate, chunk_max_duration):
|
||||
chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration)
|
||||
print(f"chunk size is {chunk_size}")
|
||||
batch_size = self.get_batch_size_for_chunk_size(chunk_size)
|
||||
print(f"batch size {batch_size}")
|
||||
T, F = self._params["T"], self._params["F"]
|
||||
out = {}
|
||||
for i in range(n_chunks):
|
||||
sources = self.separate(waveform)
|
||||
for inst, data in sources.items():
|
||||
out.setdefault(inst, []).append(data)
|
||||
for inst, data in out.items():
|
||||
out[inst] = np.concatenate(data, axis=0)
|
||||
n_batches = (waveform.shape[0]+batch_size*T*F-1)//(batch_size*T*F)
|
||||
print(f"{n_batches} to compute")
|
||||
features = get_input_dict_placeholders(self._params)
|
||||
spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input")
|
||||
istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input")
|
||||
start_t = tf.placeholder(tf.int32, shape=(), name="start")
|
||||
end_t = tf.placeholder(tf.int32, shape=(), name="end")
|
||||
builder = EstimatorSpecBuilder(features, self._params)
|
||||
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
||||
|
||||
# TODO: fix the logic, build sometimes return, sometimes set attribute
|
||||
builder._build_stft_feature()
|
||||
stft_t = builder.get_stft_feature()
|
||||
output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t)
|
||||
masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t),
|
||||
input_stft=stft_t[start_t:end_t, :, :])
|
||||
output_waveform_t = builder._inverse_stft(istft_input_t)
|
||||
waveform_t = features["waveform"]
|
||||
masked_stfts = {}
|
||||
saver = tf.train.Saver()
|
||||
|
||||
with tf.Session() as sess:
|
||||
print("restoring weights {}".format(time()))
|
||||
saver.restore(sess, latest_checkpoint)
|
||||
print("computing spectrogram {}".format(time()))
|
||||
spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t], feed_dict={waveform_t: waveform})
|
||||
print(spectrogram.shape)
|
||||
print(stft.shape)
|
||||
for i in range(n_batches):
|
||||
print("computing batch {} {}".format(i, time()))
|
||||
start = i*batch_size
|
||||
end = (i+1)*batch_size
|
||||
tmp = sess.run(masked_stft_t,
|
||||
feed_dict={spectrogram_input_t: spectrogram[start:end, ...],
|
||||
start_t: start*T, end_t: end*T, stft_t: stft})
|
||||
for instrument, masked_stft in tmp.items():
|
||||
masked_stfts.setdefault(instrument, []).append(masked_stft)
|
||||
|
||||
print("inverting spectrogram {}".format(time()))
|
||||
for instrument, masked_stft in masked_stfts.items():
|
||||
out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)})
|
||||
print("done separating {}".format(time()))
|
||||
return out
|
||||
|
||||
def separate_to_file(
|
||||
@@ -126,12 +198,15 @@ class Separator(object):
|
||||
:param filename_format: (Optional) Filename format.
|
||||
:param synchronous: (Optional) True is should by synchronous.
|
||||
"""
|
||||
print("loading audio {}".format(time()))
|
||||
waveform, sample_rate = audio_adapter.load(
|
||||
audio_descriptor,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
sample_rate=self._sample_rate)
|
||||
sources = self.separate_chunked(waveform, sample_rate, chunk_duration=chunk_duration)
|
||||
print("done loading audio {}".format(time()))
|
||||
sources = self.separate_chunked(waveform, sample_rate, chunk_duration)
|
||||
print("saving to file {}".format(time()))
|
||||
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous)
|
||||
|
||||
|
||||
@@ -20,20 +20,30 @@ from ..model.provider import get_default_model_provider
|
||||
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
|
||||
|
||||
|
||||
|
||||
def get_default_model_dir(model_dir):
|
||||
"""
|
||||
Transforms a string like 'spleeter:2stems' into an actual path.
|
||||
:param model_dir:
|
||||
:return:
|
||||
"""
|
||||
model_provider = get_default_model_provider()
|
||||
return model_provider.get(model_dir)
|
||||
|
||||
def create_estimator(params, MWF):
|
||||
"""
|
||||
Initialize tensorflow estimator that will perform separation
|
||||
|
||||
Params:
|
||||
- params: a dictionnary of parameters for building the model
|
||||
- params: a dictionary of parameters for building the model
|
||||
|
||||
Returns:
|
||||
a tensorflow estimator
|
||||
"""
|
||||
# Load model.
|
||||
model_directory = params['model_dir']
|
||||
model_provider = get_default_model_provider()
|
||||
params['model_dir'] = model_provider.get(model_directory)
|
||||
|
||||
|
||||
params['model_dir'] = get_default_model_dir(params['model_dir'])
|
||||
params['MWF'] = MWF
|
||||
# Setup config
|
||||
session_config = tf.compat.v1.ConfigProto()
|
||||
@@ -49,6 +59,14 @@ def create_estimator(params, MWF):
|
||||
return estimator
|
||||
|
||||
|
||||
def get_input_dict_placeholders(params):
|
||||
shape = (None, params['n_channels'])
|
||||
features = {
|
||||
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"),
|
||||
'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")}
|
||||
return features
|
||||
|
||||
|
||||
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
||||
""" Exports given estimator as predictor into the given directory
|
||||
and returns associated tf.predictor instance.
|
||||
@@ -57,10 +75,7 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
||||
:param directory: (Optional) path to write exported model into.
|
||||
"""
|
||||
def receiver():
|
||||
shape = (None, estimator.params['n_channels'])
|
||||
features = {
|
||||
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
|
||||
'audio_id': tf.compat.v1.placeholder(tf.string)}
|
||||
features = get_input_dict_placeholders(estimator.params)
|
||||
return tf.estimator.export.ServingInputReceiver(features, features)
|
||||
|
||||
estimator.export_saved_model(directory, receiver)
|
||||
|
||||
Reference in New Issue
Block a user