Amended patch: tests call AudioAdapter.default() / ModelProvider.default()
instead of the get_default_*() helpers, and drive evaluation and training
through spleeter.__main__ (evaluate(), and the `spleeter` app via typer's
CliRunner) instead of spleeter.commands.

Review fixes applied to added lines of tests/test_train.py:
  * 'useless_config.json' is opened with mode 'w'. The submitted patch
    opened it in the default read mode, so json.dump() could not write the
    configuration that the `train` command is then pointed at.
  * f'duration' (an f-string without placeholders) is now the plain
    literal 'duration'; the runtime value is identical.

NOTE(review): line breaks and indentation were restored from a
whitespace-collapsed copy of this patch. Hunk headers are kept verbatim;
indentation and blank context lines on the old side are a best-effort
reconstruction -- confirm against the original files before applying.

diff --git a/tests/test_eval.py b/tests/test_eval.py
index f3764b6..06bfb88 100644
--- a/tests/test_eval.py
+++ b/tests/test_eval.py
@@ -7,82 +7,78 @@ __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
-import filecmp
-import itertools
 from os import makedirs
-from os.path import splitext, basename, exists, join
+from os.path import join
 from tempfile import TemporaryDirectory
 
 import pytest
 import numpy as np
 
-import tensorflow as tf
+from spleeter.__main__ import evaluate
+from spleeter.audio.adapter import AudioAdapter
 
-from spleeter.audio.adapter import get_default_audio_adapter
-from spleeter.commands import create_argument_parser
-
-from spleeter.commands import evaluate
-
-from spleeter.utils.configuration import load_configuration
-
-BACKENDS = ["tensorflow", "librosa"]
-TEST_CONFIGURATIONS = {el:el for el in BACKENDS}
+BACKENDS = ['tensorflow', 'librosa']
+TEST_CONFIGURATIONS = {el: el for el in BACKENDS}
 
 res_4stems = {
-                "vocals": {
-                    "SDR": 3.25e-05,
-                    "SAR": -11.153575,
-                    "SIR": -1.3849,
-                    "ISR": 2.75e-05
-                },
-                "drums": {
-                    "SDR": -0.079505,
-                    "SAR": -15.7073575,
-                    "SIR": -4.972755,
-                    "ISR": 0.0013575
-                },
-                "bass":{
-                    "SDR": 2.5e-06,
-                    "SAR": -10.3520575,
-                    "SIR": -4.272325,
-                    "ISR": 2.5e-06
-                },
-                "other":{
-                    "SDR": -1.359175,
-                    "SAR": -14.7076775,
-                    "SIR": -4.761505,
-                    "ISR": -0.01528
-                }
-            }
+    'vocals': {
+        'SDR': 3.25e-05,
+        'SAR': -11.153575,
+        'SIR': -1.3849,
+        'ISR': 2.75e-05
+    },
+    'drums': {
+        'SDR': -0.079505,
+        'SAR': -15.7073575,
+        'SIR': -4.972755,
+        'ISR': 0.0013575
+    },
+    'bass': {
+        'SDR': 2.5e-06,
+        'SAR': -10.3520575,
+        'SIR': -4.272325,
+        'ISR': 2.5e-06
+    },
+    'other': {
+        'SDR': -1.359175,
+        'SAR': -14.7076775,
+        'SIR': -4.761505,
+        'ISR': -0.01528
+    }
+}
+
+
 def generate_fake_eval_dataset(path):
     """
         generate fake evaluation dataset
     """
-    aa = get_default_audio_adapter()
+    aa = AudioAdapter.default()
     n_songs = 2
     fs = 44100
     duration = 3
     n_channels = 2
     rng = np.random.RandomState(seed=0)
     for song in range(n_songs):
-        song_path = join(path, "test", f"song{song}")
+        song_path = join(path, 'test', f'song{song}')
         makedirs(song_path, exist_ok=True)
-        for instr in ["mixture", "vocals", "bass", "drums", "other"]:
-            filename = join(song_path, f"{instr}.wav")
+        for instr in ['mixture', 'vocals', 'bass', 'drums', 'other']:
+            filename = join(song_path, f'{instr}.wav')
             data = rng.rand(duration*fs, n_channels)-0.5
             aa.save(filename, data, fs)
-
 
 
 @pytest.mark.parametrize('backend', TEST_CONFIGURATIONS)
 def test_evaluate(backend):
     with TemporaryDirectory() as directory:
         generate_fake_eval_dataset(directory)
-        p = create_argument_parser()
-        arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", directory, "-B", backend])
-        params = load_configuration(arguments.configuration)
-        metrics = evaluate.entrypoint(arguments, params)
+        metrics = evaluate(
+            stft_backend=backend,
+            params_filename='spleeter:4stems',
+            mus_dir=directory,
+        )
         for instrument, metric in metrics.items():
             for m, value in metric.items():
-                assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3)
\ No newline at end of file
+                assert np.allclose(
+                    np.median(value),
+                    res_4stems[instrument][m],
+                    atol=1e-3)
diff --git a/tests/test_ffmpeg_adapter.py b/tests/test_ffmpeg_adapter.py
index 8eb284a..dc4335f 100644
--- a/tests/test_ffmpeg_adapter.py
+++ b/tests/test_ffmpeg_adapter.py
@@ -10,6 +10,11 @@ __license__ = 'MIT License'
 from os.path import join
 from tempfile import TemporaryDirectory
 
+from spleeter import SpleeterError
+from spleeter.audio.adapter import AudioAdapter
+from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
+
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 from pytest import fixture, raises
 
@@ -17,12 +22,6 @@ import numpy as np
 import ffmpeg
 # pylint: enable=import-error
 
-from spleeter import SpleeterError
-from spleeter.audio.adapter import AudioAdapter
-from spleeter.audio.adapter import get_default_audio_adapter
-from spleeter.audio.adapter import get_audio_adapter
-from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
-
 TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
 TEST_OFFSET = 0
 TEST_DURATION = 600.
@@ -32,7 +31,7 @@ TEST_SAMPLE_RATE = 44100
 @fixture(scope='session')
 def adapter():
     """ Target test audio adapter fixture. """
-    return get_default_audio_adapter()
+    return AudioAdapter.default()
 
 
 @fixture(scope='session')
diff --git a/tests/test_github_model_provider.py b/tests/test_github_model_provider.py
index 248b1d5..6313999 100644
--- a/tests/test_github_model_provider.py
+++ b/tests/test_github_model_provider.py
@@ -5,12 +5,12 @@
 
 from pytest import raises
 
-from spleeter.model.provider import get_default_model_provider
+from spleeter.model.provider import ModelProvider
 
 
 def test_checksum():
     """ Test archive checksum index retrieval. """
-    provider = get_default_model_provider()
+    provider = ModelProvider.default()
     assert provider.checksum('2stems') == \
         'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692'
     assert provider.checksum('4stems') == \
diff --git a/tests/test_train.py b/tests/test_train.py
index 8d9533a..3c63b40 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -7,107 +7,102 @@ __email__ = 'research@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
-import filecmp
-import itertools
+import json
 import os
+
 from os import makedirs
-from os.path import splitext, basename, exists, join
+from os.path import join
 from tempfile import TemporaryDirectory
 
 import numpy as np
 import pandas as pd
-import json
 
-import tensorflow as tf
+from spleeter.audio.adapter import AudioAdapter
+from spleeter.__main__ import spleeter
+from typer.testing import CliRunner
 
-from spleeter.audio.adapter import get_default_audio_adapter
-from spleeter.commands import create_argument_parser
-
-from spleeter.commands import train
-
-from spleeter.utils.configuration import load_configuration
 
 TRAIN_CONFIG = {
-    "mix_name": "mix",
-    "instrument_list": ["vocals", "other"],
-    "sample_rate":44100,
-    "frame_length":4096,
-    "frame_step":1024,
-    "T":128,
-    "F":128,
-    "n_channels":2,
-    "chunk_duration":4,
-    "n_chunks_per_song":1,
-    "separation_exponent":2,
-    "mask_extension":"zeros",
-    "learning_rate": 1e-4,
-    "batch_size":2,
-    "train_max_steps": 10,
-    "throttle_secs":20,
-    "save_checkpoints_steps":100,
-    "save_summary_steps":5,
-    "random_seed":0,
-    "model":{
-        "type":"unet.unet",
-        "params":{
-            "conv_activation":"ELU",
-            "deconv_activation":"ELU"
+    'mix_name': 'mix',
+    'instrument_list': ['vocals', 'other'],
+    'sample_rate': 44100,
+    'frame_length': 4096,
+    'frame_step': 1024,
+    'T': 128,
+    'F': 128,
+    'n_channels': 2,
+    'chunk_duration': 4,
+    'n_chunks_per_song': 1,
+    'separation_exponent': 2,
+    'mask_extension': 'zeros',
+    'learning_rate': 1e-4,
+    'batch_size': 2,
+    'train_max_steps': 10,
+    'throttle_secs': 20,
+    'save_checkpoints_steps': 100,
+    'save_summary_steps': 5,
+    'random_seed': 0,
+    'model': {
+        'type': 'unet.unet',
+        'params': {
+            'conv_activation': 'ELU',
+            'deconv_activation': 'ELU'
         }
     }
 }
 
 
-def generate_fake_training_dataset(path, instrument_list=["vocals", "other"]):
+def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
     """
         generates a fake training dataset in path:
         - generates audio files
         - generates a csv file describing the dataset
     """
-    aa = get_default_audio_adapter()
+    aa = AudioAdapter.default()
     n_songs = 2
     fs = 44100
     duration = 6
     n_channels = 2
     rng = np.random.RandomState(seed=0)
-    dataset_df = pd.DataFrame(columns=["mix_path"]+[f"{instr}_path" for instr in instrument_list]+["duration"])
+    dataset_df = pd.DataFrame(
+        columns=['mix_path'] + [
+            f'{instr}_path' for instr in instrument_list] + ['duration'])
     for song in range(n_songs):
-        song_path = join(path, "train", f"song{song}")
+        song_path = join(path, 'train', f'song{song}')
         makedirs(song_path, exist_ok=True)
-        dataset_df.loc[song, f"duration"] = duration
-        for instr in instrument_list+["mix"]:
-            filename = join(song_path, f"{instr}.wav")
+        dataset_df.loc[song, 'duration'] = duration
+        for instr in instrument_list+['mix']:
+            filename = join(song_path, f'{instr}.wav')
             data = rng.rand(duration*fs, n_channels)-0.5
             aa.save(filename, data, fs)
-            dataset_df.loc[song, f"{instr}_path"] = join("train", f"song{song}", f"{instr}.wav")
-
-    dataset_df.to_csv(join(path, "train", "train.csv"), index=False)
-
+            dataset_df.loc[song, f'{instr}_path'] = join(
+                'train',
+                f'song{song}',
+                f'{instr}.wav')
+    dataset_df.to_csv(join(path, 'train', 'train.csv'), index=False)
 
 
 def test_train():
-
-
-
     with TemporaryDirectory() as path:
-
         # generate training dataset
         generate_fake_training_dataset(path)
-
         # set training command aruments
-        p = create_argument_parser()
-        arguments = p.parse_args(["train", "-p", "useless_config.json", "-d", path])
-        TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv")
-        TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv")
-        TRAIN_CONFIG["model_dir"] = join(path, "model")
-        TRAIN_CONFIG["training_cache"] = join(path, "cache", "training")
-        TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation")
-
+        runner = CliRunner()
+        TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
+        TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
+        TRAIN_CONFIG['model_dir'] = join(path, 'model')
+        TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
+        TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
+        with open('useless_config.json', 'w') as stream:
+            json.dump(TRAIN_CONFIG, stream)
         # execute training
-        res = train.entrypoint(arguments, TRAIN_CONFIG)
-
+        result = runner.invoke(spleeter, [
+            'train',
+            '-p', 'useless_config.json',
+            '-d', path
+        ])
+        assert result.exit_code == 0
         # assert that model checkpoint was created.
-        assert os.path.exists(join(path,'model','model.ckpt-10.index'))
-        assert os.path.exists(join(path,'model','checkpoint'))
-        assert os.path.exists(join(path,'model','model.ckpt-0.meta'))
-
-if __name__=="__main__":
-    test_train()
\ No newline at end of file
+        assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
+        assert os.path.exists(join(path, 'model', 'checkpoint'))
+        assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))