🐛 fix test

This commit is contained in:
Faylixe
2020-12-08 12:31:08 +01:00
parent b8277a0126
commit 76bb91f30c
4 changed files with 113 additions and 123 deletions

View File

@@ -7,82 +7,78 @@ __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
import filecmp
import itertools
from os import makedirs
from os.path import splitext, basename, exists, join
from os.path import join
from tempfile import TemporaryDirectory
import pytest
import numpy as np
import tensorflow as tf
from spleeter.__main__ import evaluate
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.commands import create_argument_parser
from spleeter.commands import evaluate
from spleeter.utils.configuration import load_configuration
BACKENDS = ["tensorflow", "librosa"]
TEST_CONFIGURATIONS = {el:el for el in BACKENDS}
BACKENDS = ['tensorflow', 'librosa']
TEST_CONFIGURATIONS = {el: el for el in BACKENDS}
res_4stems = {
"vocals": {
"SDR": 3.25e-05,
"SAR": -11.153575,
"SIR": -1.3849,
"ISR": 2.75e-05
},
"drums": {
"SDR": -0.079505,
"SAR": -15.7073575,
"SIR": -4.972755,
"ISR": 0.0013575
},
"bass":{
"SDR": 2.5e-06,
"SAR": -10.3520575,
"SIR": -4.272325,
"ISR": 2.5e-06
},
"other":{
"SDR": -1.359175,
"SAR": -14.7076775,
"SIR": -4.761505,
"ISR": -0.01528
}
}
'vocals': {
'SDR': 3.25e-05,
'SAR': -11.153575,
'SIR': -1.3849,
'ISR': 2.75e-05
},
'drums': {
'SDR': -0.079505,
'SAR': -15.7073575,
'SIR': -4.972755,
'ISR': 0.0013575
},
'bass': {
'SDR': 2.5e-06,
'SAR': -10.3520575,
'SIR': -4.272325,
'ISR': 2.5e-06
},
'other': {
'SDR': -1.359175,
'SAR': -14.7076775,
'SIR': -4.761505,
'ISR': -0.01528
}
}
def generate_fake_eval_dataset(path):
"""
generate fake evaluation dataset
"""
aa = get_default_audio_adapter()
aa = AudioAdapter.default()
n_songs = 2
fs = 44100
duration = 3
n_channels = 2
rng = np.random.RandomState(seed=0)
for song in range(n_songs):
song_path = join(path, "test", f"song{song}")
song_path = join(path, 'test', f'song{song}')
makedirs(song_path, exist_ok=True)
for instr in ["mixture", "vocals", "bass", "drums", "other"]:
filename = join(song_path, f"{instr}.wav")
for instr in ['mixture', 'vocals', 'bass', 'drums', 'other']:
filename = join(song_path, f'{instr}.wav')
data = rng.rand(duration*fs, n_channels)-0.5
aa.save(filename, data, fs)
@pytest.mark.parametrize('backend', TEST_CONFIGURATIONS)
def test_evaluate(backend):
with TemporaryDirectory() as directory:
generate_fake_eval_dataset(directory)
p = create_argument_parser()
arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", directory, "-B", backend])
params = load_configuration(arguments.configuration)
metrics = evaluate.entrypoint(arguments, params)
metrics = evaluate(
stft_backend=backend,
params_filename='spleeter:4stems',
mus_dir=directory,
)
for instrument, metric in metrics.items():
for m, value in metric.items():
assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3)
assert np.allclose(
np.median(value),
res_4stems[instrument][m],
atol=1e-3)

View File

@@ -10,6 +10,11 @@ __license__ = 'MIT License'
from os.path import join
from tempfile import TemporaryDirectory
from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
# pyright: reportMissingImports=false
# pylint: disable=import-error
from pytest import fixture, raises
@@ -17,12 +22,6 @@ import numpy as np
import ffmpeg
# pylint: enable=import-error
from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.audio.adapter import get_audio_adapter
from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
TEST_OFFSET = 0
TEST_DURATION = 600.
@@ -32,7 +31,7 @@ TEST_SAMPLE_RATE = 44100
@fixture(scope='session')
def adapter():
""" Target test audio adapter fixture. """
return get_default_audio_adapter()
return AudioAdapter.default()
@fixture(scope='session')

View File

@@ -5,12 +5,12 @@
from pytest import raises
from spleeter.model.provider import get_default_model_provider
from spleeter.model.provider import ModelProvider
def test_checksum():
""" Test archive checksum index retrieval. """
provider = get_default_model_provider()
provider = ModelProvider.default()
assert provider.checksum('2stems') == \
'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692'
assert provider.checksum('4stems') == \

View File

@@ -7,107 +7,102 @@ __email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
import filecmp
import itertools
import json
import os
from os import makedirs
from os.path import splitext, basename, exists, join
from os.path import join
from tempfile import TemporaryDirectory
import numpy as np
import pandas as pd
import json
import tensorflow as tf
from spleeter.audio.adapter import AudioAdapter
from spleeter.__main__ import spleeter
from typer.testing import CliRunner
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.commands import create_argument_parser
from spleeter.commands import train
from spleeter.utils.configuration import load_configuration
TRAIN_CONFIG = {
"mix_name": "mix",
"instrument_list": ["vocals", "other"],
"sample_rate":44100,
"frame_length":4096,
"frame_step":1024,
"T":128,
"F":128,
"n_channels":2,
"chunk_duration":4,
"n_chunks_per_song":1,
"separation_exponent":2,
"mask_extension":"zeros",
"learning_rate": 1e-4,
"batch_size":2,
"train_max_steps": 10,
"throttle_secs":20,
"save_checkpoints_steps":100,
"save_summary_steps":5,
"random_seed":0,
"model":{
"type":"unet.unet",
"params":{
"conv_activation":"ELU",
"deconv_activation":"ELU"
'mix_name': 'mix',
'instrument_list': ['vocals', 'other'],
'sample_rate': 44100,
'frame_length': 4096,
'frame_step': 1024,
'T': 128,
'F': 128,
'n_channels': 2,
'chunk_duration': 4,
'n_chunks_per_song': 1,
'separation_exponent': 2,
'mask_extension': 'zeros',
'learning_rate': 1e-4,
'batch_size': 2,
'train_max_steps': 10,
'throttle_secs': 20,
'save_checkpoints_steps': 100,
'save_summary_steps': 5,
'random_seed': 0,
'model': {
'type': 'unet.unet',
'params': {
'conv_activation': 'ELU',
'deconv_activation': 'ELU'
}
}
}
def generate_fake_training_dataset(path, instrument_list=["vocals", "other"]):
def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
"""
generates a fake training dataset in path:
- generates audio files
- generates a csv file describing the dataset
"""
aa = get_default_audio_adapter()
aa = AudioAdapter.default()
n_songs = 2
fs = 44100
duration = 6
n_channels = 2
rng = np.random.RandomState(seed=0)
dataset_df = pd.DataFrame(columns=["mix_path"]+[f"{instr}_path" for instr in instrument_list]+["duration"])
dataset_df = pd.DataFrame(
columns=['mix_path'] + [
f'{instr}_path' for instr in instrument_list] + ['duration'])
for song in range(n_songs):
song_path = join(path, "train", f"song{song}")
song_path = join(path, 'train', f'song{song}')
makedirs(song_path, exist_ok=True)
dataset_df.loc[song, f"duration"] = duration
for instr in instrument_list+["mix"]:
filename = join(song_path, f"{instr}.wav")
dataset_df.loc[song, f'duration'] = duration
for instr in instrument_list+['mix']:
filename = join(song_path, f'{instr}.wav')
data = rng.rand(duration*fs, n_channels)-0.5
aa.save(filename, data, fs)
dataset_df.loc[song, f"{instr}_path"] = join("train", f"song{song}", f"{instr}.wav")
dataset_df.to_csv(join(path, "train", "train.csv"), index=False)
dataset_df.loc[song, f'{instr}_path'] = join(
'train',
f'song{song}',
f'{instr}.wav')
dataset_df.to_csv(join(path, 'train', 'train.csv'), index=False)
def test_train():
with TemporaryDirectory() as path:
# generate training dataset
generate_fake_training_dataset(path)
# set training command aruments
p = create_argument_parser()
arguments = p.parse_args(["train", "-p", "useless_config.json", "-d", path])
TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv")
TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv")
TRAIN_CONFIG["model_dir"] = join(path, "model")
TRAIN_CONFIG["training_cache"] = join(path, "cache", "training")
TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation")
runner = CliRunner()
TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
TRAIN_CONFIG['model_dir'] = join(path, 'model')
TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
with open('useless_config.json') as stream:
json.dump(TRAIN_CONFIG, stream)
# execute training
res = train.entrypoint(arguments, TRAIN_CONFIG)
result = runner.invoke(spleeter, [
'train',
'-p', 'useless_config.json',
'-d', path
])
assert result.exit_code == 0
# assert that model checkpoint was created.
assert os.path.exists(join(path,'model','model.ckpt-10.index'))
assert os.path.exists(join(path,'model','checkpoint'))
assert os.path.exists(join(path,'model','model.ckpt-0.meta'))
if __name__=="__main__":
test_train()
assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
assert os.path.exists(join(path, 'model', 'checkpoint'))
assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))