#!/usr/bin/env python
# coding: utf8
"""
Module for building data preprocessing pipeline using the tensorflow
data API. Data preprocessing such as audio loading, spectrogram
computation, cropping, feature caching or data augmentation is done
using a tensorflow dataset object that output a tuple (input_, output)
where:
- input is a dictionary with a single key that contains the (batched)
mix spectrogram of audio samples
- output is a dictionary of spectrogram of the isolated tracks
(ground truth)
"""
import time
import os
from os.path import exists, join, sep as SEPARATOR
# pylint: disable=import-error
import pandas as pd
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from .audio.convertor import (
db_uint_spectrogram_to_gain,
spectrogram_to_db_uint)
from .audio.spectrogram import (
compute_spectrogram_tf,
random_pitch_shift,
random_time_stretch)
from .utils.logging import get_logger
from .utils.tensor import (
check_tensor_shape,
dataset_from_csv,
set_tensor_shape,
sync_apply)
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS = {
'instrument_list': ('vocals', 'accompaniment'),
'mix_name': 'mix',
'sample_rate': 44100,
'frame_length': 4096,
'frame_step': 1024,
'T': 512,
'F': 1024
}
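# With these defaults a training patch spans
# T * frame_step / sample_rate = 512 * 1024 / 44100 ~ 11.9 seconds,
# and the F = 1024 retained bins cover frequencies up to
# F * sample_rate / frame_length = 1024 * 44100 / 4096 = 11025 Hz.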
def get_training_dataset(audio_params, audio_adapter, audio_path):
""" Builds training dataset.
:param audio_params: Audio parameters.
:param audio_adapter: Adapter to load audio from.
:param audio_path: Path of directory containing audio.
:returns: Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=audio_params.get('chunk_duration', 20.0),
random_seed=audio_params.get('random_seed', 0))
return builder.build(
audio_params.get('train_csv'),
cache_directory=audio_params.get('training_cache'),
batch_size=audio_params.get('batch_size'),
n_chunks_per_song=audio_params.get('n_chunks_per_song', 2),
random_data_augmentation=False,
convert_to_uint=True,
wait_for_cache=False)
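# Minimal usage sketch, assuming an audio adapter exposing the
# load_tf_waveform method used below (e.g. the default adapter from
# spleeter.audio.adapter) and a params dict providing 'train_csv':
#
#   from spleeter.audio.adapter import get_default_audio_adapter
#
#   params = dict(DEFAULT_AUDIO_PARAMS, train_csv='train.csv', batch_size=4)
#   dataset = get_training_dataset(
#       params, get_default_audio_adapter(), '/path/to/audio')
#
# get_validation_dataset below is used the same way with 'validation_csv'.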
def get_validation_dataset(audio_params, audio_adapter, audio_path):
""" Builds validation dataset.
:param audio_params: Audio parameters.
:param audio_adapter: Adapter to load audio from.
:param audio_path: Path of directory containing audio.
:returns: Built dataset.
"""
builder = DatasetBuilder(
audio_params,
audio_adapter,
audio_path,
chunk_duration=12.0)
return builder.build(
audio_params.get('validation_csv'),
batch_size=audio_params.get('batch_size'),
cache_directory=audio_params.get('validation_cache'),
convert_to_uint=True,
infinite_generator=False,
n_chunks_per_song=1,
# should not perform data augmentation for eval:
random_data_augmentation=False,
random_time_crop=False,
shuffle=False,
)
class InstrumentDatasetBuilder(object):
""" Instrument based filter and mapper provider. """
def __init__(self, parent, instrument):
""" Default constructor.
:param parent: Parent dataset builder.
:param instrument: Target instrument.
"""
self._parent = parent
self._instrument = instrument
self._spectrogram_key = f'{instrument}_spectrogram'
self._min_spectrogram_key = f'min_{instrument}_spectrogram'
self._max_spectrogram_key = f'max_{instrument}_spectrogram'
def load_waveform(self, sample):
""" Load waveform for given sample. """
return dict(sample, **self._parent._audio_adapter.load_tf_waveform(
sample[f'{self._instrument}_path'],
offset=sample['start'],
duration=self._parent._chunk_duration,
sample_rate=self._parent._sample_rate,
waveform_name='waveform'))
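    # NOTE: the adapter is expected to return a dict with 'waveform' and
    # 'waveform_error' entries; errored samples are dropped later through
    # DatasetBuilder.filter_error.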
def compute_spectrogram(self, sample):
""" Compute spectrogram of the given sample. """
return dict(sample, **{
self._spectrogram_key: compute_spectrogram_tf(
sample['waveform'],
frame_length=self._parent._frame_length,
frame_step=self._parent._frame_step,
spec_exponent=1.,
window_exponent=1.)})
def filter_frequencies(self, sample):
""" """
return dict(sample, **{
self._spectrogram_key:
sample[self._spectrogram_key][:, :self._parent._F, :]})
def convert_to_uint(self, sample):
""" Convert given sample from float to unit. """
return dict(sample, **spectrogram_to_db_uint(
sample[self._spectrogram_key],
tensor_key=self._spectrogram_key,
min_key=self._min_spectrogram_key,
max_key=self._max_spectrogram_key))
def filter_infinity(self, sample):
""" Filter infinity sample. """
return tf.logical_not(
tf.math.is_inf(
sample[self._min_spectrogram_key]))
def convert_to_float32(self, sample):
""" Convert given sample from unit to float. """
return dict(sample, **{
self._spectrogram_key: db_uint_spectrogram_to_gain(
sample[self._spectrogram_key],
sample[self._min_spectrogram_key],
sample[self._max_spectrogram_key])})
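    # The uint round trip above trades precision for cache size: spectrogram
    # values are stored as quantized dB magnitudes plus a per-sample min and
    # max. A rough numpy sketch of the idea, using uint8 for illustration
    # (the real conversion lives in audio.convertor):
    #
    #   db = 20 * np.log10(np.maximum(spec, 1e-10))
    #   mn, mx = db.min(), db.max()
    #   as_uint = np.uint8(255 * (db - mn) / (mx - mn))
    #   restored = 10 ** ((mn + as_uint / 255 * (mx - mn)) / 20)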
def time_crop(self, sample):
""" """
def start(sample):
""" mid_segment_start """
return tf.cast(
tf.maximum(
tf.shape(sample[self._spectrogram_key])[0]
/ 2 - self._parent._T / 2, 0),
tf.int32)
return dict(sample, **{
self._spectrogram_key: sample[self._spectrogram_key][
start(sample):start(sample) + self._parent._T, :, :]})
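    # Worked example: with T=512 and a 600-frame spectrogram,
    # start = max(600 / 2 - 512 / 2, 0) = 44, so frames [44, 556) are kept;
    # spectrograms shorter than T are cropped from frame 0.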
def filter_shape(self, sample):
""" Filter badly shaped sample. """
return check_tensor_shape(
sample[self._spectrogram_key], (
self._parent._T, self._parent._F, 2))
def reshape_spectrogram(self, sample):
""" """
return dict(sample, **{
self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key],
(self._parent._T, self._parent._F, 2))})
class DatasetBuilder(object):
"""
"""
# Margin at beginning and end of songs in seconds.
MARGIN = 0.5
# Wait period for cache (in seconds).
WAIT_PERIOD = 60
def __init__(
self,
audio_params, audio_adapter, audio_path,
random_seed=0, chunk_duration=20.0):
""" Default constructor.
NOTE: Probably need for AudioAdapter.
:param audio_params: Audio parameters to use.
:param audio_adapter: Audio adapter to use.
:param audio_path:
:param random_seed:
:param chunk_duration:
"""
# Length of segment in frames (if fs=22050 and
# frame_step=512, then T=512 corresponds to 11.89s)
self._T = audio_params['T']
# Number of frequency bins to be used (should
# be less than frame_length/2 + 1)
self._F = audio_params['F']
self._sample_rate = audio_params['sample_rate']
self._frame_length = audio_params['frame_length']
self._frame_step = audio_params['frame_step']
self._mix_name = audio_params['mix_name']
        # Cast to list since the provided instrument_list may be a tuple
        # (as in DEFAULT_AUDIO_PARAMS above).
        self._instruments = (
            [self._mix_name] + list(audio_params['instrument_list']))
self._instrument_builders = None
self._chunk_duration = chunk_duration
self._audio_adapter = audio_adapter
self._audio_params = audio_params
self._audio_path = audio_path
self._random_seed = random_seed
def expand_path(self, sample):
""" Expands audio paths for the given sample. """
return dict(sample, **{f'{instrument}_path': tf.string_join(
(self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
for instrument in self._instruments})
def filter_error(self, sample):
""" Filter errored sample. """
return tf.logical_not(sample['waveform_error'])
def filter_waveform(self, sample):
""" Filter waveform from sample. """
return {k: v for k, v in sample.items() if not k == 'waveform'}
def harmonize_spectrogram(self, sample):
""" Ensure same size for vocals and mix spectrograms. """
def _reduce(sample):
return tf.reduce_min([
tf.shape(sample[f'{instrument}_spectrogram'])[0]
for instrument in self._instruments])
return dict(sample, **{
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :]
for instrument in self._instruments})
def filter_short_segments(self, sample):
""" Filter out too short segment. """
return tf.reduce_any([
tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T
for instrument in self._instruments])
def random_time_crop(self, sample):
""" Random time crop of 11.88s. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: tf.image.random_crop(
x, (self._T, len(self._instruments) * self._F, 2),
seed=self._random_seed)))
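    # sync_apply concatenates all instrument spectrograms along the frequency
    # axis before applying the function, hence the (T, len(instruments) * F, 2)
    # crop shape: a single random_crop thus selects the same time window for
    # every instrument before the result is split back.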
def random_time_stretch(self, sample):
""" Randomly time stretch the given sample. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: random_time_stretch(
x, factor_min=0.9, factor_max=1.1)))
def random_pitch_shift(self, sample):
""" Randomly pitch shift the given sample. """
return dict(sample, **sync_apply({
f'{instrument}_spectrogram':
sample[f'{instrument}_spectrogram']
for instrument in self._instruments},
lambda x: random_pitch_shift(
x, shift_min=-1.0, shift_max=1.0), concat_axis=0))
def map_features(self, sample):
""" Select features and annotation of the given sample. """
input_ = {
f'{self._mix_name}_spectrogram':
sample[f'{self._mix_name}_spectrogram']}
output = {
f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
for instrument in self._audio_params['instrument_list']}
return (input_, output)
def compute_segments(self, dataset, n_chunks_per_song):
""" Computes segments for each song of the dataset.
:param dataset: Dataset to compute segments for.
:param n_chunks_per_song: Number of segment per song to compute.
:returns: Segmented dataset.
"""
if n_chunks_per_song <= 0:
            raise ValueError('n_chunks_per_song must be positive')
datasets = []
for k in range(n_chunks_per_song):
if n_chunks_per_song > 1:
datasets.append(
dataset.map(lambda sample: dict(sample, start=tf.maximum(
k * (
sample['duration'] - self._chunk_duration - 2
* self.MARGIN) / (n_chunks_per_song - 1)
+ self.MARGIN, 0))))
elif n_chunks_per_song == 1: # Take central segment.
datasets.append(
dataset.map(lambda sample: dict(sample, start=tf.maximum(
sample['duration'] / 2 - self._chunk_duration / 2,
0))))
dataset = datasets[-1]
for d in datasets[:-1]:
dataset = dataset.concatenate(d)
return dataset
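    # Worked example: for a 180s song with chunk_duration=20.0 and
    # n_chunks_per_song=2, chunk starts are
    # MARGIN + k * (180 - 20 - 2 * MARGIN) / 1 for k in {0, 1},
    # i.e. 0.5s and 159.5s.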
@property
def instruments(self):
""" Instrument dataset builder generator.
:yield InstrumentBuilder instance.
"""
if self._instrument_builders is None:
self._instrument_builders = []
for instrument in self._instruments:
self._instrument_builders.append(
InstrumentDatasetBuilder(self, instrument))
for builder in self._instrument_builders:
yield builder
def cache(self, dataset, cache, wait):
""" Cache the given dataset if cache is enabled. Eventually waits for
cache to be available (useful if another process is already computing
cache) if provided wait flag is True.
:param dataset: Dataset to be cached if cache is required.
:param cache: Path of cache directory to be used, None if no cache.
:param wait: If caching is enabled, True is cache should be waited.
:returns: Cached dataset if needed, original dataset otherwise.
"""
if cache is not None:
if wait:
while not exists(f'{cache}.index'):
                get_logger().info(
                    'Cache not available, waiting %s seconds',
                    self.WAIT_PERIOD)
time.sleep(self.WAIT_PERIOD)
cache_path = os.path.split(cache)[0]
os.makedirs(cache_path, exist_ok=True)
return dataset.cache(cache)
return dataset
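    # NOTE: tf.data writes cache files next to the given prefix (an index
    # file plus data shards), which is why the availability of another
    # process' cache is probed through '{cache}.index' above.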
def build(
self, csv_path,
batch_size=8, shuffle=True, convert_to_uint=True,
random_data_augmentation=False, random_time_crop=True,
infinite_generator=True, cache_directory=None,
wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,):
"""
TO BE DOCUMENTED.
"""
dataset = dataset_from_csv(csv_path)
dataset = self.compute_segments(dataset, n_chunks_per_song)
# Shuffle data
if shuffle:
dataset = dataset.shuffle(
buffer_size=200000,
seed=self._random_seed,
            # Useless once the dataset is cached:
reshuffle_each_iteration=True)
# Expand audio path.
dataset = dataset.map(self.expand_path)
        # For each instrument: load waveform, filter out errored samples,
        # compute spectrogram and keep only the first F frequency bins.
        # Waveforms are dropped afterwards.
N = num_parallel_calls
for instrument in self.instruments:
dataset = (
dataset
.map(instrument.load_waveform, num_parallel_calls=N)
.filter(self.filter_error)
.map(instrument.compute_spectrogram, num_parallel_calls=N)
.map(instrument.filter_frequencies))
dataset = dataset.map(self.filter_waveform)
# Convert to uint before caching in order to save space.
if convert_to_uint:
for instrument in self.instruments:
dataset = dataset.map(instrument.convert_to_uint)
dataset = self.cache(dataset, cache_directory, wait_for_cache)
# Check for INFINITY (should not happen)
for instrument in self.instruments:
dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitely.
if infinite_generator:
dataset = dataset.repeat(count=-1)
# Ensure same size for vocals and mix spectrograms.
# NOTE: could be done before caching ?
dataset = dataset.map(self.harmonize_spectrogram)
# Filter out too short segment.
# NOTE: could be done before caching ?
dataset = dataset.filter(self.filter_short_segments)
# Random time crop of 11.88s
if random_time_crop:
dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
else:
# frame_duration = 11.88/T
# take central segment (for validation)
for instrument in self.instruments:
dataset = dataset.map(instrument.time_crop)
        # Post cache shuffling. Done where the data are the lightest:
        # after cropping but before converting back to float.
if shuffle:
dataset = dataset.shuffle(
buffer_size=256, seed=self._random_seed,
reshuffle_each_iteration=True)
# Convert back to float32
if convert_to_uint:
for instrument in self.instruments:
dataset = dataset.map(
instrument.convert_to_float32, num_parallel_calls=N)
M = 8 # Parallel call post caching.
# Must be applied with the same factor on mix and vocals.
if random_data_augmentation:
dataset = (
dataset
.map(self.random_time_stretch, num_parallel_calls=M)
.map(self.random_pitch_shift, num_parallel_calls=M))
# Filter by shape (remove badly shaped tensors).
for instrument in self.instruments:
dataset = (
dataset
.filter(instrument.filter_shape)
.map(instrument.reshape_spectrogram))
# Select features and annotation.
dataset = dataset.map(self.map_features)
# Make batch (done after selection to avoid
# error due to unprocessed instrument spectrogram batching).
dataset = dataset.batch(batch_size)
return dataset
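# Direct builder usage sketch, given some audio_adapter instance
# (hypothetical paths; the CSV must provide a 'duration' column and one
# '<instrument>_path' column per instrument, including the mix, as consumed
# by compute_segments and expand_path above):
#
#   builder = DatasetBuilder(
#       DEFAULT_AUDIO_PARAMS, audio_adapter, '/path/to/audio')
#   dataset = builder.build('train.csv', batch_size=4)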