mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
🐛 fix separator test
This commit is contained in:
@@ -85,7 +85,7 @@ def train(
|
|||||||
logger.info('Model training done')
|
logger.info('Model training done')
|
||||||
|
|
||||||
|
|
||||||
@spleeter.commmand()
|
@spleeter.command()
|
||||||
def separate(
|
def separate(
|
||||||
files: List[Path] = AudioInputArgument,
|
files: List[Path] = AudioInputArgument,
|
||||||
adapter: str = AudioAdapterOption,
|
adapter: str = AudioAdapterOption,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from spleeter import SpleeterError
|
from spleeter import SpleeterError
|
||||||
from spleeter.audio.adapter import get_default_audio_adapter
|
from spleeter.audio.adapter import AudioAdapter
|
||||||
from spleeter.separator import Separator
|
from spleeter.separator import Separator
|
||||||
|
|
||||||
TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
|
TEST_AUDIO_DESCRIPTORS = ['audio_example.mp3', 'audio_example_mono.mp3']
|
||||||
@@ -41,7 +41,7 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
|
|||||||
|
|
||||||
@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
|
@pytest.mark.parametrize('test_file', TEST_AUDIO_DESCRIPTORS)
|
||||||
def test_separator_backends(test_file):
|
def test_separator_backends(test_file):
|
||||||
adapter = get_default_audio_adapter()
|
adapter = AudioAdapter.default()
|
||||||
waveform, _ = adapter.load(test_file)
|
waveform, _ = adapter.load(test_file)
|
||||||
|
|
||||||
separator_lib = Separator(
|
separator_lib = Separator(
|
||||||
@@ -64,11 +64,13 @@ def test_separator_backends(test_file):
|
|||||||
assert np.allclose(out_tf[instrument], out_lib[instrument], atol=1e-5)
|
assert np.allclose(out_tf[instrument], out_lib[instrument], atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize(
|
||||||
|
'test_file, configuration, backend',
|
||||||
|
TEST_CONFIGURATIONS)
|
||||||
def test_separate(test_file, configuration, backend):
|
def test_separate(test_file, configuration, backend):
|
||||||
""" Test separation from raw data. """
|
""" Test separation from raw data. """
|
||||||
instruments = MODEL_TO_INST[configuration]
|
instruments = MODEL_TO_INST[configuration]
|
||||||
adapter = get_default_audio_adapter()
|
adapter = AudioAdapter.default()
|
||||||
waveform, _ = adapter.load(test_file)
|
waveform, _ = adapter.load(test_file)
|
||||||
separator = Separator(
|
separator = Separator(
|
||||||
configuration, stft_backend=backend, multiprocess=False)
|
configuration, stft_backend=backend, multiprocess=False)
|
||||||
@@ -85,7 +87,9 @@ def test_separate(test_file, configuration, backend):
|
|||||||
assert not np.allclose(track, prediction[compared])
|
assert not np.allclose(track, prediction[compared])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize(
|
||||||
|
'test_file, configuration, backend',
|
||||||
|
TEST_CONFIGURATIONS)
|
||||||
def test_separate_to_file(test_file, configuration, backend):
|
def test_separate_to_file(test_file, configuration, backend):
|
||||||
""" Test file based separation. """
|
""" Test file based separation. """
|
||||||
instruments = MODEL_TO_INST[configuration]
|
instruments = MODEL_TO_INST[configuration]
|
||||||
@@ -102,7 +106,9 @@ def test_separate_to_file(test_file, configuration, backend):
|
|||||||
'{}/{}.wav'.format(name, instrument)))
|
'{}/{}.wav'.format(name, instrument)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize(
|
||||||
|
'test_file, configuration, backend',
|
||||||
|
TEST_CONFIGURATIONS)
|
||||||
def test_filename_format(test_file, configuration, backend):
|
def test_filename_format(test_file, configuration, backend):
|
||||||
""" Test custom filename format. """
|
""" Test custom filename format. """
|
||||||
instruments = MODEL_TO_INST[configuration]
|
instruments = MODEL_TO_INST[configuration]
|
||||||
@@ -120,7 +126,9 @@ def test_filename_format(test_file, configuration, backend):
|
|||||||
'export/{}/{}.wav'.format(name, instrument)))
|
'export/{}/{}.wav'.format(name, instrument)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
|
@pytest.mark.parametrize(
|
||||||
|
'test_file, configuration',
|
||||||
|
MODELS_AND_TEST_FILES)
|
||||||
def test_filename_conflict(test_file, configuration):
|
def test_filename_conflict(test_file, configuration):
|
||||||
""" Test error handling with static pattern. """
|
""" Test error handling with static pattern. """
|
||||||
separator = Separator(configuration, multiprocess=False)
|
separator = Separator(configuration, multiprocess=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user