From ccde870a1aa09802f8f7e0eed811aa2ce30171ac Mon Sep 17 00:00:00 2001 From: mmoussallam Date: Fri, 2 Oct 2020 17:57:02 +0200 Subject: [PATCH] avoid weird tf issue with multiple Sep instances + Pep8 --- tests/test_separator.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_separator.py b/tests/test_separator.py index b83477e..e757abf 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -7,7 +7,6 @@ __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' -import filecmp import itertools from os.path import splitext, basename, exists, join from tempfile import TemporaryDirectory @@ -33,7 +32,8 @@ MODEL_TO_INST = { MODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS)) -TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS)) +TEST_CONFIGURATIONS = list(itertools.product( + TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS)) print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__)) @@ -44,8 +44,10 @@ def test_separator_backends(test_file): adapter = get_default_audio_adapter() waveform, _ = adapter.load(test_file) - separator_lib = Separator("spleeter:2stems", stft_backend="librosa") - separator_tf = Separator("spleeter:2stems", stft_backend="tensorflow") + separator_lib = Separator( + "spleeter:2stems", stft_backend="librosa", multiprocess=False) + separator_tf = Separator( + "spleeter:2stems", stft_backend="tensorflow", multiprocess=False) # Test the stft and inverse stft provides exact reconstruction stft_matrix = separator_lib._stft(waveform) @@ -68,7 +70,8 @@ def test_separate(test_file, configuration, backend): instruments = MODEL_TO_INST[configuration] adapter = get_default_audio_adapter() waveform, _ = adapter.load(test_file) - separator = Separator(configuration, stft_backend=backend, multiprocess=False) + separator = Separator( + configuration, stft_backend=backend, multiprocess=False) prediction = separator.separate(waveform, test_file) assert len(prediction) == len(instruments) for instrument in instruments: @@ -86,7 +89,8 @@ def test_separate(test_file, configuration, backend): def test_separate_to_file(test_file, configuration, backend): """ Test file based separation. """ instruments = MODEL_TO_INST[configuration] - separator = Separator(configuration, stft_backend=backend, multiprocess=False) + separator = Separator( + configuration, stft_backend=backend, multiprocess=False) name = splitext(basename(test_file))[0] with TemporaryDirectory() as directory: separator.separate_to_file( @@ -102,7 +106,8 @@ def test_separate_to_file(test_file, configuration, backend): def test_filename_format(test_file, configuration, backend): """ Test custom filename format. """ instruments = MODEL_TO_INST[configuration] - separator = Separator(configuration, stft_backend=backend, multiprocess=False) + separator = Separator( + configuration, stft_backend=backend, multiprocess=False) name = splitext(basename(test_file))[0] with TemporaryDirectory() as directory: separator.separate_to_file(