diff --git a/tests/test_eval.py b/tests/test_eval.py index e9bf762..88c627a 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -99,14 +99,14 @@ def generate_fake_eval_dataset(path): aa.save(filename, data, fs) -def test_evaluate(path="FAKE_MUSDB_DIR"): - generate_fake_eval_dataset(path) - p = create_argument_parser() - arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", path]) - params = load_configuration(arguments.configuration) - metrics = evaluate.entrypoint(arguments, params) - for instrument, metric in metrics.items(): - print(instrument), print(metric) - for m, value in metric.items(): - print(np.median(value)), print(res_4stems[instrument][m]) - assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3) +@pytest.mark.parametrize('backend', TEST_CONFIGURATIONS) +def test_evaluate(backend): + with TemporaryDirectory() as directory: + generate_fake_eval_dataset(directory) + p = create_argument_parser() + arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", directory, "-B", backend]) + params = load_configuration(arguments.configuration) + metrics = evaluate.entrypoint(arguments, params) + for instrument, metric in metrics.items(): + for m, value in metric.items(): + assert np.allclose(np.median(value), res_4stems[backend][instrument][m], atol=1e-3)