Merge pull request #498 from deezer/tf2
Tensorflow 2 compatible version
.github/workflows/docker.yml (vendored, 10 changes)

@@ -7,7 +7,7 @@ jobs:
     strategy:
       matrix:
         platform: [cpu, gpu]
-        distribution: [3.6, 3.7, conda]
+        distribution: [3.6, 3.7, 3.8, conda]
         model: [modelless, 2stems, 4stems, 5stems]
       fail-fast: true
     steps:
@@ -69,13 +69,13 @@ jobs:
         run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin
       - name: Push deezer/spleeter:${{ env.tag }} image
         run: docker push deezer/spleeter:${{ env.tag }}
-      - if: ${{ env.tag == 'spleeter:3.7' }}
+      - if: ${{ env.tag == 'spleeter:3.8' }}
         name: Push deezer/spleeter:latest image
         run: |
-          docker tag deezer/spleeter:3.7 deezer/spleeter:latest
+          docker tag deezer/spleeter:3.8 deezer/spleeter:latest
           docker push deezer/spleeter:latest
-      - if: ${{ env.tag == 'spleeter:3.7-gpu' }}
+      - if: ${{ env.tag == 'spleeter:3.8-gpu' }}
         name: Push deezer/spleeter:gpu image
         run: |
-          docker tag deezer/spleeter:3.7-gpu deezer/spleeter:gpu
+          docker tag deezer/spleeter:3.8-gpu deezer/spleeter:gpu
           docker push deezer/spleeter:gpu

.github/workflows/pytest.yml (vendored, 2 changes)

@@ -8,7 +8,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.6, 3.7]
+        python-version: [3.6, 3.7, 3.8]
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python ${{ matrix.python-version }}

setup.py (11 changes)

@@ -14,9 +14,9 @@ __license__ = 'MIT License'
 
 # Default project values.
 project_name = 'spleeter'
-project_version = '1.5.4'
+project_version = '2.0'
 tensorflow_dependency = 'tensorflow'
-tensorflow_version = '1.15.2'
+tensorflow_version = '2.3.0'
 here = path.abspath(path.dirname(__file__))
 readme_path = path.join(here, 'README.md')
 with open(readme_path, 'r') as stream:
@@ -47,17 +47,16 @@ setup(
         'spleeter.utils',
     ],
     package_data={'spleeter.resources': ['*.json']},
-    python_requires='>=3.6, <3.8',
+    python_requires='>=3.6, <3.9',
    include_package_data=True,
    install_requires=[
        'ffmpeg-python',
        'importlib_resources ; python_version<"3.7"',
        'norbert==0.2.1',
-        'pandas==0.25.1',
+        'pandas==1.1.2',
        'requests',
        'setuptools>=41.0.0',
-        'librosa==0.7.2',
-        'numba==0.48.0',
+        'librosa==0.8.0',
        '{}=={}'.format(tensorflow_dependency, tensorflow_version),
    ],
    extras_require={

@@ -13,7 +13,7 @@ from os.path import exists
 import numpy as np
 import tensorflow as tf
 
-from tensorflow.contrib.signal import stft, hann_window
+from tensorflow.signal import stft, hann_window
 # pylint: enable=import-error
 
 from .. import SpleeterError

@@ -7,7 +7,7 @@
 import numpy as np
 import tensorflow as tf
 
-from tensorflow.contrib.signal import stft, hann_window
+from tensorflow.signal import stft, hann_window
 # pylint: enable=import-error
 
 __email__ = 'spleeter@deezer.com'

@@ -238,7 +238,7 @@ class DatasetBuilder(object):
 
     def expand_path(self, sample):
         """ Expands audio paths for the given sample. """
-        return dict(sample, **{f'{instrument}_path': tf.string_join(
+        return dict(sample, **{f'{instrument}_path': tf.strings.join(
             (self._audio_path, sample[f'{instrument}_path']), SEPARATOR)
             for instrument in self._instruments})
 
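
For context, `tf.string_join` was removed in TF2 in favour of the `tf.strings` namespace; a minimal sketch of the replacement call (values illustrative, not from this commit):

import tensorflow as tf

# tf.strings.join concatenates string tensors with an optional
# separator, replacing the removed TF1 tf.string_join alias.
path = tf.strings.join(['audio', 'vocals.wav'], separator='/')
print(path.numpy())  # b'audio/vocals.wav'
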
@@ -8,7 +8,7 @@ import importlib
 # pylint: disable=import-error
 import tensorflow as tf
 
-from tensorflow.contrib.signal import stft, inverse_stft, hann_window
+from tensorflow.signal import stft, inverse_stft, hann_window
 # pylint: enable=import-error
 
 from ..utils.tensor import pad_and_partition, pad_and_reshape

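
These import rewrites all follow the same TF2 migration rule: `tf.contrib` is gone, and the signal-processing primitives now live under `tf.signal`. A hedged round-trip sketch of the renamed API (frame sizes are illustrative, though they match the project's defaults):

import tensorflow as tf

# STFT / inverse-STFT round trip with the TF2 tf.signal API.
waveform = tf.random.normal([44100])  # one second of mono audio
spectrogram = tf.signal.stft(
    waveform,
    frame_length=4096,
    frame_step=1024,
    window_fn=tf.signal.hann_window)
# The inverse window compensates for the forward hann window overlap.
reconstructed = tf.signal.inverse_stft(
    spectrogram,
    frame_length=4096,
    frame_step=1024,
    window_fn=tf.signal.inverse_stft_window_fn(1024))
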
@@ -12,14 +12,18 @@
     >>> separator.separate_to_file(...)
 """
 
+import atexit
 import os
 import logging
 
-from time import time
 from multiprocessing import Pool
 from os.path import basename, join, splitext, dirname
+from time import time
+from typing import Container, NoReturn
 
 import numpy as np
 import tensorflow as tf
 
 from librosa.core import stft, istft
 from scipy.signal.windows import hann
 
@@ -27,64 +31,114 @@ from . import SpleeterError
 from .audio.adapter import get_default_audio_adapter
 from .audio.convertor import to_stereo
 from .utils.configuration import load_configuration
-from .utils.estimator import create_estimator, to_predictor, get_default_model_dir
+from .utils.estimator import create_estimator, get_default_model_dir
 from .model import EstimatorSpecBuilder, InputProviderFactory
 
 
 __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
-logger = logging.getLogger("spleeter")
+SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
+""" """
 
 
-def get_backend(backend):
-    assert backend in ["auto", "tensorflow", "librosa"]
-    if backend == "auto":
-        return "tensorflow" if tf.test.is_gpu_available() else "librosa"
+class DataGenerator():
+    """
+    Generator object that store a sample and generate it once while called.
+    Used to feed a tensorflow estimator without knowing the whole data at
+    build time.
+    """
+
+    def __init__(self):
+        """ Default constructor. """
+        self._current_data = None
+
+    def update_data(self, data):
+        """ Replace internal data. """
+        self._current_data = data
+
+    def __call__(self):
+        """ Generation process. """
+        buffer = self._current_data
+        while buffer:
+            yield buffer
+            buffer = self._current_data
+
+
+def get_backend(backend: str) -> str:
+    """
+    """
+    if backend not in SUPPORTED_BACKEND:
+        raise ValueError(f'Unsupported backend {backend}')
+    if backend == 'auto':
+        if len(tf.config.list_physical_devices('GPU')):
+            return 'tensorflow'
+        return 'librosa'
     return backend
 
 
 class Separator(object):
     """ A wrapper class for performing separation. """
 
-    def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True):
+    def __init__(
+            self,
+            params_descriptor,
+            MWF: bool = False,
+            stft_backend: str = 'auto',
+            multiprocess: bool = True):
         """ Default constructor.
 
         :param params_descriptor: Descriptor for TF params to be used.
         :param MWF: (Optional) True if MWF should be used, False otherwise.
         """
 
         self._params = load_configuration(params_descriptor)
         self._sample_rate = self._params['sample_rate']
         self._MWF = MWF
         self._tf_graph = tf.Graph()
-        self._predictor = None
+        self._prediction_generator = None
         self._input_provider = None
         self._builder = None
         self._features = None
         self._session = None
-        self._pool = Pool() if multiprocess else None
+        if multiprocess:
+            self._pool = Pool()
+            atexit.register(self._pool.close)
+        else:
+            self._pool = None
         self._tasks = []
-        self._params["stft_backend"] = get_backend(stft_backend)
+        self._params['stft_backend'] = get_backend(stft_backend)
+        self._data_generator = DataGenerator()
 
     def __del__(self):
+        """ """
         if self._session:
             self._session.close()
 
-    def _get_predictor(self):
-        """ Lazy loading access method for internal predictor instance.
+    def _get_prediction_generator(self):
+        """ Lazy loading access method for internal prediction generator
+        returned by the predict method of a tensorflow estimator.
 
-        :returns: Predictor to use for source separation.
+        :returns: generator of prediction.
         """
-        if self._predictor is None:
+        if self._prediction_generator is None:
             estimator = create_estimator(self._params, self._MWF)
-            self._predictor = to_predictor(estimator)
-        return self._predictor
 
-    def join(self, timeout=200):
+            def get_dataset():
+                return tf.data.Dataset.from_generator(
+                    self._data_generator,
+                    output_types={
+                        'waveform': tf.float32,
+                        'audio_id': tf.string},
+                    output_shapes={
+                        'waveform': (None, 2),
+                        'audio_id': ()})
+
+            self._prediction_generator = estimator.predict(
+                get_dataset,
+                yield_single_examples=False)
+        return self._prediction_generator
+
+    def join(self, timeout: int = 200) -> NoReturn:
         """ Wait for all pending tasks to be finished.
 
         :param timeout: (Optional) task waiting timeout.
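
The hunk above is the core of the TF2 port: `tf.contrib.predictor` no longer exists, so the exported-predictor path is replaced by a long-lived generator feeding `Estimator.predict` through `tf.data.Dataset.from_generator`. A condensed sketch of the pattern in isolation (`estimator` stands in for the object built by `create_estimator`):

import tensorflow as tf

class DataGenerator:
    """ Holds the latest sample and yields it while one is set. """
    def __init__(self):
        self._current_data = None

    def update_data(self, data):
        self._current_data = data

    def __call__(self):
        buffer = self._current_data
        while buffer:
            yield buffer
            buffer = self._current_data

generator = DataGenerator()

def input_fn():
    # Dtypes and shapes mirror the waveform dataset declared above.
    return tf.data.Dataset.from_generator(
        generator,
        output_types={'waveform': tf.float32, 'audio_id': tf.string},
        output_shapes={'waveform': (None, 2), 'audio_id': ()})

# With a real estimator (hypothetical here), prediction is driven by
# pushing data into the generator and pulling one result at a time:
# predictions = estimator.predict(input_fn, yield_single_examples=False)
# generator.update_data({'waveform': ..., 'audio_id': ...})
# result = next(predictions)
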
@@ -94,44 +148,52 @@ class Separator(object):
             task.get()
             task.wait(timeout=timeout)
 
-    def _separate_tensorflow(self, waveform, audio_descriptor):
-        """
-        Performs source separation over the given waveform with tensorflow backend.
+    def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
+        """ Performs source separation over the given waveform with tensorflow
+        backend.
 
         :param waveform: Waveform to apply separation on.
         :returns: Separated waveforms.
         """
         if not waveform.shape[-1] == 2:
             waveform = to_stereo(waveform)
-        predictor = self._get_predictor()
-        prediction = predictor({
+        prediction_generator = self._get_prediction_generator()
+        # NOTE: update data in generator before performing separation.
+        self._data_generator.update_data({
             'waveform': waveform,
-            'audio_id': audio_descriptor})
+            'audio_id': np.array(audio_descriptor)})
+        # NOTE: perform separation.
+        prediction = next(prediction_generator)
         prediction.pop('audio_id')
         return prediction
 
-    def _stft(self, data, inverse=False, length=None):
-        """
-        Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
-        channels are processed separately and are concatenated together in the result. The expected input formats are:
-        (n_samples, 2) for stft and (T, F, 2) for istft.
-        :param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse
+    def _stft(self, data, inverse: bool = False, length=None):
+        """ Single entrypoint for both stft and istft. This computes stft and
+        istft with librosa on stereo data. The two channels are processed
+        separately and are concatenated together in the result. The expected
+        input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
+
+        :param data: np.array with either the waveform or the complex
+        spectrogram depending on the parameter inverse
         :param inverse: should a stft or an istft be computed.
-        :return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
+        :returns: Stereo data as numpy array for the transform.
+        The channels are stored in the last dimension.
         """
         assert not (inverse and length is None)
         data = np.asfortranarray(data)
-        N = self._params["frame_length"]
-        H = self._params["frame_step"]
+        N = self._params['frame_length']
+        H = self._params['frame_step']
 
         win = hann(N, sym=False)
         fstft = istft if inverse else stft
-        win_len_arg = {"win_length": None,
-                       "length": None} if inverse else {"n_fft": N}
+        win_len_arg = {
+            'win_length': None,
+            'length': None} if inverse else {'n_fft': N}
         n_channels = data.shape[-1]
         out = []
         for c in range(n_channels):
-            d = np.concatenate((np.zeros((N, )), data[:, c], np.zeros((N, )))) if not inverse else data[:, :, c].T
+            d = np.concatenate(
+                (np.zeros((N, )), data[:, c], np.zeros((N, )))
+            ) if not inverse else data[:, :, c].T
             s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
             if inverse:
                 s = s[N:N+length]
@@ -141,7 +203,6 @@ class Separator(object):
             return out[0]
         return np.concatenate(out, axis=2-inverse)
 
-
     def _get_input_provider(self):
         if self._input_provider is None:
             self._input_provider = InputProviderFactory.get(self._params)
@@ -149,66 +210,83 @@ class Separator(object):
 
     def _get_features(self):
         if self._features is None:
-            self._features = self._get_input_provider().get_input_dict_placeholders()
+            provider = self._get_input_provider()
+            self._features = provider.get_input_dict_placeholders()
         return self._features
 
     def _get_builder(self):
         if self._builder is None:
-            self._builder = EstimatorSpecBuilder(self._get_features(), self._params)
+            self._builder = EstimatorSpecBuilder(
+                self._get_features(),
+                self._params)
         return self._builder
 
     def _get_session(self):
         if self._session is None:
-            saver = tf.train.Saver()
-            latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir']))
-            self._session = tf.Session()
+            saver = tf.compat.v1.train.Saver()
+            latest_checkpoint = tf.train.latest_checkpoint(
+                get_default_model_dir(self._params['model_dir']))
+            self._session = tf.compat.v1.Session()
             saver.restore(self._session, latest_checkpoint)
         return self._session
 
-    def _separate_librosa(self, waveform, audio_id):
-        """
-        Performs separation with librosa backend for STFT.
+    def _separate_librosa(self, waveform: np.ndarray, audio_id):
+        """ Performs separation with librosa backend for STFT.
         """
         with self._tf_graph.as_default():
             out = {}
             features = self._get_features()
-            # TODO: fix the logic, build sometimes return, sometimes set attribute
+            # TODO: fix the logic, build sometimes return,
+            # sometimes set attribute.
             outputs = self._get_builder().outputs
             stft = self._stft(waveform)
             if stft.shape[-1] == 1:
                 stft = np.concatenate([stft, stft], axis=-1)
             elif stft.shape[-1] > 2:
                 stft = stft[:, :2]
 
             sess = self._get_session()
-            outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id))
+            outputs = sess.run(
+                outputs,
+                feed_dict=self._get_input_provider().get_feed_dict(
+                    features,
+                    stft,
+                    audio_id))
             for inst in self._get_builder().instruments:
-                out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
+                out[inst] = self._stft(
+                    outputs[inst],
+                    inverse=True,
+                    length=waveform.shape[0])
             return out
 
-    def separate(self, waveform, audio_descriptor=""):
+    def separate(self, waveform: np.ndarray, audio_descriptor=''):
         """ Performs separation on a waveform.
 
         :param waveform: Waveform to be separated (as a numpy array)
-        :param audio_descriptor: (Optional) string describing the waveform (e.g. filename).
+        :param audio_descriptor: (Optional) string describing the waveform
+        (e.g. filename).
         """
-        if self._params["stft_backend"] == "tensorflow":
+        if self._params['stft_backend'] == 'tensorflow':
             return self._separate_tensorflow(waveform, audio_descriptor)
         else:
             return self._separate_librosa(waveform, audio_descriptor)
 
     def separate_to_file(
-            self, audio_descriptor, destination,
+            self,
+            audio_descriptor,
+            destination,
             audio_adapter=get_default_audio_adapter(),
-            offset=0, duration=600., codec='wav', bitrate='128k',
+            offset=0,
+            duration=600.,
+            codec='wav',
+            bitrate='128k',
             filename_format='{filename}/{instrument}.{codec}',
             synchronous=True):
         """ Performs source separation and export result to file using
         given audio adapter.
 
         Filename format should be a Python formattable string that could use
-        following parameters : {instrument}, {filename}, {foldername} and {codec}.
+        following parameters : {instrument}, {filename}, {foldername} and
+        {codec}.
 
         :param audio_descriptor:    Describe song to separate, used by audio
                                     adapter to retrieve and load audio data,
@@ -217,8 +295,8 @@ class Separator(object):
         :param destination:         Target directory to write output to.
         :param audio_adapter:       (Optional) Audio adapter to use for I/O.
         :param offset:              (Optional) Offset of loaded song.
-        :param duration:            (Optional) Duration of loaded song (default:
-                                    600s).
+        :param duration:            (Optional) Duration of loaded song
+                                    (default: 600s).
         :param codec:               (Optional) Export codec.
         :param bitrate:             (Optional) Export bitrate.
         :param filename_format:     (Optional) Filename format.
@@ -230,16 +308,27 @@ class Separator(object):
             duration=duration,
             sample_rate=self._sample_rate)
         sources = self.separate(waveform, audio_descriptor)
-        self.save_to_file( sources, audio_descriptor, destination,
-                           filename_format, codec, audio_adapter,
-                           bitrate, synchronous)
+        self.save_to_file(
+            sources,
+            audio_descriptor,
+            destination,
+            filename_format,
+            codec,
+            audio_adapter,
+            bitrate,
+            synchronous)
 
     def save_to_file(
-            self, sources, audio_descriptor, destination,
+            self,
+            sources,
+            audio_descriptor,
+            destination,
             filename_format='{filename}/{instrument}.{codec}',
-            codec='wav', audio_adapter=get_default_audio_adapter(),
-            bitrate='128k', synchronous=True):
-        """ export dictionary of sources to files.
+            codec='wav',
+            audio_adapter=get_default_audio_adapter(),
+            bitrate='128k',
+            synchronous=True):
+        """ Export dictionary of sources to files.
 
         :param sources:             Dictionary of sources to be exported. The
                                     keys are the name of the instruments, and
@@ -258,7 +347,6 @@ class Separator(object):
         :param synchronous:         (Optional) True is should by synchronous.
 
         """
-
         foldername = basename(dirname(audio_descriptor))
         filename = splitext(basename(audio_descriptor))[0]
         generated = []
@@ -286,6 +374,11 @@ class Separator(object):
                         bitrate))
                 self._tasks.append(task)
             else:
-                audio_adapter.save(path, data, self._sample_rate, codec, bitrate)
+                audio_adapter.save(
+                    path,
+                    data,
+                    self._sample_rate,
+                    codec,
+                    bitrate)
         if synchronous and self._pool:
             self.join()
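
Taken together, a hedged usage sketch of the reworked `Separator` (file paths are illustrative): the backend is resolved once by `get_backend`, and each call pushes data through the shared prediction generator rather than a freshly exported predictor.

from spleeter.separator import Separator

# 'auto' resolves to 'tensorflow' when a GPU is visible, 'librosa'
# otherwise; multiprocess=False skips the export pool (as the tests do).
separator = Separator(
    'spleeter:2stems', stft_backend='auto', multiprocess=False)
separator.separate_to_file('audio_example.mp3', 'output')  # illustrative paths
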
@@ -5,20 +5,14 @@
 
 from pathlib import Path
 from os.path import join
-from tempfile import gettempdir
 
 # pylint: disable=import-error
 import tensorflow as tf
 
-from tensorflow.contrib import predictor
-# pylint: enable=import-error
 
-from ..model import model_fn, InputProviderFactory
+from ..model import model_fn
 from ..model.provider import get_default_model_provider
 
-# Default exporting directory for predictor.
-DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
-
 
 
 def get_default_model_dir(model_dir):
@@ -57,24 +51,3 @@ def create_estimator(params, MWF):
         config=config
     )
     return estimator
-
-
-def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
-    """ Exports given estimator as predictor into the given directory
-    and returns associated tf.predictor instance.
-
-    :param estimator: Estimator to export.
-    :param directory: (Optional) path to write exported model into.
-    """
-
-    input_provider = InputProviderFactory.get(estimator.params)
-    def receiver():
-        features = input_provider.get_input_dict_placeholders()
-        return tf.estimator.export.ServingInputReceiver(features, features)
-
-    estimator.export_saved_model(directory, receiver)
-    versions = [
-        model for model in Path(directory).iterdir()
-        if model.is_dir() and 'temp' not in str(model)]
-    latest = str(sorted(versions)[-1])
-    return predictor.from_saved_model(latest)

@@ -56,6 +56,9 @@ res_4stems = {
 }
 
 def generate_fake_eval_dataset(path):
+    """
+    generate fake evaluation dataset
+    """
     aa = get_default_audio_adapter()
     n_songs = 2
     fs = 44100
@@ -71,6 +74,7 @@ def generate_fake_eval_dataset(path):
             aa.save(filename, data, fs)
 
 
+
 @pytest.mark.parametrize('backend', TEST_CONFIGURATIONS)
 def test_evaluate(backend):
     with TemporaryDirectory() as directory:
@@ -81,4 +85,4 @@ def test_evaluate(backend):
     metrics = evaluate.entrypoint(arguments, params)
     for instrument, metric in metrics.items():
         for m, value in metric.items():
             assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3)

@@ -7,7 +7,6 @@ __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
-import filecmp
 import itertools
 from os.path import splitext, basename, exists, join
 from tempfile import TemporaryDirectory
@@ -33,7 +32,8 @@ MODEL_TO_INST = {
 
 
 MODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS))
-TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS))
+TEST_CONFIGURATIONS = list(itertools.product(
+    TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS))
 
 
 print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__))
@@ -44,8 +44,10 @@ def test_separator_backends(test_file):
     adapter = get_default_audio_adapter()
     waveform, _ = adapter.load(test_file)
 
-    separator_lib = Separator("spleeter:2stems", stft_backend="librosa")
-    separator_tf = Separator("spleeter:2stems", stft_backend="tensorflow")
+    separator_lib = Separator(
+        "spleeter:2stems", stft_backend="librosa", multiprocess=False)
+    separator_tf = Separator(
+        "spleeter:2stems", stft_backend="tensorflow", multiprocess=False)
 
     # Test the stft and inverse stft provides exact reconstruction
     stft_matrix = separator_lib._stft(waveform)
@@ -68,7 +70,8 @@ def test_separate(test_file, configuration, backend):
     instruments = MODEL_TO_INST[configuration]
     adapter = get_default_audio_adapter()
     waveform, _ = adapter.load(test_file)
-    separator = Separator(configuration, stft_backend=backend, multiprocess=False)
+    separator = Separator(
+        configuration, stft_backend=backend, multiprocess=False)
     prediction = separator.separate(waveform, test_file)
     assert len(prediction) == len(instruments)
     for instrument in instruments:
@@ -80,14 +83,14 @@ def test_separate(test_file, configuration, backend):
         for compared in instruments:
             if instrument != compared:
                 assert not np.allclose(track, prediction[compared])
 
 
-
 @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS)
 def test_separate_to_file(test_file, configuration, backend):
     """ Test file based separation. """
     instruments = MODEL_TO_INST[configuration]
-    separator = Separator(configuration, stft_backend=backend, multiprocess=False)
+    separator = Separator(
+        configuration, stft_backend=backend, multiprocess=False)
     name = splitext(basename(test_file))[0]
     with TemporaryDirectory() as directory:
         separator.separate_to_file(
@@ -103,7 +106,8 @@ def test_separate_to_file(test_file, configuration, backend):
 def test_filename_format(test_file, configuration, backend):
     """ Test custom filename format. """
     instruments = MODEL_TO_INST[configuration]
-    separator = Separator(configuration, stft_backend=backend, multiprocess=False)
+    separator = Separator(
+        configuration, stft_backend=backend, multiprocess=False)
     name = splitext(basename(test_file))[0]
     with TemporaryDirectory() as directory:
         separator.separate_to_file(

tests/test_train.py (new file, 113 lines)

@@ -0,0 +1,113 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" Unit testing for Separator class. """
+
+__email__ = 'research@deezer.com'
+__author__ = 'Deezer Research'
+__license__ = 'MIT License'
+
+import filecmp
+import itertools
+import os
+from os import makedirs
+from os.path import splitext, basename, exists, join
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import pandas as pd
+import json
+
+import tensorflow as tf
+
+from spleeter.audio.adapter import get_default_audio_adapter
+from spleeter.commands import create_argument_parser
+
+from spleeter.commands import train
+
+from spleeter.utils.configuration import load_configuration
+
+TRAIN_CONFIG = {
+    "mix_name": "mix",
+    "instrument_list": ["vocals", "other"],
+    "sample_rate": 44100,
+    "frame_length": 4096,
+    "frame_step": 1024,
+    "T": 128,
+    "F": 128,
+    "n_channels": 2,
+    "chunk_duration": 4,
+    "n_chunks_per_song": 1,
+    "separation_exponent": 2,
+    "mask_extension": "zeros",
+    "learning_rate": 1e-4,
+    "batch_size": 2,
+    "train_max_steps": 10,
+    "throttle_secs": 20,
+    "save_checkpoints_steps": 100,
+    "save_summary_steps": 5,
+    "random_seed": 0,
+    "model": {
+        "type": "unet.unet",
+        "params": {
+            "conv_activation": "ELU",
+            "deconv_activation": "ELU"
+        }
+    }
+}
+
+
+def generate_fake_training_dataset(path, instrument_list=["vocals", "other"]):
+    """
+    generates a fake training dataset in path:
+    - generates audio files
+    - generates a csv file describing the dataset
+    """
+    aa = get_default_audio_adapter()
+    n_songs = 2
+    fs = 44100
+    duration = 6
+    n_channels = 2
+    rng = np.random.RandomState(seed=0)
+    dataset_df = pd.DataFrame(columns=["mix_path"]+[f"{instr}_path" for instr in instrument_list]+["duration"])
+    for song in range(n_songs):
+        song_path = join(path, "train", f"song{song}")
+        makedirs(song_path, exist_ok=True)
+        dataset_df.loc[song, f"duration"] = duration
+        for instr in instrument_list+["mix"]:
+            filename = join(song_path, f"{instr}.wav")
+            data = rng.rand(duration*fs, n_channels)-0.5
+            aa.save(filename, data, fs)
+            dataset_df.loc[song, f"{instr}_path"] = join("train", f"song{song}", f"{instr}.wav")
+
+    dataset_df.to_csv(join(path, "train", "train.csv"), index=False)
+
+
+def test_train():
+
+
+    with TemporaryDirectory() as path:
+
+        # generate training dataset
+        generate_fake_training_dataset(path)
+
+        # set training command arguments
+        p = create_argument_parser()
+        arguments = p.parse_args(["train", "-p", "useless_config.json", "-d", path])
+        TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv")
+        TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv")
+        TRAIN_CONFIG["model_dir"] = join(path, "model")
+        TRAIN_CONFIG["training_cache"] = join(path, "cache", "training")
+        TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation")
+
+        # execute training
+        res = train.entrypoint(arguments, TRAIN_CONFIG)
+
+        # assert that model checkpoint was created.
+        assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
+        assert os.path.exists(join(path, 'model', 'checkpoint'))
+        assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
+
+
+if __name__ == "__main__":
+    test_train()