mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
🐛 fix separator test
This commit is contained in:
@@ -85,7 +85,7 @@ def train(
|
||||
logger.info('Model training done')
|
||||
|
||||
|
||||
@spleeter.commmand()
|
||||
@spleeter.command()
|
||||
def separate(
|
||||
files: List[Path] = AudioInputArgument,
|
||||
adapter: str = AudioAdapterOption,
|
||||
|
||||
@@ -17,7 +17,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from spleeter import SpleeterError
|
||||
from spleeter.audio.adapter import get_default_audio_adapter
|
||||
from spleeter.audio.adapter import AudioAdapter
|
||||
from spleeter.separator import Separator
|
||||
|
||||
TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
|
||||
@@ -41,7 +41,7 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
|
||||
|
||||
@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
|
||||
def test_separator_backends(test_file):
|
||||
adapter = get_default_audio_adapter()
|
||||
adapter = AudioAdapter.default()
|
||||
waveform, _ = adapter.load(test_file)
|
||||
|
||||
separator_lib = Separator(
|
||||
@@ -64,11 +64,13 @@ def test_separator_backends(test_file):
|
||||
assert np.allclose(out_tf[instrument], out_lib[instrument], atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
||||
@pytest.mark.parametrize(
|
||||
'test_file, configuration, backend',
|
||||
TEST_CONFIGURATIONS)
|
||||
def test_separate(test_file, configuration, backend):
|
||||
""" Test separation from raw data. """
|
||||
instruments = MODEL_TO_INST[configuration]
|
||||
adapter = get_default_audio_adapter()
|
||||
adapter = AudioAdapter.default()
|
||||
waveform, _ = adapter.load(test_file)
|
||||
separator = Separator(
|
||||
configuration, stft_backend=backend, multiprocess=False)
|
||||
@@ -85,7 +87,9 @@ def test_separate(test_file, configuration, backend):
|
||||
assert not np.allclose(track, prediction[compared])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
||||
@pytest.mark.parametrize(
|
||||
'test_file, configuration, backend',
|
||||
TEST_CONFIGURATIONS)
|
||||
def test_separate_to_file(test_file, configuration, backend):
|
||||
""" Test file based separation. """
|
||||
instruments = MODEL_TO_INST[configuration]
|
||||
@@ -102,7 +106,9 @@ def test_separate_to_file(test_file, configuration, backend):
|
||||
'{}/{}.wav'.format(name, instrument)))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
||||
@pytest.mark.parametrize(
|
||||
'test_file, configuration, backend',
|
||||
TEST_CONFIGURATIONS)
|
||||
def test_filename_format(test_file, configuration, backend):
|
||||
""" Test custom filename format. """
|
||||
instruments = MODEL_TO_INST[configuration]
|
||||
@@ -120,7 +126,9 @@ def test_filename_format(test_file, configuration, backend):
|
||||
'export/{}/{}.wav'.format(name, instrument)))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
|
||||
@pytest.mark.parametrize(
|
||||
'test_file, configuration',
|
||||
MODELS_AND_TEST_FILES)
|
||||
def test_filename_conflict(test_file, configuration):
|
||||
""" Test error handling with static pattern. """
|
||||
separator = Separator(configuration, multiprocess=False)
|
||||
|
||||
Reference in New Issue
Block a user