From ed7bd4b945b976229b64c2f0936891bd2f245150 Mon Sep 17 00:00:00 2001 From: Faylixe Date: Tue, 8 Dec 2020 12:10:45 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20=20finalize=20refactoring?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- spleeter/audio/adapter.py | 3 +- spleeter/audio/ffmpeg.py | 6 +- spleeter/dataset.py | 184 ++++++++++++++++-------- spleeter/model/__init__.py | 1 + spleeter/model/provider/github.py | 10 +- spleeter/separator.py | 227 ++++++++++++++++++------------ 6 files changed, 269 insertions(+), 162 deletions(-) diff --git a/spleeter/audio/adapter.py b/spleeter/audio/adapter.py index e7d0c6b..b6fe1c6 100644 --- a/spleeter/audio/adapter.py +++ b/spleeter/audio/adapter.py @@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Union from .. import SpleeterError from ..types import AudioDescriptor, Signal -from ..utils.logging import get_logger +from ..utils.logging import logger # pyright: reportMissingImports=false # pylint: disable=import-error @@ -101,7 +101,6 @@ class AudioAdapter(ABC): # Defined safe loading function. def safe_load(path, offset, duration, sample_rate, dtype): - logger = get_logger() logger.info( f'Loading audio {path} from {offset} to {offset + duration}') try: diff --git a/spleeter/audio/ffmpeg.py b/spleeter/audio/ffmpeg.py index 6f5ce37..81f943a 100644 --- a/spleeter/audio/ffmpeg.py +++ b/spleeter/audio/ffmpeg.py @@ -19,7 +19,7 @@ from . import Codec from .adapter import AudioAdapter from .. 
import SpleeterError from ..types import Signal -from ..utils.logging import get_logger +from ..utils.logging import logger # pyright: reportMissingImports=false # pylint: disable=import-error @@ -161,7 +161,7 @@ class FFMPEGProcessAudioAdapter(AudioAdapter): if not os.path.exists(directory): raise SpleeterError( f'output directory does not exists: {directory}') - get_logger().debug(f'Writing file {path}') + logger.debug(f'Writing file {path}') input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]} output_kwargs = {'ar': sample_rate, 'strict': '-2'} if bitrate: @@ -180,4 +180,4 @@ class FFMPEGProcessAudioAdapter(AudioAdapter): process.wait() except IOError: raise SpleeterError(f'FFMPEG error: {process.stderr.read()}') - get_logger().info(f'File {path} written succesfully') + logger.info(f'File {path} written succesfully') diff --git a/spleeter/dataset.py b/spleeter/dataset.py index 84a55d0..4c82c43 100644 --- a/spleeter/dataset.py +++ b/spleeter/dataset.py @@ -18,12 +18,14 @@ import time import os from os.path import exists, sep as SEPARATOR +from typing import Any, Dict, Optional +from .audio.adapter import AudioAdapter from .audio.convertor import db_uint_spectrogram_to_gain from .audio.convertor import spectrogram_to_db_uint from .audio.spectrogram import compute_spectrogram_tf from .audio.spectrogram import random_pitch_shift, random_time_stretch -from .utils.logging import get_logger +from .utils.logging import logger from .utils.tensor import check_tensor_shape, dataset_from_csv from .utils.tensor import set_tensor_shape, sync_apply @@ -37,24 +39,34 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' # Default audio parameters to use. 
-DEFAULT_AUDIO_PARAMS = { +DEFAULT_AUDIO_PARAMS: Dict = { 'instrument_list': ('vocals', 'accompaniment'), 'mix_name': 'mix', 'sample_rate': 44100, 'frame_length': 4096, 'frame_step': 1024, 'T': 512, - 'F': 1024 -} + 'F': 1024} -def get_training_dataset(audio_params, audio_adapter, audio_path): - """ Builds training dataset. +def get_training_dataset( + audio_params: Dict, + audio_adapter: AudioAdapter, + audio_path: str) -> Any: + """ + Builds training dataset. - :param audio_params: Audio parameters. - :param audio_adapter: Adapter to load audio from. - :param audio_path: Path of directory containing audio. - :returns: Built dataset. + Parameters: + audio_params (Dict): + Audio parameters. + audio_adapter (AudioAdapter): + Adapter to load audio from. + audio_path (str): + Path of directory containing audio. + + Returns: + Any: + Built dataset. """ builder = DatasetBuilder( audio_params, @@ -72,13 +84,24 @@ def get_training_dataset(audio_params, audio_adapter, audio_path): wait_for_cache=False) -def get_validation_dataset(audio_params, audio_adapter, audio_path): - """ Builds validation dataset. +def get_validation_dataset( + audio_params: Dict, + audio_adapter: AudioAdapter, + audio_path: str) -> Any: + """ + Builds validation dataset. - :param audio_params: Audio parameters. - :param audio_adapter: Adapter to load audio from. - :param audio_path: Path of directory containing audio. - :returns: Built dataset. + Parameters: + audio_params (Dict): + Audio parameters. + audio_adapter (AudioAdapter): + Adapter to load audio from. + audio_path (str): + Path of directory containing audio. + + Returns: + Any: + Built dataset. """ builder = DatasetBuilder( audio_params, @@ -102,11 +125,15 @@ def get_validation_dataset(audio_params, audio_adapter, audio_path): class InstrumentDatasetBuilder(object): """ Instrument based filter and mapper provider. """ - def __init__(self, parent, instrument): - """ Default constructor. 
+ def __init__(self, parent, instrument) -> None: + """ + Default constructor. - :param parent: Parent dataset builder. - :param instrument: Target instrument. + Parameters: + parent: + Parent dataset builder. + instrument: + Target instrument. """ self._parent = parent self._instrument = instrument @@ -181,7 +208,7 @@ class InstrumentDatasetBuilder(object): self._parent._T, self._parent._F, 2)) def reshape_spectrogram(self, sample): - """ """ + """ Reshape given sample. """ return dict(sample, **{ self._spectrogram_key: set_tensor_shape( sample[self._spectrogram_key], @@ -190,27 +217,35 @@ class InstrumentDatasetBuilder(object): class DatasetBuilder(object): """ + TO BE DOCUMENTED. """ - # Margin at beginning and end of songs in seconds. - MARGIN = 0.5 + MARGIN: float = 0.5 + """ Margin at beginning and end of songs in seconds. """ - # Wait period for cache (in seconds). - WAIT_PERIOD = 60 + WAIT_PERIOD: int = 60 + """ Wait period for cache (in seconds). """ def __init__( self, - audio_params, audio_adapter, audio_path, - random_seed=0, chunk_duration=20.0): - """ Default constructor. + audio_params: Dict, + audio_adapter: AudioAdapter, + audio_path: str, + random_seed: int = 0, + chunk_duration: float = 20.0) -> None: + """ + Default constructor. - NOTE: Probably need for AudioAdapter. + NOTE: Probably need for AudioAdapter. - :param audio_params: Audio parameters to use. - :param audio_adapter: Audio adapter to use. - :param audio_path: - :param random_seed: - :param chunk_duration: + Parameters: + audio_params (Dict): + Audio parameters to use. + audio_adapter (AudioAdapter): + Audio adapter to use. 
+ audio_path (str): + random_seed (int): + chunk_duration (float): """ # Length of segment in frames (if fs=22050 and # frame_step=512, then T=512 corresponds to 11.89s) @@ -298,12 +333,22 @@ class DatasetBuilder(object): for instrument in self._audio_params['instrument_list']} return (input_, output) - def compute_segments(self, dataset, n_chunks_per_song): - """ Computes segments for each song of the dataset. + def compute_segments( + self, + dataset: Any, + n_chunks_per_song: int) -> Any: + """ + Computes segments for each song of the dataset. - :param dataset: Dataset to compute segments for. - :param n_chunks_per_song: Number of segment per song to compute. - :returns: Segmented dataset. + Parameters: + dataset (Any): + Dataset to compute segments for. + n_chunks_per_song (int): + Number of segment per song to compute. + + Returns: + Any: + Segmented dataset. """ if n_chunks_per_song <= 0: raise ValueError('n_chunks_per_song must be positif') @@ -327,10 +372,13 @@ class DatasetBuilder(object): return dataset @property - def instruments(self): - """ Instrument dataset builder generator. + def instruments(self) -> Any: + """ + Instrument dataset builder generator. - :yield InstrumentBuilder instance. + Yields: + Any: + InstrumentBuilder instance. """ if self._instrument_builders is None: self._instrument_builders = [] @@ -340,22 +388,33 @@ class DatasetBuilder(object): for builder in self._instrument_builders: yield builder - def cache(self, dataset, cache, wait): - """ Cache the given dataset if cache is enabled. Eventually waits for - cache to be available (useful if another process is already computing - cache) if provided wait flag is True. + def cache( + self, + dataset: Any, + cache: str, + wait: bool) -> Any: + """ + Cache the given dataset if cache is enabled. Eventually waits for + cache to be available (useful if another process is already + computing cache) if provided wait flag is `True`. - :param dataset: Dataset to be cached if cache is required. 
- :param cache: Path of cache directory to be used, None if no cache. - :param wait: If caching is enabled, True is cache should be waited. - :returns: Cached dataset if needed, original dataset otherwise. + Parameters: + dataset (Any): + Dataset to be cached if cache is required. + cache (str): + Path of cache directory to be used, None if no cache. + wait (bool): + If caching is enabled, True if cache should be waited for. + + Returns: + Any: + Cached dataset if needed, original dataset otherwise. """ if cache is not None: if wait: while not exists(f'{cache}.index'): - get_logger().info( - 'Cache not available, wait %s', - self.WAIT_PERIOD) + logger.info( + f'Cache not available, wait {self.WAIT_PERIOD}') time.sleep(self.WAIT_PERIOD) cache_path = os.path.split(cache)[0] os.makedirs(cache_path, exist_ok=True) @@ -363,13 +422,20 @@ class DatasetBuilder(object): return dataset def build( - self, csv_path, - batch_size=8, shuffle=True, convert_to_uint=True, - random_data_augmentation=False, random_time_crop=True, - infinite_generator=True, cache_directory=None, - wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,): + self, + csv_path: str, + batch_size: int = 8, + shuffle: bool = True, + convert_to_uint: bool = True, + random_data_augmentation: bool = False, + random_time_crop: bool = True, + infinite_generator: bool = True, + cache_directory: Optional[str] = None, + wait_for_cache: bool = False, + num_parallel_calls: int = 4, + n_chunks_per_song: float = 2,) -> Any: """ - TO BE DOCUMENTED. + TO BE DOCUMENTED.
""" dataset = dataset_from_csv(csv_path) dataset = self.compute_segments(dataset, n_chunks_per_song) diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 8b8f511..c949d9a 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -5,6 +5,7 @@ import importlib +# pyright: reportMissingImports=false # pylint: disable=import-error import tensorflow as tf diff --git a/spleeter/model/provider/github.py b/spleeter/model/provider/github.py index f423d9c..2210c41 100644 --- a/spleeter/model/provider/github.py +++ b/spleeter/model/provider/github.py @@ -25,7 +25,7 @@ from tempfile import NamedTemporaryFile from typing import Dict from . import ModelProvider -from ...utils.logging import get_logger +from ...utils.logging import logger # pyright: reportMissingImports=false # pylint: disable=import-error @@ -138,7 +138,7 @@ class GithubModelProvider(ModelProvider): self._release, name)) url = f'{url}.tar.gz' - get_logger().info(f'Downloading model archive {url}') + logger.info(f'Downloading model archive {url}') with httpx.Client(http2=True) as client: with client.strema('GET', url) as response: response.raise_for_status() @@ -147,14 +147,14 @@ class GithubModelProvider(ModelProvider): with archive as stream: for chunk in response.iter_raw(): stream.write(chunk) - get_logger().info('Validating archive checksum') + logger.info('Validating archive checksum') checksum: str = compute_file_checksum(archive.name) if checksum != self.checksum(name): raise IOError( 'Downloaded file is corrupted, please retry') - get_logger().info(f'Extracting downloaded {name} archive') + logger.info(f'Extracting downloaded {name} archive') with tarfile.open(name=archive.name) as tar: tar.extractall(path=path) finally: os.unlink(archive.name) - get_logger().info(f'{name} model file(s) extracted') + logger.info(f'{name} model file(s) extracted') diff --git a/spleeter/separator.py b/spleeter/separator.py index ad7dc3f..d5a7ab0 100644 --- a/spleeter/separator.py +++ 
b/spleeter/separator.py @@ -19,13 +19,15 @@ import os from multiprocessing import Pool from os.path import basename, join, splitext, dirname -from typing import Generator, Optional +from spleeter.model.provider import ModelProvider +from typing import Dict, Generator, Optional from . import SpleeterError -from .audio import STFTBackend -from .audio.adapter import get_default_audio_adapter +from .audio import Codec, STFTBackend +from .audio.adapter import AudioAdapter from .audio.convertor import to_stereo from .model import EstimatorSpecBuilder, InputProviderFactory +from .model import model_fn from .utils.configuration import load_configuration # pyright: reportMissingImports=false @@ -65,18 +67,6 @@ class DataGenerator(object): buffer = self._current_data -def get_backend(backend: str) -> str: - """ - """ - if backend not in SUPPORTED_BACKEND: - raise ValueError(f'Unsupported backend {backend}') - if backend == 'auto': - if len(tf.config.list_physical_devices('GPU')): - return 'tensorflow' - return 'librosa' - return backend - - def create_estimator(params, MWF): """ Initialize tensorflow estimator that will perform separation @@ -137,18 +127,21 @@ class Separator(object): else: self._pool = None self._tasks = [] - self._params['stft_backend'] = get_backend(stft_backend) + self._params['stft_backend'] = stft_backend self._data_generator = DataGenerator() def __del__(self) -> None: if self._session: self._session.close() - def _get_prediction_generator(self): - """ Lazy loading access method for internal prediction generator - returned by the predict method of a tensorflow estimator. + def _get_prediction_generator(self) -> Generator: + """ + Lazy loading access method for internal prediction generator + returned by the predict method of a tensorflow estimator. - :returns: generator of prediction. + Returns: + Generator: + Generator of prediction. 
""" if self._prediction_generator is None: estimator = create_estimator(self._params, self._MWF) @@ -181,17 +174,30 @@ class Separator(object): task.get() task.wait(timeout=timeout) - def _stft(self, data, inverse: bool = False, length=None): - """ Single entrypoint for both stft and istft. This computes stft and - istft with librosa on stereo data. The two channels are processed - separately and are concatenated together in the result. The expected - input formats are: (n_samples, 2) for stft and (T, F, 2) for istft. + def _stft( + self, + data: np.ndarray, + inverse: bool = False, + length: Optional[int] = None) -> np.ndarray: + """ + Single entrypoint for both stft and istft. This computes stft and + istft with librosa on stereo data. The two channels are processed + separately and are concatenated together in the result. The + expected input formats are: (n_samples, 2) for stft and (T, F, 2) + for istft. - :param data: np.array with either the waveform or the complex - spectrogram depending on the parameter inverse - :param inverse: should a stft or an istft be computed. - :returns: Stereo data as numpy array for the transform. - The channels are stored in the last dimension. + Parameters: + data (numpy.array): + Array with either the waveform or the complex spectrogram + depending on the parameter inverse + inverse (bool): + (Optional) Should a stft or an istft be computed. + length (Optional[int]): + + Returns: + numpy.ndarray: + Stereo data as numpy array for the transform. The channels + are stored in the last dimension. 
""" assert not (inverse and length is None) data = np.asfortranarray(data) @@ -238,19 +244,24 @@ class Separator(object): def _get_session(self): if self._session is None: saver = tf.compat.v1.train.Saver() - latest_checkpoint = tf.train.latest_checkpoint( - get_default_model_dir(self._params['model_dir'])) + provider = ModelProvider.default() + model_directory: str = provider.get(self._params['model_dir']) + latest_checkpoint = tf.train.latest_checkpoint(model_directory) self._session = tf.compat.v1.Session() saver.restore(self._session, latest_checkpoint) return self._session - def _separate_librosa(self, waveform: np.ndarray, audio_id): + def _separate_librosa( + self, + waveform: np.ndarray, + audio_descriptor: str) -> Dict: """ Performs separation with librosa backend for STFT. Parameters: waveform (numpy.ndarray): Waveform to be separated (as a numpy array) + audio_descriptor (str): """ with self._tf_graph.as_default(): out = {} @@ -269,7 +280,7 @@ class Separator(object): feed_dict=self._get_input_provider().get_feed_dict( features, stft, - audio_id)) + audio_descriptor)) for inst in self._get_builder().instruments: out[inst] = self._stft( outputs[inst], @@ -277,7 +288,10 @@ class Separator(object): length=waveform.shape[0]) return out - def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor): + def _separate_tensorflow( + self, + waveform: np.ndarray, + audio_descriptor: str) -> Dict: """ Performs source separation over the given waveform with tensorflow backend. @@ -285,6 +299,7 @@ class Separator(object): Parameters: waveform (numpy.ndarray): Waveform to be separated (as a numpy array) + audio_descriptor (str): Returns: Separated waveforms. @@ -314,44 +329,61 @@ class Separator(object): audio_descriptor (str): (Optional) string describing the waveform (e.g. filename). 
""" - if self._params['stft_backend'] == 'tensorflow': + backend: str = self._params['stft_backend'] + if backend == STFTBackend.TENSORFLOW: return self._separate_tensorflow(waveform, audio_descriptor) - else: + elif backend == STFTBackend.LIBROSA: return self._separate_librosa(waveform, audio_descriptor) + raise ValueError(f'Unsupported STFT backend {backend}') def separate_to_file( self, - audio_descriptor, - destination, - audio_adapter=get_default_audio_adapter(), - offset=0, - duration=600., - codec='wav', - bitrate='128k', - filename_format='{filename}/{instrument}.{codec}', - synchronous=True): - """ Performs source separation and export result to file using - given audio adapter. - - Filename format should be a Python formattable string that could use - following parameters : {instrument}, {filename}, {foldername} and - {codec}. - - :param audio_descriptor: Describe song to separate, used by audio - adapter to retrieve and load audio data, - in case of file based audio adapter, such - descriptor would be a file path. - :param destination: Target directory to write output to. - :param audio_adapter: (Optional) Audio adapter to use for I/O. - :param offset: (Optional) Offset of loaded song. - :param duration: (Optional) Duration of loaded song - (default: 600s). - :param codec: (Optional) Export codec. - :param bitrate: (Optional) Export bitrate. - :param filename_format: (Optional) Filename format. - :param synchronous: (Optional) True is should by synchronous. + audio_descriptor: str, + destination: str, + audio_adapter: Optional[AudioAdapter] = None, + offset: int = 0, + duration: float = 600., + codec: Codec = Codec.WAV, + bitrate: str = '128k', + filename_format: str = '{filename}/{instrument}.{codec}', + synchronous: bool = True) -> None: """ - waveform, sample_rate = audio_adapter.load( + Performs source separation and export result to file using + given audio adapter. 
+ + Filename format should be a Python formattable string that could + use the following parameters: + + - {instrument} + - {filename} + - {foldername} + - {codec}. + + Parameters: + audio_descriptor (str): + Describe song to separate, used by audio adapter to + retrieve and load audio data, in case of file based + audio adapter, such descriptor would be a file path. + destination (str): + Target directory to write output to. + audio_adapter (Optional[AudioAdapter]): + (Optional) Audio adapter to use for I/O. + offset (int): + (Optional) Offset of loaded song. + duration (float): + (Optional) Duration of loaded song (default: 600s). + codec (Codec): + (Optional) Export codec. + bitrate (str): + (Optional) Export bitrate. + filename_format (str): + (Optional) Filename format. + synchronous (bool): + (Optional) True if it should be synchronous. + """ + if audio_adapter is None: + audio_adapter = AudioAdapter.default() + waveform, _ = audio_adapter.load( audio_descriptor, offset=offset, duration=duration, @@ -369,33 +401,42 @@ class Separator(object): def save_to_file( self, - sources, - audio_descriptor, - destination, - filename_format='{filename}/{instrument}.{codec}', - codec='wav', - audio_adapter=get_default_audio_adapter(), - bitrate='128k', - synchronous=True): - """ Export dictionary of sources to files. - - :param sources: Dictionary of sources to be exported. The - keys are the name of the instruments, and - the values are Nx2 numpy arrays containing - the corresponding intrument waveform, as - returned by the separate method - :param audio_descriptor: Describe song to separate, used by audio - adapter to retrieve and load audio data, - in case of file based audio adapter, such - descriptor would be a file path. - :param destination: Target directory to write output to. - :param filename_format: (Optional) Filename format. - :param codec: (Optional) Export codec. - :param audio_adapter: (Optional) Audio adapter to use for I/O.
- :param bitrate: (Optional) Export bitrate. - :param synchronous: (Optional) True is should by synchronous. - + sources: Dict, + audio_descriptor: str, + destination: str, + filename_format: str = '{filename}/{instrument}.{codec}', + codec: Codec = Codec.WAV, + audio_adapter: Optional[AudioAdapter] = None, + bitrate: str = '128k', + synchronous: bool = True) -> None: """ + Export dictionary of sources to files. + + Parameters: + sources (Dict): + Dictionary of sources to be exported. The keys are the name + of the instruments, and the values are `N x 2` numpy arrays + containing the corresponding instrument waveform, as + returned by the separate method + audio_descriptor (str): + Describe song to separate, used by audio adapter to + retrieve and load audio data, in case of file based audio + adapter, such descriptor would be a file path. + destination (str): + Target directory to write output to. + filename_format (str): + (Optional) Filename format. + codec (Codec): + (Optional) Export codec. + audio_adapter (Optional[AudioAdapter]): + (Optional) Audio adapter to use for I/O. + bitrate (str): + (Optional) Export bitrate. + synchronous (bool): + (Optional) True if it should be synchronous. + """ + if audio_adapter is None: + audio_adapter = AudioAdapter.default() foldername = basename(dirname(audio_descriptor)) filename = splitext(basename(audio_descriptor))[0] generated = []