mirror of https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
#!/usr/bin/env python
# coding: utf8

"""
Module for building the data preprocessing pipeline using the tensorflow
data API. Data preprocessing steps such as audio loading, spectrogram
computation, cropping, feature caching and data augmentation are performed
with a tensorflow dataset object that outputs a tuple (input_, output)
where:

- input_ is a dictionary with a single key that contains the (batched)
  mix spectrogram of audio samples
- output is a dictionary of spectrograms of the isolated tracks
  (ground truth)
"""

import time
import os
from os.path import exists, join, sep as SEPARATOR

# pylint: disable=import-error
import pandas as pd
import numpy as np
import tensorflow as tf
# pylint: enable=import-error

from .audio.convertor import (
    db_uint_spectrogram_to_gain,
    spectrogram_to_db_uint)
from .audio.spectrogram import (
    compute_spectrogram_tf,
    random_pitch_shift,
    random_time_stretch)
from .utils.logging import get_logger
from .utils.tensor import (
    check_tensor_shape,
    dataset_from_csv,
    set_tensor_shape,
    sync_apply)

__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS = {
    'instrument_list': ('vocals', 'accompaniment'),
    'mix_name': 'mix',
    'sample_rate': 44100,
    'frame_length': 4096,
    'frame_step': 1024,
    'T': 512,
    'F': 1024
}

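# Illustrative note (not part of the original source): with the defaults
# above, each spectrogram chunk fed to the model has shape (T, F, 2) =
# (512, 1024, 2), i.e. 512 STFT frames, 1024 frequency bins and 2 stereo
# channels. At sample_rate=44100 with frame_step=1024, a 512-frame chunk
# covers 512 * 1024 / 44100 ~= 11.89 seconds of audio, which is the
# "11.88s" crop mentioned in the comments below.
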
def get_training_dataset(audio_params, audio_adapter, audio_path):
    """ Builds training dataset.

    :param audio_params: Audio parameters.
    :param audio_adapter: Adapter to load audio from.
    :param audio_path: Path of directory containing audio.
    :returns: Built dataset.
    """
    builder = DatasetBuilder(
        audio_params,
        audio_adapter,
        audio_path,
        chunk_duration=audio_params.get('chunk_duration', 20.0),
        random_seed=audio_params.get('random_seed', 0))
    return builder.build(
        audio_params.get('train_csv'),
        cache_directory=audio_params.get('training_cache'),
        batch_size=audio_params.get('batch_size'),
        n_chunks_per_song=audio_params.get('n_chunks_per_song', 2),
        random_data_augmentation=False,
        convert_to_uint=True,
        wait_for_cache=False)

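# Illustrative usage sketch (assumes a hypothetical `audio_adapter`
# instance and CSV files listing per-track audio paths and durations;
# not part of the original source):
#
#   params = dict(DEFAULT_AUDIO_PARAMS,
#                 train_csv='train.csv', batch_size=8)
#   dataset = get_training_dataset(params, audio_adapter, '/data/audio')
#   for input_, output in dataset.take(1):
#       print(input_['mix_spectrogram'].shape)  # (8, 512, 1024, 2)
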
def get_validation_dataset(audio_params, audio_adapter, audio_path):
    """ Builds validation dataset.

    :param audio_params: Audio parameters.
    :param audio_adapter: Adapter to load audio from.
    :param audio_path: Path of directory containing audio.
    :returns: Built dataset.
    """
    builder = DatasetBuilder(
        audio_params,
        audio_adapter,
        audio_path,
        chunk_duration=12.0)
    return builder.build(
        audio_params.get('validation_csv'),
        batch_size=audio_params.get('batch_size'),
        cache_directory=audio_params.get('validation_cache'),
        convert_to_uint=True,
        infinite_generator=False,
        n_chunks_per_song=1,
        # should not perform data augmentation for eval:
        random_data_augmentation=False,
        random_time_crop=False,
        shuffle=False)

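# Note (added commentary): unlike training, validation uses a single,
# deterministic central chunk per song, no shuffling and no augmentation,
# so successive evaluations see exactly the same data.
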
class InstrumentDatasetBuilder(object):
    """ Instrument based filter and mapper provider. """

    def __init__(self, parent, instrument):
        """ Default constructor.

        :param parent: Parent dataset builder.
        :param instrument: Target instrument.
        """
        self._parent = parent
        self._instrument = instrument
        self._spectrogram_key = f'{instrument}_spectrogram'
        self._min_spectrogram_key = f'min_{instrument}_spectrogram'
        self._max_spectrogram_key = f'max_{instrument}_spectrogram'

    def load_waveform(self, sample):
        """ Load waveform for given sample. """
        return dict(sample, **self._parent._audio_adapter.load_tf_waveform(
            sample[f'{self._instrument}_path'],
            offset=sample['start'],
            duration=self._parent._chunk_duration,
            sample_rate=self._parent._sample_rate,
            waveform_name='waveform'))

    def compute_spectrogram(self, sample):
        """ Compute spectrogram of the given sample. """
        return dict(sample, **{
            self._spectrogram_key: compute_spectrogram_tf(
                sample['waveform'],
                frame_length=self._parent._frame_length,
                frame_step=self._parent._frame_step,
                spec_exponent=1.,
                window_exponent=1.)})

    def filter_frequencies(self, sample):
        """ Keep only the first F frequency bins of the spectrogram. """
        return dict(sample, **{
            self._spectrogram_key:
                sample[self._spectrogram_key][:, :self._parent._F, :]})

    def convert_to_uint(self, sample):
        """ Convert given sample spectrogram from float to uint. """
        return dict(sample, **spectrogram_to_db_uint(
            sample[self._spectrogram_key],
            tensor_key=self._spectrogram_key,
            min_key=self._min_spectrogram_key,
            max_key=self._max_spectrogram_key))

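    # Note (added commentary): convert_to_uint / convert_to_float32 form a
    # lossy round trip. Spectrograms are quantized to unsigned integers on
    # a dB scale (keeping per-sample min/max for dequantization) before
    # caching, which shrinks the cache considerably compared to float32,
    # and are mapped back to linear gain once read from the cache.
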
    def filter_infinity(self, sample):
        """ Filter out sample whose spectrogram min is infinite. """
        return tf.logical_not(
            tf.math.is_inf(
                sample[self._min_spectrogram_key]))

    def convert_to_float32(self, sample):
        """ Convert given sample spectrogram from uint back to float32. """
        return dict(sample, **{
            self._spectrogram_key: db_uint_spectrogram_to_gain(
                sample[self._spectrogram_key],
                sample[self._min_spectrogram_key],
                sample[self._max_spectrogram_key])})

    def time_crop(self, sample):
        """ Crop the central T frames of the spectrogram. """

        def start(sample):
            """ Compute the start frame of the central segment. """
            return tf.cast(
                tf.maximum(
                    tf.shape(sample[self._spectrogram_key])[0]
                    / 2 - self._parent._T / 2, 0),
                tf.int32)
        return dict(sample, **{
            self._spectrogram_key: sample[self._spectrogram_key][
                start(sample):start(sample) + self._parent._T, :, :]})

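    # Worked example (added commentary): for a 600-frame spectrogram and
    # T=512, start = max(600 / 2 - 512 / 2, 0) = 44, so frames 44..555
    # (the central 512) are kept; shorter inputs start at frame 0.
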
    def filter_shape(self, sample):
        """ Filter out badly shaped samples. """
        return check_tensor_shape(
            sample[self._spectrogram_key], (
                self._parent._T, self._parent._F, 2))

    def reshape_spectrogram(self, sample):
        """ Set the static (T, F, 2) shape of the spectrogram tensor. """
        return dict(sample, **{
            self._spectrogram_key: set_tensor_shape(
                sample[self._spectrogram_key],
                (self._parent._T, self._parent._F, 2))})

class DatasetBuilder(object):
    """ Builds the train/validation data preprocessing pipeline as a
    tensorflow dataset. """

    # Margin at beginning and end of songs in seconds.
    MARGIN = 0.5

    # Wait period for cache (in seconds).
    WAIT_PERIOD = 60

    def __init__(
            self,
            audio_params, audio_adapter, audio_path,
            random_seed=0, chunk_duration=20.0):
        """ Default constructor.

        NOTE: Probably need for AudioAdapter.

        :param audio_params: Audio parameters to use.
        :param audio_adapter: Audio adapter to use.
        :param audio_path: Path of directory containing audio.
        :param random_seed: (Optional) seed for random operations.
        :param chunk_duration: (Optional) duration of loaded chunks (s).
        """
        # Length of segment in frames (if fs=22050 and
        # frame_step=512, then T=512 corresponds to 11.89s)
        self._T = audio_params['T']
        # Number of frequency bins to be used (should
        # be less than frame_length/2 + 1)
        self._F = audio_params['F']
        self._sample_rate = audio_params['sample_rate']
        self._frame_length = audio_params['frame_length']
        self._frame_step = audio_params['frame_step']
        self._mix_name = audio_params['mix_name']
        # list(...) so that a tuple-valued instrument_list (as in
        # DEFAULT_AUDIO_PARAMS) can be concatenated as well.
        self._instruments = (
            [self._mix_name] + list(audio_params['instrument_list']))
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
        self._audio_adapter = audio_adapter
        self._audio_params = audio_params
        self._audio_path = audio_path
        self._random_seed = random_seed

    def expand_path(self, sample):
        """ Expands audio paths for the given sample. """
        return dict(sample, **{f'{instrument}_path': tf.string_join(
            (self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
            for instrument in self._instruments})

    def filter_error(self, sample):
        """ Filter out samples whose waveform loading errored. """
        return tf.logical_not(sample['waveform_error'])

    def filter_waveform(self, sample):
        """ Remove the waveform from the sample once spectrograms exist. """
        return {k: v for k, v in sample.items() if k != 'waveform'}

    def harmonize_spectrogram(self, sample):
        """ Ensure the same time size for all instrument spectrograms. """

        def _reduce(sample):
            return tf.reduce_min([
                tf.shape(sample[f'{instrument}_spectrogram'])[0]
                for instrument in self._instruments])
        return dict(sample, **{
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram'][:_reduce(sample), :, :]
            for instrument in self._instruments})

    def filter_short_segments(self, sample):
        """ Filter out segments shorter than T frames. """
        return tf.reduce_any([
            tf.shape(sample[f'{instrument}_spectrogram'])[0] >= self._T
            for instrument in self._instruments])

    def random_time_crop(self, sample):
        """ Random time crop of 11.88s. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: tf.image.random_crop(
                x, (self._T, len(self._instruments) * self._F, 2),
                seed=self._random_seed)))

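    # Note (added commentary): sync_apply concatenates the instrument
    # spectrograms (along the frequency axis here, hence the
    # len(self._instruments) * self._F crop width) before applying the
    # function, so mix and stems receive exactly the same random crop.
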
    def random_time_stretch(self, sample):
        """ Randomly time stretch the given sample. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: random_time_stretch(
                x, factor_min=0.9, factor_max=1.1)))

    def random_pitch_shift(self, sample):
        """ Randomly pitch shift the given sample. """
        return dict(sample, **sync_apply({
            f'{instrument}_spectrogram':
                sample[f'{instrument}_spectrogram']
            for instrument in self._instruments},
            lambda x: random_pitch_shift(
                x, shift_min=-1.0, shift_max=1.0), concat_axis=0))

    def map_features(self, sample):
        """ Select features and annotation of the given sample. """
        input_ = {
            f'{self._mix_name}_spectrogram':
                sample[f'{self._mix_name}_spectrogram']}
        output = {
            f'{instrument}_spectrogram': sample[f'{instrument}_spectrogram']
            for instrument in self._audio_params['instrument_list']}
        return (input_, output)

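    # Note (added commentary): with the default parameters, map_features
    # yields, per sample:
    #   input_ = {'mix_spectrogram': <T x F x 2 tensor>}
    #   output = {'vocals_spectrogram': ..., 'accompaniment_spectrogram': ...}
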
    def compute_segments(self, dataset, n_chunks_per_song):
        """ Computes segments for each song of the dataset.

        :param dataset: Dataset to compute segments for.
        :param n_chunks_per_song: Number of segments per song to compute.
        :returns: Segmented dataset.
        """
        if n_chunks_per_song <= 0:
            raise ValueError('n_chunks_per_song must be positive')
        datasets = []
        for k in range(n_chunks_per_song):
            if n_chunks_per_song > 1:
                datasets.append(
                    dataset.map(lambda sample: dict(sample, start=tf.maximum(
                        k * (
                            sample['duration'] - self._chunk_duration - 2
                            * self.MARGIN) / (n_chunks_per_song - 1)
                        + self.MARGIN, 0))))
            elif n_chunks_per_song == 1:  # Take central segment.
                datasets.append(
                    dataset.map(lambda sample: dict(sample, start=tf.maximum(
                        sample['duration'] / 2 - self._chunk_duration / 2,
                        0))))
        dataset = datasets[-1]
        for d in datasets[:-1]:
            dataset = dataset.concatenate(d)
        return dataset

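    # Worked example (added commentary): for a 180s song with
    # chunk_duration=20 and n_chunks_per_song=2, starts are
    # k * (180 - 20 - 2 * 0.5) / 1 + 0.5 for k in {0, 1}, i.e. chunks
    # beginning at 0.5s and 159.5s, keeping a MARGIN of 0.5s at each end.
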
    @property
    def instruments(self):
        """ Instrument dataset builder generator.

        :yield: InstrumentDatasetBuilder instance.
        """
        if self._instrument_builders is None:
            self._instrument_builders = []
            for instrument in self._instruments:
                self._instrument_builders.append(
                    InstrumentDatasetBuilder(self, instrument))
        for builder in self._instrument_builders:
            yield builder

    def cache(self, dataset, cache, wait):
        """ Cache the given dataset if cache is enabled. Optionally waits
        for the cache to be available (useful if another process is already
        computing it) when the provided wait flag is True.

        :param dataset: Dataset to be cached if cache is required.
        :param cache: Path prefix of the cache to be used, None if no cache.
        :param wait: If caching is enabled, True if cache should be
            waited for.
        :returns: Cached dataset if needed, original dataset otherwise.
        """
        if cache is not None:
            if wait:
                while not exists(f'{cache}.index'):
                    get_logger().info(
                        'Cache not available, waiting %s seconds',
                        self.WAIT_PERIOD)
                    time.sleep(self.WAIT_PERIOD)
            cache_path = os.path.split(cache)[0]
            os.makedirs(cache_path, exist_ok=True)
            return dataset.cache(cache)
        return dataset

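    # Note (added commentary): tf.data's file-based cache writes files
    # prefixed by the given path, including a `{cache}.index` file, which
    # is why its existence is used above as a readiness signal.
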
    def build(
            self, csv_path,
            batch_size=8, shuffle=True, convert_to_uint=True,
            random_data_augmentation=False, random_time_crop=True,
            infinite_generator=True, cache_directory=None,
            wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2):
        """ Builds the preprocessing dataset from the given CSV.

        :param csv_path: Path of the CSV file describing the dataset.
        :param batch_size: (Optional) size of output batches.
        :param shuffle: (Optional) if True, shuffle samples.
        :param convert_to_uint: (Optional) if True, quantize spectrograms
            to uint before caching and convert them back afterwards.
        :param random_data_augmentation: (Optional) if True, apply random
            time stretch and pitch shift.
        :param random_time_crop: (Optional) if True, crop segments at a
            random position instead of the central one.
        :param infinite_generator: (Optional) if True, repeat indefinitely.
        :param cache_directory: (Optional) cache location, None to disable.
        :param wait_for_cache: (Optional) if True, wait for a cache being
            computed by another process.
        :param num_parallel_calls: (Optional) parallelism of map operations.
        :param n_chunks_per_song: (Optional) number of chunks per song.
        :returns: Built dataset.
        """
        dataset = dataset_from_csv(csv_path)
        dataset = self.compute_segments(dataset, n_chunks_per_song)
        # Shuffle data
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=200000,
                seed=self._random_seed,
                # useless once the dataset is cached:
                reshuffle_each_iteration=True)
        # Expand audio path.
        dataset = dataset.map(self.expand_path)
        # Load waveform and compute spectrogram, then filter out errored
        # samples, frequency bins above F, and the raw waveform.
        N = num_parallel_calls
        for instrument in self.instruments:
            dataset = (
                dataset
                .map(instrument.load_waveform, num_parallel_calls=N)
                .filter(self.filter_error)
                .map(instrument.compute_spectrogram, num_parallel_calls=N)
                .map(instrument.filter_frequencies))
        dataset = dataset.map(self.filter_waveform)
        # Convert to uint before caching in order to save space.
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(instrument.convert_to_uint)
        dataset = self.cache(dataset, cache_directory, wait_for_cache)
        # Check for INFINITY (should not happen)
        for instrument in self.instruments:
            dataset = dataset.filter(instrument.filter_infinity)
        # Repeat indefinitely
        if infinite_generator:
            dataset = dataset.repeat(count=-1)
        # Ensure same size for vocals and mix spectrograms.
        # NOTE: could be done before caching ?
        dataset = dataset.map(self.harmonize_spectrogram)
        # Filter out segments that are too short.
        # NOTE: could be done before caching ?
        dataset = dataset.filter(self.filter_short_segments)
        # Random time crop of 11.88s
        if random_time_crop:
            dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)
        else:
            # frame_duration = 11.88/T
            # take central segment (for validation)
            for instrument in self.instruments:
                dataset = dataset.map(instrument.time_crop)
        # Post cache shuffling. Done where the data are the lightest:
        # after cropping but before converting back to float.
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=256, seed=self._random_seed,
                reshuffle_each_iteration=True)
        # Convert back to float32
        if convert_to_uint:
            for instrument in self.instruments:
                dataset = dataset.map(
                    instrument.convert_to_float32, num_parallel_calls=N)
        M = 8  # Parallel calls post caching.
        # Must be applied with the same factor on mix and vocals.
        if random_data_augmentation:
            dataset = (
                dataset
                .map(self.random_time_stretch, num_parallel_calls=M)
                .map(self.random_pitch_shift, num_parallel_calls=M))
        # Filter by shape (remove badly shaped tensors).
        for instrument in self.instruments:
            dataset = (
                dataset
                .filter(instrument.filter_shape)
                .map(instrument.reshape_spectrogram))
        # Select features and annotation.
        dataset = dataset.map(self.map_features)
        # Make batch (done after selection to avoid
        # error due to unprocessed instrument spectrogram batching).
        dataset = dataset.batch(batch_size)
        return dataset
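
# Illustrative note (inferred from the code above, not part of the original
# source): the CSV consumed by dataset_from_csv is expected to provide one
# row per song with a `duration` column (in seconds) and one
# `<instrument>_path` column per instrument, including the mix, e.g.:
#
#   mix_path,vocals_path,accompaniment_path,duration
#   song1/mix.wav,song1/vocals.wav,song1/accompaniment.wav,180.5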