diff --git a/spleeter/audio/adapter.py b/spleeter/audio/adapter.py index d75612b..e7d0c6b 100644 --- a/spleeter/audio/adapter.py +++ b/spleeter/audio/adapter.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from importlib import import_module from pathlib import Path from spleeter.audio import Codec -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from .. import SpleeterError from ..types import AudioDescriptor, Signal @@ -35,9 +35,9 @@ class AudioAdapter(ABC): def load( self, audio_descriptor: AudioDescriptor, - offset: float, - duration: float, - sample_rate: float, + offset: Optional[float] = None, + duration: Optional[float] = None, + sample_rate: Optional[float] = None, dtype: np.dtype = np.float32) -> Signal: """ Loads the audio file denoted by the given audio descriptor and @@ -47,11 +47,11 @@ class AudioAdapter(ABC): audio_descriptor (AudioDescriptor): Describe song to load, in case of file based audio adapter, such descriptor would be a file path. - offset (float): + offset (Optional[float]): Start offset to load from in seconds. - duration (float): + duration (Optional[float]): Duration to load in seconds. - sample_rate (float): + sample_rate (Optional[float]): Sample rate to load audio with. dtype (numpy.dtype): (Optional) Numpy data type to use, default to `float32`. @@ -136,7 +136,7 @@ class AudioAdapter(ABC): data: np.ndarray, sample_rate: float, codec: Codec = None, - bitrate: str = None): + bitrate: str = None) -> None: """ Save the given audio data to the file denoted by the given path. diff --git a/spleeter/audio/ffmpeg.py b/spleeter/audio/ffmpeg.py index 890e02e..6f5ce37 100644 --- a/spleeter/audio/ffmpeg.py +++ b/spleeter/audio/ffmpeg.py @@ -8,76 +8,92 @@ used within this library. """ +import datetime as dt import os import shutil +from pathlib import Path +from typing import Dict, Optional, Union + +from . import Codec +from .adapter import AudioAdapter +from .. 
import SpleeterError +from ..types import Signal +from ..utils.logging import get_logger + +# pyright: reportMissingImports=false # pylint: disable=import-error import ffmpeg import numpy as np # pylint: enable=import-error -from .adapter import AudioAdapter -from .. import SpleeterError -from ..utils.logging import get_logger - __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' -def _check_ffmpeg_install(): - """ Ensure FFMPEG binaries are available. - - :raise SpleeterError: If ffmpeg or ffprobe is not found. - """ - for binary in ('ffmpeg', 'ffprobe'): - if shutil.which(binary) is None: - raise SpleeterError('{} binary not found'.format(binary)) - - -def _to_ffmpeg_time(n): - """ Format number of seconds to time expected by FFMPEG. - :param n: Time in seconds to format. - :returns: Formatted time in FFMPEG format. - """ - m, s = divmod(n, 60) - h, m = divmod(m, 60) - return '%d:%02d:%09.6f' % (h, m, s) - - -def _to_ffmpeg_codec(codec): - ffmpeg_codecs = { - 'm4a': 'aac', - 'ogg': 'libvorbis', - 'wma': 'wmav2', - } - return ffmpeg_codecs.get(codec) or codec - - class FFMPEGProcessAudioAdapter(AudioAdapter): - """ An AudioAdapter implementation that use FFMPEG binary through - subprocess in order to perform I/O operation for audio processing. - - When created, FFMPEG binary path will be checked and expended, - raising exception if not found. Such path could be infered using - FFMPEG_PATH environment variable. """ + An AudioAdapter implementation that uses FFMPEG binary through + subprocess in order to perform I/O operations for audio processing. + + When created, FFMPEG binary path will be checked and expanded, + raising exception if not found. Such path could be inferred using + `FFMPEG_PATH` environment variable. + """ + + SUPPORTED_CODECS: Dict[Codec, str] = { + Codec.M4A: 'aac', + Codec.OGG: 'libvorbis', + Codec.WMA: 'wmav2' + } + """ FFMPEG codec name mapping. 
""" + + def __init__(_) -> None: + """ + Default constructor, ensure FFMPEG binaries are available. + + Raises: + SpleeterError: + If ffmpeg or ffprobe is not found. + """ + for binary in ('ffmpeg', 'ffprobe'): + if shutil.which(binary) is None: + raise SpleeterError('{} binary not found'.format(binary)) def load( - self, path, offset=None, duration=None, - sample_rate=None, dtype=np.float32): - """ Loads the audio file denoted by the given path - and returns it data as a waveform. - - :param path: Path of the audio file to load data from. - :param offset: (Optional) Start offset to load from in seconds. - :param duration: (Optional) Duration to load in seconds. - :param sample_rate: (Optional) Sample rate to load audio with. - :param dtype: (Optional) Numpy data type to use, default to float32. - :returns: Loaded data a (waveform, sample_rate) tuple. - :raise SpleeterError: If any error occurs while loading audio. + _, + path: Union[Path, str], + offset: Optional[float] = None, + duration: Optional[float] = None, + sample_rate: Optional[float] = None, + dtype: np.dtype = np.float32) -> Signal: """ - _check_ffmpeg_install() + Loads the audio file denoted by the given path + and returns its data as a waveform. + + Parameters: + path (Union[Path, str]): + Path of the audio file to load data from. + offset (Optional[float]): + Start offset to load from in seconds. + duration (Optional[float]): + Duration to load in seconds. + sample_rate (Optional[float]): + Sample rate to load audio with. + dtype (numpy.dtype): + (Optional) Numpy data type to use, default to `float32`. + + Returns: + Signal: + Loaded data as a (waveform, sample_rate) tuple. + + Raises: + SpleeterError: + If any error occurs while loading audio.
+ """ + if isinstance(path, Path): + path = str(path) if not isinstance(path, str): path = path.decode() try: @@ -97,9 +113,9 @@ class FFMPEGProcessAudioAdapter(AudioAdapter): sample_rate = metadata['sample_rate'] output_kwargs = {'format': 'f32le', 'ar': sample_rate} if duration is not None: - output_kwargs['t'] = _to_ffmpeg_time(duration) + output_kwargs['t'] = str(dt.timedelta(seconds=duration)) if offset is not None: - output_kwargs['ss'] = _to_ffmpeg_time(offset) + output_kwargs['ss'] = str(dt.timedelta(seconds=offset)) process = ( ffmpeg .input(path) @@ -112,29 +128,46 @@ class FFMPEGProcessAudioAdapter(AudioAdapter): return (waveform, sample_rate) def save( - self, path, data, sample_rate, - codec=None, bitrate=None): - """ Write waveform data to the file denoted by the given path - using FFMPEG process. - - :param path: Path of the audio file to save data in. - :param data: Waveform data to write. - :param sample_rate: Sample rate to write file in. - :param codec: (Optional) Writing codec to use. - :param bitrate: (Optional) Bitrate of the written audio file. - :raise IOError: If any error occurs while using FFMPEG to write data. + self, + path: Union[Path, str], + data: np.ndarray, + sample_rate: float, + codec: Codec = None, + bitrate: str = None) -> None: """ - _check_ffmpeg_install() + Write waveform data to the file denoted by the given path using + FFMPEG process. + + Parameters: + path (Union[Path, str]): + Path of the audio file to save data in. + data (numpy.ndarray): + Waveform data to write. + sample_rate (float): + Sample rate to write file in. + codec (Codec): + (Optional) Writing codec to use, default to `None`. + bitrate (str): + (Optional) Bitrate of the written audio file, default to + `None`. + + Raises: + SpleeterError: + If the output directory does not exist, or if any error + occurs while using FFMPEG to write data.
+ """ + if isinstance(path, Path): + path = str(path) directory = os.path.dirname(path) if not os.path.exists(directory): - raise SpleeterError(f'output directory does not exists: {directory}') - get_logger().debug('Writing file %s', path) + raise SpleeterError( + f'output directory does not exists: {directory}') + get_logger().debug(f'Writing file {path}') input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]} output_kwargs = {'ar': sample_rate, 'strict': '-2'} if bitrate: output_kwargs['audio_bitrate'] = bitrate if codec is not None and codec != 'wav': - output_kwargs['codec'] = _to_ffmpeg_codec(codec) + output_kwargs['codec'] = self.SUPPORTED_CODECS.get(codec, codec) process = ( ffmpeg .input('pipe:', format='f32le', **input_kwargs) @@ -147,4 +180,4 @@ class FFMPEGProcessAudioAdapter(AudioAdapter): process.wait() except IOError: raise SpleeterError(f'FFMPEG error: {process.stderr.read()}') - get_logger().info('File %s written succesfully', path) + get_logger().info(f'File {path} written succesfully') diff --git a/spleeter/audio/spectrogram.py b/spleeter/audio/spectrogram.py index a70e4fe..3eb8d19 100644 --- a/spleeter/audio/spectrogram.py +++ b/spleeter/audio/spectrogram.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding: utf8 -""" Spectrogram specific data augmentation """ +""" Spectrogram specific data augmentation. """ # pyright: reportMissingImports=false # pylint: disable=import-error @@ -17,25 +17,35 @@ __license__ = 'MIT License' def compute_spectrogram_tf( - waveform, - frame_length=2048, frame_step=512, - spec_exponent=1., window_exponent=1.): - """ Compute magnitude / power spectrogram from waveform as - a n_samples x n_channels tensor. - - :param waveform: Input waveform as (times x number of channels) - tensor. - :param frame_length: Length of a STFT frame to use. - :param frame_step: HOP between successive frames. - :param spec_exponent: Exponent of the spectrogram (usually 1 for - magnitude spectrogram, or 2 for power spectrogram). 
- :param window_exponent: Exponent applied to the Hann windowing function - (may be useful for making perfect STFT/iSTFT - reconstruction). - :returns: Computed magnitude / power spectrogram as a - (T x F x n_channels) tensor. + waveform: tf.Tensor, + frame_length: int = 2048, + frame_step: int = 512, + spec_exponent: float = 1., + window_exponent: float = 1.) -> tf.Tensor: """ - stft_tensor = tf.transpose( + Compute magnitude / power spectrogram from waveform as a + `n_samples x n_channels` tensor. + + Parameters: + waveform (tensorflow.Tensor): + Input waveform as `(times x number of channels)` tensor. + frame_length (int): + Length of a STFT frame to use. + frame_step (int): + HOP between successive frames. + spec_exponent (float): + Exponent of the spectrogram (usually 1 for magnitude + spectrogram, or 2 for power spectrogram). + window_exponent (float): + Exponent applied to the Hann windowing function (may be + useful for making perfect STFT/iSTFT reconstruction). + + Returns: + tensorflow.Tensor: + Computed magnitude / power spectrogram as a + `(T x F x n_channels)` tensor. + """ + stft_tensor: tf.Tensor = tf.transpose( stft( tf.transpose(waveform), frame_length, @@ -49,16 +59,25 @@ def compute_spectrogram_tf( def time_stretch( - spectrogram, - factor=1.0, - method=tf.image.ResizeMethod.BILINEAR): - """ Time stretch a spectrogram preserving shape in tensorflow. Note that - this is an approximation in the frequency domain. + spectrogram: tf.Tensor, + factor: float = 1.0, + method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR + ) -> tf.Tensor: + """ + Time stretch a spectrogram preserving shape in tensorflow. Note that + this is an approximation in the frequency domain. - :param spectrogram: Input spectrogram to be time stretched as tensor. - :param factor: (Optional) Time stretch factor, must be >0, default to 1. - :param mehtod: (Optional) Interpolation method, default to BILINEAR. - :returns: Time stretched spectrogram as tensor with same shape. 
+ Parameters: + spectrogram (tensorflow.Tensor): + Input spectrogram to be time stretched as tensor. + factor (float): + (Optional) Time stretch factor, must be > 0, default to `1`. + method (tensorflow.image.ResizeMethod): + (Optional) Interpolation method, default to `BILINEAR`. + + Returns: + tensorflow.Tensor: + Time stretched spectrogram as tensor with same shape. """ T = tf.shape(spectrogram)[0] T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0] @@ -71,15 +90,27 @@ def time_stretch( return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F) -def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs): - """ Time stretch a spectrogram preserving shape with random ratio in - tensorflow. Applies time_stretch to spectrogram with a random ratio drawn - uniformly in [factor_min, factor_max]. +def random_time_stretch( + spectrogram: tf.Tensor, + factor_min: float = 0.9, + factor_max: float = 1.1, + **kwargs) -> tf.Tensor: + """ + Time stretch a spectrogram preserving shape with random ratio in + tensorflow. Applies time_stretch to spectrogram with a random ratio + drawn uniformly in `[factor_min, factor_max]`. - :param spectrogram: Input spectrogram to be time stretched as tensor. - :param factor_min: (Optional) Min time stretch factor, default to 0.9. - :param factor_max: (Optional) Max time stretch factor, default to 1.1. - :returns: Randomly time stretched spectrogram as tensor with same shape. + Parameters: + spectrogram (tensorflow.Tensor): + Input spectrogram to be time stretched as tensor. + factor_min (float): + (Optional) Min time stretch factor, default to `0.9`. + factor_max (float): + (Optional) Max time stretch factor, default to `1.1`. + + Returns: + tensorflow.Tensor: + Randomly time stretched spectrogram as tensor with same shape. 
""" factor = tf.random_uniform( shape=(1,), @@ -88,16 +119,25 @@ def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs): def pitch_shift( - spectrogram, - semitone_shift=0.0, - method=tf.image.ResizeMethod.BILINEAR): - """ Pitch shift a spectrogram preserving shape in tensorflow. Note that - this is an approximation in the frequency domain. + spectrogram: tf.Tensor, + semitone_shift: float = 0.0, + method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR + ) -> tf.Tensor: + """ + Pitch shift a spectrogram preserving shape in tensorflow. Note that + this is an approximation in the frequency domain. - :param spectrogram: Input spectrogram to be pitch shifted as tensor. - :param semitone_shift: (Optional) Pitch shift in semitone, default to 0.0. - :param mehtod: (Optional) Interpolation method, default to BILINEAR. - :returns: Pitch shifted spectrogram (same shape as spectrogram). + Parameters: + spectrogram (tensorflow.Tensor): + Input spectrogram to be pitch shifted as tensor. + semitone_shift (float): + (Optional) Pitch shift in semitone, default to `0.0`. + method (tensorflow.image.ResizeMethod): + (Optional) Interpolation method, default to `BILINEAR`. + + Returns: + tensorflow.Tensor: + Pitch shifted spectrogram (same shape as spectrogram). """ factor = 2 ** (semitone_shift / 12.) T = tf.shape(spectrogram)[0] @@ -112,16 +152,28 @@ def pitch_shift( return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT') -def random_pitch_shift(spectrogram, shift_min=-1., shift_max=1., **kwargs): - """ Pitch shift a spectrogram preserving shape with random ratio in - tensorflow. Applies pitch_shift to spectrogram with a random shift - amount (expressed in semitones) drawn uniformly in [shift_min, shift_max]. +def random_pitch_shift( + spectrogram: tf.Tensor, + shift_min: float = -1., + shift_max: float = 1., + **kwargs) -> tf.Tensor: + """ + Pitch shift a spectrogram preserving shape with random ratio in + tensorflow. 
Applies pitch_shift to spectrogram with a random shift + amount (expressed in semitones) drawn uniformly in + `[shift_min, shift_max]`. - :param spectrogram: Input spectrogram to be pitch shifted as tensor. + Parameters: + spectrogram (tensorflow.Tensor): + Input spectrogram to be pitch shifted as tensor. + shift_min (float): + (Optional) Min pitch shift in semitone, default to -1. + shift_max (float): + (Optional) Max pitch shift in semitone, default to 1. - :param shift_min: (Optional) Min pitch shift in semitone, default to -1. - :param shift_max: (Optional) Max pitch shift in semitone, default to 1. - :returns: Randomly pitch shifted spectrogram (same shape as spectrogram). + Returns: + tensorflow.Tensor: + Randomly pitch shifted spectrogram (same shape as spectrogram). """ semitone_shift = tf.random_uniform( shape=(1,),