mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
RAM is not properly released by tf.reset_default_graph
This commit is contained in:
4
Makefile
4
Makefile
@@ -8,7 +8,7 @@
|
|||||||
FEEDSTOCK = spleeter-feedstock
|
FEEDSTOCK = spleeter-feedstock
|
||||||
FEEDSTOCK_REPOSITORY = https://github.com/deezer/$(FEEDSTOCK)
|
FEEDSTOCK_REPOSITORY = https://github.com/deezer/$(FEEDSTOCK)
|
||||||
FEEDSTOCK_RECIPE = $(FEEDSTOCK)/recipe/spleeter/meta.yaml
|
FEEDSTOCK_RECIPE = $(FEEDSTOCK)/recipe/spleeter/meta.yaml
|
||||||
PYTEST_CMD = pytest -W ignore::FutureWarning -W ignore::DeprecationWarning -vv
|
PYTEST_CMD = pytest -W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked
|
||||||
|
|
||||||
all: clean build test deploy
|
all: clean build test deploy
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ build-gpu: clean
|
|||||||
python3 setup.py sdist
|
python3 setup.py sdist
|
||||||
|
|
||||||
test:
|
test:
|
||||||
$(foreach file, $(wildcard tests/test_*.py), $(PYTEST_CMD) $(file);)
|
$(PYTEST_CMD)
|
||||||
|
|
||||||
|
|
||||||
deploy:
|
deploy:
|
||||||
|
|||||||
@@ -42,67 +42,67 @@ print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
|
|||||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
||||||
def test_separate(test_file, configuration, backend):
|
def test_separate(test_file, configuration, backend):
|
||||||
""" Test separation from raw data. """
|
""" Test separation from raw data. """
|
||||||
tf.reset_default_graph()
|
with tf.Session() as sess:
|
||||||
instruments = MODEL_TO_INST[configuration]
|
instruments = MODEL_TO_INST[configuration]
|
||||||
adapter = get_default_audio_adapter()
|
adapter = get_default_audio_adapter()
|
||||||
waveform, _ = adapter.load(test_file)
|
waveform, _ = adapter.load(test_file)
|
||||||
separator = Separator(configuration, stft_backend=backend)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
prediction = separator.separate(waveform, test_file)
|
prediction = separator.separate(waveform, test_file)
|
||||||
assert len(prediction) == len(instruments)
|
assert len(prediction) == len(instruments)
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
assert instrument in prediction
|
assert instrument in prediction
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
track = prediction[instrument]
|
track = prediction[instrument]
|
||||||
assert waveform.shape[:-1] == track.shape[:-1]
|
assert waveform.shape[:-1] == track.shape[:-1]
|
||||||
assert not np.allclose(waveform, track)
|
assert not np.allclose(waveform, track)
|
||||||
for compared in instruments:
|
for compared in instruments:
|
||||||
if instrument != compared:
|
if instrument != compared:
|
||||||
assert not np.allclose(track, prediction[compared])
|
assert not np.allclose(track, prediction[compared])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
||||||
def test_separate_to_file(test_file, configuration, backend):
|
def test_separate_to_file(test_file, configuration, backend):
|
||||||
""" Test file based separation. """
|
""" Test file based separation. """
|
||||||
tf.reset_default_graph()
|
with tf.Session() as sess:
|
||||||
instruments = MODEL_TO_INST[configuration]
|
instruments = MODEL_TO_INST[configuration]
|
||||||
separator = Separator(configuration, stft_backend=backend)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
name = splitext(basename(test_file))[0]
|
name = splitext(basename(test_file))[0]
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
test_file,
|
test_file,
|
||||||
directory)
|
directory)
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
assert exists(join(
|
assert exists(join(
|
||||||
directory,
|
directory,
|
||||||
'{}/{}.wav'.format(name, instrument)))
|
'{}/{}.wav'.format(name, instrument)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
|
||||||
def test_filename_format(test_file, configuration, backend):
|
def test_filename_format(test_file, configuration, backend):
|
||||||
""" Test custom filename format. """
|
""" Test custom filename format. """
|
||||||
tf.reset_default_graph()
|
with tf.Session() as sess:
|
||||||
instruments = MODEL_TO_INST[configuration]
|
instruments = MODEL_TO_INST[configuration]
|
||||||
separator = Separator(configuration, stft_backend=backend)
|
separator = Separator(configuration, stft_backend=backend)
|
||||||
name = splitext(basename(test_file))[0]
|
name = splitext(basename(test_file))[0]
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
test_file,
|
test_file,
|
||||||
directory,
|
|
||||||
filename_format='export/{filename}/{instrument}.{codec}')
|
|
||||||
for instrument in instruments:
|
|
||||||
assert exists(join(
|
|
||||||
directory,
|
directory,
|
||||||
'export/{}/{}.wav'.format(name, instrument)))
|
filename_format='export/{filename}/{instrument}.{codec}')
|
||||||
|
for instrument in instruments:
|
||||||
|
assert exists(join(
|
||||||
|
directory,
|
||||||
|
'export/{}/{}.wav'.format(name, instrument)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
|
@pytest.mark.parametrize('test_file, configuration', MODELS_AND_TEST_FILES)
|
||||||
def test_filename_conflict(test_file, configuration):
|
def test_filename_conflict(test_file, configuration):
|
||||||
""" Test error handling with static pattern. """
|
""" Test error handling with static pattern. """
|
||||||
tf.reset_default_graph()
|
with tf.Session() as sess:
|
||||||
separator = Separator(configuration)
|
separator = Separator(configuration)
|
||||||
with TemporaryDirectory() as directory:
|
with TemporaryDirectory() as directory:
|
||||||
with pytest.raises(SpleeterError):
|
with pytest.raises(SpleeterError):
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
test_file,
|
test_file,
|
||||||
directory,
|
directory,
|
||||||
filename_format='I wanna be your lover')
|
filename_format='I wanna be your lover')
|
||||||
|
|||||||
Reference in New Issue
Block a user