First draft implementing correct reconstruction

This commit is contained in:
akhlif
2020-02-19 23:11:16 +01:00
parent f93dfbc235
commit aa7c208b39
5 changed files with 152 additions and 38 deletions


@@ -13,22 +13,31 @@
"""
import os
import logging
from time import time
from multiprocessing import Pool
from os.path import basename, join, splitext
import numpy as np
import tensorflow as tf
from . import SpleeterError
from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor
from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir
from .model import EstimatorSpecBuilder
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
logger = logging.getLogger("spleeter")
class Separator(object):
""" A wrapper class for performing separation. """
@@ -87,16 +96,79 @@ class Separator(object):
        prediction.pop('audio_id')
        return prediction
    def separate_chunked(self, waveform, sample_rate, chunk_duration=-1):
        chunk_size = waveform.shape[0] if chunk_duration == -1 else chunk_duration*sample_rate
        n_chunks = int(waveform.shape[0]/chunk_size)
    def get_valid_chunk_size(self, sample_rate: int, chunk_max_duration: float) -> int:
        """
        Given a sample rate and the maximal duration a chunk may represent, return the largest valid
        chunk size in samples. The chunk size must be a non-zero multiple of T (temporal dimension of
        the input spectrogram) times F (number of frequency bins of the input spectrogram). If no such
        value exists, T*F is returned.
        :param sample_rate: sample rate of the PCM data
        :param chunk_max_duration: maximal duration in seconds of a chunk
        :return: largest non-zero chunk size of duration less than chunk_max_duration, or the minimal valid chunk size.
        """
        assert chunk_max_duration > 0
        chunk_size = chunk_max_duration * sample_rate
        min_sample_size = self._params["T"] * self._params["F"]
        if chunk_size < min_sample_size:
            min_duration = min_sample_size / sample_rate
            logger.warning("chunk_max_duration must be at least {:.2f} seconds. Ignoring parameter".format(min_duration))
            chunk_size = min_sample_size
        return min_sample_size * int(chunk_size / min_sample_size)
    def get_batch_size_for_chunk_size(self, chunk_size):
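        """
        Return the number of T*F-sample windows contained in a chunk of ``chunk_size`` samples,
        used as the prediction batch size for that chunk.
        :param chunk_size: chunk size in samples, must be a multiple of T*F
        :return: batch size corresponding to chunk_size
        """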
        d = self._params["T"] * self._params["F"]
        assert chunk_size % d == 0
        return chunk_size // d
    def separate_chunked(self, waveform, sample_rate, chunk_max_duration):
        chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration)
        print(f"chunk size is {chunk_size}")
        batch_size = self.get_batch_size_for_chunk_size(chunk_size)
        print(f"batch size {batch_size}")
        T, F = self._params["T"], self._params["F"]
        out = {}
        for i in range(n_chunks):
            sources = self.separate(waveform)
            for inst, data in sources.items():
                out.setdefault(inst, []).append(data)
        for inst, data in out.items():
            out[inst] = np.concatenate(data, axis=0)
        n_batches = (waveform.shape[0] + batch_size*T*F - 1) // (batch_size*T*F)
        print(f"{n_batches} batches to compute")
        features = get_input_dict_placeholders(self._params)
        spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input")
        istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input")
        start_t = tf.placeholder(tf.int32, shape=(), name="start")
        end_t = tf.placeholder(tf.int32, shape=(), name="end")
        builder = EstimatorSpecBuilder(features, self._params)
        latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
        # TODO: fix the logic, some _build_* methods return a value while others set an attribute
        builder._build_stft_feature()
        stft_t = builder.get_stft_feature()
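        # The model outputs one estimated spectrogram per instrument; masks derived from
        # those estimates (presumably soft ratio masks, as in the regular separate() path)
        # are applied to the corresponding slice of the mixture STFT.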
        output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t)
        masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t),
                                                   input_stft=stft_t[start_t:end_t, :, :])
        output_waveform_t = builder._inverse_stft(istft_input_t)
        waveform_t = features["waveform"]
        masked_stfts = {}
        saver = tf.train.Saver()
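        # Restore the pretrained weights once and reuse the same session for the whole file:
        # the spectrogram computation, the per-batch masking and the final iSTFT all run here.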
        with tf.Session() as sess:
            print("restoring weights {}".format(time()))
            saver.restore(sess, latest_checkpoint)
            print("computing spectrogram {}".format(time()))
            spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t],
                                         feed_dict={waveform_t: waveform})
            print(spectrogram.shape)
            print(stft.shape)
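            # Run the separation network batch by batch: each batch of spectrogram patches
            # yields masked STFT frames per instrument, which are accumulated and inverted
            # in a single iSTFT pass after the loop.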
            for i in range(n_batches):
                print("computing batch {} {}".format(i, time()))
                start = i * batch_size
                end = (i + 1) * batch_size
                tmp = sess.run(masked_stft_t,
                               feed_dict={spectrogram_input_t: spectrogram[start:end, ...],
                                          start_t: start*T, end_t: end*T, stft_t: stft})
                for instrument, masked_stft in tmp.items():
                    masked_stfts.setdefault(instrument, []).append(masked_stft)
            print("inverting spectrogram {}".format(time()))
            for instrument, masked_stft in masked_stfts.items():
                out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)})
            print("done separating {}".format(time()))
        return out
    def separate_to_file(
@@ -126,12 +198,15 @@ class Separator(object):
        :param filename_format: (Optional) Filename format.
        :param synchronous: (Optional) True if the export should be synchronous.
        """
print("loading audio {}".format(time()))
waveform, sample_rate = audio_adapter.load(
audio_descriptor,
offset=offset,
duration=duration,
sample_rate=self._sample_rate)
sources = self.separate_chunked(waveform, sample_rate, chunk_duration=chunk_duration)
print("done loading audio {}".format(time()))
sources = self.separate_chunked(waveform, sample_rate, chunk_duration)
print("saving to file {}".format(time()))
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous)