diff --git a/spleeter/__main__.py b/spleeter/__main__.py index 0c49505..52b90f1 100644 --- a/spleeter/__main__.py +++ b/spleeter/__main__.py @@ -1,13 +1,22 @@ #!/usr/bin/env python # coding: utf8 -""" TO DOCUMENT """ +""" + Python oneliner script usage. + USAGE: python -m spleeter {train,evaluate,separate} ... +""" + +import json from functools import partial +from itertools import product +from glob import glob +from os.path import join from pathlib import Path -from typing import List +from typing import Any, Container, Dict, List +from . import SpleeterError from .audio import Codec from .audio.adapter import AudioAdapter from .options import * @@ -16,11 +25,12 @@ from .model import model_fn from .model.provider import ModelProvider from .separator import Separator from .utils.configuration import load_configuration -from .utils.logging import get_logger - +from .utils.logging import configure_logger, logger # pyright: reportMissingImports=false # pylint: disable=import-error +import numpy as np +import pandas as pd import tensorflow as tf from typer import Exit, Typer @@ -39,8 +49,7 @@ def train( """ Train a source separation model """ - # TODO: try / catch or custom decorator for function handling. - # TODO: handle verbose flag ? + configure_logger(verbose) audio_adapter = AudioAdapter.get(adapter) audio_path = str(data) params = load_configuration(params_filename) @@ -70,120 +79,19 @@ def train( input_fn=input_fn, steps=None, throttle_secs=params['throttle_secs']) - get_logger().info('Start model training') + logger.info('Start model training') tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec) ModelProvider.writeProbe(params['model_dir']) - get_logger().info('Model training done') - -_SPLIT = 'test' -_MIXTURE = 'mixture.wav' -_AUDIO_DIRECTORY = 'audio' -_METRICS_DIRECTORY = 'metrics' -_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other') -_METRICS = ('SDR', 'SAR', 'SIR', 'ISR') - - -def _compute_musdb_metrics( - arguments, - musdb_root_directory, - audio_output_directory): - """ Generates musdb metrics fro previsouly computed audio estimation. - - :param arguments: Entrypoint arguments. - :param audio_output_directory: Directory to get audio estimation from. - :returns: Path of generated metrics directory. - """ - metrics_output_directory = join( - arguments.output_path, - _METRICS_DIRECTORY) - get_logger().info('Starting musdb evaluation (this could be long) ...') - try: - import musdb - import museval - except ImportError: - logger = get_logger() - logger.error('Extra dependencies musdb and museval not found') - logger.error('Please install musdb and museval first, abort') - raise Exit(10) - dataset = musdb.DB( - root=musdb_root_directory, - is_wav=True, - subsets=[_SPLIT]) - museval.eval_mus_dir( - dataset=dataset, - estimates_dir=audio_output_directory, - output_dir=metrics_output_directory) - get_logger().info('musdb evaluation done') - return metrics_output_directory - - -def _compile_metrics(metrics_output_directory): - """ Compiles metrics from given directory and returns - results as dict. - - :param metrics_output_directory: Directory to get metrics from. - :returns: Compiled metrics as dict. - """ - songs = glob(join(metrics_output_directory, 'test/*.json')) - index = pd.MultiIndex.from_tuples( - product(_INSTRUMENTS, _METRICS), - names=['instrument', 'metric']) - pd.DataFrame([], index=['config1', 'config2'], columns=index) - metrics = { - instrument: {k: [] for k in _METRICS} - for instrument in _INSTRUMENTS} - for song in songs: - with open(song, 'r') as stream: - data = json.load(stream) - for target in data['targets']: - instrument = target['name'] - for metric in _METRICS: - sdr_med = np.median([ - frame['metrics'][metric] - for frame in target['frames'] - if not np.isnan(frame['metrics'][metric])]) - metrics[instrument][metric].append(sdr_med) - return metrics - - -@spleeter.command() -def evaluate( - adapter: str = AudioAdapterOption, - output_path: Path = AudioAdapterOption, - stft_backend: STFTBackend = AudioSTFTBackendOption, - params_filename: str = ModelParametersOption, - mus_dir: Path = MUSDBDirectoryOption, - mwf: bool = MWFOption, - verbose: bool = VerboseOption) -> None: - """ - Evaluate a model on the musDB test dataset - """ - # Separate musdb sources. - audio_output_directory = _separate_evaluation_dataset( - arguments, - mus_dir, - params) - # Compute metrics with musdb. - metrics_output_directory = _compute_musdb_metrics( - arguments, - mus_dir, - audio_output_directory) - # Compute and pretty print median metrics. - metrics = _compile_metrics(metrics_output_directory) - for instrument, metric in metrics.items(): - get_logger().info('%s:', instrument) - for metric, value in metric.items(): - get_logger().info('%s: %s', metric, f'{np.median(value):.3f}') - return metrics + logger.info('Model training done') @spleeter.commmand() def separate( + files: List[Path] = AudioInputArgument, adapter: str = AudioAdapterOption, bitrate: str = AudioBitrateOption, codec: Codec = AudioCodecOption, duration: float = AudioDurationOption, - files: List[Path] = AudioInputArgument, offset: float = AudioOffsetOption, output_path: Path = AudioAdapterOption, stft_backend: STFTBackend = AudioSTFTBackendOption, @@ -194,13 +102,7 @@ def separate( """ Separate audio file(s) """ - # TODO: try / catch or custom decorator for function handling. - # TODO: enable_logging() - # TODO: handle MWF - if verbose: - # TODO: enable_tensorflow_logging() - pass - # PREV: params = load_configuration(arguments.configuration) + configure_logger(verbose) audio_adapter: AudioAdapter = AudioAdapter.get(adapter) separator: Separator = Separator( params_filename, @@ -220,6 +122,102 @@ def separate( separator.join() +EVALUATION_SPLIT: str = 'test' +EVALUATION_METRICS_DIRECTORY: str = 'metrics' +EVALUATION_INSTRUMENTS: Container[str] = ('vocals', 'drums', 'bass', 'other') +EVALUATION_METRICS: Container[str] = ('SDR', 'SAR', 'SIR', 'ISR') +EVALUATION_MIXTURE: str = 'mixture.wav' +EVALUATION_AUDIO_DIRECTORY: str = 'audio' + + +def _compile_metrics(metrics_output_directory) -> Dict: + """ + Compiles metrics from given directory and returns results as dict. + + Parameters: + metrics_output_directory (str): + Directory to get metrics from. + + Returns: + Dict: + Compiled metrics as dict. + """ + songs = glob(join(metrics_output_directory, 'test/*.json')) + index = pd.MultiIndex.from_tuples( + product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS), + names=['instrument', 'metric']) + pd.DataFrame([], index=['config1', 'config2'], columns=index) + metrics = { + instrument: {k: [] for k in EVALUATION_METRICS} + for instrument in EVALUATION_INSTRUMENTS} + for song in songs: + with open(song, 'r') as stream: + data = json.load(stream) + for target in data['targets']: + instrument = target['name'] + for metric in EVALUATION_METRICS: + sdr_med = np.median([ + frame['metrics'][metric] + for frame in target['frames'] + if not np.isnan(frame['metrics'][metric])]) + metrics[instrument][metric].append(sdr_med) + return metrics + + +@spleeter.command() +def evaluate( + adapter: str = AudioAdapterOption, + output_path: Path = AudioAdapterOption, + stft_backend: STFTBackend = AudioSTFTBackendOption, + params_filename: str = ModelParametersOption, + mus_dir: Path = MUSDBDirectoryOption, + mwf: bool = MWFOption, + verbose: bool = VerboseOption) -> Dict: + """ + Evaluate a model on the musDB test dataset + """ + configure_logger(verbose) + try: + import musdb + import museval + except ImportError: + logger.error('Extra dependencies musdb and museval not found') + logger.error('Please install musdb and museval first, abort') + raise Exit(10) + # Separate musdb sources. + songs = glob(join(mus_dir, EVALUATION_SPLIT, '*/')) + mixtures = [join(song, EVALUATION_MIXTURE) for song in songs] + audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY) + separate( + adapter=adapter, + params_filename=params_filename, + files=mixtures, + output_path=output_path, + filename_format='{foldername}/{instrument}.{codec}', + codec=Codec.WAV, + mwf=mwf, + verbose=verbose, + stft_backend=stft_backend) + # Compute metrics with musdb. + metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY) + logger.info('Starting musdb evaluation (this could be long) ...') + dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT]) + museval.eval_mus_dir( + dataset=dataset, + estimates_dir=audio_output_directory, + output_dir=metrics_output_directory) + logger.info('musdb evaluation done') + # Compute and pretty print median metrics. + metrics = _compile_metrics(metrics_output_directory) + for instrument, metric in metrics.items(): + logger.info(f'{instrument}:') + for metric, value in metric.items(): + logger.info(f'{metric}: {np.median(value):.3f}') + return metrics + + if __name__ == '__main__': - # TODO: warnings.filterwarnings('ignore') - spleeter() + try: + spleeter() + except SpleeterError as e: + logger.error(e) diff --git a/spleeter/commands/evaluate.py b/spleeter/commands/evaluate.py deleted file mode 100644 index d6cf7d6..0000000 --- a/spleeter/commands/evaluate.py +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env python -# coding: utf8 - -""" - Entrypoint provider for performing model evaluation. - - Evaluation is performed against musDB dataset. - - USAGE: python -m spleeter evaluate \ - -p /path/to/params \ - -o /path/to/output/dir \ - [-m] \ - --mus_dir /path/to/musdb dataset -""" - -import sys -import json - -from argparse import Namespace -from itertools import product -from glob import glob -from os.path import join, exists - -# pylint: disable=import-error -import numpy as np -import pandas as pd -# pylint: enable=import-error - -from .separate import entrypoint as separate_entrypoint -from ..utils.logging import get_logger - -try: - import musdb - import museval -except ImportError: - logger = get_logger() - logger.error('Extra dependencies musdb and museval not found') - logger.error('Please install musdb and museval first, abort') - sys.exit(1) - -__email__ = 'spleeter@deezer.com' -__author__ = 'Deezer Research' -__license__ = 'MIT License' - -_SPLIT = 'test' -_MIXTURE = 'mixture.wav' -_AUDIO_DIRECTORY = 'audio' -_METRICS_DIRECTORY = 'metrics' -_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other') -_METRICS = ('SDR', 'SAR', 'SIR', 'ISR') - - -def _separate_evaluation_dataset(arguments, musdb_root_directory, params): - """ Performs audio separation on the musdb dataset from - the given directory and params. - - :param arguments: Entrypoint arguments. - :param musdb_root_directory: Directory to retrieve dataset from. - :param params: Spleeter configuration to apply to separation. - :returns: Separation output directory path. - """ - songs = glob(join(musdb_root_directory, _SPLIT, '*/')) - mixtures = [join(song, _MIXTURE) for song in songs] - audio_output_directory = join( - arguments.output_path, - _AUDIO_DIRECTORY) - separate_entrypoint( - Namespace( - audio_adapter=arguments.audio_adapter, - configuration=arguments.configuration, - inputs=mixtures, - output_path=join(audio_output_directory, _SPLIT), - filename_format='{foldername}/{instrument}.{codec}', - codec='wav', - duration=600., - offset=0., - bitrate='128k', - MWF=arguments.MWF, - verbose=arguments.verbose, - stft_backend=arguments.stft_backend), - params) - return audio_output_directory - - -def _compute_musdb_metrics( - arguments, - musdb_root_directory, - audio_output_directory): - """ Generates musdb metrics fro previsouly computed audio estimation. - - :param arguments: Entrypoint arguments. - :param audio_output_directory: Directory to get audio estimation from. - :returns: Path of generated metrics directory. - """ - metrics_output_directory = join( - arguments.output_path, - _METRICS_DIRECTORY) - get_logger().info('Starting musdb evaluation (this could be long) ...') - dataset = musdb.DB( - root=musdb_root_directory, - is_wav=True, - subsets=[_SPLIT]) - museval.eval_mus_dir( - dataset=dataset, - estimates_dir=audio_output_directory, - output_dir=metrics_output_directory) - get_logger().info('musdb evaluation done') - return metrics_output_directory - - -def _compile_metrics(metrics_output_directory): - """ Compiles metrics from given directory and returns - results as dict. - - :param metrics_output_directory: Directory to get metrics from. - :returns: Compiled metrics as dict. - """ - songs = glob(join(metrics_output_directory, 'test/*.json')) - index = pd.MultiIndex.from_tuples( - product(_INSTRUMENTS, _METRICS), - names=['instrument', 'metric']) - pd.DataFrame([], index=['config1', 'config2'], columns=index) - metrics = { - instrument: {k: [] for k in _METRICS} - for instrument in _INSTRUMENTS} - for song in songs: - with open(song, 'r') as stream: - data = json.load(stream) - for target in data['targets']: - instrument = target['name'] - for metric in _METRICS: - sdr_med = np.median([ - frame['metrics'][metric] - for frame in target['frames'] - if not np.isnan(frame['metrics'][metric])]) - metrics[instrument][metric].append(sdr_med) - return metrics - - -def entrypoint(arguments, params): - """ Command entrypoint. - - :param arguments: Command line parsed argument as argparse.Namespace. - :param params: Deserialized JSON configuration file provided in CLI args. - """ - # Parse and check musdb directory. - musdb_root_directory = arguments.mus_dir - if not exists(musdb_root_directory): - raise IOError(f'musdb directory {musdb_root_directory} not found') - # Separate musdb sources. - audio_output_directory = _separate_evaluation_dataset( - arguments, - musdb_root_directory, - params) - # Compute metrics with musdb. - metrics_output_directory = _compute_musdb_metrics( - arguments, - musdb_root_directory, - audio_output_directory) - # Compute and pretty print median metrics. - metrics = _compile_metrics(metrics_output_directory) - for instrument, metric in metrics.items(): - get_logger().info('%s:', instrument) - for metric, value in metric.items(): - get_logger().info('%s: %s', metric, f'{np.median(value):.3f}') - return metrics diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py deleted file mode 100644 index 740deb4..0000000 --- a/spleeter/commands/separate.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python -# coding: utf8 - -""" - Entrypoint provider for performing source separation. - - USAGE: python -m spleeter separate \ - -p /path/to/params \ - -i inputfile1 inputfile2 ... inputfilen - -o /path/to/output/dir \ - -i /path/to/audio1.wav /path/to/audio2.mp3 -""" - -from ..audio.adapter import get_audio_adapter -from ..separator import Separator - -__email__ = 'spleeter@deezer.com' -__author__ = 'Deezer Research' -__license__ = 'MIT License' - - -from typer import Option - -AudioAdapter = Option() - - -def entrypoint(arguments, params): - """ Command entrypoint. - - :param arguments: Command line parsed argument as argparse.Namespace. - :param params: Deserialized JSON configuration file provided in CLI args. - """ - # TODO: check with output naming. - audio_adapter = get_audio_adapter(arguments.audio_adapter) - separator = Separator( - arguments.configuration, - MWF=arguments.MWF, - stft_backend=arguments.stft_backend) - for filename in arguments.inputs: - separator.separate_to_file( - filename, - arguments.output_path, - audio_adapter=audio_adapter, - offset=arguments.offset, - duration=arguments.duration, - codec=arguments.codec, - bitrate=arguments.bitrate, - filename_format=arguments.filename_format, - synchronous=False - ) - separator.join() diff --git a/spleeter/commands/train.py b/spleeter/commands/train.py deleted file mode 100644 index 3bffaef..0000000 --- a/spleeter/commands/train.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python -# coding: utf8 - -""" - Entrypoint provider for performing model training. - - USAGE: python -m spleeter train -p /path/to/params -""" - -from functools import partial - -# pylint: disable=import-error -import tensorflow as tf -# pylint: enable=import-error - -from ..audio.adapter import get_audio_adapter -from ..dataset import get_training_dataset, get_validation_dataset -from ..model import model_fn -from ..model.provider import ModelProvider -from ..utils.logging import get_logger - -__email__ = 'spleeter@deezer.com' -__author__ = 'Deezer Research' -__license__ = 'MIT License' - - -def _create_estimator(params): - """ Creates estimator. - - :param params: TF params to build estimator from. - :returns: Built estimator. - """ - session_config = tf.compat.v1.ConfigProto() - session_config.gpu_options.per_process_gpu_memory_fraction = 0.45 - estimator = tf.estimator.Estimator( - model_fn=model_fn, - model_dir=params['model_dir'], - params=params, - config=tf.estimator.RunConfig( - save_checkpoints_steps=params['save_checkpoints_steps'], - tf_random_seed=params['random_seed'], - save_summary_steps=params['save_summary_steps'], - session_config=session_config, - log_step_count_steps=10, - keep_checkpoint_max=2)) - return estimator - - -def _create_train_spec(params, audio_adapter, audio_path): - """ Creates train spec. - - :param params: TF params to build spec from. - :returns: Built train spec. - """ - input_fn = partial(get_training_dataset, params, audio_adapter, audio_path) - train_spec = tf.estimator.TrainSpec( - input_fn=input_fn, - max_steps=params['train_max_steps']) - return train_spec - - -def _create_evaluation_spec(params, audio_adapter, audio_path): - """ Setup eval spec evaluating ever n seconds - - :param params: TF params to build spec from. - :returns: Built evaluation spec. - """ - input_fn = partial( - get_validation_dataset, - params, - audio_adapter, - audio_path) - evaluation_spec = tf.estimator.EvalSpec( - input_fn=input_fn, - steps=None, - throttle_secs=params['throttle_secs']) - return evaluation_spec - - -def entrypoint(arguments, params): - """ Command entrypoint. - - :param arguments: Command line parsed argument as argparse.Namespace. - :param params: Deserialized JSON configuration file provided in CLI args. - """ - audio_adapter = get_audio_adapter(arguments.audio_adapter) - audio_path = arguments.audio_path - estimator = _create_estimator(params) - train_spec = _create_train_spec(params, audio_adapter, audio_path) - evaluation_spec = _create_evaluation_spec( - params, - audio_adapter, - audio_path) - get_logger().info('Start model training') - tf.estimator.train_and_evaluate( - estimator, - train_spec, - evaluation_spec) - ModelProvider.writeProbe(params['model_dir']) - get_logger().info('Model training done') diff --git a/spleeter/utils/logging.py b/spleeter/utils/logging.py index 6fee540..633549e 100644 --- a/spleeter/utils/logging.py +++ b/spleeter/utils/logging.py @@ -4,58 +4,41 @@ """ Centralized logging facilities for Spleeter. """ import logging +import warnings from os import environ +# pyright: reportMissingImports=false +# pylint: disable=import-error +from tensorflow.compat.v1 import logging as tflogging +# pylint: enable=import-error + __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' -_FORMAT = '%(levelname)s:%(name)s:%(message)s' + +formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s') +handler = logging.StreamHandler() +handler.setFormatter(formatter) +logger: logging.Logger = logging.getLogger('spleeter') +logger.addHandler(handler) +logger.setLevel(logging.INFO) -class _LoggerHolder(object): - """ Logger singleton instance holder. """ - - INSTANCE = None - - -def get_tensorflow_logger(): +def configure_logger(verbose: bool) -> None: """ + Configure application logger. + + Parameters: + verbose (bool): + `True` to use verbose logger, `False` otherwise. """ - # pylint: disable=import-error - from tensorflow.compat.v1 import logging - # pylint: enable=import-error - return logging - - -def get_logger(): - """ Returns library scoped logger. - - :returns: Library logger. - """ - if _LoggerHolder.INSTANCE is None: - formatter = logging.Formatter(_FORMAT) - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger = logging.getLogger('spleeter') - logger.addHandler(handler) - logger.setLevel(logging.INFO) - _LoggerHolder.INSTANCE = logger - return _LoggerHolder.INSTANCE - - -def enable_tensorflow_logging(): - """ Enable tensorflow logging. """ - environ['TF_CPP_MIN_LOG_LEVEL'] = '1' - tf_logger = get_tensorflow_logger() - tf_logger.set_verbosity(tf_logger.INFO) - logger = get_logger() - logger.setLevel(logging.DEBUG) - - -def enable_logging(): - """ Configure default logging. """ - environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - tf_logger = get_tensorflow_logger() - tf_logger.set_verbosity(tf_logger.ERROR) + if verbose: + environ['TF_CPP_MIN_LOG_LEVEL'] = '1' + tflogging.set_verbosity(tflogging.INFO) + logger.setLevel(logging.DEBUG) + else: + warnings.filterwarnings('ignore') + environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + tflogging.set_verbosity(tflogging.ERROR)