#!/usr/bin/env python
# coding: utf8

""" Unit testing for the train command. """

__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

import filecmp
import itertools
import os
from os import makedirs
from os.path import splitext, basename, exists, join
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd
import json

import tensorflow as tf

from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.commands import create_argument_parser

from spleeter.commands import train

from spleeter.utils.configuration import load_configuration

# Minimal two-stem (vocals/other) training configuration: tiny model steps
# (train_max_steps=10) so the test completes quickly. Dataset paths
# (train_csv, model_dir, caches, ...) are filled in at runtime by test_train.
TRAIN_CONFIG = {
    "mix_name": "mix",
    "instrument_list": ["vocals", "other"],
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 128,
    "F": 128,
    "n_channels": 2,
    "chunk_duration": 4,
    "n_chunks_per_song": 1,
    "separation_exponent": 2,
    "mask_extension": "zeros",
    "learning_rate": 1e-4,
    "batch_size": 2,
    "train_max_steps": 10,
    "throttle_secs": 20,
    "save_checkpoints_steps": 100,
    "save_summary_steps": 5,
    "random_seed": 0,
    "model": {
        "type": "unet.unet",
        "params": {
            "conv_activation": "ELU",
            "deconv_activation": "ELU"
        }
    }
}


def generate_fake_training_dataset(path, instrument_list=None):
    """
    Generate a fake training dataset under ``path``:
        - random-noise audio files (mix + one file per instrument) for
          each of 2 songs, under ``<path>/train/song<i>/``,
        - a ``train.csv`` file describing the dataset.

    :param path: Root directory in which the dataset is created.
    :param instrument_list: Names of the instrument stems to generate
        (defaults to ``["vocals", "other"]``); mutable default avoided
        on purpose.
    """
    # Avoid a mutable default argument shared across calls.
    if instrument_list is None:
        instrument_list = ["vocals", "other"]
    aa = get_default_audio_adapter()
    n_songs = 2
    fs = 44100
    duration = 6
    n_channels = 2
    # Fixed seed so the generated audio (and thus the test) is deterministic.
    rng = np.random.RandomState(seed=0)
    columns = (
        ["mix_path"]
        + [f"{instr}_path" for instr in instrument_list]
        + ["duration"]
    )
    dataset_df = pd.DataFrame(columns=columns)
    for song in range(n_songs):
        song_path = join(path, "train", f"song{song}")
        makedirs(song_path, exist_ok=True)
        dataset_df.loc[song, "duration"] = duration
        # Write one wav per stem, plus the mix itself.
        for instr in instrument_list + ["mix"]:
            filename = join(song_path, f"{instr}.wav")
            # Zero-centered uniform noise in [-0.5, 0.5).
            data = rng.rand(duration * fs, n_channels) - 0.5
            aa.save(filename, data, fs)
            # CSV paths are relative to the dataset root.
            dataset_df.loc[song, f"{instr}_path"] = join(
                "train", f"song{song}", f"{instr}.wav"
            )
    dataset_df.to_csv(join(path, "train", "train.csv"), index=False)


def test_train():
    """
    Run a tiny end-to-end training (10 steps) on a fake dataset and check
    that model checkpoint files are produced in the model directory.
    """
    with TemporaryDirectory() as path:
        # Generate training dataset.
        generate_fake_training_dataset(path)

        # Set training command arguments; the config file path is unused
        # because TRAIN_CONFIG is passed directly to the entrypoint.
        p = create_argument_parser()
        arguments = p.parse_args(
            ["train", "-p", "useless_config.json", "-d", path])
        TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv")
        # Validation reuses the training csv: we only check the run completes.
        TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv")
        TRAIN_CONFIG["model_dir"] = join(path, "model")
        TRAIN_CONFIG["training_cache"] = join(path, "cache", "training")
        TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation")

        # Execute training.
        train.entrypoint(arguments, TRAIN_CONFIG)

        # Assert that model checkpoints were created (step 0 and final
        # step 10, per train_max_steps above).
        assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
        assert os.path.exists(join(path, 'model', 'checkpoint'))
        assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))


if __name__ == "__main__":
    test_train()