mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Initial commit from private spleeter
This commit is contained in:
8
spleeter/utils/__init__.py
Normal file
8
spleeter/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" This package provides utility function and classes. """
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
15
spleeter/utils/audio/__init__.py
Normal file
15
spleeter/utils/audio/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
"""
|
||||
`spleeter.utils.audio` package provides various
|
||||
tools for manipulating audio content such as :
|
||||
|
||||
- Audio adapter class for abstract interaction with audio file.
|
||||
- FFMPEG implementation for audio adapter.
|
||||
- Waveform convertion and transforming functions.
|
||||
"""
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
144
spleeter/utils/audio/adapter.py
Normal file
144
spleeter/utils/audio/adapter.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" AudioAdapter class defintion. """
|
||||
|
||||
import subprocess
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from importlib import import_module
|
||||
from os.path import exists
|
||||
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.signal import stft, hann_window
|
||||
# pylint: enable=import-error
|
||||
|
||||
from ..logging import get_logger
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
class AudioAdapter(ABC):
|
||||
""" An abstract class for manipulating audio signal. """
|
||||
|
||||
# Default audio adapter singleton instance.
|
||||
DEFAULT = None
|
||||
|
||||
@abstractmethod
|
||||
def load(
|
||||
self, audio_descriptor, offset, duration,
|
||||
sample_rate, dtype=np.float32):
|
||||
""" Loads the audio file denoted by the given audio descriptor
|
||||
and returns it data as a waveform. Aims to be implemented
|
||||
by client.
|
||||
|
||||
:param audio_descriptor: Describe song to load, in case of file
|
||||
based audio adapter, such descriptor would
|
||||
be a file path.
|
||||
:param offset: Start offset to load from in seconds.
|
||||
:param duration: Duration to load in seconds.
|
||||
:param sample_rate: Sample rate to load audio with.
|
||||
:param dtype: Numpy data type to use, default to float32.
|
||||
:returns: Loaded data as (wf, sample_rate) tuple.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load_tf_waveform(
|
||||
self, audio_descriptor,
|
||||
offset=0.0, duration=1800., sample_rate=44100,
|
||||
dtype=b'float32', waveform_name='waveform'):
|
||||
""" Load the audio and convert it to a tensorflow waveform.
|
||||
|
||||
:param audio_descriptor: Describe song to load, in case of file
|
||||
based audio adapter, such descriptor would
|
||||
be a file path.
|
||||
:param offset: Start offset to load from in seconds.
|
||||
:param duration: Duration to load in seconds.
|
||||
:param sample_rate: Sample rate to load audio with.
|
||||
:param dtype: Numpy data type to use, default to float32.
|
||||
:param waveform_name: (Optional) Name of the key in output dict.
|
||||
:returns: TF output dict with waveform as
|
||||
(T x chan numpy array) and a boolean that
|
||||
tells whether there were an error while
|
||||
trying to load the waveform.
|
||||
"""
|
||||
# Cast parameters to TF format.
|
||||
offset = tf.cast(offset, tf.float64)
|
||||
duration = tf.cast(duration, tf.float64)
|
||||
|
||||
# Defined safe loading function.
|
||||
def safe_load(path, offset, duration, sample_rate, dtype):
|
||||
get_logger().info(
|
||||
f'Loading audio {path} from {offset} to {offset + duration}')
|
||||
try:
|
||||
(data, _) = self.load(
|
||||
path.numpy(),
|
||||
offset.numpy(),
|
||||
duration.numpy(),
|
||||
sample_rate.numpy(),
|
||||
dtype=dtype.numpy())
|
||||
return (data, False)
|
||||
except Exception as e:
|
||||
get_logger().warning(e)
|
||||
return (np.float32(-1.0), True)
|
||||
|
||||
# Execute function and format results.
|
||||
results = tf.py_function(
|
||||
safe_load,
|
||||
[audio_descriptor, offset, duration, sample_rate, dtype],
|
||||
(tf.float32, tf.bool)),
|
||||
waveform, error = results[0]
|
||||
return {
|
||||
waveform_name: waveform,
|
||||
f'{waveform_name}_error': error
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self, path, data, sample_rate,
|
||||
codec=None, bitrate=None):
|
||||
""" Save the given audio data to the file denoted by
|
||||
the given path.
|
||||
|
||||
:param path: Path of the audio file to save data in.
|
||||
:param data: Waveform data to write.
|
||||
:param sample_rate: Sample rate to write file in.
|
||||
:param codec: (Optional) Writing codec to use.
|
||||
:param bitrate: (Optional) Bitrate of the written audio file.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_default_audio_adapter():
|
||||
""" Builds and returns a default audio adapter instance.
|
||||
|
||||
:returns: An audio adapter instance.
|
||||
"""
|
||||
if AudioAdapter.DEFAULT is None:
|
||||
from .ffmpeg import FFMPEGProcessAudioAdapter
|
||||
AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
|
||||
return AudioAdapter.DEFAULT
|
||||
|
||||
|
||||
def get_audio_adapter(descriptor):
|
||||
""" Load dynamically an AudioAdapter from given class descriptor.
|
||||
|
||||
:param descriptor: Adapter class descriptor (module.Class)
|
||||
:returns: Created adapter instance.
|
||||
"""
|
||||
if descriptor is None:
|
||||
return get_default_audio_adapter()
|
||||
module_path = descriptor.split('.')
|
||||
adapter_class_name = module_path[-1]
|
||||
module_path = '.'.join(module_path[:-1])
|
||||
adapter_module = import_module(module_path)
|
||||
adapter_class = getattr(adapter_module, adapter_class_name)
|
||||
if not isinstance(adapter_class, AudioAdapter):
|
||||
raise ValueError(
|
||||
f'{adapter_class_name} is not a valid AudioAdapter class')
|
||||
return adapter_class()
|
||||
88
spleeter/utils/audio/convertor.py
Normal file
88
spleeter/utils/audio/convertor.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" This module provides audio data convertion functions. """
|
||||
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
# pylint: enable=import-error
|
||||
|
||||
from ..tensor import from_float32_to_uint8, from_uint8_to_float32
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def to_n_channels(waveform, n_channels):
|
||||
""" Convert a waveform to n_channels by removing or
|
||||
duplicating channels if needed (in tensorflow).
|
||||
|
||||
:param waveform: Waveform to transform.
|
||||
:param n_channels: Number of channel to reshape waveform in.
|
||||
:returns: Reshaped waveform.
|
||||
"""
|
||||
return tf.cond(
|
||||
tf.shape(waveform)[1] >= n_channels,
|
||||
true_fn=lambda: waveform[:, :n_channels],
|
||||
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels]
|
||||
)
|
||||
|
||||
|
||||
def to_stereo(waveform):
|
||||
""" Convert a waveform to stereo by duplicating if mono,
|
||||
or truncating if too many channels.
|
||||
|
||||
:param waveform: a (N, d) numpy array.
|
||||
:returns: A stereo waveform as a (N, 1) numpy array.
|
||||
"""
|
||||
if waveform.shape[1] == 1:
|
||||
return np.repeat(waveform, 2, axis=-1)
|
||||
if waveform.shape[1] > 2:
|
||||
return waveform[:, :2]
|
||||
return waveform
|
||||
|
||||
|
||||
def gain_to_db(tensor, espilon=10e-10):
|
||||
""" Convert from gain to decibel in tensorflow.
|
||||
|
||||
:param tensor: Tensor to convert.
|
||||
:param epsilon: Operation constant.
|
||||
:returns: Converted tensor.
|
||||
"""
|
||||
return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))
|
||||
|
||||
|
||||
def db_to_gain(tensor):
|
||||
""" Convert from decibel to gain in tensorflow.
|
||||
|
||||
:param tensor_db: Tensor to convert.
|
||||
:returns: Converted tensor.
|
||||
"""
|
||||
return tf.pow(10., (tensor / 20.))
|
||||
|
||||
|
||||
def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
|
||||
""" Encodes given spectrogram into uint8 using decibel scale.
|
||||
|
||||
:param spectrogram: Spectrogram to be encoded as TF float tensor.
|
||||
:param db_range: Range in decibel for encoding.
|
||||
:returns: Encoded decibel spectrogram as uint8 tensor.
|
||||
"""
|
||||
db_spectrogram = gain_to_db(spectrogram)
|
||||
max_db_spectrogram = tf.reduce_max(db_spectrogram)
|
||||
db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range)
|
||||
return from_float32_to_uint8(db_spectrogram, **kwargs)
|
||||
|
||||
|
||||
def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
|
||||
""" Decode spectrogram from uint8 decibel scale.
|
||||
|
||||
:param db_uint_spectrogram: Decibel pectrogram to decode.
|
||||
:param min_db: Lower bound limit for decoding.
|
||||
:param max_db: Upper bound limit for decoding.
|
||||
:returns: Decoded spectrogram as float2 tensor.
|
||||
"""
|
||||
db_spectrogram = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
|
||||
return db_to_gain(db_spectrogram)
|
||||
263
spleeter/utils/audio/ffmpeg.py
Normal file
263
spleeter/utils/audio/ffmpeg.py
Normal file
@@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
"""
|
||||
This module provides an AudioAdapter implementation based on FFMPEG
|
||||
process. Such implementation is POSIXish and depends on nothing except
|
||||
standard Python libraries. Thus this implementation is the default one
|
||||
used within this library.
|
||||
"""
|
||||
|
||||
import os
|
||||
import os.path
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import numpy as np # pylint: disable=import-error
|
||||
|
||||
from .adapter import AudioAdapter
|
||||
from ..logging import get_logger
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
# Default FFMPEG binary name.
|
||||
_UNIX_BINARY = 'ffmpeg'
|
||||
_WINDOWS_BINARY = 'ffmpeg.exe'
|
||||
|
||||
|
||||
def _which(program):
|
||||
""" A pure python implementation of `which`command
|
||||
for retrieving absolute path from command name or path.
|
||||
|
||||
@see https://stackoverflow.com/a/377028/1211342
|
||||
|
||||
:param program: Program name or path to expend.
|
||||
:returns: Absolute path of program if any, None otherwise.
|
||||
"""
|
||||
def is_exe(fpath):
|
||||
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
|
||||
|
||||
fpath, _ = os.path.split(program)
|
||||
if fpath:
|
||||
if is_exe(program):
|
||||
return program
|
||||
else:
|
||||
for path in os.environ['PATH'].split(os.pathsep):
|
||||
exe_file = os.path.join(path, program)
|
||||
if is_exe(exe_file):
|
||||
return exe_file
|
||||
return None
|
||||
|
||||
|
||||
def _get_ffmpeg_path():
|
||||
""" Retrieves FFMPEG binary path using ENVVAR if defined
|
||||
or default binary name (Windows or UNIX style).
|
||||
|
||||
:returns: Absolute path of FFMPEG binary.
|
||||
:raise IOError: If FFMPEG binary cannot be found.
|
||||
"""
|
||||
ffmpeg_path = os.environ.get('FFMPEG_PATH', None)
|
||||
if ffmpeg_path is None:
|
||||
# Note: try to infer standard binary name regarding of platform.
|
||||
if platform.system() == 'Windows':
|
||||
ffmpeg_path = _WINDOWS_BINARY
|
||||
else:
|
||||
ffmpeg_path = _UNIX_BINARY
|
||||
expended = _which(ffmpeg_path)
|
||||
if expended is None:
|
||||
raise IOError(f'FFMPEG binary ({ffmpeg_path}) not found')
|
||||
return expended
|
||||
|
||||
|
||||
def _to_ffmpeg_time(n):
|
||||
""" Format number of seconds to time expected by FFMPEG.
|
||||
|
||||
:param n: Time in seconds to format.
|
||||
:returns: Formatted time in FFMPEG format.
|
||||
"""
|
||||
m, s = divmod(n, 60)
|
||||
h, m = divmod(m, 60)
|
||||
return '%d:%02d:%09.6f' % (h, m, s)
|
||||
|
||||
|
||||
def _parse_ffmpg_results(stderr):
|
||||
""" Extract number of channels and sample rate from
|
||||
the given FFMPEG STDERR output line.
|
||||
|
||||
:param stderr: STDERR output line to parse.
|
||||
:returns: Parsed n_channels and sample_rate values.
|
||||
"""
|
||||
# Setup default value.
|
||||
n_channels = 0
|
||||
sample_rate = 0
|
||||
# Find samplerate
|
||||
match = re.search(r'(\d+) hz', stderr)
|
||||
if match:
|
||||
sample_rate = int(match.group(1))
|
||||
# Channel count.
|
||||
match = re.search(r'hz, ([^,]+),', stderr)
|
||||
if match:
|
||||
mode = match.group(1)
|
||||
if mode == 'stereo':
|
||||
n_channels = 2
|
||||
else:
|
||||
match = re.match(r'(\d+) ', mode)
|
||||
n_channels = match and int(match.group(1)) or 1
|
||||
return n_channels, sample_rate
|
||||
|
||||
|
||||
class _CommandBuilder(object):
|
||||
""" A simple builder pattern class for CLI string. """
|
||||
|
||||
def __init__(self, binary):
|
||||
""" Default constructor. """
|
||||
self._command = [binary]
|
||||
|
||||
def flag(self, flag):
|
||||
""" Add flag or unlabelled opt. """
|
||||
self._command.append(flag)
|
||||
return self
|
||||
|
||||
def opt(self, short, value, formatter=str):
|
||||
""" Add option if value not None. """
|
||||
if value is not None:
|
||||
self._command.append(short)
|
||||
self._command.append(formatter(value))
|
||||
return self
|
||||
|
||||
def command(self):
|
||||
""" Build string command. """
|
||||
return self._command
|
||||
|
||||
|
||||
class FFMPEGProcessAudioAdapter(AudioAdapter):
|
||||
""" An AudioAdapter implementation that use FFMPEG binary through
|
||||
subprocess in order to perform I/O operation for audio processing.
|
||||
|
||||
When created, FFMPEG binary path will be checked and expended,
|
||||
raising exception if not found. Such path could be infered using
|
||||
FFMPEG_PATH environment variable.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
""" Default constructor. """
|
||||
self._ffmpeg_path = _get_ffmpeg_path()
|
||||
|
||||
def _get_command_builder(self):
|
||||
""" Creates and returns a command builder using FFMPEG path.
|
||||
|
||||
:returns: Built command builder.
|
||||
"""
|
||||
return _CommandBuilder(self._ffmpeg_path)
|
||||
|
||||
def load(
|
||||
self, path, offset=None, duration=None,
|
||||
sample_rate=None, dtype=np.float32):
|
||||
""" Loads the audio file denoted by the given path
|
||||
and returns it data as a waveform.
|
||||
|
||||
:param path: Path of the audio file to load data from.
|
||||
:param offset: (Optional) Start offset to load from in seconds.
|
||||
:param duration: (Optional) Duration to load in seconds.
|
||||
:param sample_rate: (Optional) Sample rate to load audio with.
|
||||
:param dtype: (Optional) Numpy data type to use, default to float32.
|
||||
:returns: Loaded data a (waveform, sample_rate) tuple.
|
||||
"""
|
||||
if not isinstance(path, str):
|
||||
path = path.decode()
|
||||
command = (
|
||||
self._get_command_builder()
|
||||
.opt('-ss', offset, formatter=_to_ffmpeg_time)
|
||||
.opt('-t', duration, formatter=_to_ffmpeg_time)
|
||||
.opt('-i', path)
|
||||
.opt('-ar', sample_rate)
|
||||
.opt('-f', 'f32le')
|
||||
.flag('-')
|
||||
.command())
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
buffer = process.stdout.read(-1)
|
||||
# Read STDERR until end of the process detected.
|
||||
while True:
|
||||
status = process.stderr.readline()
|
||||
if not status:
|
||||
raise OSError('Stream info not found')
|
||||
if isinstance(status, bytes): # Note: Python 3 compatibility.
|
||||
status = status.decode('utf8', 'ignore')
|
||||
status = status.strip().lower()
|
||||
if 'no such file' in status:
|
||||
raise IOError(f'File {path} not found')
|
||||
elif 'invalid data found' in status:
|
||||
raise IOError(f'FFMPEG error : {status}')
|
||||
elif 'audio:' in status:
|
||||
n_channels, ffmpeg_sample_rate = _parse_ffmpg_results(status)
|
||||
if sample_rate is None:
|
||||
sample_rate = ffmpeg_sample_rate
|
||||
break
|
||||
# Load waveform and clean process.
|
||||
waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
|
||||
if not waveform.dtype == np.dtype(dtype):
|
||||
waveform = waveform.astype(dtype)
|
||||
process.stdout.close()
|
||||
process.stderr.close()
|
||||
del process
|
||||
return (waveform, sample_rate)
|
||||
|
||||
def save(
|
||||
self, path, data, sample_rate,
|
||||
codec=None, bitrate=None):
|
||||
""" Write waveform data to the file denoted by the given path
|
||||
using FFMPEG process.
|
||||
|
||||
:param path: Path of the audio file to save data in.
|
||||
:param data: Waveform data to write.
|
||||
:param sample_rate: Sample rate to write file in.
|
||||
:param codec: (Optional) Writing codec to use.
|
||||
:param bitrate: (Optional) Bitrate of the written audio file.
|
||||
:raise IOError: If any error occurs while using FFMPEG to write data.
|
||||
"""
|
||||
directory = os.path.split(path)[0]
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
get_logger().debug('Writing file %s', path)
|
||||
# NOTE: Tweak.
|
||||
if codec == 'wav':
|
||||
codec = None
|
||||
command = (
|
||||
self._get_command_builder()
|
||||
.flag('-y')
|
||||
.opt('-loglevel', 'error')
|
||||
.opt('-f', 'f32le')
|
||||
.opt('-ar', sample_rate)
|
||||
.opt('-ac', data.shape[1])
|
||||
.opt('-i', '-')
|
||||
.flag('-vn')
|
||||
.opt('-acodec', codec)
|
||||
.opt('-ar', sample_rate) # Note: why twice ?
|
||||
.opt('-strict', '-2') # Note: For 'aac' codec support.
|
||||
.opt('-ab', bitrate)
|
||||
.flag(path)
|
||||
.command())
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
stdout=open(os.devnull, 'wb'),
|
||||
stdin=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
# Write data to STDIN.
|
||||
try:
|
||||
process.stdin.write(
|
||||
data.astype('<f4').tostring())
|
||||
except IOError:
|
||||
raise IOError(f'FFMPEG error: {process.stderr.read()}')
|
||||
# Clean process.
|
||||
process.stdin.close()
|
||||
if process.stderr is not None:
|
||||
process.stderr.close()
|
||||
process.wait()
|
||||
del process
|
||||
get_logger().info('File %s written', path)
|
||||
128
spleeter/utils/audio/spectrogram.py
Normal file
128
spleeter/utils/audio/spectrogram.py
Normal file
@@ -0,0 +1,128 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" Spectrogram specific data augmentation """
|
||||
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.signal import stft, hann_window
|
||||
# pylint: enable=import-error
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def compute_spectrogram_tf(
|
||||
waveform,
|
||||
frame_length=2048, frame_step=512,
|
||||
spec_exponent=1., window_exponent=1.):
|
||||
""" Compute magnitude / power spectrogram from waveform as
|
||||
a n_samples x n_channels tensor.
|
||||
|
||||
:param waveform: Input waveform as (times x number of channels)
|
||||
tensor.
|
||||
:param frame_length: Length of a STFT frame to use.
|
||||
:param frame_step: HOP between successive frames.
|
||||
:param spec_exponent: Exponent of the spectrogram (usually 1 for
|
||||
magnitude spectrogram, or 2 for power spectrogram).
|
||||
:param window_exponent: Exponent applied to the Hann windowing function
|
||||
(may be useful for making perfect STFT/iSTFT
|
||||
reconstruction).
|
||||
:returns: Computed magnitude / power spectrogram as a
|
||||
(T x F x n_channels) tensor.
|
||||
"""
|
||||
stft_tensor = tf.transpose(
|
||||
stft(
|
||||
tf.transpose(waveform),
|
||||
frame_length,
|
||||
frame_step,
|
||||
window_fn=lambda f, dtype: hann_window(
|
||||
f,
|
||||
periodic=True,
|
||||
dtype=waveform.dtype) ** window_exponent),
|
||||
perm=[1, 2, 0])
|
||||
return np.abs(stft_tensor) ** spec_exponent
|
||||
|
||||
|
||||
def time_stretch(
|
||||
spectrogram,
|
||||
factor=1.0,
|
||||
method=tf.image.ResizeMethod.BILINEAR):
|
||||
""" Time stretch a spectrogram preserving shape in tensorflow. Note that
|
||||
this is an approximation in the frequency domain.
|
||||
|
||||
:param spectrogram: Input spectrogram to be time stretched as tensor.
|
||||
:param factor: (Optional) Time stretch factor, must be >0, default to 1.
|
||||
:param mehtod: (Optional) Interpolation method, default to BILINEAR.
|
||||
:returns: Time stretched spectrogram as tensor with same shape.
|
||||
"""
|
||||
T = tf.shape(spectrogram)[0]
|
||||
T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
|
||||
F = tf.shape(spectrogram)[1]
|
||||
ts_spec = tf.image.resize_images(
|
||||
spectrogram,
|
||||
[T_ts, F],
|
||||
method=method,
|
||||
align_corners=True)
|
||||
return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
|
||||
|
||||
|
||||
def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs):
|
||||
""" Time stretch a spectrogram preserving shape with random ratio in
|
||||
tensorflow. Applies time_stretch to spectrogram with a random ratio drawn
|
||||
uniformly in [factor_min, factor_max].
|
||||
|
||||
:param spectrogram: Input spectrogram to be time stretched as tensor.
|
||||
:param factor_min: (Optional) Min time stretch factor, default to 0.9.
|
||||
:param factor_max: (Optional) Max time stretch factor, default to 1.1.
|
||||
:returns: Randomly time stretched spectrogram as tensor with same shape.
|
||||
"""
|
||||
factor = tf.random_uniform(
|
||||
shape=(1,),
|
||||
seed=0) * (factor_max - factor_min) + factor_min
|
||||
return time_stretch(spectrogram, factor=factor, **kwargs)
|
||||
|
||||
|
||||
def pitch_shift(
|
||||
spectrogram,
|
||||
semitone_shift=0.0,
|
||||
method=tf.image.ResizeMethod.BILINEAR):
|
||||
""" Pitch shift a spectrogram preserving shape in tensorflow. Note that
|
||||
this is an approximation in the frequency domain.
|
||||
|
||||
:param spectrogram: Input spectrogram to be pitch shifted as tensor.
|
||||
:param semitone_shift: (Optional) Pitch shift in semitone, default to 0.0.
|
||||
:param mehtod: (Optional) Interpolation method, default to BILINEAR.
|
||||
:returns: Pitch shifted spectrogram (same shape as spectrogram).
|
||||
"""
|
||||
factor = 2 ** (semitone_shift / 12.)
|
||||
T = tf.shape(spectrogram)[0]
|
||||
F = tf.shape(spectrogram)[1]
|
||||
F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
|
||||
ps_spec = tf.image.resize_images(
|
||||
spectrogram,
|
||||
[T, F_ps],
|
||||
method=method,
|
||||
align_corners=True)
|
||||
paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
|
||||
return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
|
||||
|
||||
|
||||
def random_pitch_shift(spectrogram, shift_min=-1., shift_max=1., **kwargs):
|
||||
""" Pitch shift a spectrogram preserving shape with random ratio in
|
||||
tensorflow. Applies pitch_shift to spectrogram with a random shift
|
||||
amount (expressed in semitones) drawn uniformly in [shift_min, shift_max].
|
||||
|
||||
:param spectrogram: Input spectrogram to be pitch shifted as tensor.
|
||||
|
||||
:param shift_min: (Optional) Min pitch shift in semitone, default to -1.
|
||||
:param shift_max: (Optional) Max pitch shift in semitone, default to 1.
|
||||
:returns: Randomly pitch shifted spectrogram (same shape as spectrogram).
|
||||
"""
|
||||
semitone_shift = tf.random_uniform(
|
||||
shape=(1,),
|
||||
seed=0) * (shift_max - shift_min) + shift_min
|
||||
return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
|
||||
47
spleeter/utils/configuration.py
Normal file
47
spleeter/utils/configuration.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" Module that provides configuration loading function. """
|
||||
|
||||
import json
|
||||
|
||||
try:
|
||||
import importlib.resources as loader
|
||||
except ImportError:
|
||||
# Try backported to PY<37 `importlib_resources`.
|
||||
import importlib_resources as loader
|
||||
|
||||
from os.path import exists
|
||||
|
||||
from .. import resources
|
||||
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
|
||||
|
||||
|
||||
def load_configuration(descriptor):
|
||||
""" Load configuration from the given descriptor. Could be
|
||||
either a `spleeter:` prefixed embedded configuration name
|
||||
or a file system path to read configuration from.
|
||||
|
||||
:param descriptor: Configuration descriptor to use for lookup.
|
||||
:returns: Loaded description as dict.
|
||||
:raise ValueError: If required embedded configuration does not exists.
|
||||
:raise IOError: If required configuration file does not exists.
|
||||
"""
|
||||
# Embedded configuration reading.
|
||||
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
|
||||
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):]
|
||||
if not loader.is_resource(resources, f'{name}.json'):
|
||||
raise ValueError(f'No embedded configuration {name} found')
|
||||
with loader.open_text(resources, f'{name}.json') as stream:
|
||||
return json.load(stream)
|
||||
# Standard file reading.
|
||||
if not exists(descriptor):
|
||||
raise IOError(f'Configuration file {descriptor} not found')
|
||||
with open(descriptor, 'r') as stream:
|
||||
return json.load(stream)
|
||||
69
spleeter/utils/estimator.py
Normal file
69
spleeter/utils/estimator.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" Utility functions for creating estimator. """
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
# pylint: disable=import-error
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib import predictor
|
||||
# pylint: enable=import-error
|
||||
|
||||
from ..model import model_fn
|
||||
from ..model.provider import get_default_model_provider
|
||||
|
||||
# Default exporting directory for predictor.
|
||||
DEFAULT_EXPORT_DIRECTORY = '/tmp/serving'
|
||||
|
||||
|
||||
def create_estimator(params, MWF):
|
||||
"""
|
||||
Initialize tensorflow estimator that will perform separation
|
||||
|
||||
Params:
|
||||
- params: a dictionnary of parameters for building the model
|
||||
|
||||
Returns:
|
||||
a tensorflow estimator
|
||||
"""
|
||||
# Load model.
|
||||
model_directory = params['model_dir']
|
||||
model_provider = get_default_model_provider()
|
||||
params['model_dir'] = model_provider.get(model_directory)
|
||||
params['MWF'] = MWF
|
||||
# Setup config
|
||||
session_config = tf.compat.v1.ConfigProto()
|
||||
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
|
||||
config = tf.estimator.RunConfig(session_config=session_config)
|
||||
# Setup estimator
|
||||
estimator = tf.estimator.Estimator(
|
||||
model_fn=model_fn,
|
||||
model_dir=params['model_dir'],
|
||||
params=params,
|
||||
config=config
|
||||
)
|
||||
return estimator
|
||||
|
||||
|
||||
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
||||
""" Exports given estimator as predictor into the given directory
|
||||
and returns associated tf.predictor instance.
|
||||
|
||||
:param estimator: Estimator to export.
|
||||
:param directory: (Optional) path to write exported model into.
|
||||
"""
|
||||
def receiver():
|
||||
shape = (None, estimator.params['n_channels'])
|
||||
features = {
|
||||
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
|
||||
'audio_id': tf.compat.v1.placeholder(tf.string)}
|
||||
return tf.estimator.export.ServingInputReceiver(features, features)
|
||||
|
||||
estimator.export_saved_model(directory, receiver)
|
||||
versions = [
|
||||
model for model in Path(directory).iterdir()
|
||||
if model.is_dir() and 'temp' not in str(model)]
|
||||
latest = str(sorted(versions)[-1])
|
||||
return predictor.from_saved_model(latest)
|
||||
45
spleeter/utils/logging.py
Normal file
45
spleeter/utils/logging.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" Centralized logging facilities for Spleeter. """
|
||||
|
||||
from os import environ
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
class _LoggerHolder(object):
|
||||
""" Logger singleton instance holder. """
|
||||
|
||||
INSTANCE = None
|
||||
|
||||
|
||||
def get_logger():
|
||||
""" Returns library scoped logger.
|
||||
|
||||
:returns: Library logger.
|
||||
"""
|
||||
if _LoggerHolder.INSTANCE is None:
|
||||
# pylint: disable=import-error
|
||||
from tensorflow.compat.v1 import logging
|
||||
# pylint: enable=import-error
|
||||
_LoggerHolder.INSTANCE = logging
|
||||
_LoggerHolder.INSTANCE.set_verbosity(_LoggerHolder.INSTANCE.ERROR)
|
||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
return _LoggerHolder.INSTANCE
|
||||
|
||||
|
||||
def enable_logging():
|
||||
""" Enable INFO level logging. """
|
||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
logger = get_logger()
|
||||
logger.set_verbosity(logger.INFO)
|
||||
|
||||
|
||||
def enable_verbose_logging():
|
||||
""" Enable DEBUG level logging. """
|
||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
|
||||
logger = get_logger()
|
||||
logger.set_verbosity(logger.DEBUG)
|
||||
191
spleeter/utils/tensor.py
Normal file
191
spleeter/utils/tensor.py
Normal file
@@ -0,0 +1,191 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" Utility function for tensorflow. """
|
||||
|
||||
# pylint: disable=import-error
|
||||
import tensorflow as tf
|
||||
import pandas as pd
|
||||
# pylint: enable=import-error
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def sync_apply(tensor_dict, func, concat_axis=1):
|
||||
""" Return a function that applies synchronously the provided func on the
|
||||
provided dictionnary of tensor. This means that func is applied to the
|
||||
concatenation of the tensors in tensor_dict. This is useful for performing
|
||||
random operation that needs the same drawn value on multiple tensor, such
|
||||
as a random time-crop on both input data and label (the same crop should be
|
||||
applied to both input data and label, so random crop cannot be applied
|
||||
separately on each of them).
|
||||
|
||||
IMPORTANT NOTE: all tensor are assumed to be the same shape.
|
||||
|
||||
Params:
|
||||
- tensor_dict: dictionary (key: strings, values: tf.tensor)
|
||||
a dictionary of tensor.
|
||||
- func: function
|
||||
function to be applied to the concatenation of the tensors in
|
||||
tensor_dict
|
||||
- concat_axis: int
|
||||
The axis on which to perform the concatenation.
|
||||
|
||||
Returns:
|
||||
processed tensors dictionary with the same name (keys) as input
|
||||
tensor_dict.
|
||||
"""
|
||||
if concat_axis not in {0, 1}:
|
||||
raise NotImplementedError(
|
||||
'Function only implemented for concat_axis equal to 0 or 1')
|
||||
tensor_list = list(tensor_dict.values())
|
||||
concat_tensor = tf.concat(tensor_list, concat_axis)
|
||||
processed_concat_tensor = func(concat_tensor)
|
||||
tensor_shape = tf.shape(list(tensor_dict.values())[0])
|
||||
D = tensor_shape[concat_axis]
|
||||
if concat_axis == 0:
|
||||
return {
|
||||
name: processed_concat_tensor[index * D:(index + 1) * D, :, :]
|
||||
for index, name in enumerate(tensor_dict)
|
||||
}
|
||||
return {
|
||||
name: processed_concat_tensor[:, index * D:(index + 1) * D, :]
|
||||
for index, name in enumerate(tensor_dict)
|
||||
}
|
||||
|
||||
|
||||
def from_float32_to_uint8(
|
||||
tensor,
|
||||
tensor_key='tensor',
|
||||
min_key='min',
|
||||
max_key='max'):
|
||||
"""
|
||||
|
||||
:param tensor:
|
||||
:param tensor_key:
|
||||
:param min_key:
|
||||
:param max_key:
|
||||
:returns:
|
||||
"""
|
||||
tensor_min = tf.reduce_min(tensor)
|
||||
tensor_max = tf.reduce_max(tensor)
|
||||
return {
|
||||
tensor_key: tf.cast(
|
||||
(tensor - tensor_min) / (tensor_max - tensor_min + 1e-16)
|
||||
* 255.9999, dtype=tf.uint8),
|
||||
min_key: tensor_min,
|
||||
max_key: tensor_max
|
||||
}
|
||||
|
||||
|
||||
def from_uint8_to_float32(tensor, tensor_min, tensor_max):
|
||||
"""
|
||||
|
||||
:param tensor:
|
||||
:param tensor_min:
|
||||
:param tensor_max:
|
||||
:returns:
|
||||
"""
|
||||
return (
|
||||
tf.cast(tensor, tf.float32)
|
||||
* (tensor_max - tensor_min)
|
||||
/ 255.9999 + tensor_min)
|
||||
|
||||
|
||||
def pad_and_partition(tensor, segment_len):
|
||||
""" Pad and partition a tensor into segment of len segment_len
|
||||
along the first dimension. The tensor is padded with 0 in order
|
||||
to ensure that the first dimension is a multiple of segment_len.
|
||||
|
||||
Tensor must be of known fixed rank
|
||||
|
||||
:Example:
|
||||
|
||||
>>> tensor = [[1, 2, 3], [4, 5, 6]]
|
||||
>>> segment_len = 2
|
||||
>>> pad_and_partition(tensor, segment_len)
|
||||
[[[1, 2], [4, 5]], [[3, 0], [6, 0]]]
|
||||
|
||||
:param tensor:
|
||||
:param segment_len:
|
||||
:returns:
|
||||
"""
|
||||
tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
|
||||
pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
|
||||
padded = tf.pad(
|
||||
tensor,
|
||||
[[0, pad_size]] + [[0, 0]] * (len(tensor.shape)-1))
|
||||
split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
|
||||
return tf.reshape(
|
||||
padded,
|
||||
tf.concat(
|
||||
[[split, segment_len], tf.shape(padded)[1:]],
|
||||
axis=0))
|
||||
|
||||
|
||||
def pad_and_reshape(instr_spec, frame_length, F):
|
||||
"""
|
||||
:param instr_spec:
|
||||
:param frame_length:
|
||||
:param F:
|
||||
:returns:
|
||||
"""
|
||||
spec_shape = tf.shape(instr_spec)
|
||||
extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
|
||||
n_extra_row = (frame_length) // 2 + 1 - F
|
||||
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
|
||||
extended_spec = tf.concat([instr_spec, extension], axis=2)
|
||||
old_shape = tf.shape(extended_spec)
|
||||
new_shape = tf.concat([
|
||||
[old_shape[0] * old_shape[1]],
|
||||
old_shape[2:]],
|
||||
axis=0)
|
||||
processed_instr_spec = tf.reshape(extended_spec, new_shape)
|
||||
return processed_instr_spec
|
||||
|
||||
|
||||
def dataset_from_csv(csv_path, **kwargs):
|
||||
""" Load dataset from a CSV file using Pandas. kwargs if any are
|
||||
forwarded to the `pandas.read_csv` function.
|
||||
|
||||
:param csv_path: Path of the CSV file to load dataset from.
|
||||
:returns: Loaded dataset.
|
||||
"""
|
||||
df = pd.read_csv(csv_path, **kwargs)
|
||||
dataset = (
|
||||
tf.data.Dataset.from_tensor_slices(
|
||||
{key: df[key].values for key in df})
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
def check_tensor_shape(tensor_tf, target_shape):
|
||||
""" Return a Tensorflow boolean graph that indicates whether
|
||||
sample[features_key] has the specified target shape. Only check
|
||||
not None entries of target_shape.
|
||||
|
||||
:param tensor_tf: Tensor to check shape for.
|
||||
:param target_shape: Target shape to compare tensor to.
|
||||
:returns: True if shape is valid, False otherwise (as TF boolean).
|
||||
"""
|
||||
result = tf.constant(True)
|
||||
for i, target_length in enumerate(target_shape):
|
||||
if target_length:
|
||||
result = tf.logical_and(
|
||||
result,
|
||||
tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i]))
|
||||
return result
|
||||
|
||||
|
||||
def set_tensor_shape(tensor, tensor_shape):
|
||||
""" Set shape for a tensor (not in place, as opposed to tf.set_shape)
|
||||
|
||||
:param tensor: Tensor to reshape.
|
||||
:param tensor_shape: Shape to apply to the tensor.
|
||||
:returns: A reshaped tensor.
|
||||
"""
|
||||
# NOTE: That SOUND LIKE IN PLACE HERE ?
|
||||
tensor.set_shape(tensor_shape)
|
||||
return tensor
|
||||
Reference in New Issue
Block a user