mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Adding option to use librosa backend.
Changes in the EstimatorBuilder to set attributes instead of returning tensors for the _build methods. InputProvider classes to handle the different backend cases. New method in Separator.
This commit is contained in:
@@ -20,14 +20,15 @@ from multiprocessing import Pool
|
||||
from os.path import basename, join, splitext
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from librosa.core import stft, istft
|
||||
from scipy.signal.windows import hann
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .utils.configuration import load_configuration
|
||||
from .utils.estimator import create_estimator, to_predictor, get_input_dict_placeholders, get_default_model_dir
|
||||
from .model import EstimatorSpecBuilder
|
||||
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
|
||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
@@ -41,7 +42,7 @@ logger = logging.getLogger("spleeter")
|
||||
class Separator(object):
|
||||
""" A wrapper class for performing separation. """
|
||||
|
||||
def __init__(self, params_descriptor, MWF=False, multiprocess=True):
|
||||
def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
|
||||
""" Default constructor.
|
||||
|
||||
:param params_descriptor: Descriptor for TF params to be used.
|
||||
@@ -53,6 +54,7 @@ class Separator(object):
|
||||
self._predictor = None
|
||||
self._pool = Pool() if multiprocess else None
|
||||
self._tasks = []
|
||||
self._params["stft_backend"] = stft_backend
|
||||
|
||||
def _get_predictor(self):
|
||||
""" Lazy loading access method for internal predictor instance.
|
||||
@@ -120,60 +122,41 @@ class Separator(object):
|
||||
assert chunk_size % d == 0
|
||||
return chunk_size//d
|
||||
|
||||
def separate_chunked(self, waveform, sample_rate, chunk_max_duration):
|
||||
chunk_size = self.get_valid_chunk_size(sample_rate, chunk_max_duration)
|
||||
print(f"chunk size is {chunk_size}")
|
||||
batch_size = self.get_batch_size_for_chunk_size(chunk_size)
|
||||
print(f"batch size {batch_size}")
|
||||
T, F = self._params["T"], self._params["F"]
|
||||
def stft(self, waveform, inverse=False):
|
||||
N = self._params["frame_length"]
|
||||
H = self._params["frame_step"]
|
||||
win = hann(N, sym=False)
|
||||
fstft = istft if inverse else stft
|
||||
win_len_arg = "win_length" if inverse else "n_fft"
|
||||
s1 = fstft(waveform[:, 0], hop_length=H, window=win, center=False, **{win_len_arg: N})
|
||||
s2 = fstft(waveform[:, 1], hop_length=H, window=win, center=False, **{win_len_arg: N})
|
||||
s1 = np.expand_dims(s1.T, 2-inverse)
|
||||
s2 = np.expand_dims(s2.T, 2-inverse)
|
||||
return np.concatenate([s1, s2], axis=2-inverse)
|
||||
|
||||
def separate_librosa(self, waveform, audio_id):
|
||||
out = {}
|
||||
n_batches = (waveform.shape[0]+batch_size*T*F-1)//(batch_size*T*F)
|
||||
print(f"{n_batches} to compute")
|
||||
features = get_input_dict_placeholders(self._params)
|
||||
spectrogram_input_t = tf.placeholder(tf.float32, shape=(None, T, F, 2), name="spectrogram_input")
|
||||
istft_input_t = tf.placeholder(tf.complex64, shape=(None, F, 2), name="istft_input")
|
||||
start_t = tf.placeholder(tf.int32, shape=(), name="start")
|
||||
end_t = tf.placeholder(tf.int32, shape=(), name="end")
|
||||
input_provider = InputProviderFactory.get(self._params)
|
||||
features = input_provider.get_input_dict_placeholders()
|
||||
|
||||
builder = EstimatorSpecBuilder(features, self._params)
|
||||
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
||||
|
||||
# TODO: fix the logic, build sometimes return, sometimes set attribute
|
||||
builder._build_stft_feature()
|
||||
stft_t = builder.get_stft_feature()
|
||||
output_dict_t = builder._build_output_dict(input_tensor=spectrogram_input_t)
|
||||
masked_stft_t = builder._build_masked_stft(builder._build_masks(output_dict_t),
|
||||
input_stft=stft_t[start_t:end_t, :, :])
|
||||
output_waveform_t = builder._inverse_stft(istft_input_t)
|
||||
waveform_t = features["waveform"]
|
||||
masked_stfts = {}
|
||||
outputs = builder.outputs
|
||||
|
||||
saver = tf.train.Saver()
|
||||
|
||||
stft = self.stft(waveform)
|
||||
with tf.Session() as sess:
|
||||
print("restoring weights {}".format(time()))
|
||||
saver.restore(sess, latest_checkpoint)
|
||||
print("computing spectrogram {}".format(time()))
|
||||
spectrogram, stft = sess.run([builder.get_spectrogram_feature(), stft_t], feed_dict={waveform_t: waveform})
|
||||
print(spectrogram.shape)
|
||||
print(stft.shape)
|
||||
for i in range(n_batches):
|
||||
print("computing batch {} {}".format(i, time()))
|
||||
start = i*batch_size
|
||||
end = (i+1)*batch_size
|
||||
tmp = sess.run(masked_stft_t,
|
||||
feed_dict={spectrogram_input_t: spectrogram[start:end, ...],
|
||||
start_t: start*T, end_t: end*T, stft_t: stft})
|
||||
for instrument, masked_stft in tmp.items():
|
||||
masked_stfts.setdefault(instrument, []).append(masked_stft)
|
||||
|
||||
print("inverting spectrogram {}".format(time()))
|
||||
for instrument, masked_stft in masked_stfts.items():
|
||||
out[instrument] = sess.run(output_waveform_t, {istft_input_t: np.concatenate(masked_stft, axis=0)})
|
||||
print("done separating {}".format(time()))
|
||||
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
||||
for inst in builder.instruments:
|
||||
out[inst] = self.stft(outputs[inst], inverse=True)
|
||||
return out
|
||||
|
||||
def separate_to_file(
|
||||
self, audio_descriptor, destination,
|
||||
audio_adapter=get_default_audio_adapter(), chunk_duration=-1,
|
||||
audio_adapter=get_default_audio_adapter(),
|
||||
offset=0, duration=600., codec='wav', bitrate='128k',
|
||||
filename_format='{filename}/{instrument}.{codec}',
|
||||
synchronous=True):
|
||||
@@ -198,15 +181,15 @@ class Separator(object):
|
||||
:param filename_format: (Optional) Filename format.
|
||||
:param synchronous: (Optional) True is should by synchronous.
|
||||
"""
|
||||
print("loading audio {}".format(time()))
|
||||
waveform, sample_rate = audio_adapter.load(
|
||||
audio_descriptor,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
sample_rate=self._sample_rate)
|
||||
print("done loading audio {}".format(time()))
|
||||
sources = self.separate_chunked(waveform, sample_rate, chunk_duration)
|
||||
print("saving to file {}".format(time()))
|
||||
if self._params["stft_backend"] == "tensorflow":
|
||||
sources = self.separate(waveform)
|
||||
else:
|
||||
sources = self.separate_librosa(waveform, audio_descriptor)
|
||||
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user