🎨 finalize refactoring

This commit is contained in:
Faylixe
2020-12-08 12:10:45 +01:00
parent 075bb97f82
commit ed7bd4b945
6 changed files with 269 additions and 162 deletions

View File

@@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Union
from .. import SpleeterError
from ..types import AudioDescriptor, Signal
from ..utils.logging import get_logger
from ..utils.logging import logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
@@ -101,7 +101,6 @@ class AudioAdapter(ABC):
# Defined safe loading function.
def safe_load(path, offset, duration, sample_rate, dtype):
logger = get_logger()
logger.info(
f'Loading audio {path} from {offset} to {offset + duration}')
try:

View File

@@ -19,7 +19,7 @@ from . import Codec
from .adapter import AudioAdapter
from .. import SpleeterError
from ..types import Signal
from ..utils.logging import get_logger
from ..utils.logging import logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
@@ -161,7 +161,7 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
if not os.path.exists(directory):
raise SpleeterError(
f'output directory does not exists: {directory}')
get_logger().debug(f'Writing file {path}')
logger.debug(f'Writing file {path}')
input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
output_kwargs = {'ar': sample_rate, 'strict': '-2'}
if bitrate:
@@ -180,4 +180,4 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
process.wait()
except IOError:
raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
get_logger().info(f'File {path} written succesfully')
logger.info(f'File {path} written succesfully')

View File

@@ -18,12 +18,14 @@ import time
import os
from os.path import exists, sep as SEPARATOR
from typing import Any, Dict, Optional
from .audio.adapter import AudioAdapter
from .audio.convertor import db_uint_spectrogram_to_gain
from .audio.convertor import spectrogram_to_db_uint
from .audio.spectrogram import compute_spectrogram_tf
from .audio.spectrogram import random_pitch_shift, random_time_stretch
from .utils.logging import get_logger
from .utils.logging import logger
from .utils.tensor import check_tensor_shape, dataset_from_csv
from .utils.tensor import set_tensor_shape, sync_apply
@@ -37,24 +39,34 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License'
# Default audio parameters to use.
DEFAULT_AUDIO_PARAMS = {
DEFAULT_AUDIO_PARAMS: Dict = {
'instrument_list': ('vocals', 'accompaniment'),
'mix_name': 'mix',
'sample_rate': 44100,
'frame_length': 4096,
'frame_step': 1024,
'T': 512,
'F': 1024
}
'F': 1024}
def get_training_dataset(audio_params, audio_adapter, audio_path):
""" Builds training dataset.
def get_training_dataset(
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str) -> Any:
"""
Builds training dataset.
:param audio_params: Audio parameters.
:param audio_adapter: Adapter to load audio from.
:param audio_path: Path of directory containing audio.
:returns: Built dataset.
Parameters:
audio_params (Dict):
Audio parameters.
audio_adapter (AudioAdapter):
Adapter to load audio from.
audio_path (str):
Path of directory containing audio.
Returns:
Any:
Built dataset.
"""
builder = DatasetBuilder(
audio_params,
@@ -72,13 +84,24 @@ def get_training_dataset(audio_params, audio_adapter, audio_path):
wait_for_cache=False)
def get_validation_dataset(audio_params, audio_adapter, audio_path):
""" Builds validation dataset.
def get_validation_dataset(
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str) -> Any:
"""
Builds validation dataset.
:param audio_params: Audio parameters.
:param audio_adapter: Adapter to load audio from.
:param audio_path: Path of directory containing audio.
:returns: Built dataset.
Parameters:
audio_params (Dict):
Audio parameters.
audio_adapter (AudioAdapter):
Adapter to load audio from.
audio_path (str):
Path of directory containing audio.
Returns:
Any:
Built dataset.
"""
builder = DatasetBuilder(
audio_params,
@@ -102,11 +125,15 @@ def get_validation_dataset(audio_params, audio_adapter, audio_path):
class InstrumentDatasetBuilder(object):
""" Instrument based filter and mapper provider. """
def __init__(self, parent, instrument):
""" Default constructor.
def __init__(self, parent, instrument) -> None:
"""
Default constructor.
:param parent: Parent dataset builder.
:param instrument: Target instrument.
Parameters:
parent:
Parent dataset builder.
instrument:
Target instrument.
"""
self._parent = parent
self._instrument = instrument
@@ -181,7 +208,7 @@ class InstrumentDatasetBuilder(object):
self._parent._T, self._parent._F, 2))
def reshape_spectrogram(self, sample):
""" """
""" Reshape given sample. """
return dict(sample, **{
self._spectrogram_key: set_tensor_shape(
sample[self._spectrogram_key],
@@ -190,27 +217,35 @@ class InstrumentDatasetBuilder(object):
class DatasetBuilder(object):
"""
TO BE DOCUMENTED.
"""
# Margin at beginning and end of songs in seconds.
MARGIN = 0.5
MARGIN: float = 0.5
""" Margin at beginning and end of songs in seconds. """
# Wait period for cache (in seconds).
WAIT_PERIOD = 60
WAIT_PERIOD: int = 60
""" Wait period for cache (in seconds). """
def __init__(
self,
audio_params, audio_adapter, audio_path,
random_seed=0, chunk_duration=20.0):
""" Default constructor.
audio_params: Dict,
audio_adapter: AudioAdapter,
audio_path: str,
random_seed: int = 0,
chunk_duration: float = 20.0) -> None:
"""
Default constructor.
NOTE: Probably need for AudioAdapter.
NOTE: Probably need for AudioAdapter.
:param audio_params: Audio parameters to use.
:param audio_adapter: Audio adapter to use.
:param audio_path:
:param random_seed:
:param chunk_duration:
Parameters:
audio_params (Dict):
Audio parameters to use.
audio_adapter (AudioAdapter):
Audio adapter to use.
audio_path (str):
random_seed (int):
chunk_duration (float):
"""
# Length of segment in frames (if fs=22050 and
# frame_step=512, then T=512 corresponds to 11.89s)
@@ -298,12 +333,22 @@ class DatasetBuilder(object):
for instrument in self._audio_params['instrument_list']}
return (input_, output)
def compute_segments(self, dataset, n_chunks_per_song):
""" Computes segments for each song of the dataset.
def compute_segments(
self,
dataset: Any,
n_chunks_per_song: int) -> Any:
"""
Computes segments for each song of the dataset.
:param dataset: Dataset to compute segments for.
:param n_chunks_per_song: Number of segment per song to compute.
:returns: Segmented dataset.
Parameters:
dataset (Any):
Dataset to compute segments for.
n_chunks_per_song (int):
Number of segment per song to compute.
Returns:
Any:
Segmented dataset.
"""
if n_chunks_per_song <= 0:
raise ValueError('n_chunks_per_song must be positif')
@@ -327,10 +372,13 @@ class DatasetBuilder(object):
return dataset
@property
def instruments(self):
""" Instrument dataset builder generator.
def instruments(self) -> Any:
"""
Instrument dataset builder generator.
:yield InstrumentBuilder instance.
Yields:
Any:
InstrumentBuilder instance.
"""
if self._instrument_builders is None:
self._instrument_builders = []
@@ -340,22 +388,33 @@ class DatasetBuilder(object):
for builder in self._instrument_builders:
yield builder
def cache(self, dataset, cache, wait):
""" Cache the given dataset if cache is enabled. Eventually waits for
cache to be available (useful if another process is already computing
cache) if provided wait flag is True.
def cache(
self,
dataset: Any,
cache: str,
wait: bool) -> Any:
"""
Cache the given dataset if cache is enabled. Eventually waits for
cache to be available (useful if another process is already
computing cache) if provided wait flag is `True`.
:param dataset: Dataset to be cached if cache is required.
:param cache: Path of cache directory to be used, None if no cache.
:param wait: If caching is enabled, True is cache should be waited.
:returns: Cached dataset if needed, original dataset otherwise.
Parameters:
dataset (Any):
Dataset to be cached if cache is required.
cache (str):
Path of cache directory to be used, None if no cache.
wait (bool):
If caching is enabled, True is cache should be waited.
Returns:
Any:
Cached dataset if needed, original dataset otherwise.
"""
if cache is not None:
if wait:
while not exists(f'{cache}.index'):
get_logger().info(
'Cache not available, wait %s',
self.WAIT_PERIOD)
logger.info(
f'Cache not available, wait {self.WAIT_PERIOD}')
time.sleep(self.WAIT_PERIOD)
cache_path = os.path.split(cache)[0]
os.makedirs(cache_path, exist_ok=True)
@@ -363,13 +422,20 @@ class DatasetBuilder(object):
return dataset
def build(
self, csv_path,
batch_size=8, shuffle=True, convert_to_uint=True,
random_data_augmentation=False, random_time_crop=True,
infinite_generator=True, cache_directory=None,
wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,):
self,
csv_path: str,
batch_size: int = 8,
shuffle: bool = True,
convert_to_uint: bool = True,
random_data_augmentation: bool = False,
random_time_crop: bool = True,
infinite_generator: bool = True,
cache_directory: Optional[str] = None,
wait_for_cache: bool = False,
num_parallel_calls: int = 4,
n_chunks_per_song: float = 2,) -> Any:
"""
TO BE DOCUMENTED.
TO BE DOCUMENTED.
"""
dataset = dataset_from_csv(csv_path)
dataset = self.compute_segments(dataset, n_chunks_per_song)

View File

@@ -5,6 +5,7 @@
import importlib
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf

View File

@@ -25,7 +25,7 @@ from tempfile import NamedTemporaryFile
from typing import Dict
from . import ModelProvider
from ...utils.logging import get_logger
from ...utils.logging import logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
@@ -138,7 +138,7 @@ class GithubModelProvider(ModelProvider):
self._release,
name))
url = f'{url}.tar.gz'
get_logger().info(f'Downloading model archive {url}')
logger.info(f'Downloading model archive {url}')
with httpx.Client(http2=True) as client:
with client.strema('GET', url) as response:
response.raise_for_status()
@@ -147,14 +147,14 @@ class GithubModelProvider(ModelProvider):
with archive as stream:
for chunk in response.iter_raw():
stream.write(chunk)
get_logger().info('Validating archive checksum')
logger.info('Validating archive checksum')
checksum: str = compute_file_checksum(archive.name)
if checksum != self.checksum(name):
raise IOError(
'Downloaded file is corrupted, please retry')
get_logger().info(f'Extracting downloaded {name} archive')
logger.info(f'Extracting downloaded {name} archive')
with tarfile.open(name=archive.name) as tar:
tar.extractall(path=path)
finally:
os.unlink(archive.name)
get_logger().info(f'{name} model file(s) extracted')
logger.info(f'{name} model file(s) extracted')

View File

@@ -19,13 +19,15 @@ import os
from multiprocessing import Pool
from os.path import basename, join, splitext, dirname
from typing import Generator, Optional
from spleeter.model.provider import ModelProvider
from typing import Dict, Generator, Optional
from . import SpleeterError
from .audio import STFTBackend
from .audio.adapter import get_default_audio_adapter
from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory
from .model import model_fn
from .utils.configuration import load_configuration
# pyright: reportMissingImports=false
@@ -65,18 +67,6 @@ class DataGenerator(object):
buffer = self._current_data
def get_backend(backend: str) -> str:
"""
"""
if backend not in SUPPORTED_BACKEND:
raise ValueError(f'Unsupported backend {backend}')
if backend == 'auto':
if len(tf.config.list_physical_devices('GPU')):
return 'tensorflow'
return 'librosa'
return backend
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
@@ -137,18 +127,21 @@ class Separator(object):
else:
self._pool = None
self._tasks = []
self._params['stft_backend'] = get_backend(stft_backend)
self._params['stft_backend'] = stft_backend
self._data_generator = DataGenerator()
def __del__(self) -> None:
if self._session:
self._session.close()
def _get_prediction_generator(self):
""" Lazy loading access method for internal prediction generator
returned by the predict method of a tensorflow estimator.
def _get_prediction_generator(self) -> Generator:
"""
Lazy loading access method for internal prediction generator
returned by the predict method of a tensorflow estimator.
:returns: generator of prediction.
Returns:
Generator:
Generator of prediction.
"""
if self._prediction_generator is None:
estimator = create_estimator(self._params, self._MWF)
@@ -181,17 +174,30 @@ class Separator(object):
task.get()
task.wait(timeout=timeout)
def _stft(self, data, inverse: bool = False, length=None):
""" Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The expected
input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
def _stft(
self,
data: np.ndarray,
inverse: bool = False,
length: Optional[int] = None) -> np.ndarray:
"""
Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The
expected input formats are: (n_samples, 2) for stft and (T, F, 2)
for istft.
:param data: np.array with either the waveform or the complex
spectrogram depending on the parameter inverse
:param inverse: should a stft or an istft be computed.
:returns: Stereo data as numpy array for the transform.
The channels are stored in the last dimension.
Parameters:
data (numpy.array):
Array with either the waveform or the complex spectrogram
depending on the parameter inverse
inverse (bool):
(Optional) Should a stft or an istft be computed.
length (Optional[int]):
Returns:
numpy.ndarray:
Stereo data as numpy array for the transform. The channels
are stored in the last dimension.
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
@@ -238,19 +244,24 @@ class Separator(object):
def _get_session(self):
if self._session is None:
saver = tf.compat.v1.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint(
get_default_model_dir(self._params['model_dir']))
provider = ModelProvider.default()
model_directory: str = provider.get(self._params['model_dir'])
latest_checkpoint = tf.train.latest_checkpoint(model_directory)
self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint)
return self._session
def _separate_librosa(self, waveform: np.ndarray, audio_id):
def _separate_librosa(
self,
waveform: np.ndarray,
audio_descriptor: str) -> Dict:
"""
Performs separation with librosa backend for STFT.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
"""
with self._tf_graph.as_default():
out = {}
@@ -269,7 +280,7 @@ class Separator(object):
feed_dict=self._get_input_provider().get_feed_dict(
features,
stft,
audio_id))
audio_descriptor))
for inst in self._get_builder().instruments:
out[inst] = self._stft(
outputs[inst],
@@ -277,7 +288,10 @@ class Separator(object):
length=waveform.shape[0])
return out
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
def _separate_tensorflow(
self,
waveform: np.ndarray,
audio_descriptor: str) -> Dict:
"""
Performs source separation over the given waveform with tensorflow
backend.
@@ -285,6 +299,7 @@ class Separator(object):
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
Returns:
Separated waveforms.
@@ -314,44 +329,61 @@ class Separator(object):
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
"""
if self._params['stft_backend'] == 'tensorflow':
backend: str = self._params['stft_backend']
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor)
else:
elif backend == STFTBackend.LIBROSA:
return self._separate_librosa(waveform, audio_descriptor)
raise ValueError(f'Unsupported STFT backend {backend}')
def separate_to_file(
self,
audio_descriptor,
destination,
audio_adapter=get_default_audio_adapter(),
offset=0,
duration=600.,
codec='wav',
bitrate='128k',
filename_format='{filename}/{instrument}.{codec}',
synchronous=True):
""" Performs source separation and export result to file using
given audio adapter.
Filename format should be a Python formattable string that could use
following parameters : {instrument}, {filename}, {foldername} and
{codec}.
:param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data,
in case of file based audio adapter, such
descriptor would be a file path.
:param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param offset: (Optional) Offset of loaded song.
:param duration: (Optional) Duration of loaded song
(default: 600s).
:param codec: (Optional) Export codec.
:param bitrate: (Optional) Export bitrate.
:param filename_format: (Optional) Filename format.
:param synchronous: (Optional) True is should by synchronous.
audio_descriptor: str,
destination: str,
audio_adapter: Optional[AudioAdapter] = None,
offset: int = 0,
duration: float = 600.,
codec: Codec = Codec.WAV,
bitrate: str = '128k',
filename_format: str = '{filename}/{instrument}.{codec}',
synchronous: bool = True) -> None:
"""
waveform, sample_rate = audio_adapter.load(
Performs source separation and export result to file using
given audio adapter.
Filename format should be a Python formattable string that could
use following parameters :
- {instrument}
- {filename}
- {foldername}
- {codec}.
Parameters:
audio_descriptor (str):
Describe song to separate, used by audio adapter to
retrieve and load audio data, in case of file based
audio adapter, such descriptor would be a file path.
destination (str):
Target directory to write output to.
audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
offset (int):
(Optional) Offset of loaded song.
duration (float):
(Optional) Duration of loaded song (default: 600s).
codec (Codec):
(Optional) Export codec.
bitrate (str):
(Optional) Export bitrate.
filename_format (str):
(Optional) Filename format.
synchronous (bool):
(Optional) True is should by synchronous.
"""
if audio_adapter is None:
audio_adapter = AudioAdapter.default()
waveform, _ = audio_adapter.load(
audio_descriptor,
offset=offset,
duration=duration,
@@ -369,33 +401,42 @@ class Separator(object):
def save_to_file(
self,
sources,
audio_descriptor,
destination,
filename_format='{filename}/{instrument}.{codec}',
codec='wav',
audio_adapter=get_default_audio_adapter(),
bitrate='128k',
synchronous=True):
""" Export dictionary of sources to files.
:param sources: Dictionary of sources to be exported. The
keys are the name of the instruments, and
the values are Nx2 numpy arrays containing
the corresponding intrument waveform, as
returned by the separate method
:param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data,
in case of file based audio adapter, such
descriptor would be a file path.
:param destination: Target directory to write output to.
:param filename_format: (Optional) Filename format.
:param codec: (Optional) Export codec.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param bitrate: (Optional) Export bitrate.
:param synchronous: (Optional) True is should by synchronous.
sources: Dict,
audio_descriptor: str,
destination: str,
filename_format: str = '{filename}/{instrument}.{codec}',
codec: Codec = Codec.WAV,
audio_adapter: Optional[AudioAdapter] = None,
bitrate: str = '128k',
synchronous: bool = True) -> None:
"""
Export dictionary of sources to files.
Parameters:
sources (Dict):
Dictionary of sources to be exported. The keys are the name
of the instruments, and the values are `N x 2` numpy arrays
containing the corresponding intrument waveform, as
returned by the separate method
audio_descriptor (str):
Describe song to separate, used by audio adapter to
retrieve and load audio data, in case of file based audio
adapter, such descriptor would be a file path.
destination (str):
Target directory to write output to.
filename_format (str):
(Optional) Filename format.
codec (Codec):
(Optional) Export codec.
audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
bitrate (str):
(Optional) Export bitrate.
synchronous (bool):
(Optional) True is should by synchronous.
"""
if audio_adapter is None:
audio_adapter = AudioAdapter.default()
foldername = basename(dirname(audio_descriptor))
filename = splitext(basename(audio_descriptor))[0]
generated = []