From 194a50e7cfbd4656030aa0eef82f24757db65933 Mon Sep 17 00:00:00 2001
From: Faylixe
Date: Mon, 7 Dec 2020 14:38:37 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=89=20=20start=20typing=20integration?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 spleeter/audio/__init__.py    |  13 +++
 spleeter/audio/adapter.py     | 194 +++++++++++++++++++++-------------
 spleeter/audio/convertor.py   | 136 +++++++++++++++++-------
 spleeter/audio/spectrogram.py |   1 +
 spleeter/commands/__init__.py |  60 ++++++-----
 spleeter/commands/separate.py |   4 +
 spleeter/types.py             |  15 +++
 7 files changed, 284 insertions(+), 139 deletions(-)
 create mode 100644 spleeter/types.py

diff --git a/spleeter/audio/__init__.py b/spleeter/audio/__init__.py
index 3d973c5..8f1343f 100644
--- a/spleeter/audio/__init__.py
+++ b/spleeter/audio/__init__.py
@@ -10,6 +10,19 @@
     - Waveform convertion and transforming functions.
 """
 
+from enum import Enum
+
 __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
+
+
+class Codec(str, Enum):
+    """ Enumeration of supported audio codecs. """
+
+    WAV: str = 'wav'
+    MP3: str = 'mp3'
+    OGG: str = 'ogg'
+    M4A: str = 'm4a'
+    WMA: str = 'wma'
+    FLAC: str = 'flac'
diff --git a/spleeter/audio/adapter.py b/spleeter/audio/adapter.py
index 994c8df..d75612b 100644
--- a/spleeter/audio/adapter.py
+++ b/spleeter/audio/adapter.py
@@ -3,21 +3,22 @@
 
 """ AudioAdapter class defintion. """
 
-import subprocess
-
 from abc import ABC, abstractmethod
 from importlib import import_module
-from os.path import exists
+from pathlib import Path
+from spleeter.audio import Codec
+from typing import Any, Dict, List, Optional, Union
+from .. import SpleeterError
+from ..types import AudioDescriptor, Signal
+from ..utils.logging import get_logger
+
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
-
-from tensorflow.signal import stft, hann_window
 # pylint: enable=import-error
-from .. import SpleeterError
-from ..utils.logging import get_logger
 
 __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
@@ -27,46 +28,72 @@ __license__ = 'MIT License'
 class AudioAdapter(ABC):
     """ An abstract class for manipulating audio signal. """
 
-    # Default audio adapter singleton instance.
-    DEFAULT = None
+    _DEFAULT: Optional['AudioAdapter'] = None
+    """ Default audio adapter singleton instance. """
 
     @abstractmethod
     def load(
-            self, audio_descriptor, offset, duration,
-            sample_rate, dtype=np.float32):
-        """ Loads the audio file denoted by the given audio descriptor
-        and returns it data as a waveform. Aims to be implemented
-        by client.
+            self,
+            audio_descriptor: AudioDescriptor,
+            offset: float,
+            duration: float,
+            sample_rate: float,
+            dtype: np.dtype = np.float32) -> Signal:
+        """
+        Loads the audio file denoted by the given audio descriptor and
+        returns its data as a waveform. Intended to be implemented by client.
 
-        :param audio_descriptor: Describe song to load, in case of file
-                                 based audio adapter, such descriptor would
-                                 be a file path.
-        :param offset: Start offset to load from in seconds.
-        :param duration: Duration to load in seconds.
-        :param sample_rate: Sample rate to load audio with.
-        :param dtype: Numpy data type to use, default to float32.
-        :returns: Loaded data as (wf, sample_rate) tuple.
+        Parameters:
+            audio_descriptor (AudioDescriptor):
+                Describes the song to load; in case of a file based audio
+                adapter, such a descriptor would be a file path.
+            offset (float):
+                Start offset to load from in seconds.
+            duration (float):
+                Duration to load in seconds.
+            sample_rate (float):
+                Sample rate to load audio with.
+            dtype (numpy.dtype):
+                (Optional) Numpy data type to use, default to `float32`.
+
+        Returns:
+            Signal:
+                Loaded data as (waveform, sample_rate) tuple.
         """
         pass
 
     def load_tf_waveform(
-            self, audio_descriptor,
-            offset=0.0, duration=1800., sample_rate=44100,
-            dtype=b'float32', waveform_name='waveform'):
-        """ Load the audio and convert it to a tensorflow waveform.
+            self,
+            audio_descriptor: AudioDescriptor,
+            offset: float = 0.0,
+            duration: float = 1800.,
+            sample_rate: int = 44100,
+            dtype: bytes = b'float32',
+            waveform_name: str = 'waveform') -> Dict[str, Any]:
+        """
+        Load the audio and convert it to a tensorflow waveform.
 
-        :param audio_descriptor: Describe song to load, in case of file
-                                 based audio adapter, such descriptor would
-                                 be a file path.
-        :param offset: Start offset to load from in seconds.
-        :param duration: Duration to load in seconds.
-        :param sample_rate: Sample rate to load audio with.
-        :param dtype: Numpy data type to use, default to float32.
-        :param waveform_name: (Optional) Name of the key in output dict.
-        :returns: TF output dict with waveform as
-                  (T x chan numpy array) and a boolean that
-                  tells whether there were an error while
-                  trying to load the waveform.
+        Parameters:
+            audio_descriptor (AudioDescriptor):
+                Describes the song to load; in case of a file based audio
+                adapter, such a descriptor would be a file path.
+            offset (float):
+                Start offset to load from in seconds.
+            duration (float):
+                Duration to load in seconds.
+            sample_rate (int):
+                Sample rate to load audio with.
+            dtype (bytes):
+                (Optional) Data type to use, default to `b'float32'`.
+            waveform_name (str):
+                (Optional) Name of the key in output dict, default to
+                `'waveform'`.
+
+        Returns:
+            Dict[str, Any]:
+                TF output dict with waveform as a `(T x chan)` numpy array
+                and a boolean that tells whether there was an error while
+                trying to load the waveform.
         """
         # Cast parameters to TF format.
         offset = tf.cast(offset, tf.float64)
@@ -100,50 +127,69 @@ class AudioAdapter(ABC):
         waveform, error = results[0]
         return {
             waveform_name: waveform,
-            f'{waveform_name}_error': error
-        }
+            f'{waveform_name}_error': error}
 
     @abstractmethod
     def save(
-            self, path, data, sample_rate,
-            codec=None, bitrate=None):
-        """ Save the given audio data to the file denoted by
-        the given path.
+            self,
+            path: Union[Path, str],
+            data: np.ndarray,
+            sample_rate: float,
+            codec: Optional[Codec] = None,
+            bitrate: Optional[str] = None):
+        """
+        Save the given audio data to the file denoted by the given path.
 
-        :param path: Path of the audio file to save data in.
-        :param data: Waveform data to write.
-        :param sample_rate: Sample rate to write file in.
-        :param codec: (Optional) Writing codec to use.
-        :param bitrate: (Optional) Bitrate of the written audio file.
+        Parameters:
+            path (Union[Path, str]):
+                Path like object of the audio file to save data in.
+            data (numpy.ndarray):
+                Waveform data to write.
+            sample_rate (float):
+                Sample rate to write file in.
+            codec (Codec):
+                (Optional) Writing codec to use, default to `None`.
+            bitrate (str):
+                (Optional) Bitrate of the written audio file, default to
+                `None`.
         """
         pass
 
-
+    @classmethod
+    def default(cls: type) -> 'AudioAdapter':
+        """
+        Builds and returns a default audio adapter instance.
-def get_default_audio_adapter():
-    """ Builds and returns a default audio adapter instance.
 
+        Returns:
+            AudioAdapter:
+                Default adapter instance to use.
+        """
+        if cls._DEFAULT is None:
+            from .ffmpeg import FFMPEGProcessAudioAdapter
+            cls._DEFAULT = FFMPEGProcessAudioAdapter()
+        return cls._DEFAULT
-    :returns: An audio adapter instance.
-    """
-    if AudioAdapter.DEFAULT is None:
-        from .ffmpeg import FFMPEGProcessAudioAdapter
-        AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
-    return AudioAdapter.DEFAULT
-
 
+    @classmethod
+    def get(cls: type, descriptor: str) -> 'AudioAdapter':
+        """
+        Dynamically load an AudioAdapter from the given class descriptor.
+        Parameters:
+            descriptor (str):
+                Adapter class descriptor (module.Class)
-def get_audio_adapter(descriptor):
-    """ Load dynamically an AudioAdapter from given class descriptor.
 
-    :param descriptor: Adapter class descriptor (module.Class)
-    :returns: Created adapter instance.
-    """
-    if descriptor is None:
-        return get_default_audio_adapter()
-    module_path = descriptor.split('.')
-    adapter_class_name = module_path[-1]
-    module_path = '.'.join(module_path[:-1])
-    adapter_module = import_module(module_path)
-    adapter_class = getattr(adapter_module, adapter_class_name)
-    if not isinstance(adapter_class, AudioAdapter):
-        raise SpleeterError(
-            f'{adapter_class_name} is not a valid AudioAdapter class')
-    return adapter_class()
+        Returns:
+            AudioAdapter:
+                Created adapter instance.
+        """
+        if not descriptor:
+            return cls.default()
+        module_path: List[str] = descriptor.split('.')
+        adapter_class_name: str = module_path[-1]
+        module_name: str = '.'.join(module_path[:-1])
+        adapter_module = import_module(module_name)
+        adapter_class = getattr(adapter_module, adapter_class_name)
+        if not issubclass(adapter_class, AudioAdapter):
+            raise SpleeterError(
+                f'{adapter_class_name} is not a valid AudioAdapter class')
+        return adapter_class()
diff --git a/spleeter/audio/convertor.py b/spleeter/audio/convertor.py
index 0751b03..6f8b135 100644
--- a/spleeter/audio/convertor.py
+++ b/spleeter/audio/convertor.py
@@ -3,39 +3,54 @@
 
 """ This module provides audio data convertion functions. """
 
+from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
+
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
 # pylint: enable=import-error
 
-from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32
-
 __email__ = 'spleeter@deezer.com'
 __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
 
-def to_n_channels(waveform, n_channels):
-    """ Convert a waveform to n_channels by removing or
-    duplicating channels if needed (in tensorflow).
+def to_n_channels(
+        waveform: tf.Tensor,
+        n_channels: int) -> tf.Tensor:
+    """
+    Convert a waveform to n_channels by removing or duplicating channels if
+    needed (in tensorflow).
 
-    :param waveform: Waveform to transform.
-    :param n_channels: Number of channel to reshape waveform in.
-    :returns: Reshaped waveform.
+    Parameters:
+        waveform (tensorflow.Tensor):
+            Waveform to transform.
+        n_channels (int):
+            Number of channels to reshape waveform in.
+
+    Returns:
+        tensorflow.Tensor:
+            Reshaped waveform.
     """
     return tf.cond(
         tf.shape(waveform)[1] >= n_channels,
         true_fn=lambda: waveform[:, :n_channels],
-        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels]
-    )
+        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels])
 
 
-def to_stereo(waveform):
-    """ Convert a waveform to stereo by duplicating if mono,
-    or truncating if too many channels.
+def to_stereo(waveform: np.ndarray) -> np.ndarray:
+    """
+    Convert a waveform to stereo by duplicating if mono, or truncating
+    if too many channels.
 
-    :param waveform: a (N, d) numpy array.
-    :returns: A stereo waveform as a (N, 1) numpy array.
+    Parameters:
+        waveform (numpy.ndarray):
+            a `(N, d)` numpy array.
+
+    Returns:
+        numpy.ndarray:
+            A stereo waveform as a `(N, 2)` numpy array.
     """
     if waveform.shape[1] == 1:
         return np.repeat(waveform, 2, axis=-1)
@@ -44,45 +59,84 @@
     return waveform
 
 
-def gain_to_db(tensor, espilon=10e-10):
-    """ Convert from gain to decibel in tensorflow.
+def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
+    """
+    Convert from gain to decibel in tensorflow.
 
-    :param tensor: Tensor to convert.
-    :param epsilon: Operation constant.
-    :returns: Converted tensor.
+    Parameters:
+        tensor (tensorflow.Tensor):
+            Tensor to convert.
+        espilon (float):
+            Operation constant.
+
+    Returns:
+        tensorflow.Tensor:
+            Converted tensor.
     """
     return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
 
 
-def db_to_gain(tensor):
-    """ Convert from decibel to gain in tensorflow.
+def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
+    """
+    Convert from decibel to gain in tensorflow.
 
-    :param tensor_db: Tensor to convert.
-    :returns: Converted tensor.
+    Parameters:
+        tensor (tensorflow.Tensor):
+            Tensor to convert.
+
+    Returns:
+        tensorflow.Tensor:
+            Converted tensor.
     """
     return tf.pow(10., (tensor / 20.))
 
 
-def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
-    """ Encodes given spectrogram into uint8 using decibel scale.
-
-    :param spectrogram: Spectrogram to be encoded as TF float tensor.
-    :param db_range: Range in decibel for encoding.
-    :returns: Encoded decibel spectrogram as uint8 tensor.
+def spectrogram_to_db_uint(
+        spectrogram: tf.Tensor,
+        db_range: float = 100.,
+        **kwargs) -> tf.Tensor:
     """
-    db_spectrogram = gain_to_db(spectrogram)
-    max_db_spectrogram = tf.reduce_max(db_spectrogram)
-    db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range)
+    Encodes given spectrogram into uint8 using decibel scale.
+
+    Parameters:
+        spectrogram (tensorflow.Tensor):
+            Spectrogram to be encoded as TF float tensor.
+        db_range (float):
+            Range in decibel for encoding.
+
+    Returns:
+        tensorflow.Tensor:
+            Encoded decibel spectrogram as `uint8` tensor.
+    """
+    db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
+    max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
+    db_spectrogram: tf.Tensor = tf.maximum(
+        db_spectrogram,
+        max_db_spectrogram - db_range)
     return from_float32_to_uint8(db_spectrogram, **kwargs)
 
 
-def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
-    """ Decode spectrogram from uint8 decibel scale.
-
-    :param db_uint_spectrogram: Decibel pectrogram to decode.
-    :param min_db: Lower bound limit for decoding.
-    :param max_db: Upper bound limit for decoding.
-    :returns: Decoded spectrogram as float2 tensor.
+def db_uint_spectrogram_to_gain(
+        db_uint_spectrogram: tf.Tensor,
+        min_db: tf.Tensor,
+        max_db: tf.Tensor) -> tf.Tensor:
     """
-    db_spectrogram = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
+    Decode spectrogram from uint8 decibel scale.
+
+    Parameters:
+        db_uint_spectrogram (tensorflow.Tensor):
+            Decibel spectrogram to decode.
+        min_db (tensorflow.Tensor):
+            Lower bound limit for decoding.
+        max_db (tensorflow.Tensor):
+            Upper bound limit for decoding.
+
+    Returns:
+        tensorflow.Tensor:
+            Decoded spectrogram as `float32` tensor.
+    """
+    db_spectrogram: tf.Tensor = from_uint8_to_float32(
+        db_uint_spectrogram,
+        min_db,
+        max_db)
     return db_to_gain(db_spectrogram)
diff --git a/spleeter/audio/spectrogram.py b/spleeter/audio/spectrogram.py
index a1a79b3..a70e4fe 100644
--- a/spleeter/audio/spectrogram.py
+++ b/spleeter/audio/spectrogram.py
@@ -3,6 +3,7 @@
 
 """ Spectrogram specific data augmentation """
 
+# pyright: reportMissingImports=false
 # pylint: disable=import-error
 import numpy as np
 import tensorflow as tf
diff --git a/spleeter/commands/__init__.py b/spleeter/commands/__init__.py
index a54e4c1..241a9bc 100644
--- a/spleeter/commands/__init__.py
+++ b/spleeter/commands/__init__.py
@@ -51,34 +51,38 @@ OPT_PARAMS = {
     'help': 'JSON filename that contains params'
 }
 
-# -s opt specification (separate).
-OPT_OFFSET = {
-    'dest': 'offset',
-    'type': float,
-    'default': 0.,
-    'help': 'Set the starting offset to separate audio from.'
-}
+Offset: OptionInfo = Option(
+    0.,
+    '--offset',
+    '-s',
+    help='Set the starting offset to separate audio from')
 
-# -d opt specification (separate).
-OPT_DURATION = {
-    'dest': 'duration',
-    'type': float,
-    'default': 600.,
-    'help': (
+Duration: OptionInfo = Option(
+    600.,
+    '--duration',
+    '-d',
+    help=(
         'Set a maximum duration for processing audio '
         '(only separate offset + duration first seconds of '
-        'the input file)')
-}
+        'the input file)'))
 
-# -w opt specification (separate)
-OPT_STFT_BACKEND = {
-    'dest': 'stft_backend',
-    'type': str,
-    'choices' : ["tensorflow", "librosa", "auto"],
-    'default': "auto",
-    'help': 'Who should be in charge of computing the stfts. Librosa is faster than tensorflow on CPU and uses'
-            ' less memory. "auto" will use tensorflow when GPU acceleration is available and librosa when not.'
-}
+
+class STFTBackendEnum(str, Enum):
+
+    AUTO: str = 'auto'
+    TENSORFLOW: str = 'tensorflow'
+    LIBROSA: str = 'librosa'
+
+
+STFTBackend: OptionInfo = Option(
+    STFTBackendEnum.AUTO,
+    '--stft-backend',
+    '-B',
+    case_sensitive=False,
+    help=(
+        'Who should be in charge of computing the STFTs. Librosa is faster '
+        'than tensorflow on CPU and uses less memory. "auto" will use '
+        'tensorflow when GPU acceleration is available and librosa when not'))
 
 
 # -c opt specification (separate).
@@ -128,6 +132,14 @@ OPT_ADAPTER = {
     'help': 'Name of the audio adapter to use for audio I/O'
 }
 
+
+AudioAdapter: OptionInfo = Option(
+    'spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter',
+    '--adapter',
+    help='Name of the audio adapter to use for audio I/O')
+
+
 # -a opt specification (train, evaluate and separate).
 OPT_VERBOSE = {
     'action': 'store_true',
diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py
index 193d8f6..740deb4 100644
--- a/spleeter/commands/separate.py
+++ b/spleeter/commands/separate.py
@@ -19,6 +19,10 @@ __author__ = 'Deezer Research'
 __license__ = 'MIT License'
 
 
+from typer import Option
+
+AudioAdapter = Option()
+
 def entrypoint(arguments, params):
     """ Command entrypoint.
 
diff --git a/spleeter/types.py b/spleeter/types.py
new file mode 100644
index 0000000..cb97577
--- /dev/null
+++ b/spleeter/types.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python
+# coding: utf8
+
+""" Custom types definition. """
+
+from typing import Any, Tuple
+
+# pyright: reportMissingImports=false
+# pylint: disable=import-error
+import numpy as np
+# pylint: enable=import-error
+
+
+AudioDescriptor: type = Any
+Signal: type = Tuple[np.ndarray, float]