spleeter/tests/test_train.py

#!/usr/bin/env python
# coding: utf8
""" Unit testing for Separator class. """
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
import json
import os
from os import makedirs
from os.path import join
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd
from typer.testing import CliRunner

from spleeter.audio.adapter import AudioAdapter
from spleeter.__main__ import spleeter
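
# Deliberately small training configuration (tiny time/frequency dimensions,
# batch size 2, only 10 training steps) so that the end-to-end test stays fast.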
TRAIN_CONFIG = {
    'mix_name': 'mix',
    'instrument_list': ['vocals', 'other'],
    'sample_rate': 44100,
    'frame_length': 4096,
    'frame_step': 1024,
    'T': 128,
    'F': 128,
    'n_channels': 2,
    'chunk_duration': 4,
    'n_chunks_per_song': 1,
    'separation_exponent': 2,
    'mask_extension': 'zeros',
    'learning_rate': 1e-4,
    'batch_size': 2,
    'train_max_steps': 10,
    'throttle_secs': 20,
    'save_checkpoints_steps': 100,
    'save_summary_steps': 5,
    'random_seed': 0,
    'model': {
        'type': 'unet.unet',
        'params': {
            'conv_activation': 'ELU',
            'deconv_activation': 'ELU'
        }
    }
}


def generate_fake_training_dataset(path,
                                   instrument_list=['vocals', 'other'],
                                   n_channels=2,
                                   n_songs=2,
                                   fs=44100,
                                   duration=6,
                                   ):
    """
    Generates a fake training dataset in path:
    - generates audio files
    - generates a csv file describing the dataset
    """
    aa = AudioAdapter.default()
    rng = np.random.RandomState(seed=0)
    dataset_df = pd.DataFrame(
        columns=['mix_path'] + [
            f'{instr}_path' for instr in instrument_list] + ['duration'])
    for song in range(n_songs):
        song_path = join(path, 'train', f'song{song}')
        makedirs(song_path, exist_ok=True)
        dataset_df.loc[song, 'duration'] = duration
        for instr in instrument_list + ['mix']:
            filename = join(song_path, f'{instr}.wav')
            # uniform random noise in [-0.5, 0.5) as placeholder audio content
            data = rng.rand(duration * fs, n_channels) - 0.5
            aa.save(filename, data, fs)
            dataset_df.loc[song, f'{instr}_path'] = join(
                'train',
                f'song{song}',
                f'{instr}.wav')
    dataset_df.to_csv(join(path, 'train', 'train.csv'), index=False)
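
# End-to-end test of the `spleeter train` CLI entry point: train a tiny model
# on the fake dataset for both mono and stereo audio, then check that
# checkpoint files were written and that the command exited cleanly.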
def test_train():
    with TemporaryDirectory() as path:
        # generate a training dataset and train once per channel count
        for n_channels in [1, 2]:
            TRAIN_CONFIG['n_channels'] = n_channels
            generate_fake_training_dataset(path,
                                           n_channels=n_channels,
                                           fs=TRAIN_CONFIG['sample_rate'],
                                           )
            # set training command arguments
            runner = CliRunner()
            model_dir = join(path, f'model_{n_channels}')
            train_dir = join(path, 'train')
            cache_dir = join(path, f'cache_{n_channels}')
            TRAIN_CONFIG['train_csv'] = join(train_dir, 'train.csv')
            TRAIN_CONFIG['validation_csv'] = join(train_dir, 'train.csv')
            TRAIN_CONFIG['model_dir'] = model_dir
            TRAIN_CONFIG['training_cache'] = join(cache_dir, 'training')
            TRAIN_CONFIG['validation_cache'] = join(cache_dir, 'validation')
            # dump the training configuration to a JSON file inside the
            # temporary directory
            config_path = join(path, 'useless_config.json')
            with open(config_path, 'w') as stream:
                json.dump(TRAIN_CONFIG, stream)
            # execute training through the CLI
            result = runner.invoke(spleeter, [
                'train',
                '-p', config_path,
                '-d', path,
                '--verbose',
            ])
            # assert that model checkpoints were created
            assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
            assert os.path.exists(join(model_dir, 'checkpoint'))
            assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
            assert result.exit_code == 0