🐛 fix separator test

This commit is contained in:
Faylixe
2020-12-08 12:38:18 +01:00
parent 7afe146358
commit 1886dde38b
2 changed files with 16 additions and 8 deletions

View File

@@ -85,7 +85,7 @@ def train(
logger.info('Model training done')
@spleeter.commmand()
@spleeter.command()
def separate(
files: List[Path] = AudioInputArgument,
adapter: str = AudioAdapterOption,

View File

@@ -17,7 +17,7 @@ import numpy as np
import tensorflow as tf
from spleeter import SpleeterError
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.audio.adapter import AudioAdapter
from spleeter.separator import Separator
TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
@@ -41,7 +41,7 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
def test_separator_backends(test_file):
adapter = get_default_audio_adapter()
adapter = AudioAdapter.default()
waveform, _ = adapter.load(test_file)
separator_lib = Separator(
@@ -64,11 +64,13 @@ def test_separator_backends(test_file):
assert np.allclose(out_tf[instrument], out_lib[instrument], atol=1e-5)
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
@pytest.mark.parametrize(
'test_file, configuration, backend',
TEST_CONFIGURATIONS)
def test_separate(test_file, configuration, backend):
""" Test separation from raw data. """
instruments = MODEL_TO_INST[configuration]
adapter = get_default_audio_adapter()
adapter = AudioAdapter.default()
waveform, _ = adapter.load(test_file)
separator = Separator(
configuration, stft_backend=backend, multiprocess=False)
@@ -85,7 +87,9 @@ def test_separate(test_file, configuration, backend):
assert not np.allclose(track, prediction[compared])
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
@pytest.mark.parametrize(
'test_file, configuration, backend',
TEST_CONFIGURATIONS)
def test_separate_to_file(test_file, configuration, backend):
""" Test file based separation. """
instruments = MODEL_TO_INST[configuration]
@@ -102,7 +106,9 @@ def test_separate_to_file(test_file, configuration, backend):
'{}/{}.wav'.format(name, instrument)))
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
@pytest.mark.parametrize(
'test_file, configuration, backend',
TEST_CONFIGURATIONS)
def test_filename_format(test_file, configuration, backend):
""" Test custom filename format. """
instruments = MODEL_TO_INST[configuration]
@@ -120,7 +126,9 @@ def test_filename_format(test_file, configuration, backend):
'export/{}/{}.wav'.format(name, instrument)))
@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
@pytest.mark.parametrize(
'test_file, configuration',
MODELS_AND_TEST_FILES)
def test_filename_conflict(test_file, configuration):
""" Test error handling with static pattern. """
separator = Separator(configuration, multiprocess=False)