#!/usr/bin/env python
# coding: utf8
"""
This module contains building functions for U-net source
separation models in a similar way as in A. Jansson et al. :
"Singing voice separation with deep u-net convolutional networks",
ISMIR 2017
Each instrument is modeled by a single U-net
convolutional / deconvolutional network that take a mix spectrogram
as input and the estimated sound spectrogram as output.
"""
from functools import partial
from typing import Any, Dict, Iterable, Optional

# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.compat.v1 import logging
from tensorflow.compat.v1.keras.initializers import he_uniform
from tensorflow.keras.layers import (
    ELU,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Conv2DTranspose,
    Dropout,
    LeakyReLU,
    Multiply,
    ReLU,
    Softmax,
)

from . import apply

# pylint: enable=import-error

__email__ = "spleeter@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
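
# Illustrative usage sketch (an assumption, not part of the original module):
# the shapes and instrument names below match spleeter's default 512x1024
# spectrogram patches with two audio channels.
#
#     mix = tf.zeros([1, 512, 1024, 2])  # [batch, time, frequency, channels]
#     outputs = softmax_unet(mix, ["vocals", "accompaniment"], params={})
#     # outputs["vocals_spectrogram"] has the same shape as `mix`.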


def _get_conv_activation_layer(params: Dict) -> Any:
    """
    Build the activation layer used after each encoder convolution.

    Parameters:
        params (Dict):
            Model parameters; the "conv_activation" key selects "ReLU" or
            "ELU", any other value falls back to LeakyReLU(0.2).

    Returns:
        Any:
            Requested activation layer instance.
    """
    conv_activation: str = params.get("conv_activation")
    if conv_activation == "ReLU":
        return ReLU()
    elif conv_activation == "ELU":
        return ELU()
    return LeakyReLU(0.2)


def _get_deconv_activation_layer(params: Dict) -> Any:
    """
    Build the activation layer used after each decoder deconvolution.

    Parameters:
        params (Dict):
            Model parameters; the "deconv_activation" key selects "LeakyReLU"
            or "ELU", any other value falls back to ReLU().

    Returns:
        Any:
            Requested activation layer instance.
    """
    deconv_activation: str = params.get("deconv_activation")
    if deconv_activation == "LeakyReLU":
        return LeakyReLU(0.2)
    elif deconv_activation == "ELU":
        return ELU()
    return ReLU()
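
# For instance (illustrative call, not from the original source):
# _get_conv_activation_layer({"conv_activation": "ELU"}) returns an ELU layer,
# while an empty dict falls through to the LeakyReLU(0.2) default;
# _get_deconv_activation_layer behaves symmetrically for the decoder.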


def apply_unet(
    input_tensor: tf.Tensor,
    output_name: str = "output",
    params: Optional[Dict] = None,
    output_mask_logit: bool = False,
) -> Any:
    """
    Apply a convolutional U-net to model a single instrument (one U-net
    is used for each instrument).

    Parameters:
        input_tensor (tensorflow.Tensor):
            Mix spectrogram tensor to be separated.
        output_name (str):
            Name of the output (masked) tensor.
        params (Optional[Dict]):
            (Optional) dict of model parameters.
        output_mask_logit (bool):
            If True, return the unactivated mask logits instead of the
            masked spectrogram.
    """
    logging.info(f"Apply unet for {output_name}")
    if params is None:
        params = {}
    conv_n_filters = params.get("conv_n_filters", [16, 32, 64, 128, 256, 512])
    conv_activation_layer = _get_conv_activation_layer(params)
    deconv_activation_layer = _get_deconv_activation_layer(params)
    kernel_initializer = he_uniform(seed=50)
    conv2d_factory = partial(
        Conv2D, strides=(2, 2), padding="same", kernel_initializer=kernel_initializer
    )
    # First layer.
    conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)
    batch1 = BatchNormalization(axis=-1)(conv1)
    rel1 = conv_activation_layer(batch1)
    # Second layer.
    conv2 = conv2d_factory(conv_n_filters[1], (5, 5))(rel1)
    batch2 = BatchNormalization(axis=-1)(conv2)
    rel2 = conv_activation_layer(batch2)
    # Third layer.
    conv3 = conv2d_factory(conv_n_filters[2], (5, 5))(rel2)
    batch3 = BatchNormalization(axis=-1)(conv3)
    rel3 = conv_activation_layer(batch3)
    # Fourth layer.
    conv4 = conv2d_factory(conv_n_filters[3], (5, 5))(rel3)
    batch4 = BatchNormalization(axis=-1)(conv4)
    rel4 = conv_activation_layer(batch4)
    # Fifth layer.
    conv5 = conv2d_factory(conv_n_filters[4], (5, 5))(rel4)
    batch5 = BatchNormalization(axis=-1)(conv5)
    rel5 = conv_activation_layer(batch5)
    # Sixth layer: the activation output is discarded, since the decoder
    # starts from the raw conv6 feature map.
    conv6 = conv2d_factory(conv_n_filters[5], (5, 5))(rel5)
    batch6 = BatchNormalization(axis=-1)(conv6)
    _ = conv_activation_layer(batch6)
    # Decoder path: transposed convolutions, mirroring the encoder.
    conv2d_transpose_factory = partial(
        Conv2DTranspose,
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )
    # First up-sampling block.
    up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))(conv6)
    up1 = deconv_activation_layer(up1)
    batch7 = BatchNormalization(axis=-1)(up1)
    drop1 = Dropout(0.5)(batch7)
    merge1 = Concatenate(axis=-1)([conv5, drop1])
    # Second up-sampling block.
    up2 = conv2d_transpose_factory(conv_n_filters[3], (5, 5))(merge1)
    up2 = deconv_activation_layer(up2)
    batch8 = BatchNormalization(axis=-1)(up2)
    drop2 = Dropout(0.5)(batch8)
    merge2 = Concatenate(axis=-1)([conv4, drop2])
    # Third up-sampling block.
    up3 = conv2d_transpose_factory(conv_n_filters[2], (5, 5))(merge2)
    up3 = deconv_activation_layer(up3)
    batch9 = BatchNormalization(axis=-1)(up3)
    drop3 = Dropout(0.5)(batch9)
    merge3 = Concatenate(axis=-1)([conv3, drop3])
    # Fourth up-sampling block (no dropout from here on).
    up4 = conv2d_transpose_factory(conv_n_filters[1], (5, 5))(merge3)
    up4 = deconv_activation_layer(up4)
    batch10 = BatchNormalization(axis=-1)(up4)
    merge4 = Concatenate(axis=-1)([conv2, batch10])
    # Fifth up-sampling block.
    up5 = conv2d_transpose_factory(conv_n_filters[0], (5, 5))(merge4)
    up5 = deconv_activation_layer(up5)
    batch11 = BatchNormalization(axis=-1)(up5)
    merge5 = Concatenate(axis=-1)([conv1, batch11])
    # Sixth up-sampling block, restoring the input spatial resolution.
    up6 = conv2d_transpose_factory(1, (5, 5))(merge5)
    up6 = deconv_activation_layer(up6)
    batch12 = BatchNormalization(axis=-1)(up6)
    # Last layer to ensure initial shape reconstruction.
    if not output_mask_logit:
        up7 = Conv2D(
            2,
            (4, 4),
            dilation_rate=(2, 2),
            activation="sigmoid",
            padding="same",
            kernel_initializer=kernel_initializer,
        )(batch12)
        output = Multiply(name=output_name)([up7, input_tensor])
        return output
    return Conv2D(
        2,
        (4, 4),
        dilation_rate=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(batch12)
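
# Note (added commentary): the sigmoid branch above yields a soft ratio mask in
# [0, 1] that is applied multiplicatively to the input spectrogram, while the
# logit branch defers masking so that softmax_unet below can normalize the
# masks across instruments.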


def unet(
    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
    """ Model function applier: build one independent U-net per instrument. """
    return apply(apply_unet, input_tensor, instruments, params)


def softmax_unet(
    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None
) -> Dict:
    """
    Apply softmax to multitrack U-net masks so that they sum to one.

    Parameters:
        input_tensor (tensorflow.Tensor):
            Mix spectrogram tensor to apply the U-nets to.
        instruments (Iterable[str]):
            Iterable that provides a collection of instruments.
        params (Optional[Dict]):
            (Optional) dict of model parameters.

    Returns:
        Dict:
            Created output tensor dict.
    """
    logit_mask_list = []
    for instrument in instruments:
        out_name = f"{instrument}_spectrogram"
        logit_mask_list.append(
            apply_unet(
                input_tensor,
                output_name=out_name,
                params=params,
                output_mask_logit=True,
            )
        )
    # Stack per-instrument logits on a new trailing axis and normalize them.
    masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))
    output_dict = {}
    for i, instrument in enumerate(instruments):
        out_name = f"{instrument}_spectrogram"
        output_dict[out_name] = Multiply(name=out_name)([masks[..., i], input_tensor])
    return output_dict
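
# Mask property sketch (illustrative, not from the original source): since the
# softmax makes the per-bin masks sum to one across instruments, the estimated
# spectrograms add back up to the mix:
#
#     out = softmax_unet(mix, ["vocals", "accompaniment"], params={})
#     recon = out["vocals_spectrogram"] + out["accompaniment_spectrogram"]
#     # recon == mix, up to floating-point precision.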