diff --git a/tests/test_eval.py b/tests/test_eval.py index 6a3634b..13fb708 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -53,6 +53,9 @@ res_4stems = { "vocals": { def generate_fake_eval_dataset(path): + """ + generate fake evaluation dataset + """ aa = get_default_audio_adapter() n_songs = 2 fs = 44100 @@ -68,12 +71,26 @@ def generate_fake_eval_dataset(path): aa.save(filename, data, fs) -def test_evaluate(path="FAKE_MUSDB_DIR"): - generate_fake_eval_dataset(path) - p = create_argument_parser() - arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", path]) - params = load_configuration(arguments.configuration) - metrics = evaluate.entrypoint(arguments, params) - for instrument, metric in metrics.items(): - for metric, value in metric.items(): - assert np.allclose(np.median(value), res_4stems[instrument][metric], atol=1e-3) \ No newline at end of file +def test_evaluate(): + """ + test evaluate command + """ + + with TemporaryDirectory() as path: + + # generate fake dataset + generate_fake_eval_dataset(path) + + # set up arguments of command + p = create_argument_parser() + arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", path]) + params = load_configuration(arguments.configuration) + + # run evaluation + metrics = evaluate.entrypoint(arguments, params) + + # assert that the metric has not changed compared to reference value + # (Note that this fails with tensorflow backend) + for instrument, metric in metrics.items(): + for metric, value in metric.items(): + assert np.allclose(np.median(value), res_4stems[instrument][metric], atol=1e-3) \ No newline at end of file