mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
🚧 typer integration
This commit is contained in:
@@ -1,13 +1,22 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" TO DOCUMENT """
|
||||
"""
|
||||
Python oneliner script usage.
|
||||
|
||||
USAGE: python -m spleeter {train,evaluate,separate} ...
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from functools import partial
|
||||
from itertools import product
|
||||
from glob import glob
|
||||
from os.path import join
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Any, Container, Dict, List
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio import Codec
|
||||
from .audio.adapter import AudioAdapter
|
||||
from .options import *
|
||||
@@ -16,11 +25,12 @@ from .model import model_fn
|
||||
from .model.provider import ModelProvider
|
||||
from .separator import Separator
|
||||
from .utils.configuration import load_configuration
|
||||
from .utils.logging import get_logger
|
||||
|
||||
from .utils.logging import configure_logger, logger
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
|
||||
from typer import Exit, Typer
|
||||
@@ -39,8 +49,7 @@ def train(
|
||||
"""
|
||||
Train a source separation model
|
||||
"""
|
||||
# TODO: try / catch or custom decorator for function handling.
|
||||
# TODO: handle verbose flag ?
|
||||
configure_logger(verbose)
|
||||
audio_adapter = AudioAdapter.get(adapter)
|
||||
audio_path = str(data)
|
||||
params = load_configuration(params_filename)
|
||||
@@ -70,120 +79,19 @@ def train(
|
||||
input_fn=input_fn,
|
||||
steps=None,
|
||||
throttle_secs=params['throttle_secs'])
|
||||
get_logger().info('Start model training')
|
||||
logger.info('Start model training')
|
||||
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
|
||||
ModelProvider.writeProbe(params['model_dir'])
|
||||
get_logger().info('Model training done')
|
||||
|
||||
_SPLIT = 'test'
|
||||
_MIXTURE = 'mixture.wav'
|
||||
_AUDIO_DIRECTORY = 'audio'
|
||||
_METRICS_DIRECTORY = 'metrics'
|
||||
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
|
||||
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
|
||||
|
||||
|
||||
def _compute_musdb_metrics(
|
||||
arguments,
|
||||
musdb_root_directory,
|
||||
audio_output_directory):
|
||||
""" Generates musdb metrics fro previsouly computed audio estimation.
|
||||
|
||||
:param arguments: Entrypoint arguments.
|
||||
:param audio_output_directory: Directory to get audio estimation from.
|
||||
:returns: Path of generated metrics directory.
|
||||
"""
|
||||
metrics_output_directory = join(
|
||||
arguments.output_path,
|
||||
_METRICS_DIRECTORY)
|
||||
get_logger().info('Starting musdb evaluation (this could be long) ...')
|
||||
try:
|
||||
import musdb
|
||||
import museval
|
||||
except ImportError:
|
||||
logger = get_logger()
|
||||
logger.error('Extra dependencies musdb and museval not found')
|
||||
logger.error('Please install musdb and museval first, abort')
|
||||
raise Exit(10)
|
||||
dataset = musdb.DB(
|
||||
root=musdb_root_directory,
|
||||
is_wav=True,
|
||||
subsets=[_SPLIT])
|
||||
museval.eval_mus_dir(
|
||||
dataset=dataset,
|
||||
estimates_dir=audio_output_directory,
|
||||
output_dir=metrics_output_directory)
|
||||
get_logger().info('musdb evaluation done')
|
||||
return metrics_output_directory
|
||||
|
||||
|
||||
def _compile_metrics(metrics_output_directory):
|
||||
""" Compiles metrics from given directory and returns
|
||||
results as dict.
|
||||
|
||||
:param metrics_output_directory: Directory to get metrics from.
|
||||
:returns: Compiled metrics as dict.
|
||||
"""
|
||||
songs = glob(join(metrics_output_directory, 'test/*.json'))
|
||||
index = pd.MultiIndex.from_tuples(
|
||||
product(_INSTRUMENTS, _METRICS),
|
||||
names=['instrument', 'metric'])
|
||||
pd.DataFrame([], index=['config1', 'config2'], columns=index)
|
||||
metrics = {
|
||||
instrument: {k: [] for k in _METRICS}
|
||||
for instrument in _INSTRUMENTS}
|
||||
for song in songs:
|
||||
with open(song, 'r') as stream:
|
||||
data = json.load(stream)
|
||||
for target in data['targets']:
|
||||
instrument = target['name']
|
||||
for metric in _METRICS:
|
||||
sdr_med = np.median([
|
||||
frame['metrics'][metric]
|
||||
for frame in target['frames']
|
||||
if not np.isnan(frame['metrics'][metric])])
|
||||
metrics[instrument][metric].append(sdr_med)
|
||||
return metrics
|
||||
|
||||
|
||||
@spleeter.command()
|
||||
def evaluate(
|
||||
adapter: str = AudioAdapterOption,
|
||||
output_path: Path = AudioAdapterOption,
|
||||
stft_backend: STFTBackend = AudioSTFTBackendOption,
|
||||
params_filename: str = ModelParametersOption,
|
||||
mus_dir: Path = MUSDBDirectoryOption,
|
||||
mwf: bool = MWFOption,
|
||||
verbose: bool = VerboseOption) -> None:
|
||||
"""
|
||||
Evaluate a model on the musDB test dataset
|
||||
"""
|
||||
# Separate musdb sources.
|
||||
audio_output_directory = _separate_evaluation_dataset(
|
||||
arguments,
|
||||
mus_dir,
|
||||
params)
|
||||
# Compute metrics with musdb.
|
||||
metrics_output_directory = _compute_musdb_metrics(
|
||||
arguments,
|
||||
mus_dir,
|
||||
audio_output_directory)
|
||||
# Compute and pretty print median metrics.
|
||||
metrics = _compile_metrics(metrics_output_directory)
|
||||
for instrument, metric in metrics.items():
|
||||
get_logger().info('%s:', instrument)
|
||||
for metric, value in metric.items():
|
||||
get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
|
||||
return metrics
|
||||
logger.info('Model training done')
|
||||
|
||||
|
||||
@spleeter.commmand()
|
||||
def separate(
|
||||
files: List[Path] = AudioInputArgument,
|
||||
adapter: str = AudioAdapterOption,
|
||||
bitrate: str = AudioBitrateOption,
|
||||
codec: Codec = AudioCodecOption,
|
||||
duration: float = AudioDurationOption,
|
||||
files: List[Path] = AudioInputArgument,
|
||||
offset: float = AudioOffsetOption,
|
||||
output_path: Path = AudioAdapterOption,
|
||||
stft_backend: STFTBackend = AudioSTFTBackendOption,
|
||||
@@ -194,13 +102,7 @@ def separate(
|
||||
"""
|
||||
Separate audio file(s)
|
||||
"""
|
||||
# TODO: try / catch or custom decorator for function handling.
|
||||
# TODO: enable_logging()
|
||||
# TODO: handle MWF
|
||||
if verbose:
|
||||
# TODO: enable_tensorflow_logging()
|
||||
pass
|
||||
# PREV: params = load_configuration(arguments.configuration)
|
||||
configure_logger(verbose)
|
||||
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
|
||||
separator: Separator = Separator(
|
||||
params_filename,
|
||||
@@ -220,6 +122,102 @@ def separate(
|
||||
separator.join()
|
||||
|
||||
|
||||
EVALUATION_SPLIT: str = 'test'
|
||||
EVALUATION_METRICS_DIRECTORY: str = 'metrics'
|
||||
EVALUATION_INSTRUMENTS: Container[str] = ('vocals', 'drums', 'bass', 'other')
|
||||
EVALUATION_METRICS: Container[str] = ('SDR', 'SAR', 'SIR', 'ISR')
|
||||
EVALUATION_MIXTURE: str = 'mixture.wav'
|
||||
EVALUATION_AUDIO_DIRECTORY: str = 'audio'
|
||||
|
||||
|
||||
def _compile_metrics(metrics_output_directory) -> Dict:
|
||||
"""
|
||||
Compiles metrics from given directory and returns results as dict.
|
||||
|
||||
Parameters:
|
||||
metrics_output_directory (str):
|
||||
Directory to get metrics from.
|
||||
|
||||
Returns:
|
||||
Dict:
|
||||
Compiled metrics as dict.
|
||||
"""
|
||||
songs = glob(join(metrics_output_directory, 'test/*.json'))
|
||||
index = pd.MultiIndex.from_tuples(
|
||||
product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
|
||||
names=['instrument', 'metric'])
|
||||
pd.DataFrame([], index=['config1', 'config2'], columns=index)
|
||||
metrics = {
|
||||
instrument: {k: [] for k in EVALUATION_METRICS}
|
||||
for instrument in EVALUATION_INSTRUMENTS}
|
||||
for song in songs:
|
||||
with open(song, 'r') as stream:
|
||||
data = json.load(stream)
|
||||
for target in data['targets']:
|
||||
instrument = target['name']
|
||||
for metric in EVALUATION_METRICS:
|
||||
sdr_med = np.median([
|
||||
frame['metrics'][metric]
|
||||
for frame in target['frames']
|
||||
if not np.isnan(frame['metrics'][metric])])
|
||||
metrics[instrument][metric].append(sdr_med)
|
||||
return metrics
|
||||
|
||||
|
||||
@spleeter.command()
|
||||
def evaluate(
|
||||
adapter: str = AudioAdapterOption,
|
||||
output_path: Path = AudioAdapterOption,
|
||||
stft_backend: STFTBackend = AudioSTFTBackendOption,
|
||||
params_filename: str = ModelParametersOption,
|
||||
mus_dir: Path = MUSDBDirectoryOption,
|
||||
mwf: bool = MWFOption,
|
||||
verbose: bool = VerboseOption) -> Dict:
|
||||
"""
|
||||
Evaluate a model on the musDB test dataset
|
||||
"""
|
||||
configure_logger(verbose)
|
||||
try:
|
||||
import musdb
|
||||
import museval
|
||||
except ImportError:
|
||||
logger.error('Extra dependencies musdb and museval not found')
|
||||
logger.error('Please install musdb and museval first, abort')
|
||||
raise Exit(10)
|
||||
# Separate musdb sources.
|
||||
songs = glob(join(mus_dir, EVALUATION_SPLIT, '*/'))
|
||||
mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
|
||||
audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
|
||||
separate(
|
||||
adapter=adapter,
|
||||
params_filename=params_filename,
|
||||
files=mixtures,
|
||||
output_path=output_path,
|
||||
filename_format='{foldername}/{instrument}.{codec}',
|
||||
codec=Codec.WAV,
|
||||
mwf=mwf,
|
||||
verbose=verbose,
|
||||
stft_backend=stft_backend)
|
||||
# Compute metrics with musdb.
|
||||
metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
|
||||
logger.info('Starting musdb evaluation (this could be long) ...')
|
||||
dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
|
||||
museval.eval_mus_dir(
|
||||
dataset=dataset,
|
||||
estimates_dir=audio_output_directory,
|
||||
output_dir=metrics_output_directory)
|
||||
logger.info('musdb evaluation done')
|
||||
# Compute and pretty print median metrics.
|
||||
metrics = _compile_metrics(metrics_output_directory)
|
||||
for instrument, metric in metrics.items():
|
||||
logger.info(f'{instrument}:')
|
||||
for metric, value in metric.items():
|
||||
logger.info(f'{metric}: {np.median(value):.3f}')
|
||||
return metrics
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# TODO: warnings.filterwarnings('ignore')
|
||||
spleeter()
|
||||
try:
|
||||
spleeter()
|
||||
except SpleeterError as e:
|
||||
logger.error(e)
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
"""
|
||||
Entrypoint provider for performing model evaluation.
|
||||
|
||||
Evaluation is performed against musDB dataset.
|
||||
|
||||
USAGE: python -m spleeter evaluate \
|
||||
-p /path/to/params \
|
||||
-o /path/to/output/dir \
|
||||
[-m] \
|
||||
--mus_dir /path/to/musdb dataset
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
|
||||
from argparse import Namespace
|
||||
from itertools import product
|
||||
from glob import glob
|
||||
from os.path import join, exists
|
||||
|
||||
# pylint: disable=import-error
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
# pylint: enable=import-error
|
||||
|
||||
from .separate import entrypoint as separate_entrypoint
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
try:
|
||||
import musdb
|
||||
import museval
|
||||
except ImportError:
|
||||
logger = get_logger()
|
||||
logger.error('Extra dependencies musdb and museval not found')
|
||||
logger.error('Please install musdb and museval first, abort')
|
||||
sys.exit(1)
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
_SPLIT = 'test'
|
||||
_MIXTURE = 'mixture.wav'
|
||||
_AUDIO_DIRECTORY = 'audio'
|
||||
_METRICS_DIRECTORY = 'metrics'
|
||||
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
|
||||
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
|
||||
|
||||
|
||||
def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
|
||||
""" Performs audio separation on the musdb dataset from
|
||||
the given directory and params.
|
||||
|
||||
:param arguments: Entrypoint arguments.
|
||||
:param musdb_root_directory: Directory to retrieve dataset from.
|
||||
:param params: Spleeter configuration to apply to separation.
|
||||
:returns: Separation output directory path.
|
||||
"""
|
||||
songs = glob(join(musdb_root_directory, _SPLIT, '*/'))
|
||||
mixtures = [join(song, _MIXTURE) for song in songs]
|
||||
audio_output_directory = join(
|
||||
arguments.output_path,
|
||||
_AUDIO_DIRECTORY)
|
||||
separate_entrypoint(
|
||||
Namespace(
|
||||
audio_adapter=arguments.audio_adapter,
|
||||
configuration=arguments.configuration,
|
||||
inputs=mixtures,
|
||||
output_path=join(audio_output_directory, _SPLIT),
|
||||
filename_format='{foldername}/{instrument}.{codec}',
|
||||
codec='wav',
|
||||
duration=600.,
|
||||
offset=0.,
|
||||
bitrate='128k',
|
||||
MWF=arguments.MWF,
|
||||
verbose=arguments.verbose,
|
||||
stft_backend=arguments.stft_backend),
|
||||
params)
|
||||
return audio_output_directory
|
||||
|
||||
|
||||
def _compute_musdb_metrics(
|
||||
arguments,
|
||||
musdb_root_directory,
|
||||
audio_output_directory):
|
||||
""" Generates musdb metrics fro previsouly computed audio estimation.
|
||||
|
||||
:param arguments: Entrypoint arguments.
|
||||
:param audio_output_directory: Directory to get audio estimation from.
|
||||
:returns: Path of generated metrics directory.
|
||||
"""
|
||||
metrics_output_directory = join(
|
||||
arguments.output_path,
|
||||
_METRICS_DIRECTORY)
|
||||
get_logger().info('Starting musdb evaluation (this could be long) ...')
|
||||
dataset = musdb.DB(
|
||||
root=musdb_root_directory,
|
||||
is_wav=True,
|
||||
subsets=[_SPLIT])
|
||||
museval.eval_mus_dir(
|
||||
dataset=dataset,
|
||||
estimates_dir=audio_output_directory,
|
||||
output_dir=metrics_output_directory)
|
||||
get_logger().info('musdb evaluation done')
|
||||
return metrics_output_directory
|
||||
|
||||
|
||||
def _compile_metrics(metrics_output_directory):
|
||||
""" Compiles metrics from given directory and returns
|
||||
results as dict.
|
||||
|
||||
:param metrics_output_directory: Directory to get metrics from.
|
||||
:returns: Compiled metrics as dict.
|
||||
"""
|
||||
songs = glob(join(metrics_output_directory, 'test/*.json'))
|
||||
index = pd.MultiIndex.from_tuples(
|
||||
product(_INSTRUMENTS, _METRICS),
|
||||
names=['instrument', 'metric'])
|
||||
pd.DataFrame([], index=['config1', 'config2'], columns=index)
|
||||
metrics = {
|
||||
instrument: {k: [] for k in _METRICS}
|
||||
for instrument in _INSTRUMENTS}
|
||||
for song in songs:
|
||||
with open(song, 'r') as stream:
|
||||
data = json.load(stream)
|
||||
for target in data['targets']:
|
||||
instrument = target['name']
|
||||
for metric in _METRICS:
|
||||
sdr_med = np.median([
|
||||
frame['metrics'][metric]
|
||||
for frame in target['frames']
|
||||
if not np.isnan(frame['metrics'][metric])])
|
||||
metrics[instrument][metric].append(sdr_med)
|
||||
return metrics
|
||||
|
||||
|
||||
def entrypoint(arguments, params):
|
||||
""" Command entrypoint.
|
||||
|
||||
:param arguments: Command line parsed argument as argparse.Namespace.
|
||||
:param params: Deserialized JSON configuration file provided in CLI args.
|
||||
"""
|
||||
# Parse and check musdb directory.
|
||||
musdb_root_directory = arguments.mus_dir
|
||||
if not exists(musdb_root_directory):
|
||||
raise IOError(f'musdb directory {musdb_root_directory} not found')
|
||||
# Separate musdb sources.
|
||||
audio_output_directory = _separate_evaluation_dataset(
|
||||
arguments,
|
||||
musdb_root_directory,
|
||||
params)
|
||||
# Compute metrics with musdb.
|
||||
metrics_output_directory = _compute_musdb_metrics(
|
||||
arguments,
|
||||
musdb_root_directory,
|
||||
audio_output_directory)
|
||||
# Compute and pretty print median metrics.
|
||||
metrics = _compile_metrics(metrics_output_directory)
|
||||
for instrument, metric in metrics.items():
|
||||
get_logger().info('%s:', instrument)
|
||||
for metric, value in metric.items():
|
||||
get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
|
||||
return metrics
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
"""
|
||||
Entrypoint provider for performing source separation.
|
||||
|
||||
USAGE: python -m spleeter separate \
|
||||
-p /path/to/params \
|
||||
-i inputfile1 inputfile2 ... inputfilen
|
||||
-o /path/to/output/dir \
|
||||
-i /path/to/audio1.wav /path/to/audio2.mp3
|
||||
"""
|
||||
|
||||
from ..audio.adapter import get_audio_adapter
|
||||
from ..separator import Separator
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
from typer import Option
|
||||
|
||||
AudioAdapter = Option()
|
||||
|
||||
|
||||
def entrypoint(arguments, params):
|
||||
""" Command entrypoint.
|
||||
|
||||
:param arguments: Command line parsed argument as argparse.Namespace.
|
||||
:param params: Deserialized JSON configuration file provided in CLI args.
|
||||
"""
|
||||
# TODO: check with output naming.
|
||||
audio_adapter = get_audio_adapter(arguments.audio_adapter)
|
||||
separator = Separator(
|
||||
arguments.configuration,
|
||||
MWF=arguments.MWF,
|
||||
stft_backend=arguments.stft_backend)
|
||||
for filename in arguments.inputs:
|
||||
separator.separate_to_file(
|
||||
filename,
|
||||
arguments.output_path,
|
||||
audio_adapter=audio_adapter,
|
||||
offset=arguments.offset,
|
||||
duration=arguments.duration,
|
||||
codec=arguments.codec,
|
||||
bitrate=arguments.bitrate,
|
||||
filename_format=arguments.filename_format,
|
||||
synchronous=False
|
||||
)
|
||||
separator.join()
|
||||
@@ -1,100 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
"""
|
||||
Entrypoint provider for performing model training.
|
||||
|
||||
USAGE: python -m spleeter train -p /path/to/params
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
# pylint: disable=import-error
|
||||
import tensorflow as tf
|
||||
# pylint: enable=import-error
|
||||
|
||||
from ..audio.adapter import get_audio_adapter
|
||||
from ..dataset import get_training_dataset, get_validation_dataset
|
||||
from ..model import model_fn
|
||||
from ..model.provider import ModelProvider
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def _create_estimator(params):
|
||||
""" Creates estimator.
|
||||
|
||||
:param params: TF params to build estimator from.
|
||||
:returns: Built estimator.
|
||||
"""
|
||||
session_config = tf.compat.v1.ConfigProto()
|
||||
session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
|
||||
estimator = tf.estimator.Estimator(
|
||||
model_fn=model_fn,
|
||||
model_dir=params['model_dir'],
|
||||
params=params,
|
||||
config=tf.estimator.RunConfig(
|
||||
save_checkpoints_steps=params['save_checkpoints_steps'],
|
||||
tf_random_seed=params['random_seed'],
|
||||
save_summary_steps=params['save_summary_steps'],
|
||||
session_config=session_config,
|
||||
log_step_count_steps=10,
|
||||
keep_checkpoint_max=2))
|
||||
return estimator
|
||||
|
||||
|
||||
def _create_train_spec(params, audio_adapter, audio_path):
|
||||
""" Creates train spec.
|
||||
|
||||
:param params: TF params to build spec from.
|
||||
:returns: Built train spec.
|
||||
"""
|
||||
input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
|
||||
train_spec = tf.estimator.TrainSpec(
|
||||
input_fn=input_fn,
|
||||
max_steps=params['train_max_steps'])
|
||||
return train_spec
|
||||
|
||||
|
||||
def _create_evaluation_spec(params, audio_adapter, audio_path):
|
||||
""" Setup eval spec evaluating ever n seconds
|
||||
|
||||
:param params: TF params to build spec from.
|
||||
:returns: Built evaluation spec.
|
||||
"""
|
||||
input_fn = partial(
|
||||
get_validation_dataset,
|
||||
params,
|
||||
audio_adapter,
|
||||
audio_path)
|
||||
evaluation_spec = tf.estimator.EvalSpec(
|
||||
input_fn=input_fn,
|
||||
steps=None,
|
||||
throttle_secs=params['throttle_secs'])
|
||||
return evaluation_spec
|
||||
|
||||
|
||||
def entrypoint(arguments, params):
|
||||
""" Command entrypoint.
|
||||
|
||||
:param arguments: Command line parsed argument as argparse.Namespace.
|
||||
:param params: Deserialized JSON configuration file provided in CLI args.
|
||||
"""
|
||||
audio_adapter = get_audio_adapter(arguments.audio_adapter)
|
||||
audio_path = arguments.audio_path
|
||||
estimator = _create_estimator(params)
|
||||
train_spec = _create_train_spec(params, audio_adapter, audio_path)
|
||||
evaluation_spec = _create_evaluation_spec(
|
||||
params,
|
||||
audio_adapter,
|
||||
audio_path)
|
||||
get_logger().info('Start model training')
|
||||
tf.estimator.train_and_evaluate(
|
||||
estimator,
|
||||
train_spec,
|
||||
evaluation_spec)
|
||||
ModelProvider.writeProbe(params['model_dir'])
|
||||
get_logger().info('Model training done')
|
||||
@@ -4,58 +4,41 @@
|
||||
""" Centralized logging facilities for Spleeter. """
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from os import environ
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
# pylint: disable=import-error
|
||||
from tensorflow.compat.v1 import logging as tflogging
|
||||
# pylint: enable=import-error
|
||||
|
||||
__email__ = 'spleeter@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
_FORMAT = '%(levelname)s:%(name)s:%(message)s'
|
||||
|
||||
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger: logging.Logger = logging.getLogger('spleeter')
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class _LoggerHolder(object):
|
||||
""" Logger singleton instance holder. """
|
||||
|
||||
INSTANCE = None
|
||||
|
||||
|
||||
def get_tensorflow_logger():
|
||||
def configure_logger(verbose: bool) -> None:
|
||||
"""
|
||||
Configure application logger.
|
||||
|
||||
Parameters:
|
||||
verbose (bool):
|
||||
`True` to use verbose logger, `False` otherwise.
|
||||
"""
|
||||
# pylint: disable=import-error
|
||||
from tensorflow.compat.v1 import logging
|
||||
# pylint: enable=import-error
|
||||
return logging
|
||||
|
||||
|
||||
def get_logger():
|
||||
""" Returns library scoped logger.
|
||||
|
||||
:returns: Library logger.
|
||||
"""
|
||||
if _LoggerHolder.INSTANCE is None:
|
||||
formatter = logging.Formatter(_FORMAT)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger = logging.getLogger('spleeter')
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
_LoggerHolder.INSTANCE = logger
|
||||
return _LoggerHolder.INSTANCE
|
||||
|
||||
|
||||
def enable_tensorflow_logging():
|
||||
""" Enable tensorflow logging. """
|
||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
tf_logger = get_tensorflow_logger()
|
||||
tf_logger.set_verbosity(tf_logger.INFO)
|
||||
logger = get_logger()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def enable_logging():
|
||||
""" Configure default logging. """
|
||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
tf_logger = get_tensorflow_logger()
|
||||
tf_logger.set_verbosity(tf_logger.ERROR)
|
||||
if verbose:
|
||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
||||
tflogging.set_verbosity(tflogging.INFO)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
else:
|
||||
warnings.filterwarnings('ignore')
|
||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
tflogging.set_verbosity(tflogging.ERROR)
|
||||
|
||||
Reference in New Issue
Block a user