mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
🎨 finalizes model provider and functions
This commit is contained in:
@@ -16,28 +16,22 @@
|
||||
|
||||
import time
|
||||
import os
|
||||
from os.path import exists, join, sep as SEPARATOR
|
||||
|
||||
from os.path import exists, sep as SEPARATOR
|
||||
|
||||
from .audio.convertor import db_uint_spectrogram_to_gain
|
||||
from .audio.convertor import spectrogram_to_db_uint
|
||||
from .audio.spectrogram import compute_spectrogram_tf
|
||||
from .audio.spectrogram import random_pitch_shift, random_time_stretch
|
||||
from .utils.logging import get_logger
|
||||
from .utils.tensor import check_tensor_shape, dataset_from_csv
|
||||
from .utils.tensor import set_tensor_shape, sync_apply
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
# pylint: enable=import-error
|
||||
|
||||
from .audio.convertor import (
|
||||
db_uint_spectrogram_to_gain,
|
||||
spectrogram_to_db_uint)
|
||||
from .audio.spectrogram import (
|
||||
compute_spectrogram_tf,
|
||||
random_pitch_shift,
|
||||
random_time_stretch)
|
||||
from .utils.logging import get_logger
|
||||
from .utils.tensor import (
|
||||
check_tensor_shape,
|
||||
dataset_from_csv,
|
||||
set_tensor_shape,
|
||||
sync_apply)
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
@@ -3,25 +3,44 @@
|
||||
|
||||
""" This package provide model functions. """
|
||||
|
||||
from typing import Callable, Dict, Iterable, Optional
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import tensorflow as tf
|
||||
# pylint: enable=import-error
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def apply(function, input_tensor, instruments, params={}):
|
||||
""" Apply given function to the input tensor.
|
||||
|
||||
:param function: Function to be applied to tensor.
|
||||
:param input_tensor: Tensor to apply blstm to.
|
||||
:param instruments: Iterable that provides a collection of instruments.
|
||||
:param params: (Optional) dict of BLSTM parameters.
|
||||
:returns: Created output tensor dict.
|
||||
def apply(
|
||||
function: Callable,
|
||||
input_tensor: tf.Tensor,
|
||||
instruments: Iterable[str],
|
||||
params: Optional[Dict] = None) -> Dict:
|
||||
"""
|
||||
output_dict = {}
|
||||
Apply given function to the input tensor.
|
||||
|
||||
Parameters:
|
||||
function:
|
||||
Function to be applied to tensor.
|
||||
input_tensor (tensorflow.Tensor):
|
||||
Tensor to apply blstm to.
|
||||
instruments (Iterable[str]):
|
||||
Iterable that provides a collection of instruments.
|
||||
params:
|
||||
(Optional) dict of BLSTM parameters.
|
||||
|
||||
Returns:
|
||||
Created output tensor dict.
|
||||
"""
|
||||
output_dict: Dict = {}
|
||||
for instrument in instruments:
|
||||
out_name = f'{instrument}_spectrogram'
|
||||
output_dict[out_name] = function(
|
||||
input_tensor,
|
||||
output_name=out_name,
|
||||
params=params)
|
||||
params=params or {})
|
||||
return output_dict
|
||||
|
||||
@@ -20,7 +20,14 @@
|
||||
selection (LSTM layer dropout rate, regularization strength).
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
from . import apply
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.compat.v1.keras.initializers import he_uniform
|
||||
from tensorflow.compat.v1.keras.layers import CuDNNLSTM
|
||||
from tensorflow.keras.layers import (
|
||||
@@ -31,22 +38,33 @@ from tensorflow.keras.layers import (
|
||||
TimeDistributed)
|
||||
# pylint: enable=import-error
|
||||
|
||||
from . import apply
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def apply_blstm(input_tensor, output_name='output', params={}):
|
||||
""" Apply BLSTM to the given input_tensor.
|
||||
|
||||
:param input_tensor: Input of the model.
|
||||
:param output_name: (Optional) name of the output, default to 'output'.
|
||||
:param params: (Optional) dict of BLSTM parameters.
|
||||
:returns: Output tensor.
|
||||
def apply_blstm(
|
||||
input_tensor: tf.Tensor,
|
||||
output_name: str = 'output',
|
||||
params: Optional[Dict] = None) -> tf.Tensor:
|
||||
"""
|
||||
units = params.get('lstm_units', 250)
|
||||
Apply BLSTM to the given input_tensor.
|
||||
|
||||
Parameters:
|
||||
input_tensor (tensorflow.Tensor):
|
||||
Input of the model.
|
||||
output_name (str):
|
||||
(Optional) name of the output, default to 'output'.
|
||||
params (Optional[Dict]):
|
||||
(Optional) dict of BLSTM parameters.
|
||||
|
||||
Returns:
|
||||
tensorflow.Tensor:
|
||||
Output tensor.
|
||||
"""
|
||||
if params is None:
|
||||
params = {}
|
||||
units: int = params.get('lstm_units', 250)
|
||||
kernel_initializer = he_uniform(seed=50)
|
||||
flatten_input = TimeDistributed(Flatten())((input_tensor))
|
||||
|
||||
@@ -65,12 +83,15 @@ def apply_blstm(input_tensor, output_name='output', params={}):
|
||||
int(flatten_input.shape[2]),
|
||||
activation='relu',
|
||||
kernel_initializer=kernel_initializer))((l3))
|
||||
output = TimeDistributed(
|
||||
output: tf.Tensor = TimeDistributed(
|
||||
Reshape(input_tensor.shape[2:]),
|
||||
name=output_name)(dense)
|
||||
return output
|
||||
|
||||
|
||||
def blstm(input_tensor, output_name='output', params={}):
|
||||
def blstm(
|
||||
input_tensor: tf.Tensor,
|
||||
output_name: str = 'output',
|
||||
params: Optional[Dict] = None) -> tf.Tensor:
|
||||
""" Model function applier. """
|
||||
return apply(apply_blstm, input_tensor, output_name, params)
|
||||
|
||||
@@ -2,16 +2,23 @@
|
||||
# coding: utf8
|
||||
|
||||
"""
|
||||
This module contains building functions for U-net source
|
||||
separation models in a similar way as in A. Jansson et al. "Singing
|
||||
voice separation with deep u-net convolutional networks", ISMIR 2017.
|
||||
Each instrument is modeled by a single U-net convolutional
|
||||
/ deconvolutional network that take a mix spectrogram as input and the
|
||||
estimated sound spectrogram as output.
|
||||
This module contains building functions for U-net source
|
||||
separation models in a similar way as in A. Jansson et al. :
|
||||
|
||||
"Singing voice separation with deep u-net convolutional networks",
|
||||
ISMIR 2017
|
||||
|
||||
Each instrument is modeled by a single U-net
|
||||
convolutional / deconvolutional network that take a mix spectrogram
|
||||
as input and the estimated sound spectrogram as output.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from . import apply
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -30,20 +37,23 @@ from tensorflow.compat.v1 import logging
|
||||
from tensorflow.compat.v1.keras.initializers import he_uniform
|
||||
# pylint: enable=import-error
|
||||
|
||||
from . import apply
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def _get_conv_activation_layer(params):
|
||||
def _get_conv_activation_layer(params: Dict) -> Any:
|
||||
"""
|
||||
> To be documented.
|
||||
|
||||
:param params:
|
||||
:returns: Required Activation function.
|
||||
Parameters:
|
||||
params (Dict):
|
||||
|
||||
Returns:
|
||||
Any:
|
||||
Required Activation function.
|
||||
"""
|
||||
conv_activation = params.get('conv_activation')
|
||||
conv_activation: str = params.get('conv_activation')
|
||||
if conv_activation == 'ReLU':
|
||||
return ReLU()
|
||||
elif conv_activation == 'ELU':
|
||||
@@ -51,13 +61,18 @@ def _get_conv_activation_layer(params):
|
||||
return LeakyReLU(0.2)
|
||||
|
||||
|
||||
def _get_deconv_activation_layer(params):
|
||||
def _get_deconv_activation_layer(params: Dict) -> Any:
|
||||
"""
|
||||
> To be documented.
|
||||
|
||||
:param params:
|
||||
:returns: Required Activation function.
|
||||
Parameters:
|
||||
params (Dict):
|
||||
|
||||
Returns:
|
||||
Any:
|
||||
Required Activation function.
|
||||
"""
|
||||
deconv_activation = params.get('deconv_activation')
|
||||
deconv_activation: str = params.get('deconv_activation')
|
||||
if deconv_activation == 'LeakyReLU':
|
||||
return LeakyReLU(0.2)
|
||||
elif deconv_activation == 'ELU':
|
||||
@@ -66,17 +81,19 @@ def _get_deconv_activation_layer(params):
|
||||
|
||||
|
||||
def apply_unet(
|
||||
input_tensor,
|
||||
output_name='output',
|
||||
params={},
|
||||
output_mask_logit=False):
|
||||
""" Apply a convolutionnal U-net to model a single instrument (one U-net
|
||||
is used for each instrument).
|
||||
input_tensor: tf.Tensor,
|
||||
output_name: str = 'output',
|
||||
params: Optional[Dict] = None,
|
||||
output_mask_logit: bool = False) -> Any:
|
||||
"""
|
||||
Apply a convolutionnal U-net to model a single instrument (one U-net
|
||||
is used for each instrument).
|
||||
|
||||
:param input_tensor:
|
||||
:param output_name: (Optional) , default to 'output'
|
||||
:param params: (Optional) , default to empty dict.
|
||||
:param output_mask_logit: (Optional) , default to False.
|
||||
Parameters:
|
||||
input_tensor (tensorflow.Tensor):
|
||||
output_name (str):
|
||||
params (Optional[Dict]):
|
||||
output_mask_logit (bool):
|
||||
"""
|
||||
logging.info(f'Apply unet for {output_name}')
|
||||
conv_n_filters = params.get('conv_n_filters', [16, 32, 64, 128, 256, 512])
|
||||
@@ -170,18 +187,32 @@ def apply_unet(
|
||||
kernel_initializer=kernel_initializer)((batch12))
|
||||
|
||||
|
||||
def unet(input_tensor, instruments, params={}):
|
||||
def unet(
|
||||
input_tensor: tf.Tensor,
|
||||
instruments: Iterable[str],
|
||||
params: Optional[Dict] = None) -> Dict:
|
||||
""" Model function applier. """
|
||||
return apply(apply_unet, input_tensor, instruments, params)
|
||||
|
||||
|
||||
def softmax_unet(input_tensor, instruments, params={}):
|
||||
""" Apply softmax to multitrack unet in order to have mask suming to one.
|
||||
def softmax_unet(
|
||||
input_tensor: tf.Tensor,
|
||||
instruments: Iterable[str],
|
||||
params: Optional[Dict] = None) -> Dict:
|
||||
"""
|
||||
Apply softmax to multitrack unet in order to have mask suming to one.
|
||||
|
||||
:param input_tensor: Tensor to apply blstm to.
|
||||
:param instruments: Iterable that provides a collection of instruments.
|
||||
:param params: (Optional) dict of BLSTM parameters.
|
||||
:returns: Created output tensor dict.
|
||||
Parameters:
|
||||
input_tensor (tensorflow.Tensor):
|
||||
Tensor to apply blstm to.
|
||||
instruments (Iterable[str]):
|
||||
Iterable that provides a collection of instruments.
|
||||
params (Optional[Dict]):
|
||||
(Optional) dict of BLSTM parameters.
|
||||
|
||||
Returns:
|
||||
Dict:
|
||||
Created output tensor dict.
|
||||
"""
|
||||
logit_mask_list = []
|
||||
for instrument in instruments:
|
||||
|
||||
@@ -5,10 +5,12 @@
|
||||
This package provides tools for downloading model from network
|
||||
using remote storage abstraction.
|
||||
|
||||
:Example:
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> provider = MyProviderImplementation()
|
||||
>>> provider.get('/path/to/local/storage', params)
|
||||
```
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -26,39 +28,52 @@ class ModelProvider(ABC):
|
||||
file download is not available.
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL_PATH = environ.get('MODEL_PATH', 'pretrained_models')
|
||||
MODEL_PROBE_PATH = '.probe'
|
||||
DEFAULT_MODEL_PATH: str = environ.get('MODEL_PATH', 'pretrained_models')
|
||||
MODEL_PROBE_PATH: str = '.probe'
|
||||
|
||||
@abstractmethod
|
||||
def download(self, name, path):
|
||||
""" Download model denoted by the given name to disk.
|
||||
def download(_, name: str, path: str) -> None:
|
||||
"""
|
||||
Download model denoted by the given name to disk.
|
||||
|
||||
:param name: Name of the model to download.
|
||||
:param path: Path of the directory to save model into.
|
||||
Parameters:
|
||||
name (str):
|
||||
Name of the model to download.
|
||||
path (str):
|
||||
Path of the directory to save model into.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def writeProbe(directory):
|
||||
""" Write a model probe file into the given directory.
|
||||
|
||||
:param directory: Directory to write probe into.
|
||||
def writeProbe(directory: str) -> None:
|
||||
"""
|
||||
probe = join(directory, ModelProvider.MODEL_PROBE_PATH)
|
||||
Write a model probe file into the given directory.
|
||||
|
||||
Parameters:
|
||||
directory (str):
|
||||
Directory to write probe into.
|
||||
"""
|
||||
probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)
|
||||
with open(probe, 'w') as stream:
|
||||
stream.write('OK')
|
||||
|
||||
def get(self, model_directory):
|
||||
""" Ensures required model is available at given location.
|
||||
def get(self, model_directory: str) -> str:
|
||||
"""
|
||||
Ensures required model is available at given location.
|
||||
|
||||
:param model_directory: Expected model_directory to be available.
|
||||
:raise IOError: If model can not be retrieved.
|
||||
Parameters:
|
||||
model_directory (str):
|
||||
Expected model_directory to be available.
|
||||
|
||||
Raises:
|
||||
IOError:
|
||||
If model can not be retrieved.
|
||||
"""
|
||||
# Expend model directory if needed.
|
||||
if not isabs(model_directory):
|
||||
model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)
|
||||
# Download it if not exists.
|
||||
model_probe = join(model_directory, self.MODEL_PROBE_PATH)
|
||||
model_probe: str = join(model_directory, self.MODEL_PROBE_PATH)
|
||||
if not exists(model_probe):
|
||||
if not exists(model_directory):
|
||||
makedirs(model_directory)
|
||||
@@ -68,14 +83,14 @@ class ModelProvider(ABC):
|
||||
self.writeProbe(model_directory)
|
||||
return model_directory
|
||||
|
||||
@classmethod
|
||||
def default(_: type) -> 'ModelProvider':
|
||||
"""
|
||||
Builds and returns a default model provider.
|
||||
|
||||
def get_default_model_provider():
|
||||
""" Builds and returns a default model provider.
|
||||
|
||||
:returns: A default model provider instance to use.
|
||||
"""
|
||||
from .github import GithubModelProvider
|
||||
host = environ.get('GITHUB_HOST', 'https://github.com')
|
||||
repository = environ.get('GITHUB_REPOSITORY', 'deezer/spleeter')
|
||||
release = environ.get('GITHUB_RELEASE', GithubModelProvider.LATEST_RELEASE)
|
||||
return GithubModelProvider(host, repository, release)
|
||||
Returns:
|
||||
ModelProvider:
|
||||
A default model provider instance to use.
|
||||
"""
|
||||
from .github import GithubModelProvider
|
||||
return GithubModelProvider.from_environ()
|
||||
|
||||
@@ -4,27 +4,34 @@
|
||||
"""
|
||||
A ModelProvider backed by Github Release feature.
|
||||
|
||||
:Example:
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from spleeter.model.provider import github
|
||||
>>> provider = github.GithubModelProvider(
|
||||
'github.com',
|
||||
'Deezer/spleeter',
|
||||
'latest')
|
||||
>>> provider.download('2stems', '/path/to/local/storage')
|
||||
```
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import tarfile
|
||||
import os
|
||||
|
||||
from os import environ
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import requests
|
||||
from typing import Dict
|
||||
|
||||
from . import ModelProvider
|
||||
from ...utils.logging import get_logger
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import httpx
|
||||
# pylint: enable=import-error
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
@@ -46,69 +53,108 @@ def compute_file_checksum(path):
|
||||
class GithubModelProvider(ModelProvider):
|
||||
""" A ModelProvider implementation backed on Github for remote storage. """
|
||||
|
||||
LATEST_RELEASE = 'v1.4.0'
|
||||
RELEASE_PATH = 'releases/download'
|
||||
CHECKSUM_INDEX = 'checksum.json'
|
||||
DEFAULT_HOST: str = 'https://github.com'
|
||||
DEFAULT_REPOSITORY: str = 'deezer/spleeter'
|
||||
|
||||
def __init__(self, host, repository, release):
|
||||
CHECKSUM_INDEX: str = 'checksum.json'
|
||||
LATEST_RELEASE: str = 'v1.4.0'
|
||||
RELEASE_PATH: str = 'releases/download'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
repository: str,
|
||||
release: str) -> None:
|
||||
""" Default constructor.
|
||||
|
||||
:param host: Host to the Github instance to reach.
|
||||
:param repository: Repository path within target Github.
|
||||
:param release: Release name to get models from.
|
||||
Parameters:
|
||||
host (str):
|
||||
Host to the Github instance to reach.
|
||||
repository (str):
|
||||
Repository path within target Github.
|
||||
release (str):
|
||||
Release name to get models from.
|
||||
"""
|
||||
self._host = host
|
||||
self._repository = repository
|
||||
self._release = release
|
||||
self._host: str = host
|
||||
self._repository: str = repository
|
||||
self._release: str = release
|
||||
|
||||
def checksum(self, name):
|
||||
""" Downloads and returns reference checksum for the given model name.
|
||||
|
||||
:param name: Name of the model to get checksum for.
|
||||
:returns: Checksum of the required model.
|
||||
:raise ValueError: If the given model name is not indexed.
|
||||
@classmethod
|
||||
def from_environ(cls: type) -> 'GithubModelProvider':
|
||||
"""
|
||||
url = '{}/{}/{}/{}/{}'.format(
|
||||
Factory method that creates provider from envvars.
|
||||
|
||||
Returns:
|
||||
GithubModelProvider:
|
||||
Created instance.
|
||||
"""
|
||||
return cls(
|
||||
environ.get('GITHUB_HOST', cls.DEFAULT_HOST),
|
||||
environ.get('GITHUB_REPOSITORY', cls.DEFAULT_REPOSITORY),
|
||||
environ.get('GITHUB_RELEASE', cls.LATEST_RELEASE))
|
||||
|
||||
def checksum(self, name: str) -> str:
|
||||
"""
|
||||
Downloads and returns reference checksum for the given model name.
|
||||
|
||||
Parameters:
|
||||
name (str):
|
||||
Name of the model to get checksum for.
|
||||
Returns:
|
||||
str:
|
||||
Checksum of the required model.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If the given model name is not indexed.
|
||||
"""
|
||||
url: str = '/'.join((
|
||||
self._host,
|
||||
self._repository,
|
||||
self.RELEASE_PATH,
|
||||
self._release,
|
||||
self.CHECKSUM_INDEX)
|
||||
response = requests.get(url)
|
||||
self.CHECKSUM_INDEX))
|
||||
response: httpx.Response = httpx.get(url)
|
||||
response.raise_for_status()
|
||||
index = response.json()
|
||||
index: Dict = response.json()
|
||||
if name not in index:
|
||||
raise ValueError('No checksum for model {}'.format(name))
|
||||
raise ValueError(f'No checksum for model {name}')
|
||||
return index[name]
|
||||
|
||||
def download(self, name, path):
|
||||
""" Download model denoted by the given name to disk.
|
||||
|
||||
:param name: Name of the model to download.
|
||||
:param path: Path of the directory to save model into.
|
||||
def download(self, name: str, path: str) -> None:
|
||||
"""
|
||||
url = '{}/{}/{}/{}/{}.tar.gz'.format(
|
||||
Download model denoted by the given name to disk.
|
||||
|
||||
Parameters:
|
||||
name (str):
|
||||
Name of the model to download.
|
||||
path (str):
|
||||
Path of the directory to save model into.
|
||||
"""
|
||||
url: str = '/'.join((
|
||||
self._host,
|
||||
self._repository,
|
||||
self.RELEASE_PATH,
|
||||
self._release,
|
||||
name)
|
||||
get_logger().info('Downloading model archive %s', url)
|
||||
with requests.get(url, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
archive = NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
with archive as stream:
|
||||
# Note: check for chunk size parameters ?
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
name))
|
||||
url = f'{url}.tar.gz'
|
||||
get_logger().info(f'Downloading model archive {url}')
|
||||
with httpx.Client(http2=True) as client:
|
||||
with client.strema('GET', url) as response:
|
||||
response.raise_for_status()
|
||||
archive = NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
with archive as stream:
|
||||
for chunk in response.iter_raw():
|
||||
stream.write(chunk)
|
||||
get_logger().info('Validating archive checksum')
|
||||
if compute_file_checksum(archive.name) != self.checksum(name):
|
||||
raise IOError('Downloaded file is corrupted, please retry')
|
||||
get_logger().info('Extracting downloaded %s archive', name)
|
||||
with tarfile.open(name=archive.name) as tar:
|
||||
tar.extractall(path=path)
|
||||
finally:
|
||||
os.unlink(archive.name)
|
||||
get_logger().info('%s model file(s) extracted', name)
|
||||
get_logger().info('Validating archive checksum')
|
||||
checksum: str = compute_file_checksum(archive.name)
|
||||
if checksum != self.checksum(name):
|
||||
raise IOError(
|
||||
'Downloaded file is corrupted, please retry')
|
||||
get_logger().info(f'Extracting downloaded {name} archive')
|
||||
with tarfile.open(name=archive.name) as tar:
|
||||
tar.extractall(path=path)
|
||||
finally:
|
||||
os.unlink(archive.name)
|
||||
get_logger().info(f'{name} model file(s) extracted')
|
||||
|
||||
@@ -4,62 +4,60 @@
|
||||
"""
|
||||
Module that provides a class wrapper for source separation.
|
||||
|
||||
:Example:
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from spleeter.separator import Separator
|
||||
>>> separator = Separator('spleeter:2stems')
|
||||
>>> separator.separate(waveform, lambda instrument, data: ...)
|
||||
>>> separator.separate_to_file(...)
|
||||
```
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import logging
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from multiprocessing import Pool
|
||||
from os.path import basename, join, splitext, dirname
|
||||
from time import time
|
||||
from typing import Container, NoReturn
|
||||
from typing import Generator, Optional
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio import STFTBackend
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||
from .utils.configuration import load_configuration
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from librosa.core import stft, istft
|
||||
from scipy.signal.windows import hann
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .utils.configuration import load_configuration
|
||||
from .utils.estimator import create_estimator, get_default_model_dir
|
||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||
# pylint: enable=import-error
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa')
|
||||
""" """
|
||||
|
||||
|
||||
class DataGenerator():
|
||||
class DataGenerator(object):
|
||||
"""
|
||||
Generator object that store a sample and generate it once while called.
|
||||
Used to feed a tensorflow estimator without knowing the whole data at
|
||||
build time.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
""" Default constructor. """
|
||||
self._current_data = None
|
||||
|
||||
def update_data(self, data):
|
||||
def update_data(self, data) -> None:
|
||||
""" Replace internal data. """
|
||||
self._current_data = data
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> Generator:
|
||||
""" Generation process. """
|
||||
buffer = self._current_data
|
||||
while buffer:
|
||||
@@ -79,19 +77,50 @@ def get_backend(backend: str) -> str:
|
||||
return backend
|
||||
|
||||
|
||||
def create_estimator(params, MWF):
|
||||
"""
|
||||
Initialize tensorflow estimator that will perform separation
|
||||
|
||||
Params:
|
||||
- params: a dictionary of parameters for building the model
|
||||
|
||||
Returns:
|
||||
a tensorflow estimator
|
||||
"""
|
||||
# Load model.
|
||||
provider: ModelProvider = ModelProvider.default()
|
||||
params['model_dir'] = provider.get(params['model_dir'])
|
||||
params['MWF'] = MWF
|
||||
# Setup config
|
||||
session_config = tf.compat.v1.ConfigProto()
|
||||
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
|
||||
config = tf.estimator.RunConfig(session_config=session_config)
|
||||
# Setup estimator
|
||||
estimator = tf.estimator.Estimator(
|
||||
model_fn=model_fn,
|
||||
model_dir=params['model_dir'],
|
||||
params=params,
|
||||
config=config)
|
||||
return estimator
|
||||
|
||||
|
||||
class Separator(object):
|
||||
""" A wrapper class for performing separation. """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params_descriptor,
|
||||
params_descriptor: str,
|
||||
MWF: bool = False,
|
||||
stft_backend: str = 'auto',
|
||||
multiprocess: bool = True):
|
||||
""" Default constructor.
|
||||
stft_backend: STFTBackend = STFTBackend.AUTO,
|
||||
multiprocess: bool = True) -> None:
|
||||
"""
|
||||
Default constructor.
|
||||
|
||||
:param params_descriptor: Descriptor for TF params to be used.
|
||||
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
||||
Parameters:
|
||||
params_descriptor (str):
|
||||
Descriptor for TF params to be used.
|
||||
MWF (bool):
|
||||
(Optional) `True` if MWF should be used, `False` otherwise.
|
||||
"""
|
||||
self._params = load_configuration(params_descriptor)
|
||||
self._sample_rate = self._params['sample_rate']
|
||||
@@ -111,8 +140,7 @@ class Separator(object):
|
||||
self._params['stft_backend'] = get_backend(stft_backend)
|
||||
self._data_generator = DataGenerator()
|
||||
|
||||
def __del__(self):
|
||||
""" """
|
||||
def __del__(self) -> None:
|
||||
if self._session:
|
||||
self._session.close()
|
||||
|
||||
@@ -140,35 +168,19 @@ class Separator(object):
|
||||
yield_single_examples=False)
|
||||
return self._prediction_generator
|
||||
|
||||
def join(self, timeout: int = 200) -> NoReturn:
|
||||
""" Wait for all pending tasks to be finished.
|
||||
def join(self, timeout: int = 200) -> None:
|
||||
"""
|
||||
Wait for all pending tasks to be finished.
|
||||
|
||||
:param timeout: (Optional) task waiting timeout.
|
||||
Parameters:
|
||||
timeout (int):
|
||||
(Optional) task waiting timeout.
|
||||
"""
|
||||
while len(self._tasks) > 0:
|
||||
task = self._tasks.pop()
|
||||
task.get()
|
||||
task.wait(timeout=timeout)
|
||||
|
||||
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
|
||||
""" Performs source separation over the given waveform with tensorflow
|
||||
backend.
|
||||
|
||||
:param waveform: Waveform to apply separation on.
|
||||
:returns: Separated waveforms.
|
||||
"""
|
||||
if not waveform.shape[-1] == 2:
|
||||
waveform = to_stereo(waveform)
|
||||
prediction_generator = self._get_prediction_generator()
|
||||
# NOTE: update data in generator before performing separation.
|
||||
self._data_generator.update_data({
|
||||
'waveform': waveform,
|
||||
'audio_id': np.array(audio_descriptor)})
|
||||
# NOTE: perform separation.
|
||||
prediction = next(prediction_generator)
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def _stft(self, data, inverse: bool = False, length=None):
|
||||
""" Single entrypoint for both stft and istft. This computes stft and
|
||||
istft with librosa on stereo data. The two channels are processed
|
||||
@@ -233,7 +245,12 @@ class Separator(object):
|
||||
return self._session
|
||||
|
||||
def _separate_librosa(self, waveform: np.ndarray, audio_id):
|
||||
""" Performs separation with librosa backend for STFT.
|
||||
"""
|
||||
Performs separation with librosa backend for STFT.
|
||||
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
Waveform to be separated (as a numpy array)
|
||||
"""
|
||||
with self._tf_graph.as_default():
|
||||
out = {}
|
||||
@@ -260,12 +277,42 @@ class Separator(object):
|
||||
length=waveform.shape[0])
|
||||
return out
|
||||
|
||||
def separate(self, waveform: np.ndarray, audio_descriptor=''):
|
||||
""" Performs separation on a waveform.
|
||||
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
|
||||
"""
|
||||
Performs source separation over the given waveform with tensorflow
|
||||
backend.
|
||||
|
||||
:param waveform: Waveform to be separated (as a numpy array)
|
||||
:param audio_descriptor: (Optional) string describing the waveform
|
||||
(e.g. filename).
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
Waveform to be separated (as a numpy array)
|
||||
|
||||
Returns:
|
||||
Separated waveforms.
|
||||
"""
|
||||
if not waveform.shape[-1] == 2:
|
||||
waveform = to_stereo(waveform)
|
||||
prediction_generator = self._get_prediction_generator()
|
||||
# NOTE: update data in generator before performing separation.
|
||||
self._data_generator.update_data({
|
||||
'waveform': waveform,
|
||||
'audio_id': np.array(audio_descriptor)})
|
||||
# NOTE: perform separation.
|
||||
prediction = next(prediction_generator)
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def separate(
|
||||
self,
|
||||
waveform: np.ndarray,
|
||||
audio_descriptor: Optional[str] = None) -> None:
|
||||
"""
|
||||
Performs separation on a waveform.
|
||||
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
Waveform to be separated (as a numpy array)
|
||||
audio_descriptor (str):
|
||||
(Optional) string describing the waveform (e.g. filename).
|
||||
"""
|
||||
if self._params['stft_backend'] == 'tensorflow':
|
||||
return self._separate_tensorflow(waveform, audio_descriptor)
|
||||
|
||||
@@ -4,14 +4,10 @@
|
||||
""" Module that provides configuration loading function. """
|
||||
|
||||
import json
|
||||
|
||||
try:
|
||||
import importlib.resources as loader
|
||||
except ImportError:
|
||||
# Try backported to PY<37 `importlib_resources`.
|
||||
import importlib_resources as loader
|
||||
import importlib.resources as loader
|
||||
|
||||
from os.path import exists
|
||||
from typing import Dict
|
||||
|
||||
from .. import resources, SpleeterError
|
||||
|
||||
@@ -20,18 +16,28 @@ __email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
_EMBEDDED_CONFIGURATION_PREFIX = 'spleeter:'
|
||||
_EMBEDDED_CONFIGURATION_PREFIX: str = 'spleeter:'
|
||||
|
||||
|
||||
def load_configuration(descriptor):
|
||||
""" Load configuration from the given descriptor. Could be
|
||||
either a `spleeter:` prefixed embedded configuration name
|
||||
or a file system path to read configuration from.
|
||||
def load_configuration(descriptor: str) -> Dict:
|
||||
"""
|
||||
Load configuration from the given descriptor. Could be either a
|
||||
`spleeter:` prefixed embedded configuration name or a file system path
|
||||
to read configuration from.
|
||||
|
||||
:param descriptor: Configuration descriptor to use for lookup.
|
||||
:returns: Loaded description as dict.
|
||||
:raise ValueError: If required embedded configuration does not exists.
|
||||
:raise SpleeterError: If required configuration file does not exists.
|
||||
Parameters:
|
||||
descriptor (str):
|
||||
Configuration descriptor to use for lookup.
|
||||
|
||||
Returns:
|
||||
Dict:
|
||||
Loaded description as dict.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If required embedded configuration does not exists.
|
||||
SpleeterError:
|
||||
If required configuration file does not exists.
|
||||
"""
|
||||
# Embedded configuration reading.
|
||||
if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" Utility functions for creating estimator. """
|
||||
|
||||
import tensorflow as tf # pylint: disable=import-error
|
||||
|
||||
from ..model import model_fn
|
||||
from ..model.provider import get_default_model_provider
|
||||
|
||||
|
||||
def get_default_model_dir(model_dir):
|
||||
"""
|
||||
Transforms a string like 'spleeter:2stems' into an actual path.
|
||||
:param model_dir:
|
||||
:return:
|
||||
"""
|
||||
model_provider = get_default_model_provider()
|
||||
return model_provider.get(model_dir)
|
||||
|
||||
|
||||
def create_estimator(params, MWF):
|
||||
"""
|
||||
Initialize tensorflow estimator that will perform separation
|
||||
|
||||
Params:
|
||||
- params: a dictionary of parameters for building the model
|
||||
|
||||
Returns:
|
||||
a tensorflow estimator
|
||||
"""
|
||||
# Load model.
|
||||
params['model_dir'] = get_default_model_dir(params['model_dir'])
|
||||
params['MWF'] = MWF
|
||||
# Setup config
|
||||
session_config = tf.compat.v1.ConfigProto()
|
||||
session_config.gpu_options.per_process_gpu_memory_fraction = 0.7
|
||||
config = tf.estimator.RunConfig(session_config=session_config)
|
||||
# Setup estimator
|
||||
estimator = tf.estimator.Estimator(
|
||||
model_fn=model_fn,
|
||||
model_dir=params['model_dir'],
|
||||
params=params,
|
||||
config=config
|
||||
)
|
||||
return estimator
|
||||
Reference in New Issue
Block a user