mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Merge pull request #282 from alreadytaikeune/librosa_backend
Librosa backend
This commit is contained in:
@@ -2,6 +2,7 @@ importlib_resources; python_version<'3.7'
|
|||||||
requests
|
requests
|
||||||
setuptools>=41.0.0
|
setuptools>=41.0.0
|
||||||
pandas==0.25.1
|
pandas==0.25.1
|
||||||
tensorflow==1.14.0
|
tensorflow==1.15.0
|
||||||
ffmpeg-python
|
ffmpeg-python
|
||||||
norbert==0.2.1
|
norbert==0.2.1
|
||||||
|
librosa==0.7.2
|
||||||
5
setup.py
5
setup.py
@@ -14,9 +14,9 @@ __license__ = 'MIT License'
|
|||||||
|
|
||||||
# Default project values.
|
# Default project values.
|
||||||
project_name = 'spleeter'
|
project_name = 'spleeter'
|
||||||
project_version = '1.4.9'
|
project_version = '1.5.0'
|
||||||
tensorflow_dependency = 'tensorflow'
|
tensorflow_dependency = 'tensorflow'
|
||||||
tensorflow_version = '1.14.0'
|
tensorflow_version = '1.15.0'
|
||||||
here = path.abspath(path.dirname(__file__))
|
here = path.abspath(path.dirname(__file__))
|
||||||
readme_path = path.join(here, 'README.md')
|
readme_path = path.join(here, 'README.md')
|
||||||
with open(readme_path, 'r') as stream:
|
with open(readme_path, 'r') as stream:
|
||||||
@@ -56,6 +56,7 @@ setup(
|
|||||||
'pandas==0.25.1',
|
'pandas==0.25.1',
|
||||||
'requests',
|
'requests',
|
||||||
'setuptools>=41.0.0',
|
'setuptools>=41.0.0',
|
||||||
|
'librosa==0.7.2',
|
||||||
'{}=={}'.format(tensorflow_dependency, tensorflow_version),
|
'{}=={}'.format(tensorflow_dependency, tensorflow_version),
|
||||||
],
|
],
|
||||||
extras_require={
|
extras_require={
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
""" This modules provides spleeter command as well as CLI parsing methods. """
|
""" This modules provides spleeter command as well as CLI parsing methods. """
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from tempfile import gettempdir
|
from tempfile import gettempdir
|
||||||
from os.path import exists, join
|
from os.path import exists, join
|
||||||
@@ -13,6 +13,8 @@ __email__ = 'research@deezer.com'
|
|||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# -i opt specification (separate).
|
# -i opt specification (separate).
|
||||||
OPT_INPUT = {
|
OPT_INPUT = {
|
||||||
'dest': 'inputs',
|
'dest': 'inputs',
|
||||||
@@ -68,6 +70,17 @@ OPT_DURATION = {
|
|||||||
'the input file)')
|
'the input file)')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# -w opt specification (separate)
|
||||||
|
OPT_STFT_BACKEND = {
|
||||||
|
'dest': 'stft_backend',
|
||||||
|
'type': str,
|
||||||
|
'choices' : ["tensorflow", "librosa", "auto"],
|
||||||
|
'default': "auto",
|
||||||
|
'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
|
||||||
|
' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# -c opt specification (separate).
|
# -c opt specification (separate).
|
||||||
OPT_CODEC = {
|
OPT_CODEC = {
|
||||||
'dest': 'codec',
|
'dest': 'codec',
|
||||||
@@ -176,6 +189,7 @@ def _create_separate_parser(parser_factory):
|
|||||||
parser.add_argument('-c', '--codec', **OPT_CODEC)
|
parser.add_argument('-c', '--codec', **OPT_CODEC)
|
||||||
parser.add_argument('-b', '--birate', **OPT_BITRATE)
|
parser.add_argument('-b', '--birate', **OPT_BITRATE)
|
||||||
parser.add_argument('-m', '--mwf', **OPT_MWF)
|
parser.add_argument('-m', '--mwf', **OPT_MWF)
|
||||||
|
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ __author__ = 'Deezer Research'
|
|||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def entrypoint(arguments, params):
|
def entrypoint(arguments, params):
|
||||||
""" Command entrypoint.
|
""" Command entrypoint.
|
||||||
|
|
||||||
@@ -29,7 +30,8 @@ def entrypoint(arguments, params):
|
|||||||
audio_adapter = get_audio_adapter(arguments.audio_adapter)
|
audio_adapter = get_audio_adapter(arguments.audio_adapter)
|
||||||
separator = Separator(
|
separator = Separator(
|
||||||
arguments.configuration,
|
arguments.configuration,
|
||||||
arguments.MWF)
|
MWF=arguments.MWF,
|
||||||
|
stft_backend=arguments.stft_backend)
|
||||||
for filename in arguments.inputs:
|
for filename in arguments.inputs:
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
filename,
|
filename,
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ __author__ = 'Deezer Research'
|
|||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
placeholder = tf.compat.v1.placeholder
|
||||||
|
|
||||||
|
|
||||||
def get_model_function(model_type):
|
def get_model_function(model_type):
|
||||||
"""
|
"""
|
||||||
Get tensorflow function of the model to be applied to the input tensor.
|
Get tensorflow function of the model to be applied to the input tensor.
|
||||||
@@ -41,6 +44,74 @@ def get_model_function(model_type):
|
|||||||
return model_function
|
return model_function
|
||||||
|
|
||||||
|
|
||||||
|
class InputProvider(object):
|
||||||
|
|
||||||
|
def __init__(self, params):
|
||||||
|
self.params = params
|
||||||
|
|
||||||
|
def get_input_dict_placeholders(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_names(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_feed_dict(self, features, *args):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class WaveformInputProvider(InputProvider):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_names(self):
|
||||||
|
return ["audio_id", "waveform"]
|
||||||
|
|
||||||
|
def get_input_dict_placeholders(self):
|
||||||
|
shape = (None, self.params['n_channels'])
|
||||||
|
features = {
|
||||||
|
'waveform': placeholder(tf.float32, shape=shape, name="waveform"),
|
||||||
|
'audio_id': placeholder(tf.string, name="audio_id")}
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_feed_dict(self, features, waveform, audio_id):
|
||||||
|
return {features["audio_id"]: audio_id, features["waveform"]: waveform}
|
||||||
|
|
||||||
|
|
||||||
|
class SpectralInputProvider(InputProvider):
|
||||||
|
|
||||||
|
def __init__(self, params):
|
||||||
|
super().__init__(params)
|
||||||
|
self.stft_input_name = "{}_stft".format(self.params["mix_name"])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_names(self):
|
||||||
|
return ["audio_id", self.stft_input_name]
|
||||||
|
|
||||||
|
def get_input_dict_placeholders(self):
|
||||||
|
features = {
|
||||||
|
self.stft_input_name: placeholder(tf.complex64,
|
||||||
|
shape=(None, self.params["frame_length"]//2+1,
|
||||||
|
self.params['n_channels']),
|
||||||
|
name=self.stft_input_name),
|
||||||
|
'audio_id': placeholder(tf.string, name="audio_id")}
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_feed_dict(self, features, stft, audio_id):
|
||||||
|
return {features["audio_id"]: audio_id, features[self.stft_input_name]: stft}
|
||||||
|
|
||||||
|
|
||||||
|
class InputProviderFactory(object):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get(params):
|
||||||
|
stft_backend = params["stft_backend"]
|
||||||
|
assert stft_backend in ("tensorflow", "librosa"), "Unexpected backend {}".format(stft_backend)
|
||||||
|
if stft_backend == "tensorflow":
|
||||||
|
return WaveformInputProvider(params)
|
||||||
|
else:
|
||||||
|
return SpectralInputProvider(params)
|
||||||
|
|
||||||
|
|
||||||
class EstimatorSpecBuilder(object):
|
class EstimatorSpecBuilder(object):
|
||||||
""" A builder class that allows to builds a multitrack unet model
|
""" A builder class that allows to builds a multitrack unet model
|
||||||
estimator. The built model estimator has a different behaviour when
|
estimator. The built model estimator has a different behaviour when
|
||||||
@@ -57,9 +128,9 @@ class EstimatorSpecBuilder(object):
|
|||||||
|
|
||||||
>>> from spleeter.model import EstimatorSpecBuilder
|
>>> from spleeter.model import EstimatorSpecBuilder
|
||||||
>>> builder = EstimatorSpecBuilder()
|
>>> builder = EstimatorSpecBuilder()
|
||||||
>>> builder.build_prediction_model()
|
>>> builder.build_predict_model()
|
||||||
>>> builder.build_evaluation_model()
|
>>> builder.build_evaluation_model()
|
||||||
>>> builder.build_training_model()
|
>>> builder.build_train_model()
|
||||||
|
|
||||||
>>> from spleeter.model import model_fn
|
>>> from spleeter.model import model_fn
|
||||||
>>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
|
>>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
|
||||||
@@ -94,6 +165,7 @@ class EstimatorSpecBuilder(object):
|
|||||||
:param features: The input features for the estimator.
|
:param features: The input features for the estimator.
|
||||||
:param params: Some hyperparameters as a dictionary.
|
:param params: Some hyperparameters as a dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._features = features
|
self._features = features
|
||||||
self._params = params
|
self._params = params
|
||||||
# Get instrument name.
|
# Get instrument name.
|
||||||
@@ -106,7 +178,10 @@ class EstimatorSpecBuilder(object):
|
|||||||
self._frame_length = params['frame_length']
|
self._frame_length = params['frame_length']
|
||||||
self._frame_step = params['frame_step']
|
self._frame_step = params['frame_step']
|
||||||
|
|
||||||
def _build_output_dict(self):
|
def include_stft_computations(self):
|
||||||
|
return self._params["stft_backend"] == "tensorflow"
|
||||||
|
|
||||||
|
def _build_model_outputs(self):
|
||||||
""" Created a batch_sizexTxFxn_channels input tensor containing
|
""" Created a batch_sizexTxFxn_channels input tensor containing
|
||||||
mix magnitude spectrogram, then an output dict from it according
|
mix magnitude spectrogram, then an output dict from it according
|
||||||
to the selected model in internal parameters.
|
to the selected model in internal parameters.
|
||||||
@@ -114,7 +189,8 @@ class EstimatorSpecBuilder(object):
|
|||||||
:returns: Build output dict.
|
:returns: Build output dict.
|
||||||
:raise ValueError: If required model_type is not supported.
|
:raise ValueError: If required model_type is not supported.
|
||||||
"""
|
"""
|
||||||
input_tensor = self._features[f'{self._mix_name}_spectrogram']
|
|
||||||
|
input_tensor = self.spectrogram_feature
|
||||||
model = self._params.get('model', None)
|
model = self._params.get('model', None)
|
||||||
if model is not None:
|
if model is not None:
|
||||||
model_type = model.get('type', self.DEFAULT_MODEL)
|
model_type = model.get('type', self.DEFAULT_MODEL)
|
||||||
@@ -124,12 +200,12 @@ class EstimatorSpecBuilder(object):
|
|||||||
apply_model = get_model_function(model_type)
|
apply_model = get_model_function(model_type)
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
raise ValueError(f'No model function {model_type} found')
|
raise ValueError(f'No model function {model_type} found')
|
||||||
return apply_model(
|
self._model_outputs = apply_model(
|
||||||
input_tensor,
|
input_tensor,
|
||||||
self._instruments,
|
self._instruments,
|
||||||
self._params['model']['params'])
|
self._params['model']['params'])
|
||||||
|
|
||||||
def _build_loss(self, output_dict, labels):
|
def _build_loss(self, labels):
|
||||||
""" Construct tensorflow loss and metrics
|
""" Construct tensorflow loss and metrics
|
||||||
|
|
||||||
:param output_dict: dictionary of network outputs (key: instrument
|
:param output_dict: dictionary of network outputs (key: instrument
|
||||||
@@ -138,6 +214,7 @@ class EstimatorSpecBuilder(object):
|
|||||||
name, value: ground truth spectrogram of the instrument)
|
name, value: ground truth spectrogram of the instrument)
|
||||||
:returns: tensorflow (loss, metrics) tuple.
|
:returns: tensorflow (loss, metrics) tuple.
|
||||||
"""
|
"""
|
||||||
|
output_dict = self.model_outputs
|
||||||
loss_type = self._params.get('loss_type', self.L1_MASK)
|
loss_type = self._params.get('loss_type', self.L1_MASK)
|
||||||
if loss_type == self.L1_MASK:
|
if loss_type == self.L1_MASK:
|
||||||
losses = {
|
losses = {
|
||||||
@@ -177,51 +254,106 @@ class EstimatorSpecBuilder(object):
|
|||||||
return tf.compat.v1.train.GradientDescentOptimizer(rate)
|
return tf.compat.v1.train.GradientDescentOptimizer(rate)
|
||||||
return tf.compat.v1.train.AdamOptimizer(rate)
|
return tf.compat.v1.train.AdamOptimizer(rate)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def instruments(self):
|
||||||
|
return self._instruments
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stft_name(self):
|
||||||
|
return f'{self._mix_name}_stft'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spectrogram_name(self):
|
||||||
|
return f'{self._mix_name}_spectrogram'
|
||||||
|
|
||||||
def _build_stft_feature(self):
|
def _build_stft_feature(self):
|
||||||
""" Compute STFT of waveform and slice the STFT in segment
|
""" Compute STFT of waveform and slice the STFT in segment
|
||||||
with the right length to feed the network.
|
with the right length to feed the network.
|
||||||
"""
|
"""
|
||||||
stft_feature = tf.transpose(
|
|
||||||
stft(
|
|
||||||
tf.transpose(self._features['waveform']),
|
|
||||||
self._frame_length,
|
|
||||||
self._frame_step,
|
|
||||||
window_fn=lambda frame_length, dtype: (
|
|
||||||
hann_window(frame_length, periodic=True, dtype=dtype)),
|
|
||||||
pad_end=True),
|
|
||||||
perm=[1, 2, 0])
|
|
||||||
self._features[f'{self._mix_name}_stft'] = stft_feature
|
|
||||||
self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
|
|
||||||
pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]
|
|
||||||
|
|
||||||
def _inverse_stft(self, stft):
|
stft_name = self.stft_name
|
||||||
|
spec_name = self.spectrogram_name
|
||||||
|
|
||||||
|
if stft_name not in self._features:
|
||||||
|
stft_feature = tf.transpose(
|
||||||
|
stft(
|
||||||
|
tf.transpose(self._features['waveform']),
|
||||||
|
self._frame_length,
|
||||||
|
self._frame_step,
|
||||||
|
window_fn=lambda frame_length, dtype: (
|
||||||
|
hann_window(frame_length, periodic=True, dtype=dtype)),
|
||||||
|
pad_end=True),
|
||||||
|
perm=[1, 2, 0])
|
||||||
|
self._features[f'{self._mix_name}_stft'] = stft_feature
|
||||||
|
if spec_name not in self._features:
|
||||||
|
self._features[spec_name] = tf.abs(
|
||||||
|
pad_and_partition(self._features[stft_name], self._T))[:, :, :self._F, :]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_outputs(self):
|
||||||
|
if not hasattr(self, "_model_outputs"):
|
||||||
|
self._build_model_outputs()
|
||||||
|
return self._model_outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self):
|
||||||
|
if not hasattr(self, "_outputs"):
|
||||||
|
self._build_outputs()
|
||||||
|
return self._outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stft_feature(self):
|
||||||
|
if self.stft_name not in self._features:
|
||||||
|
self._build_stft_feature()
|
||||||
|
return self._features[self.stft_name]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def spectrogram_feature(self):
|
||||||
|
if self.spectrogram_name not in self._features:
|
||||||
|
self._build_stft_feature()
|
||||||
|
return self._features[self.spectrogram_name]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def masks(self):
|
||||||
|
if not hasattr(self, "_masks"):
|
||||||
|
self._build_masks()
|
||||||
|
return self._masks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def masked_stfts(self):
|
||||||
|
if not hasattr(self, "_masked_stfts"):
|
||||||
|
self._build_masked_stfts()
|
||||||
|
return self._masked_stfts
|
||||||
|
|
||||||
|
def _inverse_stft(self, stft_t, time_crop=None):
|
||||||
""" Inverse and reshape the given STFT
|
""" Inverse and reshape the given STFT
|
||||||
|
|
||||||
:param stft: input STFT
|
:param stft_t: input STFT
|
||||||
:returns: inverse STFT (waveform)
|
:returns: inverse STFT (waveform)
|
||||||
"""
|
"""
|
||||||
inversed = inverse_stft(
|
inversed = inverse_stft(
|
||||||
tf.transpose(stft, perm=[2, 0, 1]),
|
tf.transpose(stft_t, perm=[2, 0, 1]),
|
||||||
self._frame_length,
|
self._frame_length,
|
||||||
self._frame_step,
|
self._frame_step,
|
||||||
window_fn=lambda frame_length, dtype: (
|
window_fn=lambda frame_length, dtype: (
|
||||||
hann_window(frame_length, periodic=True, dtype=dtype))
|
hann_window(frame_length, periodic=True, dtype=dtype))
|
||||||
) * self.WINDOW_COMPENSATION_FACTOR
|
) * self.WINDOW_COMPENSATION_FACTOR
|
||||||
reshaped = tf.transpose(inversed)
|
reshaped = tf.transpose(inversed)
|
||||||
return reshaped[:tf.shape(self._features['waveform'])[0], :]
|
if time_crop is None:
|
||||||
|
time_crop = tf.shape(self._features['waveform'])[0]
|
||||||
|
return reshaped[:time_crop, :]
|
||||||
|
|
||||||
def _build_mwf_output_waveform(self, output_dict):
|
def _build_mwf_output_waveform(self):
|
||||||
""" Perform separation with multichannel Wiener Filtering using Norbert.
|
""" Perform separation with multichannel Wiener Filtering using Norbert.
|
||||||
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
|
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
|
||||||
may be quite slow.
|
may be quite slow.
|
||||||
|
|
||||||
:param output_dict: dictionary of estimated spectrogram (key: instrument
|
|
||||||
name, value: estimated spectrogram of the instrument)
|
|
||||||
:returns: dictionary of separated waveforms (key: instrument name,
|
:returns: dictionary of separated waveforms (key: instrument name,
|
||||||
value: estimated waveform of the instrument)
|
value: estimated waveform of the instrument)
|
||||||
"""
|
"""
|
||||||
import norbert # pylint: disable=import-error
|
import norbert # pylint: disable=import-error
|
||||||
x = self._features[f'{self._mix_name}_stft']
|
output_dict = self.model_outputs
|
||||||
|
x = self.stft_feature
|
||||||
v = tf.stack(
|
v = tf.stack(
|
||||||
[
|
[
|
||||||
pad_and_reshape(
|
pad_and_reshape(
|
||||||
@@ -265,30 +397,28 @@ class EstimatorSpecBuilder(object):
|
|||||||
mask_shape[-1]))
|
mask_shape[-1]))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Invalid mask_extension parameter {extension}')
|
raise ValueError(f'Invalid mask_extension parameter {extension}')
|
||||||
n_extra_row = (self._frame_length) // 2 + 1 - self._F
|
n_extra_row = self._frame_length // 2 + 1 - self._F
|
||||||
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
|
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
|
||||||
return tf.concat([mask, extension], axis=2)
|
return tf.concat([mask, extension], axis=2)
|
||||||
|
|
||||||
def _build_manual_output_waveform(self, output_dict):
|
def _build_masks(self):
|
||||||
""" Perform ratio mask separation
|
|
||||||
|
|
||||||
:param output_dict: dictionary of estimated spectrogram (key: instrument
|
|
||||||
name, value: estimated spectrogram of the instrument)
|
|
||||||
:returns: dictionary of separated waveforms (key: instrument name,
|
|
||||||
value: estimated waveform of the instrument)
|
|
||||||
"""
|
"""
|
||||||
|
Compute masks from the output spectrograms of the model.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
output_dict = self.model_outputs
|
||||||
|
stft_feature = self.stft_feature
|
||||||
separation_exponent = self._params['separation_exponent']
|
separation_exponent = self._params['separation_exponent']
|
||||||
output_sum = tf.reduce_sum(
|
output_sum = tf.reduce_sum(
|
||||||
[e ** separation_exponent for e in output_dict.values()],
|
[e ** separation_exponent for e in output_dict.values()],
|
||||||
axis=0
|
axis=0
|
||||||
) + self.EPSILON
|
) + self.EPSILON
|
||||||
output_waveform = {}
|
out = {}
|
||||||
for instrument in self._instruments:
|
for instrument in self._instruments:
|
||||||
output = output_dict[f'{instrument}_spectrogram']
|
output = output_dict[f'{instrument}_spectrogram']
|
||||||
# Compute mask with the model.
|
# Compute mask with the model.
|
||||||
instrument_mask = (
|
instrument_mask = (output ** separation_exponent
|
||||||
output ** separation_exponent
|
+ (self.EPSILON / len(output_dict))) / output_sum
|
||||||
+ (self.EPSILON / len(output_dict))) / output_sum
|
|
||||||
# Extend mask;
|
# Extend mask;
|
||||||
instrument_mask = self._extend_mask(instrument_mask)
|
instrument_mask = self._extend_mask(instrument_mask)
|
||||||
# Stack back mask.
|
# Stack back mask.
|
||||||
@@ -298,30 +428,56 @@ class EstimatorSpecBuilder(object):
|
|||||||
axis=0)
|
axis=0)
|
||||||
instrument_mask = tf.reshape(instrument_mask, new_shape)
|
instrument_mask = tf.reshape(instrument_mask, new_shape)
|
||||||
# Remove padded part (for mask having the same size as STFT);
|
# Remove padded part (for mask having the same size as STFT);
|
||||||
stft_feature = self._features[f'{self._mix_name}_stft']
|
|
||||||
instrument_mask = instrument_mask[
|
instrument_mask = instrument_mask[
|
||||||
:tf.shape(stft_feature)[0], ...]
|
:tf.shape(stft_feature)[0], ...]
|
||||||
# Compute masked STFT and normalize it.
|
out[instrument] = instrument_mask
|
||||||
output_waveform[instrument] = self._inverse_stft(
|
self._masks = out
|
||||||
tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature)
|
|
||||||
|
def _build_masked_stfts(self):
|
||||||
|
input_stft = self.stft_feature
|
||||||
|
out = {}
|
||||||
|
for instrument, mask in self.masks.items():
|
||||||
|
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
|
||||||
|
self._masked_stfts = out
|
||||||
|
|
||||||
|
def _build_manual_output_waveform(self, masked_stft):
|
||||||
|
""" Perform ratio mask separation
|
||||||
|
|
||||||
|
:param output_dict: dictionary of estimated spectrogram (key: instrument
|
||||||
|
name, value: estimated spectrogram of the instrument)
|
||||||
|
:returns: dictionary of separated waveforms (key: instrument name,
|
||||||
|
value: estimated waveform of the instrument)
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_waveform = {}
|
||||||
|
for instrument, stft_data in masked_stft.items():
|
||||||
|
output_waveform[instrument] = self._inverse_stft(stft_data)
|
||||||
return output_waveform
|
return output_waveform
|
||||||
|
|
||||||
def _build_output_waveform(self, output_dict):
|
def _build_output_waveform(self, masked_stft):
|
||||||
""" Build output waveform from given output dict in order to be used in
|
""" Build output waveform from given output dict in order to be used in
|
||||||
prediction context. Regarding of the configuration building method will
|
prediction context. Regarding of the configuration building method will
|
||||||
be using MWF.
|
be using MWF.
|
||||||
|
|
||||||
:param output_dict: Output dict to build output waveform from.
|
|
||||||
:returns: Built output waveform.
|
:returns: Built output waveform.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._params.get('MWF', False):
|
if self._params.get('MWF', False):
|
||||||
output_waveform = self._build_mwf_output_waveform(output_dict)
|
output_waveform = self._build_mwf_output_waveform()
|
||||||
else:
|
else:
|
||||||
output_waveform = self._build_manual_output_waveform(output_dict)
|
output_waveform = self._build_manual_output_waveform(masked_stft)
|
||||||
if 'audio_id' in self._features:
|
|
||||||
output_waveform['audio_id'] = self._features['audio_id']
|
|
||||||
return output_waveform
|
return output_waveform
|
||||||
|
|
||||||
|
def _build_outputs(self):
|
||||||
|
if self.include_stft_computations():
|
||||||
|
self._outputs = self._build_output_waveform(self.masked_stfts)
|
||||||
|
else:
|
||||||
|
self._outputs = self.masked_stfts
|
||||||
|
|
||||||
|
if 'audio_id' in self._features:
|
||||||
|
self._outputs['audio_id'] = self._features['audio_id']
|
||||||
|
|
||||||
def build_predict_model(self):
|
def build_predict_model(self):
|
||||||
""" Builder interface for creating model instance that aims to perform
|
""" Builder interface for creating model instance that aims to perform
|
||||||
prediction / inference over given track. The output of such estimator
|
prediction / inference over given track. The output of such estimator
|
||||||
@@ -330,12 +486,10 @@ class EstimatorSpecBuilder(object):
|
|||||||
|
|
||||||
:returns: An estimator for performing prediction.
|
:returns: An estimator for performing prediction.
|
||||||
"""
|
"""
|
||||||
self._build_stft_feature()
|
|
||||||
output_dict = self._build_output_dict()
|
|
||||||
output_waveform = self._build_output_waveform(output_dict)
|
|
||||||
return tf.estimator.EstimatorSpec(
|
return tf.estimator.EstimatorSpec(
|
||||||
tf.estimator.ModeKeys.PREDICT,
|
tf.estimator.ModeKeys.PREDICT,
|
||||||
predictions=output_waveform)
|
predictions=self.outputs)
|
||||||
|
|
||||||
def build_evaluation_model(self, labels):
|
def build_evaluation_model(self, labels):
|
||||||
""" Builder interface for creating model instance that aims to perform
|
""" Builder interface for creating model instance that aims to perform
|
||||||
@@ -346,8 +500,7 @@ class EstimatorSpecBuilder(object):
|
|||||||
:param labels: Model labels.
|
:param labels: Model labels.
|
||||||
:returns: An estimator for performing model evaluation.
|
:returns: An estimator for performing model evaluation.
|
||||||
"""
|
"""
|
||||||
output_dict = self._build_output_dict()
|
loss, metrics = self._build_loss(labels)
|
||||||
loss, metrics = self._build_loss(output_dict, labels)
|
|
||||||
return tf.estimator.EstimatorSpec(
|
return tf.estimator.EstimatorSpec(
|
||||||
tf.estimator.ModeKeys.EVAL,
|
tf.estimator.ModeKeys.EVAL,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
@@ -362,8 +515,7 @@ class EstimatorSpecBuilder(object):
|
|||||||
:param labels: Model labels.
|
:param labels: Model labels.
|
||||||
:returns: An estimator for performing model training.
|
:returns: An estimator for performing model training.
|
||||||
"""
|
"""
|
||||||
output_dict = self._build_output_dict()
|
loss, metrics = self._build_loss(labels)
|
||||||
loss, metrics = self._build_loss(output_dict, labels)
|
|
||||||
optimizer = self._build_optimizer()
|
optimizer = self._build_optimizer()
|
||||||
train_operation = optimizer.minimize(
|
train_operation = optimizer.minimize(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
|
|||||||
@@ -13,40 +13,57 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import logging
|
||||||
|
|
||||||
from functools import partial
|
from time import time
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from pathlib import Path
|
|
||||||
from os.path import basename, join, splitext
|
from os.path import basename, join, splitext
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from librosa.core import stft, istft
|
||||||
|
from scipy.signal.windows import hann
|
||||||
|
|
||||||
from . import SpleeterError
|
from . import SpleeterError
|
||||||
from .audio.adapter import get_default_audio_adapter
|
from .audio.adapter import get_default_audio_adapter
|
||||||
from .audio.convertor import to_stereo
|
from .audio.convertor import to_stereo
|
||||||
from .model import model_fn
|
|
||||||
from .utils.configuration import load_configuration
|
from .utils.configuration import load_configuration
|
||||||
from .utils.estimator import create_estimator, to_predictor
|
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
|
||||||
|
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||||
|
|
||||||
|
|
||||||
__email__ = 'research@deezer.com'
|
__email__ = 'research@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("spleeter")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend(backend):
|
||||||
|
assert backend in ["auto", "tensorflow", "librosa"]
|
||||||
|
if backend == "auto":
|
||||||
|
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
||||||
|
return backend
|
||||||
|
|
||||||
|
|
||||||
class Separator(object):
|
class Separator(object):
|
||||||
""" A wrapper class for performing separation. """
|
""" A wrapper class for performing separation. """
|
||||||
|
|
||||||
def __init__(self, params_descriptor, MWF=False, multiprocess=True):
|
def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
|
||||||
""" Default constructor.
|
""" Default constructor.
|
||||||
|
|
||||||
:param params_descriptor: Descriptor for TF params to be used.
|
:param params_descriptor: Descriptor for TF params to be used.
|
||||||
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._params = load_configuration(params_descriptor)
|
self._params = load_configuration(params_descriptor)
|
||||||
self._sample_rate = self._params['sample_rate']
|
self._sample_rate = self._params['sample_rate']
|
||||||
self._MWF = MWF
|
self._MWF = MWF
|
||||||
self._predictor = None
|
self._predictor = None
|
||||||
self._pool = Pool() if multiprocess else None
|
self._pool = Pool() if multiprocess else None
|
||||||
self._tasks = []
|
self._tasks = []
|
||||||
|
self._params["stft_backend"] = get_backend(stft_backend)
|
||||||
|
|
||||||
def _get_predictor(self):
|
def _get_predictor(self):
|
||||||
""" Lazy loading access method for internal predictor instance.
|
""" Lazy loading access method for internal predictor instance.
|
||||||
@@ -68,7 +85,7 @@ class Separator(object):
|
|||||||
task.get()
|
task.get()
|
||||||
task.wait(timeout=timeout)
|
task.wait(timeout=timeout)
|
||||||
|
|
||||||
def separate(self, waveform):
|
def separate_tensorflow(self, waveform, audio_descriptor):
|
||||||
""" Performs source separation over the given waveform.
|
""" Performs source separation over the given waveform.
|
||||||
|
|
||||||
The separation is performed synchronously but the result
|
The separation is performed synchronously but the result
|
||||||
@@ -86,10 +103,59 @@ class Separator(object):
|
|||||||
predictor = self._get_predictor()
|
predictor = self._get_predictor()
|
||||||
prediction = predictor({
|
prediction = predictor({
|
||||||
'waveform': waveform,
|
'waveform': waveform,
|
||||||
'audio_id': ''})
|
'audio_id': audio_descriptor})
|
||||||
prediction.pop('audio_id')
|
prediction.pop('audio_id')
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
|
def stft(self, data, inverse=False, length=None):
|
||||||
|
"""
|
||||||
|
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
|
||||||
|
channels are processed separately and are concatenated together in the result. The expected input formats are:
|
||||||
|
(n_samples, 2) for stft and (T, F, 2) for istft.
|
||||||
|
:param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse
|
||||||
|
:param inverse: should a stft or an istft be computed.
|
||||||
|
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
|
||||||
|
"""
|
||||||
|
assert not (inverse and length is None)
|
||||||
|
data = np.asfortranarray(data)
|
||||||
|
N = self._params["frame_length"]
|
||||||
|
H = self._params["frame_step"]
|
||||||
|
win = hann(N, sym=False)
|
||||||
|
fstft = istft if inverse else stft
|
||||||
|
win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
|
||||||
|
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
|
||||||
|
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
|
||||||
|
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
|
||||||
|
s1 = np.expand_dims(s1.T, 2-inverse)
|
||||||
|
s2 = np.expand_dims(s2.T, 2-inverse)
|
||||||
|
return np.concatenate([s1, s2], axis=2-inverse)
|
||||||
|
|
||||||
|
def separate_librosa(self, waveform, audio_id):
|
||||||
|
out = {}
|
||||||
|
input_provider = InputProviderFactory.get(self._params)
|
||||||
|
features = input_provider.get_input_dict_placeholders()
|
||||||
|
|
||||||
|
builder = EstimatorSpecBuilder(features, self._params)
|
||||||
|
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
|
||||||
|
|
||||||
|
# TODO: fix the logic, build sometimes return, sometimes set attribute
|
||||||
|
outputs = builder.outputs
|
||||||
|
|
||||||
|
saver = tf.train.Saver()
|
||||||
|
stft = self.stft(waveform)
|
||||||
|
with tf.Session() as sess:
|
||||||
|
saver.restore(sess, latest_checkpoint)
|
||||||
|
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
||||||
|
for inst in builder.instruments:
|
||||||
|
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
|
||||||
|
return out
|
||||||
|
|
||||||
|
def separate(self, waveform, audio_descriptor):
|
||||||
|
if self._params["stft_backend"] == "tensorflow":
|
||||||
|
return self.separate_tensorflow(waveform, audio_descriptor)
|
||||||
|
else:
|
||||||
|
return self.separate_librosa(waveform, audio_descriptor)
|
||||||
|
|
||||||
def separate_to_file(
|
def separate_to_file(
|
||||||
self, audio_descriptor, destination,
|
self, audio_descriptor, destination,
|
||||||
audio_adapter=get_default_audio_adapter(),
|
audio_adapter=get_default_audio_adapter(),
|
||||||
@@ -108,6 +174,8 @@ class Separator(object):
|
|||||||
descriptor would be a file path.
|
descriptor would be a file path.
|
||||||
:param destination: Target directory to write output to.
|
:param destination: Target directory to write output to.
|
||||||
:param audio_adapter: (Optional) Audio adapter to use for I/O.
|
:param audio_adapter: (Optional) Audio adapter to use for I/O.
|
||||||
|
:param chunk_duration: (Optional) Maximum signal duration that is processed
|
||||||
|
in one pass. Default: all signal.
|
||||||
:param offset: (Optional) Offset of loaded song.
|
:param offset: (Optional) Offset of loaded song.
|
||||||
:param duration: (Optional) Duration of loaded song.
|
:param duration: (Optional) Duration of loaded song.
|
||||||
:param codec: (Optional) Export codec.
|
:param codec: (Optional) Export codec.
|
||||||
@@ -115,12 +183,17 @@ class Separator(object):
|
|||||||
:param filename_format: (Optional) Filename format.
|
:param filename_format: (Optional) Filename format.
|
||||||
:param synchronous: (Optional) True is should by synchronous.
|
:param synchronous: (Optional) True is should by synchronous.
|
||||||
"""
|
"""
|
||||||
waveform, _ = audio_adapter.load(
|
waveform, sample_rate = audio_adapter.load(
|
||||||
audio_descriptor,
|
audio_descriptor,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
sample_rate=self._sample_rate)
|
sample_rate=self._sample_rate)
|
||||||
sources = self.separate(waveform)
|
sources = self.separate(waveform, audio_descriptor)
|
||||||
|
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||||
|
audio_adapter, bitrate, synchronous)
|
||||||
|
|
||||||
|
def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
|
||||||
|
audio_adapter, bitrate, synchronous):
|
||||||
filename = splitext(basename(audio_descriptor))[0]
|
filename = splitext(basename(audio_descriptor))[0]
|
||||||
generated = []
|
generated = []
|
||||||
for instrument, data in sources.items():
|
for instrument, data in sources.items():
|
||||||
|
|||||||
@@ -13,27 +13,37 @@ import tensorflow as tf
|
|||||||
from tensorflow.contrib import predictor
|
from tensorflow.contrib import predictor
|
||||||
# pylint: enable=import-error
|
# pylint: enable=import-error
|
||||||
|
|
||||||
from ..model import model_fn
|
from ..model import model_fn, InputProviderFactory
|
||||||
from ..model.provider import get_default_model_provider
|
from ..model.provider import get_default_model_provider
|
||||||
|
|
||||||
# Default exporting directory for predictor.
|
# Default exporting directory for predictor.
|
||||||
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
|
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_model_dir(model_dir):
|
||||||
|
"""
|
||||||
|
Transforms a string like 'spleeter:2stems' into an actual path.
|
||||||
|
:param model_dir:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
model_provider = get_default_model_provider()
|
||||||
|
return model_provider.get(model_dir)
|
||||||
|
|
||||||
def create_estimator(params, MWF):
|
def create_estimator(params, MWF):
|
||||||
"""
|
"""
|
||||||
Initialize tensorflow estimator that will perform separation
|
Initialize tensorflow estimator that will perform separation
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
- params: a dictionnary of parameters for building the model
|
- params: a dictionary of parameters for building the model
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a tensorflow estimator
|
a tensorflow estimator
|
||||||
"""
|
"""
|
||||||
# Load model.
|
# Load model.
|
||||||
model_directory = params['model_dir']
|
|
||||||
model_provider = get_default_model_provider()
|
|
||||||
params['model_dir'] = model_provider.get(model_directory)
|
params['model_dir'] = get_default_model_dir(params['model_dir'])
|
||||||
params['MWF'] = MWF
|
params['MWF'] = MWF
|
||||||
# Setup config
|
# Setup config
|
||||||
session_config = tf.compat.v1.ConfigProto()
|
session_config = tf.compat.v1.ConfigProto()
|
||||||
@@ -56,11 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
|||||||
:param estimator: Estimator to export.
|
:param estimator: Estimator to export.
|
||||||
:param directory: (Optional) path to write exported model into.
|
:param directory: (Optional) path to write exported model into.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
input_provider = InputProviderFactory.get(estimator.params)
|
||||||
def receiver():
|
def receiver():
|
||||||
shape = (None, estimator.params['n_channels'])
|
features = input_provider.get_input_dict_placeholders()
|
||||||
features = {
|
|
||||||
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
|
|
||||||
'audio_id': tf.compat.v1.placeholder(tf.string)}
|
|
||||||
return tf.estimator.export.ServingInputReceiver(features, features)
|
return tf.estimator.export.ServingInputReceiver(features, features)
|
||||||
|
|
||||||
estimator.export_saved_model(directory, receiver)
|
estimator.export_saved_model(directory, receiver)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from os.path import splitext, basename, exists, join
|
|||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from spleeter import SpleeterError
|
from spleeter import SpleeterError
|
||||||
from spleeter.audio.adapter import get_default_audio_adapter
|
from spleeter.audio.adapter import get_default_audio_adapter
|
||||||
@@ -21,34 +22,38 @@ from spleeter.separator import Separator
|
|||||||
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
|
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
|
||||||
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
|
TEST_AUDIO_BASENAME = splitext(basename(TEST_AUDIO_DESCRIPTOR))[0]
|
||||||
TEST_CONFIGURATIONS = [
|
TEST_CONFIGURATIONS = [
|
||||||
('spleeter:2stems', ('vocals', 'accompaniment')),
|
('spleeter:2stems', ('vocals', 'accompaniment'), 'tensorflow'),
|
||||||
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other')),
|
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'tensorflow'),
|
||||||
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'))
|
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'tensorflow'),
|
||||||
|
('spleeter:2stems', ('vocals', 'accompaniment'), 'librosa'),
|
||||||
|
('spleeter:4stems', ('vocals', 'drums', 'bass', 'other'), 'librosa'),
|
||||||
|
('spleeter:5stems', ('vocals', 'drums', 'bass', 'piano', 'other'), 'librosa')
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||||
def test_separate(configuration, instruments):
|
def test_separate(configuration, instruments, backend):
|
||||||
""" Test separation from raw data. """
|
""" Test separation from raw data. """
|
||||||
adapter = get_default_audio_adapter()
|
adapter = get_default_audio_adapter()
|
||||||
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
|
waveform, _ = adapter.load(TEST_AUDIO_DESCRIPTOR)
|
||||||
separator = Separator(configuration)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
prediction = separator.separate(waveform)
|
prediction = separator.separate(waveform, TEST_AUDIO_DESCRIPTOR)
|
||||||
assert len(prediction) == len(instruments)
|
assert len(prediction) == len(instruments)
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
assert instrument in prediction
|
assert instrument in prediction
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
track = prediction[instrument]
|
track = prediction[instrument]
|
||||||
assert not (waveform == track).all()
|
assert waveform.shape == track.shape
|
||||||
|
assert not np.allclose(waveform, track)
|
||||||
for compared in instruments:
|
for compared in instruments:
|
||||||
if instrument != compared:
|
if instrument != compared:
|
||||||
assert not (track == prediction[compared]).all()
|
assert not np.allclose(track, prediction[compared])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||||
def test_separate_to_file(configuration, instruments):
|
def test_separate_to_file(configuration, instruments, backend):
|
||||||
""" Test file based separation. """
|
""" Test file based separation. """
|
||||||
separator = Separator(configuration)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
TEST_AUDIO_DESCRIPTOR,
|
TEST_AUDIO_DESCRIPTOR,
|
||||||
@@ -59,10 +64,10 @@ def test_separate_to_file(configuration, instruments):
|
|||||||
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
'{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('configuration, instruments', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('configuration, instruments, backend', TEST_CONFIGURATIONS)
|
||||||
def test_filename_format(configuration, instruments):
|
def test_filename_format(configuration, instruments, backend):
|
||||||
""" Test custom filename format. """
|
""" Test custom filename format. """
|
||||||
separator = Separator(configuration)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
TEST_AUDIO_DESCRIPTOR,
|
TEST_AUDIO_DESCRIPTOR,
|
||||||
@@ -74,7 +79,7 @@ def test_filename_format(configuration, instruments):
|
|||||||
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
'export/{}/{}.wav'.format(TEST_AUDIO_BASENAME, instrument)))
|
||||||
|
|
||||||
|
|
||||||
def test_filename_confilct():
|
def test_filename_conflict():
|
||||||
""" Test error handling with static pattern. """
|
""" Test error handling with static pattern. """
|
||||||
separator = Separator(TEST_CONFIGURATIONS[0][0])
|
separator = Separator(TEST_CONFIGURATIONS[0][0])
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
|
|||||||
Reference in New Issue
Block a user