Initial commit from private spleeter

This commit is contained in:
Romain
2019-10-28 14:12:13 +01:00
parent dc39414ee9
commit 556ef21214
47 changed files with 3924 additions and 3 deletions

View File

@@ -0,0 +1,8 @@
#!/usr/bin/env python
# coding: utf8
""" This package provides utility function and classes. """
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

View File

@@ -0,0 +1,15 @@
#!/usr/bin/env python
# coding: utf8
"""
The `spleeter.utils.audio` package provides various
tools for manipulating audio content, such as:
- An audio adapter class for abstract interaction with audio files.
- An FFMPEG based implementation of the audio adapter.
- Waveform conversion and transformation functions.
"""
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

View File

@@ -0,0 +1,144 @@
#!/usr/bin/env python
# coding: utf8
""" AudioAdapter class defintion. """
from abc import ABC, abstractmethod
from importlib import import_module
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from ..logging import get_logger
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
class AudioAdapter(ABC):
""" An abstract class for manipulating audio signal. """
# Default audio adapter singleton instance.
DEFAULT = None
@abstractmethod
def load(
self, audio_descriptor, offset, duration,
sample_rate, dtype=np.float32):
""" Loads the audio file denoted by the given audio descriptor
and returns it data as a waveform. Aims to be implemented
by client.
:param audio_descriptor: Describe song to load, in case of file
based audio adapter, such descriptor would
be a file path.
:param offset: Start offset to load from in seconds.
:param duration: Duration to load in seconds.
:param sample_rate: Sample rate to load audio with.
:param dtype: Numpy data type to use, default to float32.
:returns: Loaded data as (wf, sample_rate) tuple.
"""
pass
def load_tf_waveform(
self, audio_descriptor,
offset=0.0, duration=1800., sample_rate=44100,
dtype=b'float32', waveform_name='waveform'):
""" Load the audio and convert it to a tensorflow waveform.
:param audio_descriptor: Describes the song to load; for a file
based audio adapter, such a descriptor
would be a file path.
:param offset: Start offset to load from, in seconds.
:param duration: Duration to load, in seconds.
:param sample_rate: Sample rate to load audio with.
:param dtype: Numpy data type to use, defaults to float32.
:param waveform_name: (Optional) Name of the key in the output dict.
:returns: TF output dict with the waveform as a
(T x channels) tensor and a boolean that
tells whether there was an error while
trying to load the waveform.
"""
# Cast parameters to TF format.
offset = tf.cast(offset, tf.float64)
duration = tf.cast(duration, tf.float64)
# Define safe loading function.
def safe_load(path, offset, duration, sample_rate, dtype):
get_logger().info(
f'Loading audio {path} from {offset} to {offset + duration}')
try:
(data, _) = self.load(
path.numpy(),
offset.numpy(),
duration.numpy(),
sample_rate.numpy(),
dtype=dtype.numpy())
return (data, False)
except Exception as e:
get_logger().warning(e)
return (np.float32(-1.0), True)
# Execute function and format results.
waveform, error = tf.py_function(
safe_load,
[audio_descriptor, offset, duration, sample_rate, dtype],
(tf.float32, tf.bool))
return {
waveform_name: waveform,
f'{waveform_name}_error': error
}
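# Illustrative usage sketch (comment only, not part of the original
# module); the file name below is hypothetical. Within a tf.data
# pipeline, a waveform dict can be built from a descriptor tensor:
# >>> adapter = get_default_audio_adapter()
# >>> features = adapter.load_tf_waveform(
# ...     tf.constant('audio/track.wav'),
# ...     offset=0.0, duration=20.0, sample_rate=44100)
# >>> sorted(features.keys())
# ['waveform', 'waveform_error']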
@abstractmethod
def save(
self, path, data, sample_rate,
codec=None, bitrate=None):
""" Save the given audio data to the file denoted by
the given path.
:param path: Path of the audio file to save data in.
:param data: Waveform data to write.
:param sample_rate: Sample rate to write file in.
:param codec: (Optional) Writing codec to use.
:param bitrate: (Optional) Bitrate of the written audio file.
"""
pass
def get_default_audio_adapter():
""" Builds and returns a default audio adapter instance.
:returns: An audio adapter instance.
"""
if AudioAdapter.DEFAULT is None:
from .ffmpeg import FFMPEGProcessAudioAdapter
AudioAdapter.DEFAULT = FFMPEGProcessAudioAdapter()
return AudioAdapter.DEFAULT
def get_audio_adapter(descriptor):
""" Load dynamically an AudioAdapter from given class descriptor.
:param descriptor: Adapter class descriptor (module.Class)
:returns: Created adapter instance.
"""
if descriptor is None:
return get_default_audio_adapter()
module_path = descriptor.split('.')
adapter_class_name = module_path[-1]
module_path = '.'.join(module_path[:-1])
adapter_module = import_module(module_path)
adapter_class = getattr(adapter_module, adapter_class_name)
if not issubclass(adapter_class, AudioAdapter):
raise ValueError(
f'{adapter_class_name} is not a valid AudioAdapter class')
return adapter_class()
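# Illustrative usage sketch (comment only); the descriptor below points
# at the FFMPEG adapter shipped with this package, and passing None
# falls back to the default FFMPEG based adapter:
# >>> adapter = get_audio_adapter(
# ...     'spleeter.utils.audio.ffmpeg.FFMPEGProcessAudioAdapter')
# >>> isinstance(adapter, AudioAdapter)
# True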

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python
# coding: utf8
""" This module provides audio data convertion functions. """
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
# pylint: enable=import-error
from ..tensor import from_float32_to_uint8, from_uint8_to_float32
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def to_n_channels(waveform, n_channels):
""" Convert a waveform to n_channels by removing or
duplicating channels if needed (in tensorflow).
:param waveform: Waveform to transform.
:param n_channels: Number of channels to reshape the waveform to.
:returns: Reshaped waveform.
"""
return tf.cond(
tf.shape(waveform)[1] >= n_channels,
true_fn=lambda: waveform[:, :n_channels],
false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels]
)
def to_stereo(waveform):
""" Convert a waveform to stereo by duplicating if mono,
or truncating if too many channels.
:param waveform: a (N, d) numpy array.
:returns: A stereo waveform as a (N, 2) numpy array.
"""
if waveform.shape[1] == 1:
return np.repeat(waveform, 2, axis=-1)
if waveform.shape[1] > 2:
return waveform[:, :2]
return waveform
def gain_to_db(tensor, epsilon=10e-10):
""" Convert from gain to decibel in tensorflow.
:param tensor: Tensor to convert.
:param epsilon: Small constant to avoid taking the log of zero.
:returns: Converted tensor.
"""
return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, epsilon))
def db_to_gain(tensor):
""" Convert from decibel to gain in tensorflow.
:param tensor: Tensor to convert.
:returns: Converted tensor.
"""
return tf.pow(10., (tensor / 20.))
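# Quick sanity check (comment only, not part of the original module):
# 20. / np.log(10) * log(x) equals 20 * log10(x), so a gain of 10
# corresponds to 20 dB and db_to_gain(20.) maps back to a gain of 10:
# >>> gain_to_db(tf.constant(10.))   # evaluates to ~20.0
# >>> db_to_gain(tf.constant(20.))   # evaluates to ~10.0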
def spectrogram_to_db_uint(spectrogram, db_range=100., **kwargs):
""" Encodes given spectrogram into uint8 using decibel scale.
:param spectrogram: Spectrogram to be encoded as TF float tensor.
:param db_range: Range in decibel for encoding.
:returns: Encoded decibel spectrogram as uint8 tensor.
"""
db_spectrogram = gain_to_db(spectrogram)
max_db_spectrogram = tf.reduce_max(db_spectrogram)
db_spectrogram = tf.maximum(db_spectrogram, max_db_spectrogram - db_range)
return from_float32_to_uint8(db_spectrogram, **kwargs)
def db_uint_spectrogram_to_gain(db_uint_spectrogram, min_db, max_db):
""" Decode spectrogram from uint8 decibel scale.
:param db_uint_spectrogram: Decibel spectrogram to decode.
:param min_db: Lower bound limit for decoding.
:param max_db: Upper bound limit for decoding.
:returns: Decoded spectrogram as float32 tensor.
"""
db_spectrogram = from_uint8_to_float32(db_uint_spectrogram, min_db, max_db)
return db_to_gain(db_spectrogram)
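# Illustrative round trip sketch (comment only); `spectrogram` is a
# hypothetical float32 magnitude spectrogram tensor. Encoding then
# decoding recovers values up to uint8 quantization error within the
# retained db_range:
# >>> encoded = spectrogram_to_db_uint(spectrogram, db_range=100.)
# >>> decoded = db_uint_spectrogram_to_gain(
# ...     encoded['tensor'], encoded['min'], encoded['max'])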

View File

@@ -0,0 +1,263 @@
#!/usr/bin/env python
# coding: utf8
"""
This module provides an AudioAdapter implementation based on an FFMPEG
process. This implementation is POSIX friendly and depends on nothing
but standard Python libraries (and an FFMPEG binary), and is therefore
the default implementation used within this library.
"""
import os
import os.path
import platform
import re
import subprocess
import numpy as np # pylint: disable=import-error
from .adapter import AudioAdapter
from ..logging import get_logger
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# Default FFMPEG binary name.
_UNIX_BINARY = 'ffmpeg'
_WINDOWS_BINARY = 'ffmpeg.exe'
def _which(program):
""" A pure python implementation of `which`command
for retrieving absolute path from command name or path.
@see https://stackoverflow.com/a/377028/1211342
:param program: Program name or path to expend.
:returns: Absolute path of program if any, None otherwise.
"""
def is_exe(fpath):
return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
fpath, _ = os.path.split(program)
if fpath:
if is_exe(program):
return program
else:
for path in os.environ['PATH'].split(os.pathsep):
exe_file = os.path.join(path, program)
if is_exe(exe_file):
return exe_file
return None
def _get_ffmpeg_path():
""" Retrieves FFMPEG binary path using ENVVAR if defined
or default binary name (Windows or UNIX style).
:returns: Absolute path of FFMPEG binary.
:raise IOError: If FFMPEG binary cannot be found.
"""
ffmpeg_path = os.environ.get('FFMPEG_PATH', None)
if ffmpeg_path is None:
# Note: try to infer standard binary name depending on platform.
if platform.system() == 'Windows':
ffmpeg_path = _WINDOWS_BINARY
else:
ffmpeg_path = _UNIX_BINARY
expanded = _which(ffmpeg_path)
if expanded is None:
raise IOError(f'FFMPEG binary ({ffmpeg_path}) not found')
return expanded
def _to_ffmpeg_time(n):
""" Format number of seconds to time expected by FFMPEG.
:param n: Time in seconds to format.
:returns: Formatted time in FFMPEG format.
"""
m, s = divmod(n, 60)
h, m = divmod(m, 60)
return '%d:%02d:%09.6f' % (h, m, s)
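# Worked example (comment only): 3661.5 seconds is 1 hour, 1 minute
# and 1.5 seconds, hence:
# >>> _to_ffmpeg_time(3661.5)
# '1:01:01.500000'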
def _parse_ffmpeg_results(stderr):
""" Extract number of channels and sample rate from
the given FFMPEG STDERR output line.
:param stderr: STDERR output line to parse.
:returns: Parsed n_channels and sample_rate values.
"""
# Setup default value.
n_channels = 0
sample_rate = 0
# Find samplerate
match = re.search(r'(\d+) hz', stderr)
if match:
sample_rate = int(match.group(1))
# Channel count.
match = re.search(r'hz, ([^,]+),', stderr)
if match:
mode = match.group(1)
if mode == 'stereo':
n_channels = 2
else:
match = re.match(r'(\d+) ', mode)
n_channels = match and int(match.group(1)) or 1
return n_channels, sample_rate
class _CommandBuilder(object):
""" A simple builder pattern class for CLI string. """
def __init__(self, binary):
""" Default constructor. """
self._command = [binary]
def flag(self, flag):
""" Add flag or unlabelled opt. """
self._command.append(flag)
return self
def opt(self, short, value, formatter=str):
""" Add option if value not None. """
if value is not None:
self._command.append(short)
self._command.append(formatter(value))
return self
def command(self):
""" Build string command. """
return self._command
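# Illustrative usage sketch (comment only); the binary name and file
# name are hypothetical:
# >>> _CommandBuilder('ffmpeg').opt('-i', 'in.wav').flag('-').command()
# ['ffmpeg', '-i', 'in.wav', '-']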
class FFMPEGProcessAudioAdapter(AudioAdapter):
""" An AudioAdapter implementation that use FFMPEG binary through
subprocess in order to perform I/O operation for audio processing.
When created, FFMPEG binary path will be checked and expended,
raising exception if not found. Such path could be infered using
FFMPEG_PATH environment variable.
"""
def __init__(self):
""" Default constructor. """
self._ffmpeg_path = _get_ffmpeg_path()
def _get_command_builder(self):
""" Creates and returns a command builder using FFMPEG path.
:returns: Built command builder.
"""
return _CommandBuilder(self._ffmpeg_path)
def load(
self, path, offset=None, duration=None,
sample_rate=None, dtype=np.float32):
""" Loads the audio file denoted by the given path
and returns it data as a waveform.
:param path: Path of the audio file to load data from.
:param offset: (Optional) Start offset to load from in seconds.
:param duration: (Optional) Duration to load in seconds.
:param sample_rate: (Optional) Sample rate to load audio with.
:param dtype: (Optional) Numpy data type to use, default to float32.
:returns: Loaded data a (waveform, sample_rate) tuple.
"""
if not isinstance(path, str):
path = path.decode()
command = (
self._get_command_builder()
.opt('-ss', offset, formatter=_to_ffmpeg_time)
.opt('-t', duration, formatter=_to_ffmpeg_time)
.opt('-i', path)
.opt('-ar', sample_rate)
.opt('-f', 'f32le')
.flag('-')
.command())
process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
buffer = process.stdout.read(-1)
# Read STDERR until end of the process detected.
while True:
status = process.stderr.readline()
if not status:
raise OSError('Stream info not found')
if isinstance(status, bytes): # Note: Python 3 compatibility.
status = status.decode('utf8', 'ignore')
status = status.strip().lower()
if 'no such file' in status:
raise IOError(f'File {path} not found')
elif 'invalid data found' in status:
raise IOError(f'FFMPEG error : {status}')
elif 'audio:' in status:
n_channels, ffmpeg_sample_rate = _parse_ffmpeg_results(status)
if sample_rate is None:
sample_rate = ffmpeg_sample_rate
break
# Load waveform and clean process.
waveform = np.frombuffer(buffer, dtype='<f4').reshape(-1, n_channels)
if not waveform.dtype == np.dtype(dtype):
waveform = waveform.astype(dtype)
process.stdout.close()
process.stderr.close()
del process
return (waveform, sample_rate)
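# Illustrative usage sketch (comment only); the file name is
# hypothetical:
# >>> adapter = FFMPEGProcessAudioAdapter()
# >>> waveform, rate = adapter.load(
# ...     'track.mp3', offset=10.0, duration=30.0, sample_rate=44100)
# >>> waveform.shape   # (n_samples, n_channels)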
def save(
self, path, data, sample_rate,
codec=None, bitrate=None):
""" Write waveform data to the file denoted by the given path
using FFMPEG process.
:param path: Path of the audio file to save data in.
:param data: Waveform data to write.
:param sample_rate: Sample rate to write file in.
:param codec: (Optional) Writing codec to use.
:param bitrate: (Optional) Bitrate of the written audio file.
:raise IOError: If any error occurs while using FFMPEG to write data.
"""
directory = os.path.split(path)[0]
if directory and not os.path.exists(directory):
os.makedirs(directory)
get_logger().debug('Writing file %s', path)
# NOTE: Tweak.
if codec == 'wav':
codec = None
command = (
self._get_command_builder()
.flag('-y')
.opt('-loglevel', 'error')
.opt('-f', 'f32le')
.opt('-ar', sample_rate)
.opt('-ac', data.shape[1])
.opt('-i', '-')
.flag('-vn')
.opt('-acodec', codec)
.opt('-ar', sample_rate) # Note: why twice ?
.opt('-strict', '-2') # Note: For 'aac' codec support.
.opt('-ab', bitrate)
.flag(path)
.command())
process = subprocess.Popen(
command,
stdout=open(os.devnull, 'wb'),
stdin=subprocess.PIPE,
stderr=subprocess.PIPE)
# Write data to STDIN.
try:
process.stdin.write(
data.astype('<f4').tostring())
except IOError:
raise IOError(f'FFMPEG error: {process.stderr.read()}')
# Clean process.
process.stdin.close()
if process.stderr is not None:
process.stderr.close()
process.wait()
del process
get_logger().info('File %s written', path)

View File

@@ -0,0 +1,128 @@
#!/usr/bin/env python
# coding: utf8
""" Spectrogram specific data augmentation """
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
from tensorflow.contrib.signal import stft, hann_window
# pylint: enable=import-error
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def compute_spectrogram_tf(
waveform,
frame_length=2048, frame_step=512,
spec_exponent=1., window_exponent=1.):
""" Compute magnitude / power spectrogram from waveform as
a n_samples x n_channels tensor.
:param waveform: Input waveform as (times x number of channels)
tensor.
:param frame_length: Length of a STFT frame to use.
:param frame_step: Hop size between successive frames.
:param spec_exponent: Exponent of the spectrogram (usually 1 for
magnitude spectrogram, or 2 for power spectrogram).
:param window_exponent: Exponent applied to the Hann windowing function
(may be useful for making perfect STFT/iSTFT
reconstruction).
:returns: Computed magnitude / power spectrogram as a
(T x F x n_channels) tensor.
"""
stft_tensor = tf.transpose(
stft(
tf.transpose(waveform),
frame_length,
frame_step,
window_fn=lambda f, dtype: hann_window(
f,
periodic=True,
dtype=waveform.dtype) ** window_exponent),
perm=[1, 2, 0])
return np.abs(stft_tensor) ** spec_exponent
def time_stretch(
spectrogram,
factor=1.0,
method=tf.image.ResizeMethod.BILINEAR):
""" Time stretch a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
:param spectrogram: Input spectrogram to be time stretched as tensor.
:param factor: (Optional) Time stretch factor, must be >0, default to 1.
:param method: (Optional) Interpolation method, default to BILINEAR.
:returns: Time stretched spectrogram as tensor with same shape.
"""
T = tf.shape(spectrogram)[0]
T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]
F = tf.shape(spectrogram)[1]
ts_spec = tf.image.resize_images(
spectrogram,
[T_ts, F],
method=method,
align_corners=True)
return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)
def random_time_stretch(spectrogram, factor_min=0.9, factor_max=1.1, **kwargs):
""" Time stretch a spectrogram preserving shape with random ratio in
tensorflow. Applies time_stretch to spectrogram with a random ratio drawn
uniformly in [factor_min, factor_max].
:param spectrogram: Input spectrogram to be time stretched as tensor.
:param factor_min: (Optional) Min time stretch factor, default to 0.9.
:param factor_max: (Optional) Max time stretch factor, default to 1.1.
:returns: Randomly time stretched spectrogram as tensor with same shape.
"""
factor = tf.random_uniform(
shape=(1,),
seed=0) * (factor_max - factor_min) + factor_min
return time_stretch(spectrogram, factor=factor, **kwargs)
def pitch_shift(
spectrogram,
semitone_shift=0.0,
method=tf.image.ResizeMethod.BILINEAR):
""" Pitch shift a spectrogram preserving shape in tensorflow. Note that
this is an approximation in the frequency domain.
:param spectrogram: Input spectrogram to be pitch shifted as tensor.
:param semitone_shift: (Optional) Pitch shift in semitone, default to 0.0.
:param method: (Optional) Interpolation method, default to BILINEAR.
:returns: Pitch shifted spectrogram (same shape as spectrogram).
"""
factor = 2 ** (semitone_shift / 12.)
T = tf.shape(spectrogram)[0]
F = tf.shape(spectrogram)[1]
F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]
ps_spec = tf.image.resize_images(
spectrogram,
[T, F_ps],
method=method,
align_corners=True)
paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]
return tf.pad(ps_spec[:, :F, :], paddings, 'CONSTANT')
def random_pitch_shift(spectrogram, shift_min=-1., shift_max=1., **kwargs):
""" Pitch shift a spectrogram preserving shape with random ratio in
tensorflow. Applies pitch_shift to spectrogram with a random shift
amount (expressed in semitones) drawn uniformly in [shift_min, shift_max].
:param spectrogram: Input spectrogram to be pitch shifted as tensor.
:param shift_min: (Optional) Min pitch shift in semitone, default to -1.
:param shift_max: (Optional) Max pitch shift in semitone, default to 1.
:returns: Randomly pitch shifted spectrogram (same shape as spectrogram).
"""
semitone_shift = tf.random_uniform(
shape=(1,),
seed=0) * (shift_max - shift_min) + shift_min
return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)
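# Illustrative usage sketch (comment only); the frame parameters are
# hypothetical. A training pipeline can chain these augmentations on a
# (T x F x n_channels) spectrogram computed from a waveform tensor:
# >>> spec = compute_spectrogram_tf(
# ...     waveform, frame_length=4096, frame_step=1024)
# >>> spec = random_time_stretch(spec, factor_min=0.9, factor_max=1.1)
# >>> spec = random_pitch_shift(spec, shift_min=-1.0, shift_max=1.0)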

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python
# coding: utf8
""" Module that provides configuration loading function. """
import json
try:
import importlib.resources as loader
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as loader
from os.path import exists
from .. import resources
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
def load_configuration(descriptor):
""" Load configuration from the given descriptor. Could be
either a `spleeter:` prefixed embedded configuration name
or a file system path to read configuration from.
:param descriptor: Configuration descriptor to use for lookup.
:returns: Loaded configuration as dict.
:raise ValueError: If the required embedded configuration does not exist.
:raise IOError: If the required configuration file does not exist.
"""
# Embedded configuration reading.
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX):]
if not loader.is_resource(resources, f'{name}.json'):
raise ValueError(f'No embedded configuration {name} found')
with loader.open_text(resources, f'{name}.json') as stream:
return json.load(stream)
# Standard file reading.
if not exists(descriptor):
raise IOError(f'Configuration file {descriptor} not found')
with open(descriptor, 'r') as stream:
return json.load(stream)
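# Illustrative usage sketch (comment only); the configuration names
# below are examples, not an exhaustive list:
# >>> config = load_configuration('spleeter:2stems')
# >>> config = load_configuration('path/to/custom_config.json')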

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python
# coding: utf8
""" Utility functions for creating estimator. """
from pathlib import Path
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.contrib import predictor
# pylint: enable=import-error
from ..model import model_fn
from ..model.provider import get_default_model_provider
# Default exporting directory for predictor.
DEFAULT_EXPORT_DIRECTORY = '/tmp/serving'
def create_estimator(params, MWF):
"""
Initialize a tensorflow estimator that will perform separation.
Params:
- params: a dictionary of parameters for building the model
- MWF: whether to use multichannel Wiener filtering
Returns:
a tensorflow estimator
"""
# Load model.
model_directory = params['model_dir']
model_provider = get_default_model_provider()
params['model_dir'] = model_provider.get(model_directory)
params['MWF'] = MWF
# Setup config
session_config = tf.compat.v1.ConfigProto()
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
config = tf.estimator.RunConfig(session_config=session_config)
# Setup estimator
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=params['model_dir'],
params=params,
config=config
)
return estimator
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
""" Exports given estimator as predictor into the given directory
and returns associated tf.predictor instance.
:param estimator: Estimator to export.
:param directory: (Optional) path to write exported model into.
:returns: Created predictor instance.
"""
def receiver():
shape = (None, estimator.params['n_channels'])
features = {
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape),
'audio_id': tf.compat.v1.placeholder(tf.string)}
return tf.estimator.export.ServingInputReceiver(features, features)
estimator.export_saved_model(directory, receiver)
versions = [
model for model in Path(directory).iterdir()
if model.is_dir() and 'temp' not in str(model)]
latest = str(sorted(versions)[-1])
return predictor.from_saved_model(latest)
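# Illustrative usage sketch (comment only); `params` is assumed to hold
# a model configuration dict (see the configuration module above) and
# `waveform` a (n_samples, n_channels) float32 array:
# >>> estimator = create_estimator(params, MWF=False)
# >>> prediction_fn = to_predictor(estimator)
# >>> outputs = prediction_fn({'waveform': waveform, 'audio_id': b''})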

45
spleeter/utils/logging.py Normal file
View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python
# coding: utf8
""" Centralized logging facilities for Spleeter. """
from os import environ
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
class _LoggerHolder(object):
""" Logger singleton instance holder. """
INSTANCE = None
def get_logger():
""" Returns library scoped logger.
:returns: Library logger.
"""
if _LoggerHolder.INSTANCE is None:
# pylint: disable=import-error
from tensorflow.compat.v1 import logging
# pylint: enable=import-error
_LoggerHolder.INSTANCE = logging
_LoggerHolder.INSTANCE.set_verbosity(_LoggerHolder.INSTANCE.ERROR)
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
return _LoggerHolder.INSTANCE
def enable_logging():
""" Enable INFO level logging. """
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
logger = get_logger()
logger.set_verbosity(logger.INFO)
def enable_verbose_logging():
""" Enable DEBUG level logging. """
environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
logger = get_logger()
logger.set_verbosity(logger.DEBUG)

191
spleeter/utils/tensor.py Normal file
View File

@@ -0,0 +1,191 @@
#!/usr/bin/env python
# coding: utf8
""" Utility function for tensorflow. """
# pylint: disable=import-error
import tensorflow as tf
import pandas as pd
# pylint: enable=import-error
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def sync_apply(tensor_dict, func, concat_axis=1):
""" Return a function that applies synchronously the provided func on the
provided dictionnary of tensor. This means that func is applied to the
concatenation of the tensors in tensor_dict. This is useful for performing
random operation that needs the same drawn value on multiple tensor, such
as a random time-crop on both input data and label (the same crop should be
applied to both input data and label, so random crop cannot be applied
separately on each of them).
IMPORTANT NOTE: all tensor are assumed to be the same shape.
Params:
- tensor_dict: dictionary (key: strings, values: tf.tensor)
a dictionary of tensor.
- func: function
function to be applied to the concatenation of the tensors in
tensor_dict
- concat_axis: int
The axis on which to perform the concatenation.
Returns:
processed tensors dictionary with the same name (keys) as input
tensor_dict.
"""
if concat_axis not in {0, 1}:
raise NotImplementedError(
'Function only implemented for concat_axis equal to 0 or 1')
tensor_list = list(tensor_dict.values())
concat_tensor = tf.concat(tensor_list, concat_axis)
processed_concat_tensor = func(concat_tensor)
tensor_shape = tf.shape(list(tensor_dict.values())[0])
D = tensor_shape[concat_axis]
if concat_axis == 0:
return {
name: processed_concat_tensor[index * D:(index + 1) * D, :, :]
for index, name in enumerate(tensor_dict)
}
return {
name: processed_concat_tensor[:, index * D:(index + 1) * D, :]
for index, name in enumerate(tensor_dict)
}
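# Illustrative usage sketch (comment only); tensor names are
# hypothetical and the applied function must keep the size along
# concat_axis unchanged (a shape preserving function, such as the
# random_time_stretch defined earlier in this commit), so that the
# same random draw is shared by both spectrograms:
# >>> stretched = sync_apply(
# ...     {'mix': mix_spectrogram, 'vocals': vocals_spectrogram},
# ...     random_time_stretch,
# ...     concat_axis=1)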
def from_float32_to_uint8(
tensor,
tensor_key='tensor',
min_key='min',
max_key='max'):
"""
Scale a float32 tensor to the [0, 255] range and cast it to uint8,
keeping the original min and max so the operation can be inverted.
:param tensor: Tensor to quantize.
:param tensor_key: Key of the quantized tensor in the output dict.
:param min_key: Key of the tensor minimum in the output dict.
:param max_key: Key of the tensor maximum in the output dict.
:returns: Dict with the quantized tensor and its original min and max.
"""
tensor_min = tf.reduce_min(tensor)
tensor_max = tf.reduce_max(tensor)
return {
tensor_key: tf.cast(
(tensor - tensor_min) / (tensor_max - tensor_min + 1e-16)
* 255.9999, dtype=tf.uint8),
min_key: tensor_min,
max_key: tensor_max
}
def from_uint8_to_float32(tensor, tensor_min, tensor_max):
"""
Invert from_float32_to_uint8, mapping a uint8 tensor back to float32
within the [tensor_min, tensor_max] range.
:param tensor: uint8 tensor to convert back to float32.
:param tensor_min: Minimum of the original float32 tensor.
:param tensor_max: Maximum of the original float32 tensor.
:returns: Converted float32 tensor.
"""
return (
tf.cast(tensor, tf.float32)
* (tensor_max - tensor_min)
/ 255.9999 + tensor_min)
def pad_and_partition(tensor, segment_len):
""" Pad and partition a tensor into segment of len segment_len
along the first dimension. The tensor is padded with 0 in order
to ensure that the first dimension is a multiple of segment_len.
Tensor must be of known fixed rank
:Example:
>>> tensor = [[1, 2, 3], [4, 5, 6]]
>>> segment_len = 2
>>> pad_and_partition(tensor, segment_len)
[[[1, 2], [4, 5]], [[3, 0], [6, 0]]]
:param tensor:
:param segment_len:
:returns:
"""
tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)
pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)
padded = tf.pad(
tensor,
[[0, pad_size]] + [[0, 0]] * (len(tensor.shape)-1))
split = (tf.shape(padded)[0] + segment_len - 1) // segment_len
return tf.reshape(
padded,
tf.concat(
[[split, segment_len], tf.shape(padded)[1:]],
axis=0))
def pad_and_reshape(instr_spec, frame_length, F):
"""
Pad the frequency axis (axis 2) of a rank 4 instrument spectrogram
up to frame_length // 2 + 1 bins, then merge its first two dimensions.
:param instr_spec: Instrument spectrogram tensor (rank 4).
:param frame_length: STFT frame length used to derive the target
number of frequency bins.
:param F: Current number of frequency bins in instr_spec.
:returns: Padded and reshaped spectrogram tensor.
"""
spec_shape = tf.shape(instr_spec)
extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))
n_extra_row = (frame_length) // 2 + 1 - F
extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])
extended_spec = tf.concat([instr_spec, extension], axis=2)
old_shape = tf.shape(extended_spec)
new_shape = tf.concat([
[old_shape[0] * old_shape[1]],
old_shape[2:]],
axis=0)
processed_instr_spec = tf.reshape(extended_spec, new_shape)
return processed_instr_spec
def dataset_from_csv(csv_path, **kwargs):
""" Load dataset from a CSV file using Pandas. kwargs if any are
forwarded to the `pandas.read_csv` function.
:param csv_path: Path of the CSV file to load dataset from.
:returns: Loaded dataset.
"""
df = pd.read_csv(csv_path, **kwargs)
dataset = (
tf.data.Dataset.from_tensor_slices(
{key: df[key].values for key in df})
)
return dataset
def check_tensor_shape(tensor_tf, target_shape):
""" Return a Tensorflow boolean graph that indicates whether
sample[features_key] has the specified target shape. Only check
not None entries of target_shape.
:param tensor_tf: Tensor to check shape for.
:param target_shape: Target shape to compare tensor to.
:returns: True if shape is valid, False otherwise (as TF boolean).
"""
result = tf.constant(True)
for i, target_length in enumerate(target_shape):
if target_length:
result = tf.logical_and(
result,
tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i]))
return result
def set_tensor_shape(tensor, tensor_shape):
""" Set shape for a tensor (not in place, as opposed to tf.set_shape)
:param tensor: Tensor to reshape.
:param tensor_shape: Shape to apply to the tensor.
:returns: A reshaped tensor.
"""
# NOTE: That SOUND LIKE IN PLACE HERE ?
tensor.set_shape(tensor_shape)
return tensor