mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Merge pull request #282 from alreadytaikeune/librosa_backend
Librosa backend
This commit is contained in:
@@ -2,6 +2,7 @@ importlib_resources; python_version<'3.7'
|
||||
requests
|
||||
setuptools>=41.0.0
|
||||
pandas==0.25.1
|
||||
tensorflow==1.14.0
|
||||
tensorflow==1.15.0
|
||||
ffmpeg-python
|
||||
norbert==0.2.1
|
||||
norbert==0.2.1
|
||||
librosa==0.7.2
|
||||
5
setup.py
5
setup.py
@@ -14,9 +14,9 @@ __license__ = 'MIT License'
|
||||
|
||||
# Default project values.
|
||||
project_name = 'spleeter'
|
||||
project_version = '1.4.9'
|
||||
project_version = '1.5.0'
|
||||
tensorflow_dependency = 'tensorflow'
|
||||
tensorflow_version = '1.14.0'
|
||||
tensorflow_version = '1.15.0'
|
||||
here = path.abspath(path.dirname(__file__))
|
||||
readme_path = path.join(here, 'README.md')
|
||||
with open(readme_path, 'r') as stream:
|
||||
@@ -56,6 +56,7 @@ setup(
|
||||
'pandas==0.25.1',
|
||||
'requests',
|
||||
'setuptools>=41.0.0',
|
||||
'librosa==0.7.2',
|
||||
'{}=={}'.format(tensorflow_dependency, tensorflow_version),
|
||||
],
|
||||
extras_require={
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
""" This modules provides spleeter command as well as CLI parsing methods. """
|
||||
|
||||
import json
|
||||
|
||||
import logging
|
||||
from argparse import ArgumentParser
|
||||
from tempfile import gettempdir
|
||||
from os.path import exists, join
|
||||
@@ -13,6 +13,8 @@ __email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
|
||||
# -i opt specification (separate).
|
||||
OPT_INPUT = {
|
||||
'dest': 'inputs',
|
||||
@@ -68,6 +70,17 @@ OPT_DURATION = {
|
||||
'the input file)')
|
||||
}
|
||||
|
||||
# -w opt specification (separate)
|
||||
OPT_STFT_BACKEND = {
|
||||
'dest': 'stft_backend',
|
||||
'type': str,
|
||||
'choices' : ["tensorflow", "librosa", "auto"],
|
||||
'default': "auto",
|
||||
'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
|
||||
' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
|
||||
}
|
||||
|
||||
|
||||
# -c opt specification (separate).
|
||||
OPT_CODEC = {
|
||||
'dest': 'codec',
|
||||
@@ -176,6 +189,7 @@ def _create_separate_parser(parser_factory):
|
||||
parser.add_argument('-c', '--codec', **OPT_CODEC)
|
||||
parser.add_argument('-b', '--birate', **OPT_BITRATE)
|
||||
parser.add_argument('-m', '--mwf', **OPT_MWF)
|
||||
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ __author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
|
||||
def entrypoint(arguments, params):
|
||||
""" Command entrypoint.
|
||||
|
||||
@@ -29,7 +30,8 @@ def entrypoint(arguments, params):
|
||||
audio_adapter = get_audio_adapter(arguments.audio_adapter)
|
||||
separator = Separator(
|
||||
arguments.configuration,
|
||||
arguments.MWF)
|
||||
MWF=arguments.MWF,
|
||||
stft_backend=arguments.stft_backend)
|
||||
for filename in arguments.inputs:
|
||||
separator.separate_to_file(
|
||||
filename,
|
||||
|
||||
@@ -18,6 +18,9 @@ __author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
placeholder = tf.compat.v1.placeholder
|
||||
|
||||
|
||||
def get_model_function(model_type):
|
||||
"""
|
||||
Get tensorflow function of the model to be applied to the input tensor.
|
||||
@@ -41,6 +44,74 @@ def get_model_function(model_type):
|
||||
return model_function
|
||||
|
||||
|
||||
class InputProvider(object):
|
||||
|
||||
def __init__(self, params):
|
||||
self.params = params
|
||||
|
||||
def get_input_dict_placeholders(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def input_names(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_feed_dict(self, features, *args):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class WaveformInputProvider(InputProvider):
|
||||
|
||||
@property
|
||||
def input_names(self):
|
||||
return ["audio_id", "waveform"]
|
||||
|
||||
def get_input_dict_placeholders(self):
|
||||
shape = (None, self.params['n_channels'])
|
||||
features = {
|
||||
'waveform': placeholder(tf.float32, shape=shape, name="waveform"),
|
||||
'audio_id': placeholder(tf.string, name="audio_id")}
|
||||
return features
|
||||
|
||||
def get_feed_dict(self, features, waveform, audio_id):
|
||||
return {features["audio_id"]: audio_id, features["waveform"]: waveform}
|
||||
|
||||
|
||||
class SpectralInputProvider(InputProvider):
|
||||
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
self.stft_input_name = "{}_stft".format(self.params["mix_name"])
|
||||
|
||||
@property
|
||||
def input_names(self):
|
||||
return ["audio_id", self.stft_input_name]
|
||||
|
||||
def get_input_dict_placeholders(self):
|
||||
features = {
|
||||
self.stft_input_name: placeholder(tf.complex64,
|
||||
shape=(None, self.params["frame_length"]//2+1,
|
||||
self.params['n_channels']),
|
||||
name=self.stft_input_name),
|
||||
'audio_id': placeholder(tf.string, name="audio_id")}
|
||||
return features
|
||||
|
||||
def get_feed_dict(self, features, stft, audio_id):
|
||||
return {features["audio_id"]: audio_id, features[self.stft_input_name]: stft}
|
||||
|
||||
|
||||
class InputProviderFactory(object):
|
||||
|
||||
@staticmethod
|
||||
def get(params):
|
||||
stft_backend = params["stft_backend"]
|
||||
assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
|
||||
if stft_backend == "tensorflow":
|
||||
return WaveformInputProvider(params)
|
||||
else:
|
||||
return SpectralInputProvider(params)
|
||||
|
||||
|
||||
class EstimatorSpecBuilder(object):
|
||||
""" A builder class that allows to builds a multitrack unet model
|
||||
estimator. The built model estimator has a different behaviour when
|
||||
@@ -57,9 +128,9 @@ class EstimatorSpecBuilder(object):
|
||||
|
||||
>>> from spleeter.model import EstimatorSpecBuilder
|
||||
>>> builder = EstimatorSpecBuilder()
|
||||
>>> builder.build_prediction_model()
|
||||
>>> builder.build_predict_model()
|
||||
>>> builder.build_evaluation_model()
|
||||
>>> builder.build_training_model()
|
||||
>>> builder.build_train_model()
|
||||
|
||||
>>> from spleeter.model import model_fn
|
||||
>>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
|
||||
@@ -94,6 +165,7 @@ class EstimatorSpecBuilder(object):
|
||||
:param features: The input features for the estimator.
|
||||
:param params: Some hyperparameters as a dictionary.
|
||||
"""
|
||||
|
||||
self._features = features
|
||||
self._params = params
|
||||
# Get instrument name.
|
||||
@@ -106,7 +178,10 @@ class EstimatorSpecBuilder(object):
|
||||
self._frame_length = params['frame_length']
|
||||
self._frame_step = params['frame_step']
|
||||
|
||||
def _build_output_dict(self):
|
||||
def include_stft_computations(self):
|
||||
return self._params["stft_backend"] == "tensorflow"
|
||||
|
||||
def _build_model_outputs(self):
|
||||
""" Created a batch_sizexTxFxn_channels input tensor containing
|
||||
mix magnitude spectrogram, then an output dict from it according
|
||||
to the selected model in internal parameters.
|
||||
@@ -114,7 +189,8 @@ class EstimatorSpecBuilder(object):
|
||||
:returns: Build output dict.
|
||||
:raise ValueError: If required model_type is not supported.
|
||||
"""
|
||||
input_tensor = self._features[f'{self._mix_name}_spectrogram']
|
||||
|
||||
input_tensor = self.spectrogram_feature
|
||||
model = self._params.get('model', None)
|
||||
if model is not None:
|
||||
model_type = model.get('type', self.DEFAULT_MODEL)
|
||||
@@ -124,12 +200,12 @@ class EstimatorSpecBuilder(object):
|
||||
apply_model = get_model_function(model_type)
|
||||
except ModuleNotFoundError:
|
||||
raise ValueError(f'No model function {model_type} found')
|
||||
return apply_model(
|
||||
self._model_outputs = apply_model(
|
||||
input_tensor,
|
||||
self._instruments,
|
||||
self._params['model']['params'])
|
||||
|
||||
def _build_loss(self, output_dict, labels):
|
||||
def _build_loss(self, labels):
|
||||
""" Construct tensorflow loss and metrics
|
||||
|
||||
:param output_dict: dictionary of network outputs (key: instrument
|
||||
@@ -138,6 +214,7 @@ class EstimatorSpecBuilder(object):
|
||||
name, value: ground truth spectrogram of the instrument)
|
||||
:returns: tensorflow (loss, metrics) tuple.
|
||||
"""
|
||||
output_dict = self.model_outputs
|
||||
loss_type = self._params.get('loss_type', self.L1_MASK)
|
||||
if loss_type == self.L1_MASK:
|
||||
losses = {
|
||||
@@ -177,51 +254,106 @@ class EstimatorSpecBuilder(object):
|
||||
return tf.compat.v1.train.GradientDescentOptimizer(rate)
|
||||
return tf.compat.v1.train.AdamOptimizer(rate)
|
||||
|
||||
@property
|
||||
def instruments(self):
|
||||
return self._instruments
|
||||
|
||||
@property
|
||||
def stft_name(self):
|
||||
return f'{self._mix_name}_stft'
|
||||
|
||||
@property
|
||||
def spectrogram_name(self):
|
||||
return f'{self._mix_name}_spectrogram'
|
||||
|
||||
def _build_stft_feature(self):
|
||||
""" Compute STFT of waveform and slice the STFT in segment
|
||||
with the right length to feed the network.
|
||||
"""
|
||||
stft_feature = tf.transpose(
|
||||
stft(
|
||||
tf.transpose(self._features['waveform']),
|
||||
self._frame_length,
|
||||
self._frame_step,
|
||||
window_fn=lambda frame_length, dtype: (
|
||||
hann_window(frame_length, periodic=True, dtype=dtype)),
|
||||
pad_end=True),
|
||||
perm=[1, 2, 0])
|
||||
self._features[f'{self._mix_name}_stft'] = stft_feature
|
||||
self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
|
||||
pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]
|
||||
|
||||
def _inverse_stft(self, stft):
|
||||
stft_name = self.stft_name
|
||||
spec_name = self.spectrogram_name
|
||||
|
||||
if stft_name not in self._features:
|
||||
stft_feature = tf.transpose(
|
||||
stft(
|
||||
tf.transpose(self._features['waveform']),
|
||||
self._frame_length,
|
||||
self._frame_step,
|
||||
window_fn=lambda frame_length, dtype: (
|
||||
hann_window(frame_length, periodic=True, dtype=dtype)),
|
||||
pad_end=True),
|
||||
perm=[1, 2, 0])
|
||||
self._features[f'{self._mix_name}_stft'] = stft_feature
|
||||
if spec_name not in self._features:
|
||||
self._features[spec_name] = tf.abs(
|
||||
pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :]
|
||||
|
||||
@property
|
||||
def model_outputs(self):
|
||||
if not hasattr(self, "_model_outputs"):
|
||||
self._build_model_outputs()
|
||||
return self._model_outputs
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
if not hasattr(self, "_outputs"):
|
||||
self._build_outputs()
|
||||
return self._outputs
|
||||
|
||||
@property
|
||||
def stft_feature(self):
|
||||
if self.stft_name not in self._features:
|
||||
self._build_stft_feature()
|
||||
return self._features[self.stft_name]
|
||||
|
||||
@property
|
||||
def spectrogram_feature(self):
|
||||
if self.spectrogram_name not in self._features:
|
||||
self._build_stft_feature()
|
||||
return self._features[self.spectrogram_name]
|
||||
|
||||
@property
|
||||
def masks(self):
|
||||
if not hasattr(self, "_masks"):
|
||||
self._build_masks()
|
||||
return self._masks
|
||||
|
||||
@property
|
||||
def masked_stfts(self):
|
||||
if not hasattr(self, "_masked_stfts"):
|
||||
self._build_masked_stfts()
|
||||
return self._masked_stfts
|
||||
|
||||
def _inverse_stft(self, stft_t, time_crop=None):
|
||||
""" Inverse and reshape the given STFT
|
||||
|
||||
:param stft: input STFT
|
||||
:param stft_t: input STFT
|
||||
:returns: inverse STFT (waveform)
|
||||
"""
|
||||
inversed = inverse_stft(
|
||||
tf.transpose(stft, perm=[2, 0, 1]),
|
||||
tf.transpose(stft_t, perm=[2, 0, 1]),
|
||||
self._frame_length,
|
||||
self._frame_step,
|
||||
window_fn=lambda frame_length, dtype: (
|
||||
hann_window(frame_length, periodic=True, dtype=dtype))
|
||||
) * self.WINDOW_COMPENSATION_FACTOR
|
||||
reshaped = tf.transpose(inversed)
|
||||
return reshaped[:tf.shape(self._features['waveform'])[0], :]
|
||||
if time_crop is None:
|
||||
time_crop = tf.shape(self._features['waveform'])[0]
|
||||
return reshaped[:time_crop, :]
|
||||
|
||||
def _build_mwf_output_waveform(self, output_dict):
|
||||
def _build_mwf_output_waveform(self):
|
||||
""" Perform separation with multichannel Wiener Filtering using Norbert.
|
||||
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
|
||||
may be quite slow.
|
||||
|
||||
:param output_dict: dictionary of estimated spectrogram (key: instrument
|
||||
name, value: estimated spectrogram of the instrument)
|
||||
:returns: dictionary of separated waveforms (key: instrument name,
|
||||
value: estimated waveform of the instrument)
|
||||
"""
|
||||
import norbert # pylint: disable=import-error
|
||||
x = self._features[f'{self._mix_name}_stft']
|
||||
output_dict = self.model_outputs
|
||||
x = self.stft_feature
|
||||
v = tf.stack(
|
||||
[
|
||||
pad_and_reshape(
|
||||
@@ -265,30 +397,28 @@ class EstimatorSpecBuilder(object):
|
||||
mask_shape[-1]))
|
||||
else:
|
||||
raise ValueError(f'Invalid mask_extension parameter {extension}')
|
||||
n_extra_row = (self._frame_length) // 2 + 1 - self._F
|
||||
n_extra_row = self._frame_length // 2 + 1 - self._F
|
||||
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
|
||||
return tf.concat([mask, extension], axis=2)
|
||||
|
||||
def _build_manual_output_waveform(self, output_dict):
|
||||
""" Perform ratio mask separation
|
||||
|
||||
:param output_dict: dictionary of estimated spectrogram (key: instrument
|
||||
name, value: estimated spectrogram of the instrument)
|
||||
:returns: dictionary of separated waveforms (key: instrument name,
|
||||
value: estimated waveform of the instrument)
|
||||
def _build_masks(self):
|
||||
"""
|
||||
Compute masks from the output spectrograms of the model.
|
||||
:return:
|
||||
"""
|
||||
output_dict = self.model_outputs
|
||||
stft_feature = self.stft_feature
|
||||
separation_exponent = self._params['separation_exponent']
|
||||
output_sum = tf.reduce_sum(
|
||||
[e ** separation_exponent for e in output_dict.values()],
|
||||
axis=0
|
||||
) + self.EPSILON
|
||||
output_waveform = {}
|
||||
out = {}
|
||||
for instrument in self._instruments:
|
||||
output = output_dict[f'{instrument}_spectrogram']
|
||||
# Compute mask with the model.
|
||||
instrument_mask = (
|
||||
output ** separation_exponent
|
||||
+ (self.EPSILON / len(output_dict))) / output_sum
|
||||
instrument_mask = (output ** separation_exponent
|
||||
+ (self.EPSILON / len(output_dict))) / output_sum
|
||||
# Extend mask;
|
||||
instrument_mask = self._extend_mask(instrument_mask)
|
||||
# Stack back mask.
|
||||
@@ -298,30 +428,56 @@ class EstimatorSpecBuilder(object):
|
||||
axis=0)
|
||||
instrument_mask = tf.reshape(instrument_mask, new_shape)
|
||||
# Remove padded part (for mask having the same size as STFT);
|
||||
stft_feature = self._features[f'{self._mix_name}_stft']
|
||||
|
||||
instrument_mask = instrument_mask[
|
||||
:tf.shape(stft_feature)[0], ...]
|
||||
# Compute masked STFT and normalize it.
|
||||
output_waveform[instrument] = self._inverse_stft(
|
||||
tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature)
|
||||
:tf.shape(stft_feature)[0], ...]
|
||||
out[instrument] = instrument_mask
|
||||
self._masks = out
|
||||
|
||||
def _build_masked_stfts(self):
|
||||
input_stft = self.stft_feature
|
||||
out = {}
|
||||
for instrument, mask in self.masks.items():
|
||||
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
|
||||
self._masked_stfts = out
|
||||
|
||||
def _build_manual_output_waveform(self, masked_stft):
|
||||
""" Perform ratio mask separation
|
||||
|
||||
:param output_dict: dictionary of estimated spectrogram (key: instrument
|
||||
name, value: estimated spectrogram of the instrument)
|
||||
:returns: dictionary of separated waveforms (key: instrument name,
|
||||
value: estimated waveform of the instrument)
|
||||
"""
|
||||
|
||||
output_waveform = {}
|
||||
for instrument, stft_data in masked_stft.items():
|
||||
output_waveform[instrument] = self._inverse_stft(stft_data)
|
||||
return output_waveform
|
||||
|
||||
def _build_output_waveform(self, output_dict):
|
||||
def _build_output_waveform(self, masked_stft):
|
||||
""" Build output waveform from given output dict in order to be used in
|
||||
prediction context. Regarding of the configuration building method will
|
||||
be using MWF.
|
||||
|
||||
:param output_dict: Output dict to build output waveform from.
|
||||
:returns: Built output waveform.
|
||||
"""
|
||||
|
||||
if self._params.get('MWF', False):
|
||||
output_waveform = self._build_mwf_output_waveform(output_dict)
|
||||
output_waveform = self._build_mwf_output_waveform()
|
||||
else:
|
||||
output_waveform = self._build_manual_output_waveform(output_dict)
|
||||
if 'audio_id' in self._features:
|
||||
output_waveform['audio_id'] = self._features['audio_id']
|
||||
output_waveform = self._build_manual_output_waveform(masked_stft)
|
||||
return output_waveform
|
||||
|
||||
def _build_outputs(self):
|
||||
if self.include_stft_computations():
|
||||
self._outputs = self._build_output_waveform(self.masked_stfts)
|
||||
else:
|
||||
self._outputs = self.masked_stfts
|
||||
|
||||
if 'audio_id' in self._features:
|
||||
self._outputs['audio_id'] = self._features['audio_id']
|
||||
|
||||
def build_predict_model(self):
|
||||
""" Builder interface for creating model instance that aims to perform
|
||||
prediction / inference over given track. The output of such estimator
|
||||
@@ -330,12 +486,10 @@ class EstimatorSpecBuilder(object):
|
||||
|
||||
:returns: An estimator for performing prediction.
|
||||
"""
|
||||
self._build_stft_feature()
|
||||
output_dict = self._build_output_dict()
|
||||
output_waveform = self._build_output_waveform(output_dict)
|
||||
|
||||
return tf.estimator.EstimatorSpec(
|
||||
tf.estimator.ModeKeys.PREDICT,
|
||||
predictions=output_waveform)
|
||||
predictions=self.outputs)
|
||||
|
||||
def build_evaluation_model(self, labels):
|
||||
""" Builder interface for creating model instance that aims to perform
|
||||
@@ -346,8 +500,7 @@ class EstimatorSpecBuilder(object):
|
||||
:param labels: Model labels.
|
||||
:returns: An estimator for performing model evaluation.
|
||||
"""
|
||||
output_dict = self._build_output_dict()
|
||||
loss, metrics = self._build_loss(output_dict, labels)
|
||||
loss, metrics = self._build_loss(labels)
|
||||
return tf.estimator.EstimatorSpec(
|
||||
tf.estimator.ModeKeys.EVAL,
|
||||
loss=loss,
|
||||
@@ -362,8 +515,7 @@ class EstimatorSpecBuilder(object):
|
||||
:param labels: Model labels.
|
||||
:returns: An estimator for performing model training.
|
||||
"""
|
||||
output_dict = self._build_output_dict()
|
||||
loss, metrics = self._build_loss(output_dict, labels)
|
||||
loss, metrics = self._build_loss(labels)
|
||||
optimizer = self._build_optimizer()
|
||||
train_operation = optimizer.minimize(
|
||||
loss=loss,
|
||||
|
||||
@@ -13,40 +13,57 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
from functools import partial
|
||||
from time import time
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
from os.path import basename, join, splitext
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from librosa.core import stft, istft
|
||||
from scipy.signal.windows import hann
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .model import model_fn
|
||||
from .utils.configuration import load_configuration
|
||||
from .utils.estimator import create_estimator, to_predictor
|
||||
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
|
||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
logger = logging.getLogger("spleeter")
|
||||
|
||||
|
||||
|
||||
def get_backend(backend):
|
||||
assert backend in ["auto", "tensorflow", "librosa"]
|
||||
if backend == "auto":
|
||||
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
||||
return backend
|
||||
|
||||
|
||||
class Separator(object):
|
||||
""" A wrapper class for performing separation. """
|
||||
|
||||
def __init__(self, params_descriptor, MWF=False, multiprocess=True):
|
||||
def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
|
||||
""" Default constructor.
|
||||
|
||||
:param params_descriptor: Descriptor for TF params to be used.
|
||||
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
||||
"""
|
||||
|
||||
self._params = load_configuration(params_descriptor)
|
||||
self._sample_rate = self._params['sample_rate']
|
||||
self._MWF = MWF
|
||||
self._predictor = None
|
||||
self._pool = Pool() if multiprocess else None
|
||||
self._tasks = []
|
||||
self._params["stft_backend"] = get_backend(stft_backend)
|
||||
|
||||
def _get_predictor(self):
|
||||
""" Lazy loading access method for internal predictor instance.
|
||||
@@ -68,7 +85,7 @@ class Separator(object):
|
||||
task.get()
|
||||
task.wait(timeout=timeout)
|
||||
|
||||
def separate(self, waveform):
|
||||
def separate_tensorflow(self, waveform, audio_descriptor):
|
||||
""" Performs source separation over the given waveform.
|
||||
|
||||
The separation is performed synchronously but the result
|
||||
@@ -86,10 +103,59 @@ class Separator(object):
|
||||
predictor = self._get_predictor()
|
||||
prediction = predictor({
|
||||
'waveform': waveform,
|
||||
'audio_id': ''})
|
||||
'audio_id': audio_descriptor})
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def stft(self, data, inverse=False, length=None):
|
||||
"""
|
||||
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
|
||||
channels are processed separately and are concatenated together in the result. The expected input formats are:
|
||||
(n_samples, 2) for stft and (T, F, 2) for istft.
|
||||
:param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse
|
||||
:param inverse: should a stft or an istft be computed.
|
||||
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
|
||||
"""
|
||||
assert not (inverse and length is None)
|
||||
data = np.asfortranarray(data)
|
||||
N = self._params["frame_length"]
|
||||
H = self._params["frame_step"]
|
||||
win = hann(N, sym=False)
|
||||
fstft = istft if inverse else stft
|
||||
win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
|
||||
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
|
||||
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
|
||||
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
|
||||
s1 = np.expand_dims(s1.T, 2-inverse)
|
||||
s2 = np.expand_dims(s2.T, 2-inverse)
|
||||
return np.concatenate([s1, s2], axis=2-inverse)
|
||||
|
||||
def separate_librosa(self, waveform, audio_id):
|
||||
out = {}
|
||||
input_provider = InputProviderFactory.get(self._params)
|
||||
features = input_provider.get_input_dict_placeholders()
|
||||
|
||||
builder = EstimatorSpecBuilder(features, self._params)
|
||||
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
||||
|
||||
# TODO: fix the logic, build sometimes return, sometimes set attribute
|
||||
outputs = builder.outputs
|
||||
|
||||
saver = tf.train.Saver()
|
||||
stft = self.stft(waveform)
|
||||
with tf.Session() as sess:
|
||||
saver.restore(sess, latest_checkpoint)
|
||||
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
||||
for inst in builder.instruments:
|
||||
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
|
||||
return out
|
||||
|
||||
def separate(self, waveform, audio_descriptor):
|
||||
if self._params["stft_backend"] == "tensorflow":
|
||||
return self.separate_tensorflow(waveform, audio_descriptor)
|
||||
else:
|
||||
return self.separate_librosa(waveform, audio_descriptor)
|
||||
|
||||
def separate_to_file(
|
||||
self, audio_descriptor, destination,
|
||||
audio_adapter=get_default_audio_adapter(),
|
||||
@@ -108,6 +174,8 @@ class Separator(object):
|
||||
descriptor would be a file path.
|
||||
:param destination: Target directory to write output to.
|
||||
:param audio_adapter: (Optional) Audio adapter to use for I/O.
|
||||
:param chunk_duration: (Optional) Maximum signal duration that is processed
|
||||
in one pass. Default: all signal.
|
||||
:param offset: (Optional) Offset of loaded song.
|
||||
:param duration: (Optional) Duration of loaded song.
|
||||
:param codec: (Optional) Export codec.
|
||||
@@ -115,12 +183,17 @@ class Separator(object):
|
||||
:param filename_format: (Optional) Filename format.
|
||||
:param synchronous: (Optional) True is should by synchronous.
|
||||
"""
|
||||
waveform, _ = audio_adapter.load(
|
||||
waveform, sample_rate = audio_adapter.load(
|
||||
audio_descriptor,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
sample_rate=self._sample_rate)
|
||||
sources = self.separate(waveform)
|
||||
sources = self.separate(waveform, audio_descriptor)
|
||||
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous)
|
||||
|
||||
def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous):
|
||||
filename = splitext(basename(audio_descriptor))[0]
|
||||
generated = []
|
||||
for instrument, data in sources.items():
|
||||
|
||||
@@ -13,27 +13,37 @@ import tensorflow as tf
|
||||
from tensorflow.contrib import predictor
|
||||
# pylint: enable=import-error
|
||||
|
||||
from ..model import model_fn
|
||||
from ..model import model_fn, InputProviderFactory
|
||||
from ..model.provider import get_default_model_provider
|
||||
|
||||
# Default exporting directory for predictor.
|
||||
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
|
||||
|
||||
|
||||
|
||||
def get_default_model_dir(model_dir):
|
||||
"""
|
||||
Transforms a string like 'spleeter:2stems' into an actual path.
|
||||
:param model_dir:
|
||||
:return:
|
||||
"""
|
||||
model_provider = get_default_model_provider()
|
||||
return model_provider.get(model_dir)
|
||||
|
||||
def create_estimator(params, MWF):
|
||||
"""
|
||||
Initialize tensorflow estimator that will perform separation
|
||||
|
||||
Params:
|
||||
- params: a dictionnary of parameters for building the model
|
||||
- params: a dictionary of parameters for building the model
|
||||
|
||||
Returns:
|
||||
a tensorflow estimator
|
||||
"""
|
||||
# Load model.
|
||||
model_directory = params['model_dir']
|
||||
model_provider = get_default_model_provider()
|
||||
params['model_dir'] = model_provider.get(model_directory)
|
||||
|
||||
|
||||
params['model_dir'] = get_default_model_dir(params['model_dir'])
|
||||
params['MWF'] = MWF
|
||||
# Setup config
|
||||
session_config = tf.compat.v1.ConfigProto()
|
||||
@@ -56,11 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
||||
:param estimator: Estimator to export.
|
||||
:param directory: (Optional) path to write exported model into.
|
||||
"""
|
||||
|
||||
input_provider = InputProviderFactory.get(estimator.params)
|
||||
def receiver():
|
||||
shape = (None, estimator.params['n_channels'])
|
||||
features = {
|
||||
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
|
||||
'audio_id': tf.compat.v1.placeholder(tf.string)}
|
||||
features = input_provider.get_input_dict_placeholders()
|
||||
return tf.estimator.export.ServingInputReceiver(features, features)
|
||||
|
||||
estimator.export_saved_model(directory, receiver)
|
||||
|
||||
@@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from spleeter import SpleeterError
|
||||
from spleeter.audio.adapter import get_default_audio_adapter
|
||||
@@ -21,34 +22,38 @@ from spleeter.separator import Separator
|
||||
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
|
||||
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
|
||||
TEST_CONFIGURATIONS = [
|
||||
('spleeter:2stems', ('vocals', 'accompaniment')),
|
||||
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')),
|
||||
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'))
|
||||
('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'),
|
||||
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'),
|
||||
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'),
|
||||
('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'),
|
||||
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'),
|
||||
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa')
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
||||
def test_separate(configuration, instruments):
|
||||
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||
def test_separate(configuration, instruments, backend):
|
||||
""" Test separation from raw data. """
|
||||
adapter = get_default_audio_adapter()
|
||||
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
|
||||
separator = Separator(configuration)
|
||||
prediction = separator.separate(waveform)
|
||||
separator = Separator(configuration, stft_backend=backend)
|
||||
prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR)
|
||||
assert len(prediction) == len(instruments)
|
||||
for instrument in instruments:
|
||||
assert instrument in prediction
|
||||
for instrument in instruments:
|
||||
track = prediction[instrument]
|
||||
assert not (waveform == track).all()
|
||||
assert waveform.shape == track.shape
|
||||
assert not np.allclose(waveform, track)
|
||||
for compared in instruments:
|
||||
if instrument != compared:
|
||||
assert not (track == prediction[compared]).all()
|
||||
assert not np.allclose(track, prediction[compared])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
||||
def test_separate_to_file(configuration, instruments):
|
||||
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||
def test_separate_to_file(configuration, instruments, backend):
|
||||
""" Test file based separation. """
|
||||
separator = Separator(configuration)
|
||||
separator = Separator(configuration, stft_backend=backend)
|
||||
with TemporaryDirectory() as directory:
|
||||
separator.separate_to_file(
|
||||
TEST_AUDIO_DESCRIPTOR,
|
||||
@@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments):
|
||||
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
||||
def test_filename_format(configuration, instruments):
|
||||
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||
def test_filename_format(configuration, instruments, backend):
|
||||
""" Test custom filename format. """
|
||||
separator = Separator(configuration)
|
||||
separator = Separator(configuration, stft_backend=backend)
|
||||
with TemporaryDirectory() as directory:
|
||||
separator.separate_to_file(
|
||||
TEST_AUDIO_DESCRIPTOR,
|
||||
@@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments):
|
||||
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
||||
|
||||
|
||||
def test_filename_confilct():
|
||||
def test_filename_conflict():
|
||||
""" Test error handling with static pattern. """
|
||||
separator = Separator(TEST_CONFIGURATIONS[0][0])
|
||||
with TemporaryDirectory() as directory:
|
||||
|
||||
Reference in New Issue
Block a user