#!/usr/bin/env python
# coding: utf8

"""
This module provides audio data conversion functions.
"""

# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import tensorflow as tf
# pylint: enable=import-error

from ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32

__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'


def to_n_channels(
        waveform: tf.Tensor,
        n_channels: int) -> tf.Tensor:
    """
    Convert a waveform to n_channels by removing or duplicating channels
    if needed (in tensorflow).

    Channels live on axis 1: the code reads `tf.shape(waveform)[1]` and
    slices `[:, :n_channels]`. Because `tf.tile` is called with two
    multiples, `waveform` must be a 2-D `(N, channels)` tensor. When there
    are fewer channels than requested, the existing channels are repeated
    cyclically. For example, `[c0, c1]` becomes `[c0, c1, c0]` for
    `n_channels=3`.

    Parameters:
        waveform (tensorflow.Tensor):
            Waveform to transform.
        n_channels (int):
            Number of channels to reshape the waveform into.

    Returns:
        tensorflow.Tensor:
            Reshaped waveform with exactly `n_channels` channels.
    """
    return tf.cond(
        tf.shape(waveform)[1] >= n_channels,
        # Enough channels already: drop the extra ones.
        true_fn=lambda: waveform[:, :n_channels],
        # Too few channels: tiling n_channels times along axis 1 always
        # yields at least n_channels columns; keep only the first ones.
        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels])


def to_stereo(waveform: np.ndarray) -> np.ndarray:
    """
    Convert a waveform to stereo by duplicating if mono, or truncating if
    too many channels.

    Parameters:
        waveform (numpy.ndarray):
            A `(N, d)` numpy array, channels on axis 1.

    Returns:
        numpy.ndarray:
            A stereo waveform as a `(N, 2)` numpy array.
            An input with exactly 2 channels is returned unchanged (same
            object, no copy). A degenerate input with 0 channels is also
            returned unchanged.
    """
    if waveform.shape[1] == 1:
        # Mono: duplicate the single channel into left and right.
        return np.repeat(waveform, 2, axis=-1)
    if waveform.shape[1] > 2:
        # More than stereo: keep the first two channels (a numpy view).
        return waveform[:, :2]
    return waveform


def gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:
    """
    Convert from gain to decibel in tensorflow.

    This computes `20 * log10(max(tensor, espilon))`.

    Parameters:
        tensor (tensorflow.Tensor):
            Tensor to convert.
        espilon (float):
            Floor applied to `tensor` before the logarithm, so that zero
            or negative gains do not produce `-inf` / `nan`. The default
            `10e-10` equals 1e-9, which bounds the output below at
            -180 dB. The misspelled name is kept for backward
            compatibility with keyword callers.

    Returns:
        tensorflow.Tensor:
            Converted tensor.
    """
    # 20 / ln(10) * ln(x) == 20 * log10(x), using the natural-log op.
    return 20. / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))


def db_to_gain(tensor: tf.Tensor) -> tf.Tensor:
    """
    Convert from decibel to gain in tensorflow.

    This computes `10 ** (tensor / 20)`. It is the inverse of `gain_to_db`
    for gains above that function's epsilon floor.

    Parameters:
        tensor (tensorflow.Tensor):
            Tensor to convert.

    Returns:
        tensorflow.Tensor:
            Converted tensor.
    """
    return tf.pow(10., (tensor / 20.))


def spectrogram_to_db_uint(
        spectrogram: tf.Tensor,
        db_range: float = 100.,
        **kwargs) -> tf.Tensor:
    """
    Encodes given spectrogram into uint8 using decibel scale.

    The spectrogram is first converted to dB with `gain_to_db`. Every value
    more than `db_range` dB below the global maximum is then raised to
    `max - db_range`. The maximum is taken over the whole tensor. The
    result is passed to `from_float32_to_uint8`.

    Parameters:
        spectrogram (tensorflow.Tensor):
            Spectrogram to be encoded as TF float tensor.
        db_range (float):
            Range in decibel for encoding.
        **kwargs:
            Extra keyword arguments forwarded unchanged to
            `from_float32_to_uint8`.

    Returns:
        tensorflow.Tensor:
            Encoded decibel spectrogram as `uint8` tensor, i.e. whatever
            `from_float32_to_uint8` returns.
            NOTE(review): decoding with `db_uint_spectrogram_to_gain`
            requires the min/max dB bounds used during encoding. Confirm
            whether `from_float32_to_uint8` returns them alongside the
            tensor, and whether the `tf.Tensor` return annotation is
            accurate.
    """
    db_spectrogram: tf.Tensor = gain_to_db(spectrogram)
    max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)
    # Floor every value at (max - db_range). Since max is the global
    # maximum, this bounds the dynamic range to db_range dB.
    clipped_db_spectrogram: tf.Tensor = tf.maximum(
        db_spectrogram, max_db_spectrogram - db_range)
    return from_float32_to_uint8(clipped_db_spectrogram, **kwargs)


def db_uint_spectrogram_to_gain(
        db_uint_spectrogram: tf.Tensor,
        min_db: tf.Tensor,
        max_db: tf.Tensor) -> tf.Tensor:
    """
    Decode spectrogram from uint8 decibel scale.

    The uint8 tensor is mapped back to float dB values with
    `from_uint8_to_float32` using the given bounds. These are then
    converted to gain with `db_to_gain`.

    Parameters:
        db_uint_spectrogram (tensorflow.Tensor):
            Decibel spectrogram to decode.
        min_db (tensorflow.Tensor):
            Lower bound limit for decoding.
        max_db (tensorflow.Tensor):
            Upper bound limit for decoding.

    Returns:
        tensorflow.Tensor:
            Decoded spectrogram as `float32` tensor.
    """
    db_spectrogram: tf.Tensor = from_uint8_to_float32(
        db_uint_spectrogram, min_db, max_db)
    return db_to_gain(db_spectrogram)