🚧 typer integration

This commit is contained in:
Faylixe
2020-12-08 11:26:08 +01:00
parent 1991b222e2
commit 054fcdde46
5 changed files with 144 additions and 480 deletions

View File

@@ -1,13 +1,22 @@
#!/usr/bin/env python
# coding: utf8
""" TO DOCUMENT """
"""
Python oneliner script usage.
USAGE: python -m spleeter {train,evaluate,separate} ...
"""
import json
from functools import partial
from itertools import product
from glob import glob
from os.path import join
from pathlib import Path
from typing import List
from typing import Any, Container, Dict, List
from . import SpleeterError
from .audio import Codec
from .audio.adapter import AudioAdapter
from .options import *
@@ -16,11 +25,12 @@ from .model import model_fn
from .model.provider import ModelProvider
from .separator import Separator
from .utils.configuration import load_configuration
from .utils.logging import get_logger
from .utils.logging import configure_logger, logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import pandas as pd
import tensorflow as tf
from typer import Exit, Typer
@@ -39,8 +49,7 @@ def train(
"""
Train a source separation model
"""
# TODO: try / catch or custom decorator for function handling.
# TODO: handle verbose flag ?
configure_logger(verbose)
audio_adapter = AudioAdapter.get(adapter)
audio_path = str(data)
params = load_configuration(params_filename)
@@ -70,120 +79,19 @@ def train(
input_fn=input_fn,
steps=None,
throttle_secs=params['throttle_secs'])
get_logger().info('Start model training')
logger.info('Start model training')
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
ModelProvider.writeProbe(params['model_dir'])
get_logger().info('Model training done')
# musdb subset to evaluate (passed as `subsets=[_SPLIT]` to `musdb.DB`).
_SPLIT = 'test'
# Mixture file name -- presumably the per-song input file; confirm at usage.
_MIXTURE = 'mixture.wav'
# Audio directory name -- presumably relative to the output path; confirm.
_AUDIO_DIRECTORY = 'audio'
# Name of the metrics directory, joined to the output path.
_METRICS_DIRECTORY = 'metrics'
# Instrument names used as first-level keys of the compiled metrics dict.
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
# Metric names read from each frame of the evaluation JSON files.
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
def _compute_musdb_metrics(
        arguments,
        musdb_root_directory,
        audio_output_directory):
    """ Runs the museval evaluation over previously computed estimates.

    :param arguments: Entrypoint arguments (`output_path` is read from it).
    :param musdb_root_directory: Root directory given to `musdb.DB`.
    :param audio_output_directory: Directory to get audio estimation from.
    :returns: Path of generated metrics directory.
    :raises Exit: With code 10 when musdb or museval cannot be imported.
    """
    output_directory = join(arguments.output_path, _METRICS_DIRECTORY)
    get_logger().info('Starting musdb evaluation (this could be long) ...')
    # Both packages are optional extras, hence the import at call time.
    try:
        import musdb
        import museval
    except ImportError:
        log = get_logger()
        log.error('Extra dependencies musdb and museval not found')
        log.error('Please install musdb and museval first, abort')
        raise Exit(10)
    museval.eval_mus_dir(
        dataset=musdb.DB(
            root=musdb_root_directory,
            is_wav=True,
            subsets=[_SPLIT]),
        estimates_dir=audio_output_directory,
        output_dir=output_directory)
    get_logger().info('musdb evaluation done')
    return output_directory
def _compile_metrics(metrics_output_directory):
    """ Compiles metrics from given directory and returns
    results as dict.

    Every `test/*.json` file of the directory is read; for each target
    and each metric, the median over the non-NaN frames is appended to
    the result.

    :param metrics_output_directory: Directory to get metrics from.
    :returns: Compiled metrics as dict, mapping instrument name to metric
              name to the list of per-file medians.
    """
    # The pandas index / frame previously built here were discarded right
    # away, so they are no longer created.
    metrics = {
        instrument: {metric: [] for metric in _METRICS}
        for instrument in _INSTRUMENTS}
    for song in glob(join(metrics_output_directory, 'test/*.json')):
        with open(song, 'r') as stream:
            data = json.load(stream)
        for target in data['targets']:
            instrument = target['name']
            frames = target['frames']
            for metric in _METRICS:
                values = [frame['metrics'][metric] for frame in frames]
                # NaN frames are left out of the median.
                median = np.median([v for v in values if not np.isnan(v)])
                metrics[instrument][metric].append(median)
    return metrics
# NOTE(review): `output_path` defaults to `AudioAdapterOption`, the same
# object as `adapter` -- looks like a copy/paste slip; confirm which option
# declaration was intended.
@spleeter.command()
def evaluate(
        adapter: str = AudioAdapterOption,
        output_path: Path = AudioAdapterOption,
        stft_backend: STFTBackend = AudioSTFTBackendOption,
        params_filename: str = ModelParametersOption,
        mus_dir: Path = MUSDBDirectoryOption,
        mwf: bool = MWFOption,
        verbose: bool = VerboseOption) -> Dict:
    """
    Evaluate a model on the musDB test dataset
    """
    # NOTE(review): `arguments`, `params` and `_separate_evaluation_dataset`
    # are not defined in the visible part of this module -- they need to be
    # rebuilt from the typed parameters above before this command can run.
    # Separate musdb sources.
    audio_output_directory = _separate_evaluation_dataset(
        arguments,
        mus_dir,
        params)
    # Compute metrics with musdb.
    metrics_output_directory = _compute_musdb_metrics(
        arguments,
        mus_dir,
        audio_output_directory)
    # Compute and pretty print median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, scores in metrics.items():
        get_logger().info('%s:', instrument)
        # Distinct name for the per-instrument dict: the inner loop no
        # longer rebinds the variable it iterates over.
        for metric, value in scores.items():
            get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
    # The compiled metrics are returned, hence the `Dict` annotation.
    return metrics
logger.info('Model training done')
@spleeter.commmand()
def separate(
files: List[Path] = AudioInputArgument,
adapter: str = AudioAdapterOption,
bitrate: str = AudioBitrateOption,
codec: Codec = AudioCodecOption,
duration: float = AudioDurationOption,
files: List[Path] = AudioInputArgument,
offset: float = AudioOffsetOption,
output_path: Path = AudioAdapterOption,
stft_backend: STFTBackend = AudioSTFTBackendOption,
@@ -194,13 +102,7 @@ def separate(
"""
Separate audio file(s)
"""
# TODO: try / catch or custom decorator for function handling.
# TODO: enable_logging()
# TODO: handle MWF
if verbose:
# TODO: enable_tensorflow_logging()
pass
# PREV: params = load_configuration(arguments.configuration)
configure_logger(verbose)
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
separator: Separator = Separator(
params_filename,
@@ -220,6 +122,102 @@ def separate(
separator.join()
# musdb subset to evaluate; also the sub-directory holding metric files.
EVALUATION_SPLIT: str = 'test'
# Name of the metrics directory, joined to the output path.
EVALUATION_METRICS_DIRECTORY: str = 'metrics'
# Instrument names used as first-level keys of the compiled metrics.
EVALUATION_INSTRUMENTS: Container[str] = ('vocals', 'drums', 'bass', 'other')
# Metric names read from each frame of the evaluation JSON files.
EVALUATION_METRICS: Container[str] = ('SDR', 'SAR', 'SIR', 'ISR')
# File name of the mixture inside each song directory.
EVALUATION_MIXTURE: str = 'mixture.wav'
# Name of the audio estimates directory, joined to the output path.
EVALUATION_AUDIO_DIRECTORY: str = 'audio'


def _compile_metrics(metrics_output_directory) -> Dict:
    """
    Compiles metrics from given directory and returns results as dict.

    Every `*.json` file found in the `EVALUATION_SPLIT` sub-directory is
    read; for each target and each metric, the median over the non-NaN
    frames is appended to the result.

    Parameters:
        metrics_output_directory (str):
            Directory to get metrics from.

    Returns:
        Dict:
            Compiled metrics as dict, mapping instrument name to metric
            name to the list of per-file medians.
    """
    # The pandas index / frame previously built here were discarded right
    # away, so they are no longer created.
    metrics = {
        instrument: {metric: [] for metric in EVALUATION_METRICS}
        for instrument in EVALUATION_INSTRUMENTS}
    pattern = join(metrics_output_directory, EVALUATION_SPLIT, '*.json')
    for song in glob(pattern):
        with open(song, 'r') as stream:
            data = json.load(stream)
        for target in data['targets']:
            instrument = target['name']
            frames = target['frames']
            for metric in EVALUATION_METRICS:
                values = [frame['metrics'][metric] for frame in frames]
                # NaN frames are left out of the median.
                median = np.median([v for v in values if not np.isnan(v)])
                metrics[instrument][metric].append(median)
    return metrics
# NOTE(review): `output_path` defaults to `AudioAdapterOption`, the same
# object as `adapter` -- looks like a copy/paste slip; confirm which option
# declaration was intended.
@spleeter.command()
def evaluate(
        adapter: str = AudioAdapterOption,
        output_path: Path = AudioAdapterOption,
        stft_backend: STFTBackend = AudioSTFTBackendOption,
        params_filename: str = ModelParametersOption,
        mus_dir: Path = MUSDBDirectoryOption,
        mwf: bool = MWFOption,
        verbose: bool = VerboseOption) -> Dict:
    """
    Evaluate a model on the musDB test dataset
    """
    configure_logger(verbose)
    # musdb and museval are optional extras: imported at call time so the
    # other commands keep working without them.
    try:
        import musdb
        import museval
    except ImportError:
        logger.error('Extra dependencies musdb and museval not found')
        logger.error('Please install musdb and museval first, abort')
        raise Exit(10)
    # Separate musdb sources.
    songs = glob(join(mus_dir, EVALUATION_SPLIT, '*/'))
    mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
    # `separate` is invoked as a plain function here, so typer never
    # resolves its option defaults: bitrate, duration and offset are given
    # explicitly. Estimates go to <audio directory>/<split>, since the
    # audio directory is what museval receives as `estimates_dir` below.
    separate(
        files=mixtures,
        adapter=adapter,
        bitrate='128k',
        codec=Codec.WAV,
        duration=600.,
        offset=0.,
        output_path=join(audio_output_directory, EVALUATION_SPLIT),
        stft_backend=stft_backend,
        filename_format='{foldername}/{instrument}.{codec}',
        params_filename=params_filename,
        mwf=mwf,
        verbose=verbose)
    # Compute metrics with musdb.
    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
    logger.info('Starting musdb evaluation (this could be long) ...')
    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory)
    logger.info('musdb evaluation done')
    # Compute and pretty print median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, scores in metrics.items():
        logger.info(f'{instrument}:')
        # Distinct name for the per-instrument dict: the inner loop no
        # longer rebinds the variable it iterates over.
        for metric, value in scores.items():
            logger.info(f'{metric}: {np.median(value):.3f}')
    return metrics
if __name__ == '__main__':
    # TODO: warnings.filterwarnings('ignore')
    # Single invocation, inside the handler: a SpleeterError is reported
    # through the logger instead of escaping from an unguarded first call.
    try:
        spleeter()
    except SpleeterError as e:
        logger.error(e)

View File

@@ -1,166 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model evaluation.
Evaluation is performed against musDB dataset.
USAGE: python -m spleeter evaluate \
-p /path/to/params \
-o /path/to/output/dir \
[-m] \
--mus_dir /path/to/musdb dataset
"""
import sys
import json
from argparse import Namespace
from itertools import product
from glob import glob
from os.path import join, exists
# pylint: disable=import-error
import numpy as np
import pandas as pd
# pylint: enable=import-error
from .separate import entrypoint as separate_entrypoint
from ..utils.logging import get_logger
try:
import musdb
import museval
except ImportError:
logger = get_logger()
logger.error('Extra dependencies musdb and museval not found')
logger.error('Please install musdb and museval first, abort')
sys.exit(1)
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
# musdb subset to evaluate; also the sub-directory searched for songs.
_SPLIT = 'test'
# File name of the mixture inside each song directory.
_MIXTURE = 'mixture.wav'
# Name of the audio estimates directory, joined to the output path.
_AUDIO_DIRECTORY = 'audio'
# Name of the metrics directory, joined to the output path.
_METRICS_DIRECTORY = 'metrics'
# Instrument names used as first-level keys of the compiled metrics dict.
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
# Metric names read from each frame of the evaluation JSON files.
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
    """ Runs the separation entrypoint on every mixture found in the
    evaluated split of the given musdb directory.

    :param arguments: Entrypoint arguments.
    :param musdb_root_directory: Directory to retrieve dataset from.
    :param params: Spleeter configuration to apply to separation.
    :returns: Separation output directory path.
    """
    audio_output_directory = join(arguments.output_path, _AUDIO_DIRECTORY)
    song_directories = glob(join(musdb_root_directory, _SPLIT, '*/'))
    # Arguments are rebuilt by hand: the separation entrypoint expects the
    # namespace its own command line would produce.
    separation_arguments = Namespace(
        audio_adapter=arguments.audio_adapter,
        configuration=arguments.configuration,
        inputs=[
            join(directory, _MIXTURE)
            for directory in song_directories],
        output_path=join(audio_output_directory, _SPLIT),
        filename_format='{foldername}/{instrument}.{codec}',
        codec='wav',
        duration=600.,
        offset=0.,
        bitrate='128k',
        MWF=arguments.MWF,
        verbose=arguments.verbose,
        stft_backend=arguments.stft_backend)
    separate_entrypoint(separation_arguments, params)
    return audio_output_directory
def _compute_musdb_metrics(
        arguments,
        musdb_root_directory,
        audio_output_directory):
    """ Runs the museval evaluation over previously computed estimates.

    :param arguments: Entrypoint arguments (`output_path` is read from it).
    :param musdb_root_directory: Root directory given to `musdb.DB`.
    :param audio_output_directory: Directory to get audio estimation from.
    :returns: Path of generated metrics directory.
    """
    output_directory = join(arguments.output_path, _METRICS_DIRECTORY)
    get_logger().info('Starting musdb evaluation (this could be long) ...')
    museval.eval_mus_dir(
        dataset=musdb.DB(
            root=musdb_root_directory,
            is_wav=True,
            subsets=[_SPLIT]),
        estimates_dir=audio_output_directory,
        output_dir=output_directory)
    get_logger().info('musdb evaluation done')
    return output_directory
def _compile_metrics(metrics_output_directory):
    """ Compiles metrics from given directory and returns
    results as dict.

    Every `test/*.json` file of the directory is read; for each target
    and each metric, the median over the non-NaN frames is appended to
    the result.

    :param metrics_output_directory: Directory to get metrics from.
    :returns: Compiled metrics as dict, mapping instrument name to metric
              name to the list of per-file medians.
    """
    # The pandas index / frame previously built here were discarded right
    # away, so they are no longer created.
    metrics = {
        instrument: {metric: [] for metric in _METRICS}
        for instrument in _INSTRUMENTS}
    for song in glob(join(metrics_output_directory, 'test/*.json')):
        with open(song, 'r') as stream:
            data = json.load(stream)
        for target in data['targets']:
            instrument = target['name']
            frames = target['frames']
            for metric in _METRICS:
                values = [frame['metrics'][metric] for frame in frames]
                # NaN frames are left out of the median.
                median = np.median([v for v in values if not np.isnan(v)])
                metrics[instrument][metric].append(median)
    return metrics
def entrypoint(arguments, params):
    """ Command entrypoint.

    Separates the musdb sources, computes the metrics on the result, then
    logs the median of each metric per instrument.

    :param arguments: Command line parsed argument as argparse.Namespace.
    :param params: Deserialized JSON configuration file provided in CLI args.
    :returns: Compiled metrics as dict.
    :raises IOError: If the musdb directory does not exist.
    """
    musdb_root_directory = arguments.mus_dir
    if not exists(musdb_root_directory):
        raise IOError(f'musdb directory {musdb_root_directory} not found')
    # Separation first, then evaluation of the produced estimates.
    estimates_directory = _separate_evaluation_dataset(
        arguments, musdb_root_directory, params)
    metrics = _compile_metrics(
        _compute_musdb_metrics(
            arguments, musdb_root_directory, estimates_directory))
    # Pretty print median metrics.
    log = get_logger()
    for instrument, scores in metrics.items():
        log.info('%s:', instrument)
        for metric, value in scores.items():
            log.info('%s: %s', metric, f'{np.median(value):.3f}')
    return metrics

View File

@@ -1,51 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing source separation.
USAGE: python -m spleeter separate \
-p /path/to/params \
-i inputfile1 inputfile2 ... inputfilen
-o /path/to/output/dir \
-i /path/to/audio1.wav /path/to/audio2.mp3
"""
from ..audio.adapter import get_audio_adapter
from ..separator import Separator
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
from typer import Option
# NOTE(review): typer Option created with no arguments; nothing in the
# visible part of this module references it -- confirm it is still needed.
AudioAdapter = Option()
def entrypoint(arguments, params):
    """ Command entrypoint.

    Builds a separator from the given configuration, queues the
    separation of every input file, then joins the separator.

    :param arguments: Command line parsed argument as argparse.Namespace.
    :param params: Deserialized JSON configuration file provided in CLI args.
    """
    # TODO: check with output naming.
    separator = Separator(
        arguments.configuration,
        MWF=arguments.MWF,
        stft_backend=arguments.stft_backend)
    # Options shared by every file are gathered once.
    export_options = dict(
        audio_adapter=get_audio_adapter(arguments.audio_adapter),
        offset=arguments.offset,
        duration=arguments.duration,
        codec=arguments.codec,
        bitrate=arguments.bitrate,
        filename_format=arguments.filename_format,
        synchronous=False)
    for filename in arguments.inputs:
        separator.separate_to_file(
            filename,
            arguments.output_path,
            **export_options)
    # Files are submitted with synchronous=False, hence the final join.
    separator.join()

View File

@@ -1,100 +0,0 @@
#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model training.
USAGE: python -m spleeter train -p /path/to/params
"""
from functools import partial
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
from ..audio.adapter import get_audio_adapter
from ..dataset import get_training_dataset, get_validation_dataset
from ..model import model_fn
from ..model.provider import ModelProvider
from ..utils.logging import get_logger
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def _create_estimator(params):
    """ Builds the TensorFlow estimator wrapping `model_fn`.

    :param params: TF params to build estimator from.
    :returns: Built estimator.
    """
    session_config = tf.compat.v1.ConfigProto()
    # Per-process GPU memory fraction is set to 0.45.
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params['save_checkpoints_steps'],
        tf_random_seed=params['random_seed'],
        save_summary_steps=params['save_summary_steps'],
        session_config=session_config,
        log_step_count_steps=10,
        keep_checkpoint_max=2)
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=run_config)
def _create_train_spec(params, audio_adapter, audio_path):
    """ Builds the train spec, fed by the training dataset.

    :param params: TF params to build spec from.
    :param audio_adapter: Audio adapter forwarded to the dataset builder.
    :param audio_path: Audio path forwarded to the dataset builder.
    :returns: Built train spec.
    """
    return tf.estimator.TrainSpec(
        input_fn=partial(
            get_training_dataset,
            params,
            audio_adapter,
            audio_path),
        max_steps=params['train_max_steps'])
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Builds the evaluation spec, fed by the validation dataset and
    throttled by `params['throttle_secs']`.

    :param params: TF params to build spec from.
    :param audio_adapter: Audio adapter forwarded to the dataset builder.
    :param audio_path: Audio path forwarded to the dataset builder.
    :returns: Built evaluation spec.
    """
    validation_input_fn = partial(
        get_validation_dataset, params, audio_adapter, audio_path)
    return tf.estimator.EvalSpec(
        input_fn=validation_input_fn,
        steps=None,
        throttle_secs=params['throttle_secs'])
def entrypoint(arguments, params):
    """ Command entrypoint.

    Builds the estimator with its train and evaluation specs, runs
    training and evaluation, then writes the model probe.

    :param arguments: Command line parsed argument as argparse.Namespace.
    :param params: Deserialized JSON configuration file provided in CLI args.
    """
    adapter = get_audio_adapter(arguments.audio_adapter)
    dataset_path = arguments.audio_path
    estimator = _create_estimator(params)
    specs = (
        _create_train_spec(params, adapter, dataset_path),
        _create_evaluation_spec(params, adapter, dataset_path))
    log = get_logger()
    log.info('Start model training')
    tf.estimator.train_and_evaluate(estimator, *specs)
    ModelProvider.writeProbe(params['model_dir'])
    log.info('Model training done')

View File

@@ -4,58 +4,41 @@
""" Centralized logging facilities for Spleeter. """
import logging
import warnings
from os import environ
# pyright: reportMissingImports=false
# pylint: disable=import-error
from tensorflow.compat.v1 import logging as tflogging
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
_FORMAT = '%(levelname)s:%(name)s:%(message)s'
# Module-level logger setup: a stream handler using the
# `levelname:name:message` format is attached to the 'spleeter' logger,
# whose level is set to INFO.
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger: logging.Logger = logging.getLogger('spleeter')
logger.addHandler(handler)
logger.setLevel(logging.INFO)
class _LoggerHolder:
    """ Holds the logger singleton as a class attribute. """

    # None until a logger instance is stored here.
    INSTANCE = None
def get_tensorflow_logger():
def configure_logger(verbose: bool) -> None:
"""
Configure application logger.
Parameters:
verbose (bool):
`True` to use verbose logger, `False` otherwise.
"""
# pylint: disable=import-error
from tensorflow.compat.v1 import logging
# pylint: enable=import-error
return logging
def get_logger():
    """ Returns library scoped logger, creating it on first call.

    :returns: Library logger.
    """
    if _LoggerHolder.INSTANCE is not None:
        return _LoggerHolder.INSTANCE
    # First call: attach a formatted stream handler to the 'spleeter'
    # logger and keep it in the holder for later calls.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(_FORMAT))
    library_logger = logging.getLogger('spleeter')
    library_logger.addHandler(stream_handler)
    library_logger.setLevel(logging.INFO)
    _LoggerHolder.INSTANCE = library_logger
    return _LoggerHolder.INSTANCE
def enable_tensorflow_logging():
    """ Enable tensorflow logging.

    Sets `TF_CPP_MIN_LOG_LEVEL` to '1', the tensorflow logger verbosity
    to INFO and the library logger level to DEBUG.
    """
    environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
    tensorflow_logging = get_tensorflow_logger()
    tensorflow_logging.set_verbosity(tensorflow_logging.INFO)
    get_logger().setLevel(logging.DEBUG)
def enable_logging():
    """ Configure default logging.

    Sets `TF_CPP_MIN_LOG_LEVEL` to '3' and the tensorflow logger
    verbosity to ERROR.
    """
    environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tensorflow_logging = get_tensorflow_logger()
    tensorflow_logging.set_verbosity(tensorflow_logging.ERROR)
if verbose:
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tflogging.set_verbosity(tflogging.INFO)
logger.setLevel(logging.DEBUG)
else:
warnings.filterwarnings('ignore')
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tflogging.set_verbosity(tflogging.ERROR)