🐛 fix test

This commit is contained in:
Faylixe
2020-12-08 12:31:08 +01:00
parent b8277a0126
commit 76bb91f30c
4 changed files with 113 additions and 123 deletions

View File

@@ -7,82 +7,78 @@ __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
import filecmp
import itertools
from os import makedirs from os import makedirs
from os.path import splitext, basename, exists, join from os.path import join
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import pytest import pytest
import numpy as np import numpy as np
import tensorflow as tf from spleeter.__main__ import evaluate
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.adapter import get_default_audio_adapter BACKENDS = ['tensorflow', 'librosa']
from spleeter.commands import create_argument_parser TEST_CONFIGURATIONS = {el: el for el in BACKENDS}
from spleeter.commands import evaluate
from spleeter.utils.configuration import load_configuration
BACKENDS = ["tensorflow", "librosa"]
TEST_CONFIGURATIONS = {el:el for el in BACKENDS}
res_4stems = { res_4stems = {
"vocals": { 'vocals': {
"SDR": 3.25e-05, 'SDR': 3.25e-05,
"SAR": -11.153575, 'SAR': -11.153575,
"SIR": -1.3849, 'SIR': -1.3849,
"ISR": 2.75e-05 'ISR': 2.75e-05
}, },
"drums": { 'drums': {
"SDR": -0.079505, 'SDR': -0.079505,
"SAR": -15.7073575, 'SAR': -15.7073575,
"SIR": -4.972755, 'SIR': -4.972755,
"ISR": 0.0013575 'ISR': 0.0013575
}, },
"bass":{ 'bass': {
"SDR": 2.5e-06, 'SDR': 2.5e-06,
"SAR": -10.3520575, 'SAR': -10.3520575,
"SIR": -4.272325, 'SIR': -4.272325,
"ISR": 2.5e-06 'ISR': 2.5e-06
}, },
"other":{ 'other': {
"SDR": -1.359175, 'SDR': -1.359175,
"SAR": -14.7076775, 'SAR': -14.7076775,
"SIR": -4.761505, 'SIR': -4.761505,
"ISR": -0.01528 'ISR': -0.01528
} }
} }
def generate_fake_eval_dataset(path): def generate_fake_eval_dataset(path):
""" """
generate fake evaluation dataset generate fake evaluation dataset
""" """
aa = get_default_audio_adapter() aa = AudioAdapter.default()
n_songs = 2 n_songs = 2
fs = 44100 fs = 44100
duration = 3 duration = 3
n_channels = 2 n_channels = 2
rng = np.random.RandomState(seed=0) rng = np.random.RandomState(seed=0)
for song in range(n_songs): for song in range(n_songs):
song_path = join(path, "test", f"song{song}") song_path = join(path, 'test', f'song{song}')
makedirs(song_path, exist_ok=True) makedirs(song_path, exist_ok=True)
for instr in ["mixture", "vocals", "bass", "drums", "other"]: for instr in ['mixture', 'vocals', 'bass', 'drums', 'other']:
filename = join(song_path, f"{instr}.wav") filename = join(song_path, f'{instr}.wav')
data = rng.rand(duration*fs, n_channels)-0.5 data = rng.rand(duration*fs, n_channels)-0.5
aa.save(filename, data, fs) aa.save(filename, data, fs)
@pytest.mark.parametrize('backend', TEST_CONFIGURATIONS) @pytest.mark.parametrize('backend', TEST_CONFIGURATIONS)
def test_evaluate(backend): def test_evaluate(backend):
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
generate_fake_eval_dataset(directory) generate_fake_eval_dataset(directory)
p = create_argument_parser() metrics = evaluate(
arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", directory, "-B", backend]) stft_backend=backend,
params = load_configuration(arguments.configuration) params_filename='spleeter:4stems',
metrics = evaluate.entrypoint(arguments, params) mus_dir=directory,
)
for instrument, metric in metrics.items(): for instrument, metric in metrics.items():
for m, value in metric.items(): for m, value in metric.items():
assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3) assert np.allclose(
np.median(value),
res_4stems[instrument][m],
atol=1e-3)

View File

@@ -10,6 +10,11 @@ __license__ = 'MIT License'
from os.path import join from os.path import join
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
# pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
from pytest import fixture, raises from pytest import fixture, raises
@@ -17,12 +22,6 @@ import numpy as np
import ffmpeg import ffmpeg
# pylint: enable=import-error # pylint: enable=import-error
from spleeter import SpleeterError
from spleeter.audio.adapter import AudioAdapter
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.audio.adapter import get_audio_adapter
from spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter
TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3' TEST_AUDIO_DESCRIPTOR = 'audio_example.mp3'
TEST_OFFSET = 0 TEST_OFFSET = 0
TEST_DURATION = 600. TEST_DURATION = 600.
@@ -32,7 +31,7 @@ TEST_SAMPLE_RATE = 44100
@fixture(scope='session') @fixture(scope='session')
def adapter(): def adapter():
""" Target test audio adapter fixture. """ """ Target test audio adapter fixture. """
return get_default_audio_adapter() return AudioAdapter.default()
@fixture(scope='session') @fixture(scope='session')

View File

@@ -5,12 +5,12 @@
from pytest import raises from pytest import raises
from spleeter.model.provider import get_default_model_provider from spleeter.model.provider import ModelProvider
def test_checksum(): def test_checksum():
""" Test archive checksum index retrieval. """ """ Test archive checksum index retrieval. """
provider = get_default_model_provider() provider = ModelProvider.default()
assert provider.checksum('2stems') == \ assert provider.checksum('2stems') == \
'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692' 'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692'
assert provider.checksum('4stems') == \ assert provider.checksum('4stems') == \

View File

@@ -7,107 +7,102 @@ __email__ = 'research@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
import filecmp import json
import itertools
import os import os
from os import makedirs from os import makedirs
from os.path import splitext, basename, exists, join from os.path import join
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import json
import tensorflow as tf from spleeter.audio.adapter import AudioAdapter
from spleeter.__main__ import spleeter
from typer.testing import CliRunner
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.commands import create_argument_parser
from spleeter.commands import train
from spleeter.utils.configuration import load_configuration
TRAIN_CONFIG = { TRAIN_CONFIG = {
"mix_name": "mix", 'mix_name': 'mix',
"instrument_list": ["vocals", "other"], 'instrument_list': ['vocals', 'other'],
"sample_rate":44100, 'sample_rate': 44100,
"frame_length":4096, 'frame_length': 4096,
"frame_step":1024, 'frame_step': 1024,
"T":128, 'T': 128,
"F":128, 'F': 128,
"n_channels":2, 'n_channels': 2,
"chunk_duration":4, 'chunk_duration': 4,
"n_chunks_per_song":1, 'n_chunks_per_song': 1,
"separation_exponent":2, 'separation_exponent': 2,
"mask_extension":"zeros", 'mask_extension': 'zeros',
"learning_rate": 1e-4, 'learning_rate': 1e-4,
"batch_size":2, 'batch_size': 2,
"train_max_steps": 10, 'train_max_steps': 10,
"throttle_secs":20, 'throttle_secs': 20,
"save_checkpoints_steps":100, 'save_checkpoints_steps': 100,
"save_summary_steps":5, 'save_summary_steps': 5,
"random_seed":0, 'random_seed': 0,
"model":{ 'model': {
"type":"unet.unet", 'type': 'unet.unet',
"params":{ 'params': {
"conv_activation":"ELU", 'conv_activation': 'ELU',
"deconv_activation":"ELU" 'deconv_activation': 'ELU'
} }
} }
} }
def generate_fake_training_dataset(path, instrument_list=["vocals", "other"]): def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
""" """
generates a fake training dataset in path: generates a fake training dataset in path:
- generates audio files - generates audio files
- generates a csv file describing the dataset - generates a csv file describing the dataset
""" """
aa = get_default_audio_adapter() aa = AudioAdapter.default()
n_songs = 2 n_songs = 2
fs = 44100 fs = 44100
duration = 6 duration = 6
n_channels = 2 n_channels = 2
rng = np.random.RandomState(seed=0) rng = np.random.RandomState(seed=0)
dataset_df = pd.DataFrame(columns=["mix_path"]+[f"{instr}_path" for instr in instrument_list]+["duration"]) dataset_df = pd.DataFrame(
columns=['mix_path'] + [
f'{instr}_path' for instr in instrument_list] + ['duration'])
for song in range(n_songs): for song in range(n_songs):
song_path = join(path, "train", f"song{song}") song_path = join(path, 'train', f'song{song}')
makedirs(song_path, exist_ok=True) makedirs(song_path, exist_ok=True)
dataset_df.loc[song, f"duration"] = duration dataset_df.loc[song, f'duration'] = duration
for instr in instrument_list+["mix"]: for instr in instrument_list+['mix']:
filename = join(song_path, f"{instr}.wav") filename = join(song_path, f'{instr}.wav')
data = rng.rand(duration*fs, n_channels)-0.5 data = rng.rand(duration*fs, n_channels)-0.5
aa.save(filename, data, fs) aa.save(filename, data, fs)
dataset_df.loc[song, f"{instr}_path"] = join("train", f"song{song}", f"{instr}.wav") dataset_df.loc[song, f'{instr}_path'] = join(
'train',
dataset_df.to_csv(join(path, "train", "train.csv"), index=False) f'song{song}',
f'{instr}.wav')
dataset_df.to_csv(join(path, 'train', 'train.csv'), index=False)
def test_train(): def test_train():
with TemporaryDirectory() as path: with TemporaryDirectory() as path:
# generate training dataset # generate training dataset
generate_fake_training_dataset(path) generate_fake_training_dataset(path)
# set training command arguments # set training command arguments
p = create_argument_parser() runner = CliRunner()
arguments = p.parse_args(["train", "-p", "useless_config.json", "-d", path]) TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv") TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv") TRAIN_CONFIG['model_dir'] = join(path, 'model')
TRAIN_CONFIG["model_dir"] = join(path, "model") TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
TRAIN_CONFIG["training_cache"] = join(path, "cache", "training") TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation") with open('useless_config.json') as stream:
json.dump(TRAIN_CONFIG, stream)
# execute training # execute training
res = train.entrypoint(arguments, TRAIN_CONFIG) result = runner.invoke(spleeter, [
'train',
'-p', 'useless_config.json',
'-d', path
])
assert result.exit_code == 0
# assert that model checkpoint was created. # assert that model checkpoint was created.
assert os.path.exists(join(path,'model','model.ckpt-10.index')) assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
assert os.path.exists(join(path,'model','checkpoint')) assert os.path.exists(join(path, 'model', 'checkpoint'))
assert os.path.exists(join(path,'model','model.ckpt-0.meta')) assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
if __name__=="__main__":
test_train()