Merge pull request #498 from deezer/tf2

Tensorflow 2 compatible version
This commit is contained in:
Romain Hennequin
2020-10-09 16:09:50 +02:00
committed by GitHub
12 changed files with 310 additions and 124 deletions

View File

@@ -7,7 +7,7 @@ jobs:
strategy: strategy:
matrix: matrix:
platform: [cpu, gpu] platform: [cpu, gpu]
distribution: [3.6, 3.7, conda] distribution: [3.6, 3.7, 3.8, conda]
model: [modelless, 2stems, 4stems, 5stems] model: [modelless, 2stems, 4stems, 5stems]
fail-fast: true fail-fast: true
steps: steps:
@@ -69,13 +69,13 @@ jobs:
run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin
- name: Push deezer/spleeter:${{ env.tag }} image - name: Push deezer/spleeter:${{ env.tag }} image
run: docker push deezer/spleeter:${{ env.tag }} run: docker push deezer/spleeter:${{ env.tag }}
- if: ${{ env.tag == 'spleeter:3.7' }} - if: ${{ env.tag == 'spleeter:3.8' }}
name: Push deezer/spleeter:latest image name: Push deezer/spleeter:latest image
run: | run: |
docker tag deezer/spleeter:3.7 deezer/spleeter:latest docker tag deezer/spleeter:3.8 deezer/spleeter:latest
docker push deezer/spleeter:latest docker push deezer/spleeter:latest
- if: ${{ env.tag == 'spleeter:3.7-gpu' }} - if: ${{ env.tag == 'spleeter:3.8-gpu' }}
name: Push deezer/spleeter:gpu image name: Push deezer/spleeter:gpu image
run: | run: |
docker tag deezer/spleeter:3.7-gpu deezer/spleeter:gpu docker tag deezer/spleeter:3.8-gpu deezer/spleeter:gpu
docker push deezer/spleeter:gpu docker push deezer/spleeter:gpu

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: [3.6, 3.7] python-version: [3.6, 3.7, 3.8]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

View File

@@ -14,9 +14,9 @@ __license__ = 'MIT License'
# Default project values. # Default project values.
project_name = 'spleeter' project_name = 'spleeter'
project_version = '1.5.4' project_version = '2.0'
tensorflow_dependency = 'tensorflow' tensorflow_dependency = 'tensorflow'
tensorflow_version = '1.15.2' tensorflow_version = '2.3.0'
here = path.abspath(path.dirname(__file__)) here = path.abspath(path.dirname(__file__))
readme_path = path.join(here, 'README.md') readme_path = path.join(here, 'README.md')
with open(readme_path, 'r') as stream: with open(readme_path, 'r') as stream:
@@ -47,17 +47,16 @@ setup(
'spleeter.utils', 'spleeter.utils',
], ],
package_data={'spleeter.resources': ['*.json']}, package_data={'spleeter.resources': ['*.json']},
python_requires='>=3.6, <3.8', python_requires='>=3.6, <3.9',
include_package_data=True, include_package_data=True,
install_requires=[ install_requires=[
'ffmpeg-python', 'ffmpeg-python',
'importlib_resources ; python_version<"3.7"', 'importlib_resources ; python_version<"3.7"',
'norbert==0.2.1', 'norbert==0.2.1',
'pandas==0.25.1', 'pandas==1.1.2',
'requests', 'requests',
'setuptools>=41.0.0', 'setuptools>=41.0.0',
'librosa==0.7.2', 'librosa==0.8.0',
'numba==0.48.0',
'{}=={}'.format(tensorflow_dependency, tensorflow_version), '{}=={}'.format(tensorflow_dependency, tensorflow_version),
], ],
extras_require={ extras_require={

View File

@@ -13,7 +13,7 @@ from os.path import exists
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.signal import stft, hann_window from tensorflow.signal import stft, hann_window
# pylint: enable=import-error # pylint: enable=import-error
from .. import SpleeterError from .. import SpleeterError

View File

@@ -7,7 +7,7 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.signal import stft, hann_window from tensorflow.signal import stft, hann_window
# pylint: enable=import-error # pylint: enable=import-error
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'

View File

@@ -238,7 +238,7 @@ class DatasetBuilder(object):
def expand_path(self, sample): def expand_path(self, sample):
""" Expands audio paths for the given sample. """ """ Expands audio paths for the given sample. """
return dict(sample, **{f'{instrument}_path': tf.string_join( return dict(sample, **{f'{instrument}_path': tf.strings.join(
(self._audio_path, sample[f'{instrument}_path']), SEPARATOR) (self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
for instrument in self._instruments}) for instrument in self._instruments})

View File

@@ -8,7 +8,7 @@ import importlib
# pylint: disable=import-error # pylint: disable=import-error
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.signal import stft, inverse_stft, hann_window from tensorflow.signal import stft, inverse_stft, hann_window
# pylint: enable=import-error # pylint: enable=import-error
from ..utils.tensor import pad_and_partition, pad_and_reshape from ..utils.tensor import pad_and_partition, pad_and_reshape

View File

@@ -12,14 +12,18 @@
>>> separator.separate_to_file(...) >>> separator.separate_to_file(...)
""" """
import atexit
import os import os
import logging import logging
from time import time
from multiprocessing import Pool from multiprocessing import Pool
from os.path import basename, join, splitext, dirname from os.path import basename, join, splitext, dirname
from time import time
from typing import Container, NoReturn
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from librosa.core import stft, istft from librosa.core import stft, istft
from scipy.signal.windows import hann from scipy.signal.windows import hann
@@ -27,64 +31,114 @@ from . import SpleeterError
from .audio.adapter import get_default_audio_adapter from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo from .audio.convertor import to_stereo
from .utils.configuration import load_configuration from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor, get_default_model_dir from .utils.estimator import create_estimator, get_default_model_dir
from .model import EstimatorSpecBuilder, InputProviderFactory from .model import EstimatorSpecBuilder, InputProviderFactory
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
logger = logging.getLogger("spleeter") """ """
class DataGenerator():
"""
Generator object that store a sample and generate it once while called.
Used to feed a tensorflow estimator without knowing the whole data at
build time.
"""
def get_backend(backend): def __init__(self):
assert backend in ["auto", "tensorflow", "librosa"] """ Default constructor. """
if backend == "auto": self._current_data = None
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
def update_data(self, data):
""" Replace internal data. """
self._current_data = data
def __call__(self):
""" Generation process. """
buffer = self._current_data
while buffer:
yield buffer
buffer = self._current_data
def get_backend(backend: str) -> str:
"""
"""
if backend not in SUPPORTED_BACKEND:
raise ValueError(f'Unsupported backend {backend}')
if backend == 'auto':
if len(tf.config.list_physical_devices('GPU')):
return 'tensorflow'
return 'librosa'
return backend return backend
class Separator(object): class Separator(object):
""" A wrapper class for performing separation. """ """ A wrapper class for performing separation. """
def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True): def __init__(
self,
params_descriptor,
MWF: bool = False,
stft_backend: str = 'auto',
multiprocess: bool = True):
""" Default constructor. """ Default constructor.
:param params_descriptor: Descriptor for TF params to be used. :param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise. :param MWF: (Optional) True if MWF should be used, False otherwise.
""" """
self._params = load_configuration(params_descriptor) self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate'] self._sample_rate = self._params['sample_rate']
self._MWF = MWF self._MWF = MWF
self._tf_graph = tf.Graph() self._tf_graph = tf.Graph()
self._predictor = None self._prediction_generator = None
self._input_provider = None self._input_provider = None
self._builder = None self._builder = None
self._features = None self._features = None
self._session = None self._session = None
self._pool = Pool() if multiprocess else None if multiprocess:
self._pool = Pool()
atexit.register(self._pool.close)
else:
self._pool = None
self._tasks = [] self._tasks = []
self._params["stft_backend"] = get_backend(stft_backend) self._params['stft_backend'] = get_backend(stft_backend)
self._data_generator = DataGenerator()
def __del__(self): def __del__(self):
""" """
if self._session: if self._session:
self._session.close() self._session.close()
def _get_predictor(self): def _get_prediction_generator(self):
""" Lazy loading access method for internal predictor instance. """ Lazy loading access method for internal prediction generator
returned by the predict method of a tensorflow estimator.
:returns: Predictor to use for source separation. :returns: generator of prediction.
""" """
if self._predictor is None: if self._prediction_generator is None:
estimator = create_estimator(self._params, self._MWF) estimator = create_estimator(self._params, self._MWF)
self._predictor = to_predictor(estimator)
return self._predictor
def join(self, timeout=200): def get_dataset():
return tf.data.Dataset.from_generator(
self._data_generator,
output_types={
'waveform': tf.float32,
'audio_id': tf.string},
output_shapes={
'waveform': (None, 2),
'audio_id': ()})
self._prediction_generator = estimator.predict(
get_dataset,
yield_single_examples=False)
return self._prediction_generator
def join(self, timeout: int = 200) -> NoReturn:
""" Wait for all pending tasks to be finished. """ Wait for all pending tasks to be finished.
:param timeout: (Optional) task waiting timeout. :param timeout: (Optional) task waiting timeout.
@@ -94,44 +148,52 @@ class Separator(object):
task.get() task.get()
task.wait(timeout=timeout) task.wait(timeout=timeout)
def _separate_tensorflow(self, waveform, audio_descriptor): def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
""" """ Performs source separation over the given waveform with tensorflow
Performs source separation over the given waveform with tensorflow backend. backend.
:param waveform: Waveform to apply separation on. :param waveform: Waveform to apply separation on.
:returns: Separated waveforms. :returns: Separated waveforms.
""" """
if not waveform.shape[-1] == 2: if not waveform.shape[-1] == 2:
waveform = to_stereo(waveform) waveform = to_stereo(waveform)
predictor = self._get_predictor() prediction_generator = self._get_prediction_generator()
prediction = predictor({ # NOTE: update data in generator before performing separation.
self._data_generator.update_data({
'waveform': waveform, 'waveform': waveform,
'audio_id': audio_descriptor}) 'audio_id': np.array(audio_descriptor)})
# NOTE: perform separation.
prediction = next(prediction_generator)
prediction.pop('audio_id') prediction.pop('audio_id')
return prediction return prediction
def _stft(self, data, inverse=False, length=None): def _stft(self, data, inverse: bool = False, length=None):
""" """ Single entrypoint for both stft and istft. This computes stft and
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two istft with librosa on stereo data. The two channels are processed
channels are processed separately and are concatenated together in the result. The expected input formats are: separately and are concatenated together in the result. The expected
(n_samples, 2) for stft and (T, F, 2) for istft. input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
:param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse
:param data: np.array with either the waveform or the complex
spectrogram depending on the parameter inverse
:param inverse: should a stft or an istft be computed. :param inverse: should a stft or an istft be computed.
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension :returns: Stereo data as numpy array for the transform.
The channels are stored in the last dimension.
""" """
assert not (inverse and length is None) assert not (inverse and length is None)
data = np.asfortranarray(data) data = np.asfortranarray(data)
N = self._params["frame_length"] N = self._params['frame_length']
H = self._params["frame_step"] H = self._params['frame_step']
win = hann(N, sym=False) win = hann(N, sym=False)
fstft = istft if inverse else stft fstft = istft if inverse else stft
win_len_arg = {"win_length": None, win_len_arg = {
"length": None} if inverse else {"n_fft": N} 'win_length': None,
'length': None} if inverse else {'n_fft': N}
n_channels = data.shape[-1] n_channels = data.shape[-1]
out = [] out = []
for c in range(n_channels): for c in range(n_channels):
d = np.concatenate((np.zeros((N, )), data[:, c], np.zeros((N, )))) if not inverse else data[:, :, c].T d = np.concatenate(
(np.zeros((N, )), data[:, c], np.zeros((N, )))
) if not inverse else data[:, :, c].T
s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg) s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
if inverse: if inverse:
s = s[N:N+length] s = s[N:N+length]
@@ -141,7 +203,6 @@ class Separator(object):
return out[0] return out[0]
return np.concatenate(out, axis=2-inverse) return np.concatenate(out, axis=2-inverse)
def _get_input_provider(self): def _get_input_provider(self):
if self._input_provider is None: if self._input_provider is None:
self._input_provider = InputProviderFactory.get(self._params) self._input_provider = InputProviderFactory.get(self._params)
@@ -149,66 +210,83 @@ class Separator(object):
def _get_features(self): def _get_features(self):
if self._features is None: if self._features is None:
self._features = self._get_input_provider().get_input_dict_placeholders() provider = self._get_input_provider()
self._features = provider.get_input_dict_placeholders()
return self._features return self._features
def _get_builder(self): def _get_builder(self):
if self._builder is None: if self._builder is None:
self._builder = EstimatorSpecBuilder(self._get_features(), self._params) self._builder = EstimatorSpecBuilder(
self._get_features(),
self._params)
return self._builder return self._builder
def _get_session(self): def _get_session(self):
if self._session is None: if self._session is None:
saver = tf.train.Saver() saver = tf.compat.v1.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) latest_checkpoint = tf.train.latest_checkpoint(
self._session = tf.Session() get_default_model_dir(self._params['model_dir']))
self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint) saver.restore(self._session, latest_checkpoint)
return self._session return self._session
def _separate_librosa(self, waveform, audio_id): def _separate_librosa(self, waveform: np.ndarray, audio_id):
""" """ Performs separation with librosa backend for STFT.
Performs separation with librosa backend for STFT.
""" """
with self._tf_graph.as_default(): with self._tf_graph.as_default():
out = {} out = {}
features = self._get_features() features = self._get_features()
# TODO: fix the logic, build sometimes return,
# TODO: fix the logic, build sometimes return, sometimes set attribute # sometimes set attribute.
outputs = self._get_builder().outputs outputs = self._get_builder().outputs
stft = self._stft(waveform) stft = self._stft(waveform)
if stft.shape[-1] == 1: if stft.shape[-1] == 1:
stft = np.concatenate([stft, stft], axis=-1) stft = np.concatenate([stft, stft], axis=-1)
elif stft.shape[-1] > 2: elif stft.shape[-1] > 2:
stft = stft[:, :2] stft = stft[:, :2]
sess = self._get_session() sess = self._get_session()
outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id)) outputs = sess.run(
outputs,
feed_dict=self._get_input_provider().get_feed_dict(
features,
stft,
audio_id))
for inst in self._get_builder().instruments: for inst in self._get_builder().instruments:
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0]) out[inst] = self._stft(
outputs[inst],
inverse=True,
length=waveform.shape[0])
return out return out
def separate(self, waveform, audio_descriptor=""): def separate(self, waveform: np.ndarray, audio_descriptor=''):
""" Performs separation on a waveform. """ Performs separation on a waveform.
:param waveform: Waveform to be separated (as a numpy array) :param waveform: Waveform to be separated (as a numpy array)
:param audio_descriptor: (Optional) string describing the waveform (e.g. filename). :param audio_descriptor: (Optional) string describing the waveform
(e.g. filename).
""" """
if self._params["stft_backend"] == "tensorflow": if self._params['stft_backend'] == 'tensorflow':
return self._separate_tensorflow(waveform, audio_descriptor) return self._separate_tensorflow(waveform, audio_descriptor)
else: else:
return self._separate_librosa(waveform, audio_descriptor) return self._separate_librosa(waveform, audio_descriptor)
def separate_to_file( def separate_to_file(
self, audio_descriptor, destination, self,
audio_descriptor,
destination,
audio_adapter=get_default_audio_adapter(), audio_adapter=get_default_audio_adapter(),
offset=0, duration=600., codec='wav', bitrate='128k', offset=0,
duration=600.,
codec='wav',
bitrate='128k',
filename_format='{filename}/{instrument}.{codec}', filename_format='{filename}/{instrument}.{codec}',
synchronous=True): synchronous=True):
""" Performs source separation and export result to file using """ Performs source separation and export result to file using
given audio adapter. given audio adapter.
Filename format should be a Python formattable string that could use Filename format should be a Python formattable string that could use
following parameters : {instrument}, {filename}, {foldername} and {codec}. following parameters : {instrument}, {filename}, {foldername} and
{codec}.
:param audio_descriptor: Describe song to separate, used by audio :param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data, adapter to retrieve and load audio data,
@@ -217,8 +295,8 @@ class Separator(object):
:param destination: Target directory to write output to. :param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O. :param audio_adapter: (Optional) Audio adapter to use for I/O.
:param offset: (Optional) Offset of loaded song. :param offset: (Optional) Offset of loaded song.
:param duration: (Optional) Duration of loaded song (default: :param duration: (Optional) Duration of loaded song
600s). (default: 600s).
:param codec: (Optional) Export codec. :param codec: (Optional) Export codec.
:param bitrate: (Optional) Export bitrate. :param bitrate: (Optional) Export bitrate.
:param filename_format: (Optional) Filename format. :param filename_format: (Optional) Filename format.
@@ -230,16 +308,27 @@ class Separator(object):
duration=duration, duration=duration,
sample_rate=self._sample_rate) sample_rate=self._sample_rate)
sources = self.separate(waveform, audio_descriptor) sources = self.separate(waveform, audio_descriptor)
self.save_to_file( sources, audio_descriptor, destination, self.save_to_file(
filename_format, codec, audio_adapter, sources,
bitrate, synchronous) audio_descriptor,
destination,
filename_format,
codec,
audio_adapter,
bitrate,
synchronous)
def save_to_file( def save_to_file(
self, sources, audio_descriptor, destination, self,
sources,
audio_descriptor,
destination,
filename_format='{filename}/{instrument}.{codec}', filename_format='{filename}/{instrument}.{codec}',
codec='wav', audio_adapter=get_default_audio_adapter(), codec='wav',
bitrate='128k', synchronous=True): audio_adapter=get_default_audio_adapter(),
""" export dictionary of sources to files. bitrate='128k',
synchronous=True):
""" Export dictionary of sources to files.
:param sources: Dictionary of sources to be exported. The :param sources: Dictionary of sources to be exported. The
keys are the name of the instruments, and keys are the name of the instruments, and
@@ -258,7 +347,6 @@ class Separator(object):
:param synchronous: (Optional) True is should by synchronous. :param synchronous: (Optional) True is should by synchronous.
""" """
foldername = basename(dirname(audio_descriptor)) foldername = basename(dirname(audio_descriptor))
filename = splitext(basename(audio_descriptor))[0] filename = splitext(basename(audio_descriptor))[0]
generated = [] generated = []
@@ -286,6 +374,11 @@ class Separator(object):
bitrate)) bitrate))
self._tasks.append(task) self._tasks.append(task)
else: else:
audio_adapter.save(path, data, self._sample_rate, codec, bitrate) audio_adapter.save(
path,
data,
self._sample_rate,
codec,
bitrate)
if synchronous and self._pool: if synchronous and self._pool:
self.join() self.join()

View File

@@ -5,20 +5,14 @@
from pathlib import Path from pathlib import Path
from os.path import join from os.path import join
from tempfile import gettempdir
# pylint: disable=import-error # pylint: disable=import-error
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib import predictor
# pylint: enable=import-error
from ..model import model_fn, InputProviderFactory from ..model import model_fn
from ..model.provider import get_default_model_provider from ..model.provider import get_default_model_provider
# Default exporting directory for predictor.
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
def get_default_model_dir(model_dir): def get_default_model_dir(model_dir):
@@ -57,24 +51,3 @@ def create_estimator(params, MWF):
config=config config=config
) )
return estimator return estimator
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
""" Exports given estimator as predictor into the given directory
and returns associated tf.predictor instance.
:param estimator: Estimator to export.
:param directory: (Optional) path to write exported model into.
"""
input_provider = InputProviderFactory.get(estimator.params)
def receiver():
features = input_provider.get_input_dict_placeholders()
return tf.estimator.export.ServingInputReceiver(features, features)
estimator.export_saved_model(directory, receiver)
versions = [
model for model in Path(directory).iterdir()
if model.is_dir() and 'temp' not in str(model)]
latest = str(sorted(versions)[-1])
return predictor.from_saved_model(latest)

View File

@@ -56,6 +56,9 @@ res_4stems = {
} }
def generate_fake_eval_dataset(path): def generate_fake_eval_dataset(path):
"""
generate fake evaluation dataset
"""
aa = get_default_audio_adapter() aa = get_default_audio_adapter()
n_songs = 2 n_songs = 2
fs = 44100 fs = 44100
@@ -71,6 +74,7 @@ def generate_fake_eval_dataset(path):
aa.save(filename, data, fs) aa.save(filename, data, fs)
@pytest.mark.parametrize('backend', TEST_CONFIGURATIONS) @pytest.mark.parametrize('backend', TEST_CONFIGURATIONS)
def test_evaluate(backend): def test_evaluate(backend):
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:

View File

@@ -7,7 +7,6 @@ __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
import filecmp
import itertools import itertools
from os.path import splitext, basename, exists, join from os.path import splitext, basename, exists, join
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
@@ -33,7 +32,8 @@ MODEL_TO_INST = {
MODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS)) MODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS))
TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS)) TEST_CONFIGURATIONS = list(itertools.product(
TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS))
print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__)) print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
@@ -44,8 +44,10 @@ def test_separator_backends(test_file):
adapter = get_default_audio_adapter() adapter = get_default_audio_adapter()
waveform, _ = adapter.load(test_file) waveform, _ = adapter.load(test_file)
separator_lib = Separator("spleeter:2stems", stft_backend="librosa") separator_lib = Separator(
separator_tf = Separator("spleeter:2stems", stft_backend="tensorflow") "spleeter:2stems", stft_backend="librosa", multiprocess=False)
separator_tf = Separator(
"spleeter:2stems", stft_backend="tensorflow", multiprocess=False)
# Test the stft and inverse stft provides exact reconstruction # Test the stft and inverse stft provides exact reconstruction
stft_matrix = separator_lib._stft(waveform) stft_matrix = separator_lib._stft(waveform)
@@ -68,7 +70,8 @@ def test_separate(test_file, configuration, backend):
instruments = MODEL_TO_INST[configuration] instruments = MODEL_TO_INST[configuration]
adapter = get_default_audio_adapter() adapter = get_default_audio_adapter()
waveform, _ = adapter.load(test_file) waveform, _ = adapter.load(test_file)
separator = Separator(configuration, stft_backend=backend, multiprocess=False) separator = Separator(
configuration, stft_backend=backend, multiprocess=False)
prediction = separator.separate(waveform, test_file) prediction = separator.separate(waveform, test_file)
assert len(prediction) == len(instruments) assert len(prediction) == len(instruments)
for instrument in instruments: for instrument in instruments:
@@ -82,12 +85,12 @@ def test_separate(test_file, configuration, backend):
assert not np.allclose(track, prediction[compared]) assert not np.allclose(track, prediction[compared])
@pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
def test_separate_to_file(test_file, configuration, backend): def test_separate_to_file(test_file, configuration, backend):
""" Test file based separation. """ """ Test file based separation. """
instruments = MODEL_TO_INST[configuration] instruments = MODEL_TO_INST[configuration]
separator = Separator(configuration, stft_backend=backend, multiprocess=False) separator = Separator(
configuration, stft_backend=backend, multiprocess=False)
name = splitext(basename(test_file))[0] name = splitext(basename(test_file))[0]
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
separator.separate_to_file( separator.separate_to_file(
@@ -103,7 +106,8 @@ def test_separate_to_file(test_file, configuration, backend):
def test_filename_format(test_file, configuration, backend): def test_filename_format(test_file, configuration, backend):
""" Test custom filename format. """ """ Test custom filename format. """
instruments = MODEL_TO_INST[configuration] instruments = MODEL_TO_INST[configuration]
separator = Separator(configuration, stft_backend=backend, multiprocess=False) separator = Separator(
configuration, stft_backend=backend, multiprocess=False)
name = splitext(basename(test_file))[0] name = splitext(basename(test_file))[0]
with TemporaryDirectory() as directory: with TemporaryDirectory() as directory:
separator.separate_to_file( separator.separate_to_file(

113
tests/test_train.py Normal file
View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python
# coding: utf8
""" Unit testing for Separator class. """
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
import filecmp
import itertools
import os
from os import makedirs
from os.path import splitext, basename, exists, join
from tempfile import TemporaryDirectory
import numpy as np
import pandas as pd
import json
import tensorflow as tf
from spleeter.audio.adapter import get_default_audio_adapter
from spleeter.commands import create_argument_parser
from spleeter.commands import train
from spleeter.utils.configuration import load_configuration
TRAIN_CONFIG = {
"mix_name": "mix",
"instrument_list": ["vocals", "other"],
"sample_rate":44100,
"frame_length":4096,
"frame_step":1024,
"T":128,
"F":128,
"n_channels":2,
"chunk_duration":4,
"n_chunks_per_song":1,
"separation_exponent":2,
"mask_extension":"zeros",
"learning_rate": 1e-4,
"batch_size":2,
"train_max_steps": 10,
"throttle_secs":20,
"save_checkpoints_steps":100,
"save_summary_steps":5,
"random_seed":0,
"model":{
"type":"unet.unet",
"params":{
"conv_activation":"ELU",
"deconv_activation":"ELU"
}
}
}
def generate_fake_training_dataset(path, instrument_list=["vocals", "other"]):
"""
generates a fake training dataset in path:
- generates audio files
- generates a csv file describing the dataset
"""
aa = get_default_audio_adapter()
n_songs = 2
fs = 44100
duration = 6
n_channels = 2
rng = np.random.RandomState(seed=0)
dataset_df = pd.DataFrame(columns=["mix_path"]+[f"{instr}_path" for instr in instrument_list]+["duration"])
for song in range(n_songs):
song_path = join(path, "train", f"song{song}")
makedirs(song_path, exist_ok=True)
dataset_df.loc[song, f"duration"] = duration
for instr in instrument_list+["mix"]:
filename = join(song_path, f"{instr}.wav")
data = rng.rand(duration*fs, n_channels)-0.5
aa.save(filename, data, fs)
dataset_df.loc[song, f"{instr}_path"] = join("train", f"song{song}", f"{instr}.wav")
dataset_df.to_csv(join(path, "train", "train.csv"), index=False)
def test_train():
with TemporaryDirectory() as path:
# generate training dataset
generate_fake_training_dataset(path)
# set training command aruments
p = create_argument_parser()
arguments = p.parse_args(["train", "-p", "useless_config.json", "-d", path])
TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv")
TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv")
TRAIN_CONFIG["model_dir"] = join(path, "model")
TRAIN_CONFIG["training_cache"] = join(path, "cache", "training")
TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation")
# execute training
res = train.entrypoint(arguments, TRAIN_CONFIG)
# assert that model checkpoint was created.
assert os.path.exists(join(path,'model','model.ckpt-10.index'))
assert os.path.exists(join(path,'model','checkpoint'))
assert os.path.exists(join(path,'model','model.ckpt-0.meta'))
if __name__=="__main__":
test_train()