Initial commit from private spleeter

Romain
2019-10-28 14:12:13 +01:00
parent dc39414ee9
commit 556ef21214
47 changed files with 3924 additions and 3 deletions

spleeter/__init__.py Normal file

@@ -0,0 +1,18 @@
#!/usr/bin/env python
# coding: utf8
"""
Spleeter is the Deezer source separation library with pretrained models.
The library is based on Tensorflow:
- It provides already trained models for performing separation.
- It makes it easy to train source separation models with Tensorflow
(provided you have a dataset of isolated sources).
This module allows you to interact with Spleeter from the command line
through train, evaluation and source separation actions.
"""
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

spleeter/__main__.py Normal file

@@ -0,0 +1,52 @@
#!/usr/bin/env python
# coding: utf8
"""
Python one-liner script usage.
USAGE: python -m spleeter {train,evaluate,separate} ...
"""
import sys
import warnings
from .commands import create_argument_parser
from .utils.configuration import load_configuration
from .utils.logging import enable_logging, enable_verbose_logging
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def main(argv):
""" Spleeter runner. Parse provided command line arguments
and run entrypoint for required command (either train,
evaluate or separate).
:param argv: Provided command line arguments.
"""
parser = create_argument_parser()
arguments = parser.parse_args(argv[1:])
if arguments.verbose:
enable_verbose_logging()
else:
enable_logging()
if arguments.command == 'separate':
from .commands.separate import entrypoint
elif arguments.command == 'train':
from .commands.train import entrypoint
elif arguments.command == 'evaluate':
from .commands.evaluate import entrypoint
params = load_configuration(arguments.params_filename)
entrypoint(arguments, params)
def entrypoint():
""" Command line entrypoint. """
warnings.filterwarnings('ignore')
main(sys.argv)
if __name__ == '__main__':
entrypoint()
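For illustration, parsing a separate command line with the parser defined in spleeter.commands resolves the command name and the default parameters file before any entrypoint module is imported (a minimal sketch, assuming the package is importable):

>>> from spleeter.commands import create_argument_parser
>>> arguments = create_argument_parser().parse_args(
...     ['separate', '-i', 'audio.mp3'])
>>> arguments.command
'separate'
>>> arguments.params_filename
'spleeter:2stems'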

spleeter/commands/__init__.py Normal file

@@ -0,0 +1,182 @@
#!/usr/bin/env python
# coding: utf8
""" This modules provides spleeter command as well as CLI parsing methods. """
import json
from argparse import ArgumentParser
from tempfile import gettempdir
from os.path import exists, join
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# -i opt specification.
OPT_INPUT = {
'dest': 'audio_filenames',
'nargs': '+',
'help': 'List of input audio filenames',
'required': True
}
# -o opt specification.
OPT_OUTPUT = {
'dest': 'output_path',
'default': join(gettempdir(), 'separated_audio'),
'help': 'Path of the output directory to write audio files in'
}
# -p opt specification.
OPT_PARAMS = {
'dest': 'params_filename',
'default': 'spleeter:2stems',
'type': str,
'action': 'store',
'help': 'JSON filename that contains params'
}
# -n opt specification.
OPT_OUTPUT_NAMING = {
'dest': 'output_naming',
'default': 'filename',
'choices': ('directory', 'filename'),
'help': (
'Choice for naming the output base path: '
'"filename" (use the input filename, i.e '
'/path/to/audio/mix.wav will be separated to '
'<output_path>/mix/<instument1>.wav, '
'<output_path>/mix/<instument2>.wav...) or '
'"directory" (use the name of the input last level'
' directory, for instance /path/to/audio/mix.wav '
'will be separated to <output_path>/audio/<instument1>.wav'
', <output_path>/audio/<instument2>.wav)')
}
# -d opt specification (separate).
OPT_DURATION = {
'dest': 'max_duration',
'type': float,
'default': 600.,
'help': (
'Set a maximum duration for processing audio '
'(only separate max_duration first seconds of '
'the input file)')
}
# -c opt specification.
OPT_CODEC = {
'dest': 'audio_codec',
'choices': ('wav', 'mp3', 'ogg', 'm4a', 'wma', 'flac'),
'default': 'wav',
'help': 'Audio codec to be used for the separated output'
}
# -m opt specification.
OPT_MWF = {
'dest': 'MWF',
'action': 'store_const',
'const': True,
'default': False,
'help': 'Whether to use multichannel Wiener filtering for separation',
}
# --mus_dir opt specification.
OPT_MUSDB = {
'dest': 'mus_dir',
'type': str,
'required': True,
'help': 'Path to folder with musDB'
}
# -d opt specification (train).
OPT_DATA = {
'dest': 'audio_path',
'type': str,
'required': True,
'help': 'Path of the folder containing audio data for training'
}
# -a opt specification.
OPT_ADAPTER = {
'dest': 'audio_adapter',
'type': str,
'help': 'Name of the audio adapter to use for audio I/O'
}
# --verbose opt specification.
OPT_VERBOSE = {
'action': 'store_true',
'help': 'Shows verbose logs'
}
def _add_common_options(parser):
""" Add common option to the given parser.
:param parser: Parser to add common opt to.
"""
parser.add_argument('-a', '--adapter', **OPT_ADAPTER)
parser.add_argument('-p', '--params_filename', **OPT_PARAMS)
parser.add_argument('--verbose', **OPT_VERBOSE)
def _create_train_parser(parser_factory):
""" Creates an argparser for training command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory('train', help='Train a source separation model')
_add_common_options(parser)
parser.add_argument('-d', '--data', **OPT_DATA)
return parser
def _create_evaluate_parser(parser_factory):
""" Creates an argparser for evaluation command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory(
'evaluate',
help='Evaluate a model on the musDB test dataset')
_add_common_options(parser)
parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
parser.add_argument('--mus_dir', **OPT_MUSDB)
parser.add_argument('-m', '--mwf', **OPT_MWF)
return parser
def _create_separate_parser(parser_factory):
""" Creates an argparser for separation command
:param parser_factory: Factory to use to create parser instance.
:returns: Created and configured parser.
"""
parser = parser_factory('separate', help='Separate audio files')
_add_common_options(parser)
parser.add_argument('-i', '--audio_filenames', **OPT_INPUT)
parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
parser.add_argument('-n', '--output_naming', **OPT_OUTPUT_NAMING)
parser.add_argument('-d', '--max_duration', **OPT_DURATION)
parser.add_argument('-c', '--audio_codec', **OPT_CODEC)
parser.add_argument('-m', '--mwf', **OPT_MWF)
return parser
def create_argument_parser():
""" Creates overall command line parser for Spleeter.
:returns: Created argument parser.
"""
parser = ArgumentParser(prog='python -m spleeter')
subparsers = parser.add_subparsers()
subparsers.dest = 'command'
subparsers.required = True
_create_separate_parser(subparsers.add_parser)
_create_train_parser(subparsers.add_parser)
_create_evaluate_parser(subparsers.add_parser)
return parser
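As a usage sketch, the evaluate subcommand inherits the common options and adds the MWF switch and the musdb directory (assuming the package is importable):

>>> from spleeter.commands import create_argument_parser
>>> args = create_argument_parser().parse_args(
...     ['evaluate', '--mus_dir', '/data/musdb', '-m'])
>>> args.MWF, args.mus_dir
(True, '/data/musdb')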

spleeter/commands/evaluate.py Normal file

@@ -0,0 +1,154 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model evaluation.
Evaluation is performed against musDB dataset.
USAGE: python -m spleeter evaluate \
-p /path/to/params \
-o /path/to/output/dir \
[-m] \
--mus_dir /path/to/musdb/dataset
"""
import json
from argparse import Namespace
from glob import glob
from os.path import join, exists
# pylint: disable=import-error
import musdb
import museval
import numpy as np
# pylint: enable=import-error
from .separate import entrypoint as separate_entrypoint
from ..utils.logging import get_logger
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_SPLIT = 'test'
_MIXTURE = 'mixture.wav'
_NAMING = 'directory'
_AUDIO_DIRECTORY = 'audio'
_METRICS_DIRECTORY = 'metrics'
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
""" Performs audio separation on the musdb dataset from
the given directory and params.
:param arguments: Entrypoint arguments.
:param musdb_root_directory: Directory to retrieve dataset from.
:param params: Spleeter configuration to apply to separation.
:returns: Separation output directory path.
"""
songs = glob(join(musdb_root_directory, _SPLIT, '*/'))
mixtures = [join(song, _MIXTURE) for song in songs]
audio_output_directory = join(
arguments.output_path,
_AUDIO_DIRECTORY)
separate_entrypoint(
Namespace(
audio_adapter=arguments.audio_adapter,
audio_filenames=mixtures,
audio_codec='wav',
output_path=join(audio_output_directory, _SPLIT),
output_naming=_NAMING,
max_duration=600.,
MWF=arguments.MWF,
verbose=arguments.verbose),
params)
return audio_output_directory
def _compute_musdb_metrics(
arguments,
musdb_root_directory,
audio_output_directory):
""" Generates musdb metrics fro previsouly computed audio estimation.
:param arguments: Entrypoint arguments.
:param audio_output_directory: Directory to get audio estimation from.
:returns: Path of generated metrics directory.
"""
metrics_output_directory = join(
arguments.output_path,
_METRICS_DIRECTORY)
get_logger().info('Starting musdb evaluation (this could be long) ...')
dataset = musdb.DB(
root=musdb_root_directory,
is_wav=True,
subsets=[_SPLIT])
museval.eval_mus_dir(
dataset=dataset,
estimates_dir=audio_output_directory,
output_dir=metrics_output_directory)
get_logger().info('musdb evaluation done')
return metrics_output_directory
def _compile_metrics(metrics_output_directory):
""" Compiles metrics from given directory and returns
results as dict.
:param metrics_output_directory: Directory to get metrics from.
:returns: Compiled metrics as dict.
"""
songs = glob(join(metrics_output_directory, 'test/*.json'))
metrics = {
instrument: {k: [] for k in _METRICS}
for instrument in _INSTRUMENTS}
for song in songs:
with open(song, 'r') as stream:
data = json.load(stream)
for target in data['targets']:
instrument = target['name']
for metric in _METRICS:
                metric_median = np.median([
                    frame['metrics'][metric]
                    for frame in target['frames']
                    if not np.isnan(frame['metrics'][metric])])
                metrics[instrument][metric].append(metric_median)
return metrics
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
# Parse and check musdb directory.
musdb_root_directory = arguments.mus_dir
if not exists(musdb_root_directory):
raise IOError(f'musdb directory {musdb_root_directory} not found')
# Separate musdb sources.
audio_output_directory = _separate_evaluation_dataset(
arguments,
musdb_root_directory,
params)
# Compute metrics with musdb.
metrics_output_directory = _compute_musdb_metrics(
arguments,
musdb_root_directory,
audio_output_directory)
# Compute and pretty print median metrics.
metrics = _compile_metrics(metrics_output_directory)
    for instrument, instrument_metrics in metrics.items():
        get_logger().info('%s:', instrument)
        for metric, value in instrument_metrics.items():
get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
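The compiled metrics dict maps instrument to metric to a list of per-song medians, so each final log line reduces one such list with a second median; a small self-contained check with hypothetical values:

>>> import numpy as np
>>> metrics = {'vocals': {'SDR': [4.2, 5.1, 3.9]}}  # hypothetical values
>>> f"{np.median(metrics['vocals']['SDR']):.3f}"
'4.200'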

spleeter/commands/separate.py Normal file

@@ -0,0 +1,180 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing source separation.
USAGE: python -m spleeter separate \
    -p /path/to/params \
    -i /path/to/audio1.wav /path/to/audio2.mp3 \
    -o /path/to/output/dir
"""
from multiprocessing import Pool
from os.path import isabs, join, split, splitext
from tempfile import gettempdir
# pylint: disable=import-error
import tensorflow as tf
import numpy as np
# pylint: enable=import-error
from ..utils.audio.adapter import get_audio_adapter
from ..utils.audio.convertor import to_n_channels
from ..utils.estimator import create_estimator
from ..utils.tensor import set_tensor_shape
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def get_dataset(audio_adapter, filenames_and_crops, sample_rate, n_channels):
""""
Build a tensorflow dataset of waveform from a filename list wit crop
information.
Params:
- audio_adapter: An AudioAdapter instance to load audio from.
- filenames_and_crops: list of (audio_filename, start, duration)
tuples separation is performed on each filaneme
from start (in seconds) to start + duration
(in seconds).
- sample_rate: audio sample_rate of the input and output audio
signals
- n_channels: int, number of channels of the input and output
audio signals
    Returns:
A tensorflow dataset of waveform to feed a tensorflow estimator in
predict mode.
"""
filenames, starts, ends = list(zip(*filenames_and_crops))
dataset = tf.data.Dataset.from_tensor_slices({
'audio_id': list(filenames),
'start': list(starts),
'end': list(ends)
})
# Load waveform.
dataset = dataset.map(
lambda sample: dict(
sample,
**audio_adapter.load_tf_waveform(
sample['audio_id'],
sample_rate=sample_rate,
offset=sample['start'],
duration=sample['end'] - sample['start'])),
num_parallel_calls=2)
# Filter out error.
dataset = dataset.filter(
lambda sample: tf.logical_not(sample['waveform_error']))
# Convert waveform to the right number of channels.
dataset = dataset.map(
lambda sample: dict(
sample,
waveform=to_n_channels(sample['waveform'], n_channels)))
# Set number of channels (required for the model).
dataset = dataset.map(
lambda sample: dict(
sample,
waveform=set_tensor_shape(sample['waveform'], (None, n_channels))))
return dataset
def process_audio(
audio_adapter,
filenames_and_crops, estimator, output_path,
sample_rate, n_channels, codec, output_naming):
"""
Perform separation on a list of audio ids.
Params:
- audio_adapter: Audio adapter to use for audio I/O.
        - filenames_and_crops: list of (audio_filename, start, end)
                        tuples; separation is performed on each filename
                        from start (in seconds) to end (in seconds).
- estimator: the tensorflow estimator that performs the
source separation.
- output_path: output_path where to export separated files.
- sample_rate: audio sample_rate of the input and output audio
signals
- n_channels: int, number of channels of the input and output
audio signals
        - codec: string, codec to be used for export (could be
                        "wav", "mp3", "ogg", "m4a" or anything else
                        supported by ffmpeg).
        - output_naming: string (either "filename" or "directory"),
                        naming convention for output.
for an input file /path/to/audio/input_file.wav:
* if output_naming is equal to "filename":
output files will be put in the directory <output_path>/input_file
(<output_path>/input_file/<instrument1>.<codec>,
<output_path>/input_file/<instrument2>.<codec>...).
* if output_naming is equal to "directory":
output files will be put in the directory <output_path>/audio/
(<output_path>/audio/<instrument1>.<codec>,
<output_path>/audio/<instrument2>.<codec>...)
Use "directory" when separating the MusDB dataset.
"""
# Get estimator
prediction = estimator.predict(
lambda: get_dataset(
audio_adapter,
filenames_and_crops,
sample_rate,
n_channels),
yield_single_examples=False)
# initialize pool for audio export
pool = Pool(16)
tasks = []
for sample in prediction:
sample_filename = sample.pop('audio_id', 'unknown_filename').decode()
input_directory, input_filename = split(sample_filename)
if output_naming == 'directory':
output_dirname = split(input_directory)[1]
elif output_naming == 'filename':
output_dirname = splitext(input_filename)[0]
else:
raise ValueError(f'Unknown output naming {output_naming}')
for instrument, waveform in sample.items():
filename = join(
output_path,
output_dirname,
f'{instrument}.{codec}')
tasks.append(
pool.apply_async(
audio_adapter.save,
(filename, waveform, sample_rate, codec)))
# Wait for everything to be written
for task in tasks:
task.wait(timeout=20)
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
audio_adapter = get_audio_adapter(arguments.audio_adapter)
filenames = arguments.audio_filenames
output_path = arguments.output_path
max_duration = arguments.max_duration
audio_codec = arguments.audio_codec
output_naming = arguments.output_naming
estimator = create_estimator(params, arguments.MWF)
filenames_and_crops = [
(filename, 0., max_duration)
for filename in filenames]
process_audio(
audio_adapter,
filenames_and_crops,
estimator,
output_path,
params['sample_rate'],
params['n_channels'],
codec=audio_codec,
output_naming=output_naming)
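The entrypoint can also be driven programmatically with an argparse.Namespace, mirroring what evaluate.py does above; a minimal sketch, assuming params has been obtained through load_configuration and that passing audio_adapter=None selects the default adapter:

>>> from argparse import Namespace
>>> from spleeter.commands.separate import entrypoint
>>> entrypoint(
...     Namespace(
...         audio_adapter=None, audio_filenames=['audio.mp3'],
...         audio_codec='wav', output_path='/tmp/separated_audio',
...         output_naming='filename', max_duration=600.,
...         MWF=False, verbose=False),
...     params)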

spleeter/commands/train.py Normal file

@@ -0,0 +1,98 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model training.
USAGE: python -m spleeter train -p /path/to/params
"""
from functools import partial
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
from ..dataset import get_training_dataset, get_validation_dataset
from ..model import model_fn
from ..utils.audio.adapter import get_audio_adapter
from ..utils.logging import get_logger
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def _create_estimator(params):
""" Creates estimator.
:param params: TF params to build estimator from.
:returns: Built estimator.
"""
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=tf.estimator.RunConfig(
save_checkpoints_steps=params['save_checkpoints_steps'],
tf_random_seed=params['random_seed'],
save_summary_steps=params['save_summary_steps'],
session_config=session_config,
log_step_count_steps=10,
keep_checkpoint_max=2))
return estimator
def _create_train_spec(params, audio_adapter, audio_path):
""" Creates train spec.
:param params: TF params to build spec from.
:returns: Built train spec.
"""
input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
train_spec = tf.estimator.TrainSpec(
input_fn=input_fn,
max_steps=params['train_max_steps'])
return train_spec
def _create_evaluation_spec(params, audio_adapter, audio_path):
""" Setup eval spec evaluating ever n seconds
:param params: TF params to build spec from.
:returns: Built evaluation spec.
"""
input_fn = partial(
get_validation_dataset,
params,
audio_adapter,
audio_path)
evaluation_spec = tf.estimator.EvalSpec(
input_fn=input_fn,
steps=None,
throttle_secs=params['throttle_secs'])
return evaluation_spec
def entrypoint(arguments, params):
""" Command entrypoint.
:param arguments: Command line parsed argument as argparse.Namespace.
:param params: Deserialized JSON configuration file provided in CLI args.
"""
audio_adapter = get_audio_adapter(arguments.audio_adapter)
audio_path = arguments.audio_path
estimator = _create_estimator(params)
train_spec = _create_train_spec(params, audio_adapter, audio_path)
evaluation_spec = _create_evaluation_spec(
params,
audio_adapter,
audio_path)
get_logger().info('Start model training')
tf.estimator.train_and_evaluate(
estimator,
train_spec,
evaluation_spec)
get_logger().info('Model training done')
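For reference, _create_estimator and the two spec builders only read a handful of keys from params; the base configs added later in this commit provide them all. A minimal subset, using the 2stems values:

params = {
    'model_dir': '2stems',
    'save_checkpoints_steps': 150,
    'random_seed': 0,
    'save_summary_steps': 5,
    'train_max_steps': 1000000,   # read by _create_train_spec
    'throttle_secs': 300,         # read by _create_evaluation_spec
}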

spleeter/dataset.py Normal file

@@ -0,0 +1,464 @@
#!/usr/bin/env python
# coding: utf8
"""
Module for building data preprocessing pipeline using the tensorflow data
API.
Data preprocessing such as audio loading, spectrogram computation, cropping,
feature caching or data augmentation is done using a tensorflow dataset object
that outputs a tuple (input_, output) where:
- input_ is a dictionary with a single key that contains the (batched) mix
spectrogram of audio samples
- output is a dictionary of spectrogram of the isolated tracks (ground truth)
"""
import time
import os
from os.path import exists, join, sep as SEPARATOR
# pylint: disable=import-error
import pandas as pd
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from .utils.audio.convertor import (
db_uint_spectrogram_to_gain,
spectrogram_to_db_uint)
from .utils.audio.spectrogram import (
compute_spectrogram_tf,
random_pitch_shift,
random_time_stretch)
from .utils.logging import get_logger
from .utils.tensor import (
check_tensor_shape,
dataset_from_csv,
set_tensor_shape,
sync_apply)
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# Default datasets path parameter to use.
DEFAULT_DATASETS_PATH = join(
'audio_database',
'separated_sources',
'experiments',
'karaoke_vocal_extraction',
'tensorflow_experiment'
)
# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS = {
'instrument_list': ('vocals', 'accompaniment'),
'mix_name': 'mix',
'sample_rate': 44100,
'frame_length': 4096,
'frame_step': 1024,
'T': 512,
'F': 1024
}
def get_training_dataset(audio_params, audio_adapter, audio_path):
""" Builds training dataset.
:param audio_params: Audio parameters.
:param audio_adapter: Adapter to load audio from.
:param audio_path: Path of directory containing audio.
:returns: Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=audio_params.get('chunk_duration', 20.0),
random_seed=audio_params.get('random_seed', 0))
return builder.build(
audio_params.get('train_csv'),
cache_directory=audio_params.get('training_cache'),
batch_size=audio_params.get('batch_size'),
n_chunks_per_song=audio_params.get('n_chunks_per_song', 2),
random_data_augmentation=False,
convert_to_uint=True,
wait_for_cache=False)
def get_validation_dataset(audio_params, audio_adapter, audio_path):
""" Builds validation dataset.
:param audio_params: Audio parameters.
:param audio_adapter: Adapter to load audio from.
:param audio_path: Path of directory containing audio.
:returns: Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=12.0)
return builder.build(
audio_params.get('validation_csv'),
batch_size=audio_params.get('batch_size'),
cache_directory=audio_params.get('training_cache'),
convert_to_uint=True,
infinite_generator=False,
n_chunks_per_song=1,
# should not perform data augmentation for eval:
random_data_augmentation=False,
random_time_crop=False,
shuffle=False,
)
class InstrumentDatasetBuilder(object):
""" Instrument based filter and mapper provider. """
def __init__(self, parent, instrument):
""" Default constructor.
:param parent: Parent dataset builder.
:param instrument: Target instrument.
"""
self._parent = parent
self._instrument = instrument
self._spectrogram_key = f'{instrument}_spectrogram'
self._min_spectrogram_key = f'min_{instrument}_spectrogram'
self._max_spectrogram_key = f'max_{instrument}_spectrogram'
def load_waveform(self, sample):
""" Load waveform for given sample. """
return dict(sample, **self._parent._audio_adapter.load_tf_waveform(
sample[f'{self._instrument}_path'],
offset=sample['start'],
duration=self._parent._chunk_duration,
sample_rate=self._parent._sample_rate,
waveform_name='waveform'))
def compute_spectrogram(self, sample):
""" Compute spectrogram of the given sample. """
return dict(sample, **{
self._spectrogram_key: compute_spectrogram_tf(
sample['waveform'],
frame_length=self._parent._frame_length,
frame_step=self._parent._frame_step,
spec_exponent=1.,
window_exponent=1.)})
def filter_frequencies(self, sample):
""" """
return dict(sample, **{
self._spectrogram_key:
sample[self._spectrogram_key][:, :self._parent._F, :]})
def convert_to_uint(self, sample):
""" Convert given sample from float to unit. """
return dict(sample, **spectrogram_to_db_uint(
sample[self._spectrogram_key],
tensor_key=self._spectrogram_key,
min_key=self._min_spectrogram_key,
max_key=self._max_spectrogram_key))
def filter_infinity(self, sample):
""" Filter infinity sample. """
return tf.logical_not(
tf.math.is_inf(
sample[self._min_spectrogram_key]))
def convert_to_float32(self, sample):
""" Convert given sample from unit to float. """
return dict(sample, **{
self._spectrogram_key: db_uint_spectrogram_to_gain(
sample[self._spectrogram_key],
sample[self._min_spectrogram_key],
sample[self._max_spectrogram_key])})
def time_crop(self, sample):
""" """
def start(sample):
""" mid_segment_start """
return tf.cast(
tf.maximum(
tf.shape(sample[self._spectrogram_key])[0]
/ 2 - self._parent._T / 2, 0),
tf.int32)
return dict(sample, **{
self._spectrogram_key: sample[self._spectrogram_key][
start(sample):start(sample) + self._parent._T, :, :]})
def filter_shape(self, sample):
""" Filter badly shaped sample. """
return check_tensor_shape(
sample[self._spectrogram_key], (
self._parent._T, self._parent._F, 2))
def reshape_spectrogram(self, sample):
""" """
return dict(sample, **{
self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key],
(self._parent._T, self._parent._F, 2))})
class DatasetBuilder(object):
"""
"""
# Margin at beginning and end of songs in seconds.
MARGIN = 0.5
# Wait period for cache (in seconds).
WAIT_PERIOD = 60
def __init__(
self,
audio_params, audio_adapter, audio_path,
random_seed=0, chunk_duration=20.0):
""" Default constructor.
        :param audio_params: Audio parameters to use.
        :param audio_adapter: Audio adapter to use.
        :param audio_path: Path of the directory containing audio data.
        :param random_seed: (Optional) random seed, default to 0.
        :param chunk_duration: (Optional) chunk duration in seconds, default to 20.0.
"""
# Length of segment in frames (if fs=22050 and
# frame_step=512, then T=512 corresponds to 11.89s)
self._T = audio_params['T']
# Number of frequency bins to be used (should
# be less than frame_length/2 + 1)
self._F = audio_params['F']
self._sample_rate = audio_params['sample_rate']
self._frame_length = audio_params['frame_length']
self._frame_step = audio_params['frame_step']
self._mix_name = audio_params['mix_name']
self._instruments = [self._mix_name] + audio_params['instrument_list']
self._instrument_builders = None
self._chunk_duration = chunk_duration
self._audio_adapter = audio_adapter
self._audio_params = audio_params
self._audio_path = audio_path
self._random_seed = random_seed
def expand_path(self, sample):
""" Expands audio paths for the given sample. """
return dict(sample, **{f'{instrument}_path': tf.string_join(
(self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
for instrument in self._instruments})
def filter_error(self, sample):
""" Filter errored sample. """
return tf.logical_not(sample['waveform_error'])
def filter_waveform(self, sample):
""" Filter waveform from sample. """
return {k: v for k, v in sample.items() if not k == 'waveform'}
def harmonize_spectrogram(self, sample):
""" Ensure same size for vocals and mix spectrograms. """
def _reduce(sample):
return tf.reduce_min([
tf.shape(sample[f'{instrument}_spectrogram'])[0]
for instrument in self._instruments])
return dict(sample, **{
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :]
for instrument in self._instruments})
def filter_short_segments(self, sample):
""" Filter out too short segment. """
return tf.reduce_any([
tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T
for instrument in self._instruments])
def random_time_crop(self, sample):
""" Random time crop of 11.88s. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: tf.image.random_crop(
x, (self._T, len(self._instruments) * self._F, 2),
seed=self._random_seed)))
def random_time_stretch(self, sample):
""" Randomly time stretch the given sample. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: random_time_stretch(
x, factor_min=0.9, factor_max=1.1)))
def random_pitch_shift(self, sample):
""" Randomly pitch shift the given sample. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: random_pitch_shift(
x, shift_min=-1.0, shift_max=1.0), concat_axis=0))
def map_features(self, sample):
""" Select features and annotation of the given sample. """
input_ = {
f'{self._mix_name}_spectrogram':
sample[f'{self._mix_name}_spectrogram']}
output = {
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
for instrument in self._audio_params['instrument_list']}
return (input_, output)
def compute_segments(self, dataset, n_chunks_per_song):
""" Computes segments for each song of the dataset.
:param dataset: Dataset to compute segments for.
        :param n_chunks_per_song: Number of segments per song to compute.
:returns: Segmented dataset.
"""
if n_chunks_per_song <= 0:
            raise ValueError('n_chunks_per_song must be positive')
datasets = []
for k in range(n_chunks_per_song):
if n_chunks_per_song > 1:
datasets.append(
dataset.map(lambda sample: dict(sample, start=tf.maximum(
k * (
sample['duration'] - self._chunk_duration - 2
* self.MARGIN) / (n_chunks_per_song - 1)
+ self.MARGIN, 0))))
elif n_chunks_per_song == 1: # Take central segment.
datasets.append(
dataset.map(lambda sample: dict(sample, start=tf.maximum(
sample['duration'] / 2 - self._chunk_duration / 2,
0))))
dataset = datasets[-1]
for d in datasets[:-1]:
dataset = dataset.concatenate(d)
return dataset
@property
def instruments(self):
""" Instrument dataset builder generator.
:yield InstrumentBuilder instance.
"""
if self._instrument_builders is None:
self._instrument_builders = []
for instrument in self._instruments:
self._instrument_builders.append(
InstrumentDatasetBuilder(self, instrument))
for builder in self._instrument_builders:
yield builder
def cache(self, dataset, cache, wait):
""" Cache the given dataset if cache is enabled. Eventually waits for
cache to be available (useful if another process is already computing
cache) if provided wait flag is True.
:param dataset: Dataset to be cached if cache is required.
:param cache: Path of cache directory to be used, None if no cache.
:param wait: If caching is enabled, True is cache should be waited.
:returns: Cached dataset if needed, original dataset otherwise.
"""
if cache is not None:
if wait:
while not exists(f'{cache}.index'):
get_logger().info(
'Cache not available, wait %s',
self.WAIT_PERIOD)
time.sleep(self.WAIT_PERIOD)
cache_path = os.path.split(cache)[0]
os.makedirs(cache_path, exist_ok=True)
return dataset.cache(cache)
return dataset
def build(
self, csv_path,
batch_size=8, shuffle=True, convert_to_uint=True,
random_data_augmentation=False, random_time_crop=True,
infinite_generator=True, cache_directory=None,
wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,):
"""
TO BE DOCUMENTED.
"""
dataset = dataset_from_csv(csv_path)
dataset = self.compute_segments(dataset, n_chunks_per_song)
# Shuffle data
if shuffle:
dataset = dataset.shuffle(
buffer_size=200000,
seed=self._random_seed,
            # useless since it is cached:
reshuffle_each_iteration=True)
# Expand audio path.
dataset = dataset.map(self.expand_path)
# Load waveform, compute spectrogram, and filtering error,
# K bins frequencies, and waveform.
N = num_parallel_calls
for instrument in self.instruments:
dataset = (
dataset
.map(instrument.load_waveform, num_parallel_calls=N)
.filter(self.filter_error)
.map(instrument.compute_spectrogram, num_parallel_calls=N)
.map(instrument.filter_frequencies))
dataset = dataset.map(self.filter_waveform)
# Convert to uint before caching in order to save space.
if convert_to_uint:
for instrument in self.instruments:
dataset = dataset.map(instrument.convert_to_uint)
dataset = self.cache(dataset, cache_directory, wait_for_cache)
# Check for INFINITY (should not happen)
for instrument in self.instruments:
dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitely.
if infinite_generator:
dataset = dataset.repeat(count=-1)
# Ensure same size for vocals and mix spectrograms.
        # NOTE: could be done before caching?
dataset = dataset.map(self.harmonize_spectrogram)
# Filter out too short segment.
        # NOTE: could be done before caching?
dataset = dataset.filter(self.filter_short_segments)
# Random time crop of 11.88s
if random_time_crop:
dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
else:
# frame_duration = 11.88/T
# take central segment (for validation)
for instrument in self.instruments:
dataset = dataset.map(instrument.time_crop)
        # Post cache shuffling. Done where the data are the lightest:
        # after cropping but before converting back to float.
if shuffle:
dataset = dataset.shuffle(
buffer_size=256, seed=self._random_seed,
reshuffle_each_iteration=True)
# Convert back to float32
if convert_to_uint:
for instrument in self.instruments:
dataset = dataset.map(
instrument.convert_to_float32, num_parallel_calls=N)
M = 8 # Parallel call post caching.
# Must be applied with the same factor on mix and vocals.
if random_data_augmentation:
dataset = (
dataset
.map(self.random_time_stretch, num_parallel_calls=M)
.map(self.random_pitch_shift, num_parallel_calls=M))
# Filter by shape (remove badly shaped tensors).
for instrument in self.instruments:
dataset = (
dataset
.filter(instrument.filter_shape)
.map(instrument.reshape_spectrogram))
# Select features and annotation.
dataset = dataset.map(self.map_features)
# Make batch (done after selection to avoid
# error due to unprocessed instrument spectrogram batching).
dataset = dataset.batch(batch_size)
return dataset
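Driving the builder end to end requires a CSV listing, for each track, the <instrument>_path columns (mix_path, vocals_path and accompaniment_path with the defaults) plus a duration column; a minimal sketch with hypothetical paths, assuming get_audio_adapter(None) returns the default adapter:

>>> from spleeter.dataset import DEFAULT_AUDIO_PARAMS, get_training_dataset
>>> from spleeter.utils.audio.adapter import get_audio_adapter
>>> params = dict(DEFAULT_AUDIO_PARAMS, train_csv='train.csv', batch_size=4)
>>> dataset = get_training_dataset(params, get_audio_adapter(None), '/data')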

spleeter/model/__init__.py Normal file

@@ -0,0 +1,397 @@
#!/usr/bin/env python
# coding: utf8
""" This package provide an estimator builder as well as model functions. """
import importlib
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.contrib.signal import stft, inverse_stft, hann_window
# pylint: enable=import-error
from ..utils.tensor import pad_and_partition, pad_and_reshape
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def get_model_function(model_type):
"""
Get tensorflow function of the model to be applied to the input tensor.
For instance "unet.softmax_unet" will return the softmax_unet function
in the "unet.py" submodule of the current module (spleeter.model).
Params:
- model_type: str
the relative module path to the model function.
Returns:
A tensorflow function to be applied to the input tensor to get the
multitrack output.
"""
relative_path_to_module = '.'.join(model_type.split('.')[:-1])
model_name = model_type.split('.')[-1]
main_module = '.'.join((__name__, 'functions'))
path_to_module = f'{main_module}.{relative_path_to_module}'
module = importlib.import_module(path_to_module)
model_function = getattr(module, model_name)
return model_function
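# Example resolution: model_type 'unet.softmax_unet' yields
# path_to_module 'spleeter.model.functions.unet' and model_name
# 'softmax_unet', i.e. the softmax U-net defined in functions/unet.py.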
class EstimatorSpecBuilder(object):
""" A builder class that allows to builds a multitrack unet model
estimator. The built model estimator has a different behaviour when
used in a train/eval mode and in predict mode.
* In train/eval mode: it takes as input and outputs magnitude spectrogram
* In predict mode: it takes as input and outputs waveform. The whole
separation process is then done in this function
for performance reason: it makes it possible to run
the whole spearation process (including STFT and
inverse STFT) on GPU.
:Example:
>>> from spleeter.model import EstimatorSpecBuilder
>>> builder = EstimatorSpecBuilder()
>>> builder.build_prediction_model()
>>> builder.build_evaluation_model()
>>> builder.build_training_model()
>>> from spleeter.model import model_fn
>>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)
"""
# Supported model functions.
DEFAULT_MODEL = 'unet.unet'
# Supported loss functions.
L1_MASK = 'L1_mask'
WEIGHTED_L1_MASK = 'weighted_L1_mask'
# Supported optimizers.
ADADELTA = 'Adadelta'
SGD = 'SGD'
# Math constants.
WINDOW_COMPENSATION_FACTOR = 2./3.
EPSILON = 1e-10
def __init__(self, features, params):
""" Default constructor. Depending on built model
usage, the provided features should be different:
* In train/eval mode: features is a dictionary with a
"mix_spectrogram" key, associated to the
mix magnitude spectrogram.
* In predict mode: features is a dictionary with a "waveform"
key, associated to the waveform of the sound
to be separated.
:param features: The input features for the estimator.
:param params: Some hyperparameters as a dictionary.
"""
self._features = features
self._params = params
# Get instrument name.
self._mix_name = params['mix_name']
self._instruments = params['instrument_list']
# Get STFT/signals parameters
self._n_channels = params['n_channels']
self._T = params['T']
self._F = params['F']
self._frame_length = params['frame_length']
self._frame_step = params['frame_step']
def _build_output_dict(self):
""" Created a batch_sizexTxFxn_channels input tensor containing
mix magnitude spectrogram, then an output dict from it according
to the selected model in internal parameters.
:returns: Build output dict.
:raise ValueError: If required model_type is not supported.
"""
input_tensor = self._features[f'{self._mix_name}_spectrogram']
model = self._params.get('model', None)
if model is not None:
model_type = model.get('type', self.DEFAULT_MODEL)
else:
model_type = self.DEFAULT_MODEL
try:
apply_model = get_model_function(model_type)
except ModuleNotFoundError:
raise ValueError(f'No model function {model_type} found')
return apply_model(
input_tensor,
self._instruments,
self._params['model']['params'])
def _build_loss(self, output_dict, labels):
""" Construct tensorflow loss and metrics
:param output_dict: dictionary of network outputs (key: instrument
name, value: estimated spectrogram of the instrument)
:param labels: dictionary of target outputs (key: instrument
name, value: ground truth spectrogram of the instrument)
:returns: tensorflow (loss, metrics) tuple.
"""
loss_type = self._params.get('loss_type', self.L1_MASK)
if loss_type == self.L1_MASK:
losses = {
name: tf.reduce_mean(tf.abs(output - labels[name]))
for name, output in output_dict.items()
}
elif loss_type == self.WEIGHTED_L1_MASK:
losses = {
name: tf.reduce_mean(
tf.reduce_mean(
labels[name],
axis=[1, 2, 3],
keep_dims=True) *
tf.abs(output - labels[name]))
for name, output in output_dict.items()
}
else:
raise ValueError(f"Unkwnown loss type: {loss_type}")
loss = tf.reduce_sum(list(losses.values()))
# Add metrics for monitoring each instrument.
metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}
metrics['absolute_difference'] = tf.compat.v1.metrics.mean(loss)
return loss, metrics
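    # Example: with L1_mask, an estimate [0.2, 0.8] against a label
    # [0.0, 1.0] contributes mean(|output - labels[name]|) = 0.2 to its
    # instrument loss; the total loss sums these means over instruments.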
def _build_optimizer(self):
""" Builds an optimizer instance from internal parameter values.
Default to AdamOptimizer if not specified.
:returns: Optimizer instance from internal configuration.
"""
name = self._params.get('optimizer')
if name == self.ADADELTA:
return tf.compat.v1.train.AdadeltaOptimizer()
rate = self._params['learning_rate']
if name == self.SGD:
return tf.compat.v1.train.GradientDescentOptimizer(rate)
return tf.compat.v1.train.AdamOptimizer(rate)
def _build_stft_feature(self):
""" Compute STFT of waveform and slice the STFT in segment
with the right length to feed the network.
"""
stft_feature = tf.transpose(
stft(
tf.transpose(self._features['waveform']),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype)),
pad_end=True),
perm=[1, 2, 0])
self._features[f'{self._mix_name}_stft'] = stft_feature
self._features[f'{self._mix_name}_spectrogram'] = tf.abs(
pad_and_partition(stft_feature, self._T))[:, :, :self._F, :]
def _inverse_stft(self, stft):
""" Inverse and reshape the given STFT
:param stft: input STFT
:returns: inverse STFT (waveform)
"""
inversed = inverse_stft(
tf.transpose(stft, perm=[2, 0, 1]),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
hann_window(frame_length, periodic=True, dtype=dtype))
) * self.WINDOW_COMPENSATION_FACTOR
reshaped = tf.transpose(inversed)
return reshaped[:tf.shape(self._features['waveform'])[0], :]
def _build_mwf_output_waveform(self, output_dict):
""" Perform separation with multichannel Wiener Filtering using Norbert.
Note: multichannel Wiener Filtering is not coded in Tensorflow and thus
may be quite slow.
:param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument)
:returns: dictionary of separated waveforms (key: instrument name,
value: estimated waveform of the instrument)
"""
import norbert # pylint: disable=import-error
x = self._features[f'{self._mix_name}_stft']
v = tf.stack(
[
pad_and_reshape(
output_dict[f'{instrument}_spectrogram'],
self._frame_length,
self._F)[:tf.shape(x)[0], ...]
for instrument in self._instruments
],
axis=3)
input_args = [v, x]
        stft_function = tf.py_function(
            lambda v, x: norbert.wiener(v.numpy(), x.numpy()),
            input_args,
            tf.complex64)
        return {
            instrument: self._inverse_stft(stft_function[:, :, :, k])
for k, instrument in enumerate(self._instruments)
}
def _extend_mask(self, mask):
""" Extend mask, from reduced number of frequency bin to the number of
frequency bin in the STFT.
:param mask: restricted mask
:returns: extended mask
:raise ValueError: If invalid mask_extension parameter is set.
"""
extension = self._params['mask_extension']
# Extend with average
# (dispatch according to energy in the processed band)
if extension == "average":
extension_row = tf.reduce_mean(mask, axis=2, keepdims=True)
# Extend with 0
# (avoid extension artifacts but not conservative separation)
elif extension == "zeros":
mask_shape = tf.shape(mask)
extension_row = tf.zeros((
mask_shape[0],
mask_shape[1],
1,
mask_shape[-1]))
else:
raise ValueError(f'Invalid mask_extension parameter {extension}')
        n_extra_row = self._frame_length // 2 + 1 - self._F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
return tf.concat([mask, extension], axis=2)
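    # Example: with the default frame_length=4096 and F=1024 the STFT has
    # 4096 // 2 + 1 = 2049 frequency bins, hence n_extra_row = 1025 rows
    # are tiled for the remaining high-frequency bins.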
def _build_manual_output_waveform(self, output_dict):
""" Perform ratio mask separation
:param output_dict: dictionary of estimated spectrogram (key: instrument
name, value: estimated spectrogram of the instrument)
:returns: dictionary of separated waveforms (key: instrument name,
value: estimated waveform of the instrument)
"""
separation_exponent = self._params['separation_exponent']
output_sum = tf.reduce_sum(
[e ** separation_exponent for e in output_dict.values()],
axis=0
) + self.EPSILON
output_waveform = {}
for instrument in self._instruments:
output = output_dict[f'{instrument}_spectrogram']
# Compute mask with the model.
instrument_mask = (
output ** separation_exponent
+ (self.EPSILON / len(output_dict))) / output_sum
# Extend mask;
instrument_mask = self._extend_mask(instrument_mask)
# Stack back mask.
old_shape = tf.shape(instrument_mask)
new_shape = tf.concat(
[[old_shape[0] * old_shape[1]], old_shape[2:]],
axis=0)
instrument_mask = tf.reshape(instrument_mask, new_shape)
# Remove padded part (for mask having the same size as STFT);
stft_feature = self._features[f'{self._mix_name}_stft']
instrument_mask = instrument_mask[
:tf.shape(stft_feature)[0], ...]
# Compute masked STFT and normalize it.
output_waveform[instrument] = self._inverse_stft(
tf.cast(instrument_mask, dtype=tf.complex64) * stft_feature)
return output_waveform
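    # Example: with separation_exponent=2 and two instruments estimated at
    # 0.3 and 0.7 in a bin, the masks are 0.09/0.58 and 0.49/0.58 and sum
    # to one (up to the EPSILON terms), so the separation is conservative.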
def _build_output_waveform(self, output_dict):
""" Build output waveform from given output dict in order to be used in
prediction context. Regarding of the configuration building method will
be using MWF.
:param output_dict: Output dict to build output waveform from.
:returns: Built output waveform.
"""
if self._params.get('MWF', False):
output_waveform = self._build_mwf_output_waveform(output_dict)
else:
output_waveform = self._build_manual_output_waveform(output_dict)
if 'audio_id' in self._features:
output_waveform['audio_id'] = self._features['audio_id']
return output_waveform
def build_predict_model(self):
""" Builder interface for creating model instance that aims to perform
prediction / inference over given track. The output of such estimator
will be a dictionary with a "<instrument>" key per separated instrument
, associated to the estimated separated waveform of the instrument.
:returns: An estimator for performing prediction.
"""
self._build_stft_feature()
output_dict = self._build_output_dict()
output_waveform = self._build_output_waveform(output_dict)
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.PREDICT,
predictions=output_waveform)
def build_evaluation_model(self, labels):
""" Builder interface for creating model instance that aims to perform
model evaluation. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram.
:param labels: Model labels.
:returns: An estimator for performing model evaluation.
"""
output_dict = self._build_output_dict()
loss, metrics = self._build_loss(output_dict, labels)
return tf.estimator.EstimatorSpec(
tf.estimator.ModeKeys.EVAL,
loss=loss,
eval_metric_ops=metrics)
def build_train_model(self, labels):
""" Builder interface for creating model instance that aims to perform
model training. The output of such estimator will be a dictionary
with a key "<instrument>_spectrogram" per separated instrument,
associated to the estimated separated instrument magnitude spectrogram.
:param labels: Model labels.
:returns: An estimator for performing model training.
"""
output_dict = self._build_output_dict()
loss, metrics = self._build_loss(output_dict, labels)
optimizer = self._build_optimizer()
train_operation = optimizer.minimize(
loss=loss,
global_step=tf.compat.v1.train.get_global_step())
return tf.estimator.EstimatorSpec(
mode=tf.estimator.ModeKeys.TRAIN,
loss=loss,
train_op=train_operation,
eval_metric_ops=metrics,
)
def model_fn(features, labels, mode, params, config):
"""
:param features:
:param labels:
:param mode: Estimator mode.
:param params:
:param config: TF configuration (not used).
:returns: Built EstimatorSpec.
:raise ValueError: If estimator mode is not supported.
"""
builder = EstimatorSpecBuilder(features, params)
if mode == tf.estimator.ModeKeys.PREDICT:
return builder.build_predict_model()
elif mode == tf.estimator.ModeKeys.EVAL:
return builder.build_evaluation_model(labels)
elif mode == tf.estimator.ModeKeys.TRAIN:
return builder.build_train_model(labels)
raise ValueError(f'Unknown mode {mode}')

spleeter/model/functions/__init__.py Normal file

@@ -0,0 +1,27 @@
#!/usr/bin/env python
# coding: utf8
""" This package provide model functions. """
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def apply(function, input_tensor, instruments, params={}):
""" Apply given function to the input tensor.
:param function: Function to be applied to tensor.
    :param input_tensor: Tensor to apply the function to.
    :param instruments: Iterable that provides a collection of instruments.
    :param params: (Optional) dict of model parameters.
:returns: Created output tensor dict.
"""
output_dict = {}
for instrument in instruments:
out_name = f'{instrument}_spectrogram'
output_dict[out_name] = function(
input_tensor,
output_name=out_name,
params=params)
return output_dict
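A stand-in function makes the fan-out over instruments visible; a sketch, assuming this module is importable as spleeter.model.functions:

>>> from spleeter.model.functions import apply
>>> outputs = apply(
...     lambda tensor, output_name, params: output_name,
...     'mix_tensor', ('vocals', 'accompaniment'))
>>> sorted(outputs)
['accompaniment_spectrogram', 'vocals_spectrogram']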

spleeter/model/functions/blstm.py Normal file

@@ -0,0 +1,76 @@
#!/usr/bin/env python
# coding: utf8
"""
This system (UHL1) uses a bi-directional LSTM network as described in:
`S. Uhlich, M. Porcu, F. Giron, M. Enenkl, T. Kemp, N. Takahashi and
Y. Mitsufuji.
"Improving music source separation based on deep neural networks through
data augmentation and network blending", Proc. ICASSP, 2017.`
It has three BLSTM layers, each having 500 cells. For each instrument,
a network is trained which predicts the target instrument amplitude from
the mixture amplitude in the STFT domain (frame size: 4096, hop size:
1024). The raw output of each network is then combined by a multichannel
Wiener filter. The network is trained on musdb where we split train into
train_train and train_valid with 86 and 14 songs, respectively. The
validation set is used to perform early stopping and hyperparameter
selection (LSTM layer dropout rate, regularization strength).
"""
# pylint: disable=import-error
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
from tensorflow.keras.layers import (
Bidirectional,
Dense,
Flatten,
Reshape,
TimeDistributed)
# pylint: enable=import-error
from . import apply
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def apply_blstm(input_tensor, output_name='output', params={}):
""" Apply BLSTM to the given input_tensor.
:param input_tensor: Input of the model.
:param output_name: (Optional) name of the output, default to 'output'.
:param params: (Optional) dict of BLSTM parameters.
:returns: Output tensor.
"""
units = params.get('lstm_units', 250)
kernel_initializer = he_uniform(seed=50)
    flatten_input = TimeDistributed(Flatten())(input_tensor)
def create_bidirectional():
return Bidirectional(
CuDNNLSTM(
units,
kernel_initializer=kernel_initializer,
return_sequences=True))
    l1 = create_bidirectional()(flatten_input)
    l2 = create_bidirectional()(l1)
    l3 = create_bidirectional()(l2)
dense = TimeDistributed(
Dense(
int(flatten_input.shape[2]),
activation='relu',
            kernel_initializer=kernel_initializer))(l3)
output = TimeDistributed(
Reshape(input_tensor.shape[2:]),
name=output_name)(dense)
return output
def blstm(input_tensor, instruments, params={}):
    """ Model function applier. """
    return apply(apply_blstm, input_tensor, instruments, params)
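Dimension check against the UHL1 description above, assuming the default T=512, F=1024 stereo input of the base configs:

>>> 1024 * 2    # features per time step after TimeDistributed(Flatten())
2048
>>> 2 * 250     # outputs per step of each Bidirectional CuDNNLSTM layer
500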

spleeter/model/functions/unet.py Normal file

@@ -0,0 +1,201 @@
#!/usr/bin/env python
# coding: utf8
"""
This module contains building functions for U-net source separation models.
Each instrument is modeled by a single U-net convolutional/deconvolutional
network that takes a mix spectrogram as input and outputs the estimated
spectrogram of the isolated source.
"""
from functools import partial
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.keras.layers import (
BatchNormalization,
Concatenate,
Conv2D,
Conv2DTranspose,
Dropout,
ELU,
LeakyReLU,
Multiply,
ReLU,
Softmax)
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
# pylint: enable=import-error
from . import apply
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def _get_conv_activation_layer(params):
"""
    :param params: Model parameters dict; reads the 'conv_activation' key.
    :returns: Required activation layer.
"""
conv_activation = params.get('conv_activation')
if conv_activation == 'ReLU':
return ReLU()
elif conv_activation == 'ELU':
return ELU()
return LeakyReLU(0.2)
def _get_deconv_activation_layer(params):
"""
    :param params: Model parameters dict; reads the 'deconv_activation' key.
    :returns: Required activation layer.
"""
deconv_activation = params.get('deconv_activation')
if deconv_activation == 'LeakyReLU':
return LeakyReLU(0.2)
elif deconv_activation == 'ELU':
return ELU()
return ReLU()
def apply_unet(
input_tensor,
output_name='output',
params={},
output_mask_logit=False):
""" Apply a convolutionnal U-net to model a single instrument (one U-net
is used for each instrument).
:param input_tensor:
:param output_name: (Optional) , default to 'output'
:param params: (Optional) , default to empty dict.
:param output_mask_logit: (Optional) , default to False.
"""
logging.info(f'Apply unet for {output_name}')
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
conv_activation_layer = _get_conv_activation_layer(params)
deconv_activation_layer = _get_deconv_activation_layer(params)
kernel_initializer = he_uniform(seed=50)
conv2d_factory = partial(
Conv2D,
strides=(2, 2),
padding='same',
kernel_initializer=kernel_initializer)
# First layer.
conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)
batch1 = BatchNormalization(axis=-1)(conv1)
rel1 = conv_activation_layer(batch1)
# Second layer.
conv2 = conv2d_factory(conv_n_filters[1], (5, 5))(rel1)
batch2 = BatchNormalization(axis=-1)(conv2)
rel2 = conv_activation_layer(batch2)
# Third layer.
conv3 = conv2d_factory(conv_n_filters[2], (5, 5))(rel2)
batch3 = BatchNormalization(axis=-1)(conv3)
rel3 = conv_activation_layer(batch3)
# Fourth layer.
conv4 = conv2d_factory(conv_n_filters[3], (5, 5))(rel3)
batch4 = BatchNormalization(axis=-1)(conv4)
rel4 = conv_activation_layer(batch4)
# Fifth layer.
conv5 = conv2d_factory(conv_n_filters[4], (5, 5))(rel4)
batch5 = BatchNormalization(axis=-1)(conv5)
rel5 = conv_activation_layer(batch5)
# Sixth layer
conv6 = conv2d_factory(conv_n_filters[5], (5, 5))(rel5)
batch6 = BatchNormalization(axis=-1)(conv6)
_ = conv_activation_layer(batch6)
    # Decoder (up-sampling path with skip connections).
conv2d_transpose_factory = partial(
Conv2DTranspose,
strides=(2, 2),
padding='same',
kernel_initializer=kernel_initializer)
    # First up-sampling block.
    up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))(conv6)
up1 = deconv_activation_layer(up1)
batch7 = BatchNormalization(axis=-1)(up1)
drop1 = Dropout(0.5)(batch7)
merge1 = Concatenate(axis=-1)([conv5, drop1])
    # Second up-sampling block.
    up2 = conv2d_transpose_factory(conv_n_filters[3], (5, 5))(merge1)
up2 = deconv_activation_layer(up2)
batch8 = BatchNormalization(axis=-1)(up2)
drop2 = Dropout(0.5)(batch8)
merge2 = Concatenate(axis=-1)([conv4, drop2])
    # Third up-sampling block.
    up3 = conv2d_transpose_factory(conv_n_filters[2], (5, 5))(merge2)
up3 = deconv_activation_layer(up3)
batch9 = BatchNormalization(axis=-1)(up3)
drop3 = Dropout(0.5)(batch9)
merge3 = Concatenate(axis=-1)([conv3, drop3])
    # Fourth up-sampling block.
    up4 = conv2d_transpose_factory(conv_n_filters[1], (5, 5))(merge3)
up4 = deconv_activation_layer(up4)
batch10 = BatchNormalization(axis=-1)(up4)
merge4 = Concatenate(axis=-1)([conv2, batch10])
    # Fifth up-sampling block.
    up5 = conv2d_transpose_factory(conv_n_filters[0], (5, 5))(merge4)
up5 = deconv_activation_layer(up5)
batch11 = BatchNormalization(axis=-1)(up5)
merge5 = Concatenate(axis=-1)([conv1, batch11])
    # Sixth up-sampling block.
    up6 = conv2d_transpose_factory(1, (5, 5), strides=(2, 2))(merge5)
up6 = deconv_activation_layer(up6)
batch12 = BatchNormalization(axis=-1)(up6)
# Last layer to ensure initial shape reconstruction.
if not output_mask_logit:
up7 = Conv2D(
2,
(4, 4),
dilation_rate=(2, 2),
activation='sigmoid',
padding='same',
            kernel_initializer=kernel_initializer)(batch12)
output = Multiply(name=output_name)([up7, input_tensor])
return output
return Conv2D(
2,
(4, 4),
dilation_rate=(2, 2),
padding='same',
        kernel_initializer=kernel_initializer)(batch12)
def unet(input_tensor, instruments, params={}):
""" Model function applier. """
return apply(apply_unet, input_tensor, instruments, params)
def softmax_unet(input_tensor, instruments, params={}):
""" Apply softmax to multitrack unet in order to have mask suming to one.
:param input_tensor: Tensor to apply blstm to.
:param instruments: Iterable that provides a collection of instruments.
:param params: (Optional) dict of BLSTM parameters.
:returns: Created output tensor dict.
"""
logit_mask_list = []
for instrument in instruments:
out_name = f'{instrument}_spectrogram'
logit_mask_list.append(
apply_unet(
input_tensor,
output_name=out_name,
params=params,
output_mask_logit=True))
masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
output_dict = {}
for i, instrument in enumerate(instruments):
out_name = f'{instrument}_spectrogram'
output_dict[out_name] = Multiply(name=out_name)([
masks[..., i],
input_tensor])
return output_dict
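Shape walk-through for the encoder above, assuming the default 512x1024x2 input of the base configs; each of the six stride-(2, 2) convolutions halves the time and frequency axes:

>>> t, f = 512, 1024
>>> for _ in range(6):  # conv1 .. conv6
...     t, f = t // 2, f // 2
>>> t, f
(8, 16)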

spleeter/model/provider/__init__.py Normal file

@@ -0,0 +1,79 @@
#!/usr/bin/env python
# coding: utf8
"""
This package provides tools for downloading models from the network
using a remote storage abstraction.
:Example:
>>> provider = MyProviderImplementation()
>>> provider.get('/path/to/local/storage')
"""
from abc import ABC, abstractmethod
from os import environ, makedirs
from os.path import exists, isabs, join, sep
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
class ModelProvider(ABC):
"""
    A ModelProvider manages model files on disk and
    downloads them when they are not available.
"""
DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models')
MODEL_PROBE_PATH = '.probe'
@abstractmethod
def download(self, name, path):
""" Download model denoted by the given name to disk.
:param name: Name of the model to download.
:param path: Path of the directory to save model into.
"""
pass
def writeProbe(self, directory):
""" Write a model probe file into the given directory.
:param directory: Directory to write probe into.
"""
with open(join(directory, self.MODEL_PROBE_PATH), 'w') as stream:
stream.write('OK')
def get(self, model_directory):
""" Ensures required model is available at given location.
:param model_directory: Expected model_directory to be available.
:raise IOError: If model can not be retrieved.
"""
        # Expand model directory if needed.
if not isabs(model_directory):
model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
        # Download it if it does not exist yet.
model_probe = join(model_directory, self.MODEL_PROBE_PATH)
if not exists(model_probe):
if not exists(model_directory):
makedirs(model_directory)
self.download(
model_directory.split(sep)[-1],
model_directory)
self.writeProbe(model_directory)
return model_directory
def get_default_model_provider():
""" Builds and returns a default model provider.
:returns: A default model provider instance to use.
"""
from .github import GithubModelProvider
host = environ.get('GITHUB_HOST', 'https://github.com')
repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
return GithubModelProvider(host, repository, release)
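# Illustrative usage sketch: the default provider can be pointed at a fork
# or mirror through the environment variables above ('myfork/spleeter' is
# a hypothetical repository name).
#
#   >>> environ['GITHUB_REPOSITORY'] = 'myfork/spleeter'
#   >>> provider = get_default_model_provider()
#   >>> provider.get('2stems')  # downloads once, then reuses the probe file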

View File

@@ -0,0 +1,73 @@
#!/usr/bin/env python
# coding: utf8
"""
A ModelProvider backed by Github Release feature.
:Example:
>>> from spleeter.model.provider import github
>>> provider = github.GithubModelProvider(
        'https://github.com',
        'deezer/spleeter',
'latest')
>>> provider.download('2stems', '/path/to/local/storage')
"""
import tarfile
from os import environ
from tempfile import TemporaryFile
from shutil import copyfileobj
import requests
from . import ModelProvider
from ...utils.logging import get_logger
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
class GithubModelProvider(ModelProvider):
""" A ModelProvider implementation backed on Github for remote storage. """
LATEST_RELEASE = 'v1.4.0'
RELEASE_PATH = 'releases/download'
def __init__(self, host, repository, release):
""" Default constructor.
:param host: Host to the Github instance to reach.
:param repository: Repository path within target Github.
:param release: Release name to get models from.
"""
self._host = host
self._repository = repository
self._release = release
def download(self, name, path):
""" Download model denoted by the given name to disk.
:param name: Name of the model to download.
:param path: Path of the directory to save model into.
"""
url = '{}/{}/{}/{}/{}.tar.gz'.format(
self._host,
self._repository,
self.RELEASE_PATH,
self._release,
name)
get_logger().info('Downloading model archive %s', url)
response = requests.get(url, stream=True)
if response.status_code != 200:
raise IOError(f'Resource {url} not found')
with TemporaryFile() as stream:
copyfileobj(response.raw, stream)
get_logger().debug('Extracting downloaded archive')
stream.seek(0)
            with tarfile.open(fileobj=stream) as tar:
                tar.extractall(path=path)
get_logger().debug('Model file extracted')
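# Illustrative note: with the defaults above, download('2stems', path)
# resolves to the archive URL
# https://github.com/deezer/spleeter/releases/download/v1.4.0/2stems.tar.gz
# which is then extracted into the given path.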

View File

@@ -0,0 +1,28 @@
{
    "train_csv": "path/to/train.csv",
    "validation_csv": "path/to/test.csv",
    "model_dir": "2stems",
    "mix_name": "mix",
    "instrument_list": ["vocals", "accompaniment"],
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
    "n_channels": 2,
    "separation_exponent": 2,
    "mask_extension": "zeros",
    "learning_rate": 1e-4,
    "batch_size": 4,
    "training_cache": "training_cache",
    "validation_cache": "validation_cache",
    "train_max_steps": 1000000,
    "throttle_secs": 300,
    "random_seed": 0,
    "save_checkpoints_steps": 150,
    "save_summary_steps": 5,
    "model": {
        "type": "unet.unet",
        "params": {}
    }
}

View File

@@ -0,0 +1,31 @@
{
    "train_csv": "path/to/train.csv",
    "validation_csv": "path/to/val.csv",
    "model_dir": "4stems",
    "mix_name": "mix",
    "instrument_list": ["vocals", "drums", "bass", "other"],
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
    "n_channels": 2,
    "separation_exponent": 2,
    "mask_extension": "zeros",
    "learning_rate": 1e-4,
    "batch_size": 4,
    "training_cache": "training_cache",
    "validation_cache": "validation_cache",
    "train_max_steps": 1500000,
    "throttle_secs": 600,
    "random_seed": 3,
    "save_checkpoints_steps": 300,
    "save_summary_steps": 5,
    "model": {
        "type": "unet.unet",
        "params": {
            "conv_activation": "ELU",
            "deconv_activation": "ELU"
        }
    }
}

View File

@@ -0,0 +1,31 @@
{
    "train_csv": "path/to/train.csv",
    "validation_csv": "path/to/test.csv",
    "model_dir": "5stems",
    "mix_name": "mix",
    "instrument_list": ["vocals", "piano", "drums", "bass", "other"],
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
    "n_channels": 2,
    "separation_exponent": 2,
    "mask_extension": "zeros",
    "learning_rate": 1e-4,
    "batch_size": 4,
    "training_cache": "training_cache",
    "validation_cache": "validation_cache",
    "train_max_steps": 2500000,
    "throttle_secs": 600,
    "random_seed": 8,
    "save_checkpoints_steps": 300,
    "save_summary_steps": 5,
    "model": {
        "type": "unet.softmax_unet",
        "params": {
            "conv_activation": "ELU",
            "deconv_activation": "ELU"
        }
    }
}

View File

@@ -0,0 +1,8 @@
#!/usr/bin/env python
# coding: utf8
""" Packages that provides static resources file for the library. """
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

View File

@@ -0,0 +1,32 @@
{
    "train_csv": "configs/musdb_train.csv",
    "validation_csv": "configs/musdb_validation.csv",
    "model_dir": "musdb_model",
    "mix_name": "mix",
    "instrument_list": ["vocals", "drums", "bass", "other"],
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
    "n_channels": 2,
    "n_chunks_per_song": 1,
    "separation_exponent": 2,
    "mask_extension": "zeros",
    "learning_rate": 1e-4,
    "batch_size": 4,
    "training_cache": "training_cache",
    "validation_cache": "validation_cache",
    "train_max_steps": 100000,
    "throttle_secs": 600,
    "random_seed": 3,
    "save_checkpoints_steps": 300,
    "save_summary_steps": 5,
    "model": {
        "type": "unet.unet",
        "params": {
            "conv_activation": "ELU",
            "deconv_activation": "ELU"
        }
    }
}

127
spleeter/separator.py Normal file
View File

@@ -0,0 +1,127 @@
#!/usr/bin/env python
# coding: utf8
"""
Module that provides a class wrapper for source separation.
:Example:
>>> from spleeter.separator import Separator
>>> separator = Separator('spleeter:2stems')
>>> separator.separate(waveform)
>>> separator.separate_to_file(...)
"""
from multiprocessing import Pool
from os.path import join
from .utils.audio.adapter import get_default_audio_adapter
from .utils.audio.convertor import to_stereo
from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
class Separator(object):
""" A wrapper class for performing separation. """
def __init__(self, params_descriptor, MWF=False):
""" Default constructor.
:param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
self._predictor = None
self._pool = Pool()
self._tasks = []
def _get_predictor(self):
""" Lazy loading access method for internal predictor instance.
:returns: Predictor to use for source separation.
"""
if self._predictor is None:
estimator = create_estimator(self._params, self._MWF)
self._predictor = to_predictor(estimator)
return self._predictor
def join(self, timeout=20):
""" Wait for all pending tasks to be finished.
:param timeout: (Optional) task waiting timeout.
"""
while len(self._tasks) > 0:
task = self._tasks.pop()
task.get()
task.wait(timeout=timeout)
def separate(self, waveform):
""" Performs source separation over the given waveform.
The separation is performed synchronously but the result
processing is done asynchronously, allowing for instance
to export audio in parallel (through multiprocessing).
        :param waveform: Waveform to apply separation on.
        :returns: Dict that maps an instrument name to the corresponding
                  separated waveform.
"""
        if waveform.shape[-1] != 2:
waveform = to_stereo(waveform)
predictor = self._get_predictor()
prediction = predictor({
'waveform': waveform,
'audio_id': ''})
prediction.pop('audio_id')
return prediction
def separate_to_file(
self, audio_descriptor, destination,
audio_adapter=get_default_audio_adapter(),
offset=0, duration=600., codec='wav', bitrate='128k',
synchronous=True):
""" Performs source separation and export result to file using
given audio adapter.
:param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data,
in case of file based audio adapter, such
descriptor would be a file path.
:param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param offset: (Optional) Offset of loaded song.
:param duration: (Optional) Duration of loaded song.
:param codec: (Optional) Export codec.
:param bitrate: (Optional) Export bitrate.
        :param synchronous: (Optional) True if operation should be synchronous.
"""
waveform, _ = audio_adapter.load(
audio_descriptor,
offset=offset,
duration=duration,
sample_rate=self._sample_rate)
sources = self.separate(waveform)
for instrument, data in sources.items():
task = self._pool.apply_async(audio_adapter.save, (
join(destination, f'{instrument}.{codec}'),
data,
self._sample_rate,
codec,
bitrate))
self._tasks.append(task)
if synchronous:
self.join()
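# Illustrative usage sketch (paths are hypothetical):
#
#   >>> separator = Separator('spleeter:2stems')
#   >>> separator.separate_to_file('/path/to/mix.wav', '/tmp/separated')
#   >>> # -> /tmp/separated/vocals.wav, /tmp/separated/accompaniment.wav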

View File

@@ -0,0 +1,8 @@
#!/usr/bin/env python
# coding: utf8
""" This package provides utility function and classes. """
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

View File

@@ -0,0 +1,15 @@
#!/usr/bin/env python
# coding: utf8
"""
`spleeter.utils.audio` package provides various
tools for manipulating audio content such as:
- Audio adapter class for abstract interaction with audio files.
- FFMPEG implementation for audio adapter.
- Waveform conversion and transforming functions.
"""
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

View File

@@ -0,0 +1,144 @@
#!/usr/bin/env python
# coding: utf8
""" AudioAdapter class defintion. """
from abc import ABC, abstractmethod
from importlib import import_module
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from ..logging import get_logger
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
class AudioAdapter(ABC):
""" An abstract class for manipulating audio signal. """
# Default audio adapter singleton instance.
DEFAULT = None
@abstractmethod
def load(
self, audio_descriptor, offset, duration,
sample_rate, dtype=np.float32):
""" Loads the audio file denoted by the given audio descriptor
and returns it data as a waveform. Aims to be implemented
by client.
:param audio_descriptor: Describe song to load, in case of file
based audio adapter, such descriptor would
be a file path.
:param offset: Start offset to load from in seconds.
:param duration: Duration to load in seconds.
:param sample_rate: Sample rate to load audio with.
:param dtype: Numpy data type to use, default to float32.
:returns: Loaded data as (wf, sample_rate) tuple.
"""
pass
def load_tf_waveform(
self, audio_descriptor,
offset=0.0, duration=1800., sample_rate=44100,
dtype=b'float32', waveform_name='waveform'):
""" Load the audio and convert it to a tensorflow waveform.
:param audio_descriptor: Describe song to load, in case of file
based audio adapter, such descriptor would
be a file path.
:param offset: Start offset to load from in seconds.
:param duration: Duration to load in seconds.
:param sample_rate: Sample rate to load audio with.
:param dtype: Numpy data type to use, default to float32.
:param waveform_name: (Optional) Name of the key in output dict.
        :returns: TF output dict with the waveform as a
                  (T x n_channels) tensor, and a boolean that
                  tells whether there was an error while
                  trying to load the waveform.
"""
# Cast parameters to TF format.
offset = tf.cast(offset, tf.float64)
duration = tf.cast(duration, tf.float64)
        # Define safe loading function.
def safe_load(path, offset, duration, sample_rate, dtype):
get_logger().info(
f'Loading audio {path} from {offset} to {offset + duration}')
try:
(data, _) = self.load(
path.numpy(),
offset.numpy(),
duration.numpy(),
sample_rate.numpy(),
dtype=dtype.numpy())
return (data, False)
except Exception as e:
get_logger().warning(e)
return (np.float32(-1.0), True)
# Execute function and format results.
        results = tf.py_function(
            safe_load,
            [audio_descriptor, offset, duration, sample_rate, dtype],
            (tf.float32, tf.bool))
        waveform, error = results
return {
waveform_name: waveform,
f'{waveform_name}_error': error
}
@abstractmethod
def save(
self, path, data, sample_rate,
codec=None, bitrate=None):
""" Save the given audio data to the file denoted by
the given path.
:param path: Path of the audio file to save data in.
:param data: Waveform data to write.
:param sample_rate: Sample rate to write file in.
:param codec: (Optional) Writing codec to use.
:param bitrate: (Optional) Bitrate of the written audio file.
"""
pass
def get_default_audio_adapter():
""" Builds and returns a default audio adapter instance.
:returns: An audio adapter instance.
"""
if AudioAdapter.DEFAULT is None:
from .ffmpeg import FFMPEGProcessAudioAdapter
AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
return AudioAdapter.DEFAULT
def get_audio_adapter(descriptor):
""" Load dynamically an AudioAdapter from given class descriptor.
:param descriptor: Adapter class descriptor (module.Class)
:returns: Created adapter instance.
"""
if descriptor is None:
return get_default_audio_adapter()
module_path = descriptor.split('.')
adapter_class_name = module_path[-1]
module_path = '.'.join(module_path[:-1])
adapter_module = import_module(module_path)
adapter_class = getattr(adapter_module, adapter_class_name)
    if not issubclass(adapter_class, AudioAdapter):
raise ValueError(
f'{adapter_class_name} is not a valid AudioAdapter class')
return adapter_class()
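# Illustrative usage sketch of the dynamic loading above:
#
#   >>> adapter = get_audio_adapter(
#   ...     'spleeter.utils.audio.ffmpeg.FFMPEGProcessAudioAdapter')
#   >>> waveform, rate = adapter.load('/path/to/mix.wav', 0., 600., 44100)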

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python
# coding: utf8
""" This module provides audio data convertion functions. """
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from ..tensor import from_float32_to_uint8, from_uint8_to_float32
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def to_n_channels(waveform, n_channels):
""" Convert a waveform to n_channels by removing or
duplicating channels if needed (in tensorflow).
:param waveform: Waveform to transform.
    :param n_channels: Number of channels to reshape waveform in.
:returns: Reshaped waveform.
"""
return tf.cond(
tf.shape(waveform)[1] >= n_channels,
true_fn=lambda: waveform[:, :n_channels],
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels]
)
def to_stereo(waveform):
""" Convert a waveform to stereo by duplicating if mono,
or truncating if too many channels.
:param waveform: a (N, d) numpy array.
:returns: A stereo waveform as a (N, 1) numpy array.
"""
if waveform.shape[1] == 1:
return np.repeat(waveform, 2, axis=-1)
if waveform.shape[1] > 2:
return waveform[:, :2]
return waveform
def gain_to_db(tensor, epsilon=10e-10):
    """ Convert from gain to decibel in tensorflow.
    :param tensor: Tensor to convert.
    :param epsilon: Operation constant.
    :returns: Converted tensor.
    """
    return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, epsilon))
def db_to_gain(tensor):
""" Convert from decibel to gain in tensorflow.
:param tensor_db: Tensor to convert.
:returns: Converted tensor.
"""
return tf.pow(10., (tensor / 20.))
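# Worked example (illustrative): since 20 / ln(10) * ln(x) == 20 * log10(x),
# a gain of 2.0 maps to gain_to_db(2.0) ~= 6.02 dB, and db_to_gain(6.02)
# maps back to ~2.0.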
def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
""" Encodes given spectrogram into uint8 using decibel scale.
:param spectrogram: Spectrogram to be encoded as TF float tensor.
:param db_range: Range in decibel for encoding.
:returns: Encoded decibel spectrogram as uint8 tensor.
"""
db_spectrogram = gain_to_db(spectrogram)
max_db_spectrogram = tf.reduce_max(db_spectrogram)
db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range)
return from_float32_to_uint8(db_spectrogram, **kwargs)
def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
""" Decode spectrogram from uint8 decibel scale.
:param db_uint_spectrogram: Decibel pectrogram to decode.
:param min_db: Lower bound limit for decoding.
:param max_db: Upper bound limit for decoding.
:returns: Decoded spectrogram as float2 tensor.
"""
db_spectrogram = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
return db_to_gain(db_spectrogram)
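# Illustrative round-trip sketch using the helpers above:
#
#   >>> encoded = spectrogram_to_db_uint(spectrogram)
#   >>> restored = db_uint_spectrogram_to_gain(
#   ...     encoded['tensor'], encoded['min'], encoded['max'])
#
# Bins more than db_range below the maximum are clipped before encoding,
# so the round trip is lossy for very quiet bins.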

View File

@@ -0,0 +1,263 @@
#!/usr/bin/env python
# coding: utf8
"""
This module provides an AudioAdapter implementation based on FFMPEG
process. Such implementation is POSIXish and depends on nothing except
standard Python libraries. Thus this implementation is the default one
used within this library.
"""
import os
import os.path
import platform
import re
import subprocess
import numpy as np # pylint: disable=import-error
from .adapter import AudioAdapter
from ..logging import get_logger
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# Default FFMPEG binary name.
_UNIX_BINARY = 'ffmpeg'
_WINDOWS_BINARY = 'ffmpeg.exe'
def _which(program):
""" A pure python implementation of `which`command
for retrieving absolute path from command name or path.
@see https://stackoverflow.com/a/377028/1211342
:param program: Program name or path to expend.
:returns: Absolute path of program if any, None otherwise.
"""
def is_exe(fpath):
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
fpath, _ = os.path.split(program)
if fpath:
if is_exe(program):
return program
else:
for path in os.environ['PATH'].split(os.pathsep):
exe_file = os.path.join(path, program)
if is_exe(exe_file):
return exe_file
return None
def _get_ffmpeg_path():
""" Retrieves FFMPEG binary path using ENVVAR if defined
or default binary name (Windows or UNIX style).
:returns: Absolute path of FFMPEG binary.
:raise IOError: If FFMPEG binary cannot be found.
"""
ffmpeg_path = os.environ.get('FFMPEG_PATH', None)
if ffmpeg_path is None:
        # Note: try to infer standard binary name depending on platform.
if platform.system() == 'Windows':
ffmpeg_path = _WINDOWS_BINARY
else:
ffmpeg_path = _UNIX_BINARY
    expanded = _which(ffmpeg_path)
    if expanded is None:
        raise IOError(f'FFMPEG binary ({ffmpeg_path}) not found')
    return expanded
def _to_ffmpeg_time(n):
""" Format number of seconds to time expected by FFMPEG.
:param n: Time in seconds to format.
:returns: Formatted time in FFMPEG format.
"""
m, s = divmod(n, 60)
h, m = divmod(m, 60)
return '%d:%02d:%09.6f' % (h, m, s)
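# For instance (illustrative):
#   >>> _to_ffmpeg_time(83.5)
#   '0:01:23.500000'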
def _parse_ffmpeg_results(stderr):
""" Extract number of channels and sample rate from
the given FFMPEG STDERR output line.
:param stderr: STDERR output line to parse.
:returns: Parsed n_channels and sample_rate values.
"""
# Setup default value.
n_channels = 0
sample_rate = 0
# Find samplerate
match = re.search(r'(\d+) hz', stderr)
if match:
sample_rate = int(match.group(1))
# Channel count.
match = re.search(r'hz, ([^,]+),', stderr)
if match:
mode = match.group(1)
if mode == 'stereo':
n_channels = 2
else:
match = re.match(r'(\d+) ', mode)
            n_channels = int(match.group(1)) if match else 1
return n_channels, sample_rate
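# For instance (illustrative, on an already lowercased status line):
#   >>> _parse_ffmpeg_results('stream #0:0: audio: mp3, 44100 hz, stereo, ...')
#   (2, 44100)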
class _CommandBuilder(object):
""" A simple builder pattern class for CLI string. """
def __init__(self, binary):
""" Default constructor. """
self._command = [binary]
def flag(self, flag):
""" Add flag or unlabelled opt. """
self._command.append(flag)
return self
def opt(self, short, value, formatter=str):
""" Add option if value not None. """
if value is not None:
self._command.append(short)
self._command.append(formatter(value))
return self
def command(self):
""" Build string command. """
return self._command
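# Illustrative usage (hypothetical arguments):
#   >>> _CommandBuilder('ffmpeg').opt('-i', 'in.wav').flag('-').command()
#   ['ffmpeg', '-i', 'in.wav', '-']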
class FFMPEGProcessAudioAdapter(AudioAdapter):
""" An AudioAdapter implementation that use FFMPEG binary through
subprocess in order to perform I/O operation for audio processing.
When created, FFMPEG binary path will be checked and expended,
raising exception if not found. Such path could be infered using
FFMPEG_PATH environment variable.
"""
def __init__(self):
""" Default constructor. """
self._ffmpeg_path = _get_ffmpeg_path()
def _get_command_builder(self):
""" Creates and returns a command builder using FFMPEG path.
:returns: Built command builder.
"""
return _CommandBuilder(self._ffmpeg_path)
def load(
self, path, offset=None, duration=None,
sample_rate=None, dtype=np.float32):
""" Loads the audio file denoted by the given path
and returns it data as a waveform.
:param path: Path of the audio file to load data from.
:param offset: (Optional) Start offset to load from in seconds.
:param duration: (Optional) Duration to load in seconds.
:param sample_rate: (Optional) Sample rate to load audio with.
:param dtype: (Optional) Numpy data type to use, default to float32.
:returns: Loaded data a (waveform, sample_rate) tuple.
"""
if not isinstance(path, str):
path = path.decode()
command = (
self._get_command_builder()
.opt('-ss', offset, formatter=_to_ffmpeg_time)
.opt('-t', duration, formatter=_to_ffmpeg_time)
.opt('-i', path)
.opt('-ar', sample_rate)
.opt('-f', 'f32le')
.flag('-')
.command())
process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
buffer = process.stdout.read(-1)
# Read STDERR until end of the process detected.
while True:
status = process.stderr.readline()
if not status:
raise OSError('Stream info not found')
if isinstance(status, bytes): # Note: Python 3 compatibility.
status = status.decode('utf8', 'ignore')
status = status.strip().lower()
if 'no such file' in status:
raise IOError(f'File {path} not found')
elif 'invalid data found' in status:
raise IOError(f'FFMPEG error : {status}')
elif 'audio:' in status:
                n_channels, ffmpeg_sample_rate = _parse_ffmpeg_results(status)
if sample_rate is None:
sample_rate = ffmpeg_sample_rate
break
# Load waveform and clean process.
waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
if not waveform.dtype == np.dtype(dtype):
waveform = waveform.astype(dtype)
process.stdout.close()
process.stderr.close()
del process
return (waveform, sample_rate)
def save(
self, path, data, sample_rate,
codec=None, bitrate=None):
""" Write waveform data to the file denoted by the given path
using FFMPEG process.
:param path: Path of the audio file to save data in.
:param data: Waveform data to write.
:param sample_rate: Sample rate to write file in.
:param codec: (Optional) Writing codec to use.
:param bitrate: (Optional) Bitrate of the written audio file.
:raise IOError: If any error occurs while using FFMPEG to write data.
"""
        directory = os.path.split(path)[0]
        if directory and not os.path.exists(directory):
            os.makedirs(directory)
get_logger().debug('Writing file %s', path)
        # Note: 'wav' is a container, not an FFMPEG codec name; drop it so
        # that FFMPEG picks the default PCM encoder from the extension.
if codec == 'wav':
codec = None
command = (
self._get_command_builder()
.flag('-y')
.opt('-loglevel', 'error')
.opt('-f', 'f32le')
.opt('-ar', sample_rate)
.opt('-ac', data.shape[1])
.opt('-i', '-')
.flag('-vn')
.opt('-acodec', codec)
.opt('-ar', sample_rate) # Note: why twice ?
.opt('-strict', '-2') # Note: For 'aac' codec support.
.opt('-ab', bitrate)
.flag(path)
.command())
process = subprocess.Popen(
command,
stdout=open(os.devnull, 'wb'),
stdin=subprocess.PIPE,
stderr=subprocess.PIPE)
# Write data to STDIN.
try:
process.stdin.write(
                data.astype('<f4').tobytes())
except IOError:
raise IOError(f'FFMPEG error: {process.stderr.read()}')
# Clean process.
process.stdin.close()
if process.stderr is not None:
process.stderr.close()
process.wait()
del process
get_logger().info('File %s written', path)

View File

@@ -0,0 +1,128 @@
#!/usr/bin/env python
# coding: utf8
""" Spectrogram specific data augmentation """
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from tensorflow.contrib.signal import stft, hann_window
# pylint: enable=import-error
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def compute_spectrogram_tf(
waveform,
frame_length=2048, frame_step=512,
spec_exponent=1., window_exponent=1.):
""" Compute magnitude / power spectrogram from waveform as
a n_samples x n_channels tensor.
:param waveform: Input waveform as (times x number of channels)
tensor.
:param frame_length: Length of a STFT frame to use.
:param frame_step: HOP between successive frames.
:param spec_exponent: Exponent of the spectrogram (usually 1 for
magnitude spectrogram, or 2 for power spectrogram).
:param window_exponent: Exponent applied to the Hann windowing function
(may be useful for making perfect STFT/iSTFT
reconstruction).
:returns: Computed magnitude / power spectrogram as a
(T x F x n_channels) tensor.
"""
stft_tensor = tf.transpose(
stft(
tf.transpose(waveform),
frame_length,
frame_step,
window_fn=lambda f, dtype: hann_window(
f,
periodic=True,
dtype=waveform.dtype) ** window_exponent),
perm=[1, 2, 0])
return np.abs(stft_tensor) ** spec_exponent
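# Shape sketch (illustrative, using the training defaults of this commit):
# with frame_length=4096 and frame_step=1024, one second of stereo audio
# at 44.1 kHz yields 1 + (44100 - 4096) // 1024 = 40 frames and
# 4096 // 2 + 1 = 2049 frequency bins, i.e. a (40, 2049, 2) tensor.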
def time_stretch(
spectrogram,
factor=1.0,
method=tf.image.ResizeMethod.BILINEAR):
""" Time stretch a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
:param spectrogram: Input spectrogram to be time stretched as tensor.
:param factor: (Optional) Time stretch factor, must be >0, default to 1.
    :param method: (Optional) Interpolation method, default to BILINEAR.
:returns: Time stretched spectrogram as tensor with same shape.
"""
T = tf.shape(spectrogram)[0]
T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
F = tf.shape(spectrogram)[1]
ts_spec = tf.image.resize_images(
spectrogram,
[T_ts, F],
method=method,
align_corners=True)
return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs):
""" Time stretch a spectrogram preserving shape with random ratio in
tensorflow. Applies time_stretch to spectrogram with a random ratio drawn
uniformly in [factor_min, factor_max].
:param spectrogram: Input spectrogram to be time stretched as tensor.
:param factor_min: (Optional) Min time stretch factor, default to 0.9.
:param factor_max: (Optional) Max time stretch factor, default to 1.1.
:returns: Randomly time stretched spectrogram as tensor with same shape.
"""
factor = tf.random_uniform(
shape=(1,),
seed=0) * (factor_max - factor_min) + factor_min
return time_stretch(spectrogram, factor=factor, **kwargs)
def pitch_shift(
spectrogram,
semitone_shift=0.0,
method=tf.image.ResizeMethod.BILINEAR):
""" Pitch shift a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
:param spectrogram: Input spectrogram to be pitch shifted as tensor.
:param semitone_shift: (Optional) Pitch shift in semitone, default to 0.0.
    :param method: (Optional) Interpolation method, default to BILINEAR.
:returns: Pitch shifted spectrogram (same shape as spectrogram).
"""
factor = 2 ** (semitone_shift / 12.)
T = tf.shape(spectrogram)[0]
F = tf.shape(spectrogram)[1]
F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
ps_spec = tf.image.resize_images(
spectrogram,
[T, F_ps],
method=method,
align_corners=True)
paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
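# Worked example (illustrative): a +12 semitone shift gives a factor of
# 2 ** (12 / 12.) = 2, so bins are stretched to twice their index and
# cropped back to F; a -12 shift halves them and the missing top bins
# are zero padded.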
def random_pitch_shift(spectrogram, shift_min=-1., shift_max=1., **kwargs):
""" Pitch shift a spectrogram preserving shape with random ratio in
tensorflow. Applies pitch_shift to spectrogram with a random shift
amount (expressed in semitones) drawn uniformly in [shift_min, shift_max].
:param spectrogram: Input spectrogram to be pitch shifted as tensor.
:param shift_min: (Optional) Min pitch shift in semitone, default to -1.
:param shift_max: (Optional) Max pitch shift in semitone, default to 1.
:returns: Randomly pitch shifted spectrogram (same shape as spectrogram).
"""
semitone_shift = tf.random_uniform(
shape=(1,),
seed=0) * (shift_max - shift_min) + shift_min
return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python
# coding: utf8
""" Module that provides configuration loading function. """
import json
try:
import importlib.resources as loader
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as loader
from os.path import exists
from .. import resources
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
def load_configuration(descriptor):
""" Load configuration from the given descriptor. Could be
either a `spleeter:` prefixed embedded configuration name
or a file system path to read configuration from.
:param descriptor: Configuration descriptor to use for lookup.
:returns: Loaded description as dict.
    :raise ValueError: If required embedded configuration does not exist.
    :raise IOError: If required configuration file does not exist.
"""
# Embedded configuration reading.
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):]
if not loader.is_resource(resources, f'{name}.json'):
raise ValueError(f'No embedded configuration {name} found')
with loader.open_text(resources, f'{name}.json') as stream:
return json.load(stream)
# Standard file reading.
if not exists(descriptor):
raise IOError(f'Configuration file {descriptor} not found')
with open(descriptor, 'r') as stream:
return json.load(stream)
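# Illustrative usage of both descriptor forms:
#   >>> load_configuration('spleeter:2stems')       # embedded resource
#   >>> load_configuration('/path/to/config.json')  # plain file (hypothetical)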

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python
# coding: utf8
""" Utility functions for creating estimator. """
from pathlib import Path
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.contrib import predictor
# pylint: enable=import-error
from ..model import model_fn
from ..model.provider import get_default_model_provider
# Default exporting directory for predictor.
DEFAULT_EXPORT_DIRECTORY = '/tmp/serving'
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
Params:
- params: a dictionnary of parameters for building the model
Returns:
a tensorflow estimator
"""
# Load model.
model_directory = params['model_dir']
model_provider = get_default_model_provider()
params['model_dir'] = model_provider.get(model_directory)
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config
)
return estimator
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
""" Exports given estimator as predictor into the given directory
and returns associated tf.predictor instance.
:param estimator: Estimator to export.
:param directory: (Optional) path to write exported model into.
"""
def receiver():
shape = (None, estimator.params['n_channels'])
features = {
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
'audio_id': tf.compat.v1.placeholder(tf.string)}
return tf.estimator.export.ServingInputReceiver(features, features)
estimator.export_saved_model(directory, receiver)
versions = [
model for model in Path(directory).iterdir()
if model.is_dir() and 'temp' not in str(model)]
latest = str(sorted(versions)[-1])
return predictor.from_saved_model(latest)
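# Illustrative sketch wiring both helpers together (params is assumed to
# be a configuration dict as returned by load_configuration):
#
#   >>> estimator = create_estimator(params, MWF=False)
#   >>> predict = to_predictor(estimator)
#   >>> prediction = predict({'waveform': waveform, 'audio_id': ''})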

45
spleeter/utils/logging.py Normal file
View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python
# coding: utf8
""" Centralized logging facilities for Spleeter. """
from os import environ
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
class _LoggerHolder(object):
""" Logger singleton instance holder. """
INSTANCE = None
def get_logger():
""" Returns library scoped logger.
:returns: Library logger.
"""
if _LoggerHolder.INSTANCE is None:
# pylint: disable=import-error
from tensorflow.compat.v1 import logging
# pylint: enable=import-error
_LoggerHolder.INSTANCE = logging
_LoggerHolder.INSTANCE.set_verbosity(_LoggerHolder.INSTANCE.ERROR)
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
return _LoggerHolder.INSTANCE
def enable_logging():
""" Enable INFO level logging. """
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
logger = get_logger()
logger.set_verbosity(logger.INFO)
def enable_verbose_logging():
""" Enable DEBUG level logging. """
environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
logger = get_logger()
logger.set_verbosity(logger.DEBUG)

191
spleeter/utils/tensor.py Normal file
View File

@@ -0,0 +1,191 @@
#!/usr/bin/env python
# coding: utf8
""" Utility function for tensorflow. """
# pylint: disable=import-error
import tensorflow as tf
import pandas as pd
# pylint: enable=import-error
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def sync_apply(tensor_dict, func, concat_axis=1):
""" Return a function that applies synchronously the provided func on the
provided dictionnary of tensor. This means that func is applied to the
concatenation of the tensors in tensor_dict. This is useful for performing
random operation that needs the same drawn value on multiple tensor, such
as a random time-crop on both input data and label (the same crop should be
applied to both input data and label, so random crop cannot be applied
separately on each of them).
IMPORTANT NOTE: all tensor are assumed to be the same shape.
Params:
- tensor_dict: dictionary (key: strings, values: tf.tensor)
a dictionary of tensor.
- func: function
function to be applied to the concatenation of the tensors in
tensor_dict
- concat_axis: int
The axis on which to perform the concatenation.
Returns:
processed tensors dictionary with the same name (keys) as input
tensor_dict.
"""
if concat_axis not in {0, 1}:
raise NotImplementedError(
'Function only implemented for concat_axis equal to 0 or 1')
tensor_list = list(tensor_dict.values())
concat_tensor = tf.concat(tensor_list, concat_axis)
processed_concat_tensor = func(concat_tensor)
tensor_shape = tf.shape(list(tensor_dict.values())[0])
D = tensor_shape[concat_axis]
if concat_axis == 0:
return {
name: processed_concat_tensor[index * D:(index + 1) * D, :, :]
for index, name in enumerate(tensor_dict)
}
return {
name: processed_concat_tensor[:, index * D:(index + 1) * D, :]
for index, name in enumerate(tensor_dict)
}
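# Illustrative usage sketch: crop a mix and its vocals target at the same
# random time offset by concatenating along the frequency axis, cropping
# once along time, then splitting back ('mix' and 'vocals' are
# hypothetical tensors assumed to share the same (T, F, n_channels) shape).
#
#   >>> crop = lambda t: tf.random_crop(
#   ...     t, [256, tf.shape(t)[1], tf.shape(t)[2]])
#   >>> out = sync_apply({'mix': mix, 'vocals': vocals}, crop, concat_axis=1)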
def from_float32_to_uint8(
tensor,
tensor_key='tensor',
min_key='min',
max_key='max'):
"""
:param tensor:
:param tensor_key:
:param min_key:
:param max_key:
:returns:
"""
tensor_min = tf.reduce_min(tensor)
tensor_max = tf.reduce_max(tensor)
return {
tensor_key: tf.cast(
(tensor - tensor_min) / (tensor_max - tensor_min + 1e-16)
* 255.9999, dtype=tf.uint8),
min_key: tensor_min,
max_key: tensor_max
}
def from_uint8_to_float32(tensor, tensor_min, tensor_max):
"""
:param tensor:
:param tensor_min:
:param tensor_max:
:returns:
"""
return (
tf.cast(tensor, tf.float32)
* (tensor_max - tensor_min)
/ 255.9999 + tensor_min)
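# Worked note (illustrative): a tensor spanning [min, max] is mapped
# affinely onto {0, ..., 255}, so from_uint8_to_float32 inverts the
# encoding up to a quantization error of about (max - min) / 256.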
def pad_and_partition(tensor, segment_len):
""" Pad and partition a tensor into segment of len segment_len
along the first dimension. The tensor is padded with 0 in order
to ensure that the first dimension is a multiple of segment_len.
Tensor must be of known fixed rank
:Example:
>>> tensor = [[1, 2, 3], [4, 5, 6]]
>>> segment_len = 2
>>> pad_and_partition(tensor, segment_len)
[[[1, 2], [4, 5]], [[3, 0], [6, 0]]]
:param tensor:
:param segment_len:
:returns:
"""
tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
padded = tf.pad(
tensor,
[[0, pad_size]] + [[0, 0]] * (len(tensor.shape)-1))
split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
return tf.reshape(
padded,
tf.concat(
[[split, segment_len], tf.shape(padded)[1:]],
axis=0))
def pad_and_reshape(instr_spec, frame_length, F):
"""
:param instr_spec:
:param frame_length:
:param F:
:returns:
"""
spec_shape = tf.shape(instr_spec)
extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
    n_extra_row = frame_length // 2 + 1 - F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
extended_spec = tf.concat([instr_spec, extension], axis=2)
old_shape = tf.shape(extended_spec)
new_shape = tf.concat([
[old_shape[0] * old_shape[1]],
old_shape[2:]],
axis=0)
processed_instr_spec = tf.reshape(extended_spec, new_shape)
return processed_instr_spec
def dataset_from_csv(csv_path, **kwargs):
""" Load dataset from a CSV file using Pandas. kwargs if any are
forwarded to the `pandas.read_csv` function.
:param csv_path: Path of the CSV file to load dataset from.
:returns: Loaded dataset.
"""
df = pd.read_csv(csv_path, **kwargs)
dataset = (
tf.data.Dataset.from_tensor_slices(
{key: df[key].values for key in df})
)
return dataset
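# Illustrative usage (hypothetical CSV path): each dataset element is a
# dict mapping a column name to its value.
#   >>> dataset = dataset_from_csv('configs/musdb_train.csv')
#   >>> dataset = dataset.map(lambda sample: ...)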
def check_tensor_shape(tensor_tf, target_shape):
""" Return a Tensorflow boolean graph that indicates whether
sample[features_key] has the specified target shape. Only check
not None entries of target_shape.
:param tensor_tf: Tensor to check shape for.
:param target_shape: Target shape to compare tensor to.
:returns: True if shape is valid, False otherwise (as TF boolean).
"""
result = tf.constant(True)
for i, target_length in enumerate(target_shape):
if target_length:
result = tf.logical_and(
result,
tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i]))
return result
def set_tensor_shape(tensor, tensor_shape):
""" Set shape for a tensor (not in place, as opposed to tf.set_shape)
:param tensor: Tensor to reshape.
:param tensor_shape: Shape to apply to the tensor.
:returns: A reshaped tensor.
"""
# NOTE: That SOUND LIKE IN PLACE HERE ?
tensor.set_shape(tensor_shape)
return tensor
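# Illustrative usage sketch: both helpers are typically chained in a
# tf.data pipeline to validate and then pin shapes.
#
#   >>> dataset = dataset.filter(
#   ...     lambda sample: check_tensor_shape(sample['waveform'], (None, 2)))
#   >>> dataset = dataset.map(
#   ...     lambda sample: dict(sample, waveform=set_tensor_shape(
#   ...         sample['waveform'], (None, 2))))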