diff --git a/spleeter/__main__.py b/spleeter/__main__.py index c977cfe..d2ca490 100644 --- a/spleeter/__main__.py +++ b/spleeter/__main__.py @@ -7,6 +7,9 @@ USAGE: python -m spleeter {train,evaluate,separate} ... """ +# NOTE: disable TF logging before import. +from .utils.logging import configure_logger, logger + import json from functools import partial @@ -25,7 +28,6 @@ from .model import model_fn from .model.provider import ModelProvider from .separator import Separator from .utils.configuration import load_configuration -from .utils.logging import configure_logger, logger # pyright: reportMissingImports=false # pylint: disable=import-error @@ -200,7 +202,7 @@ def evaluate( filename_format='{foldername}/{instrument}.{codec}', params_filename=params_filename, mwf=mwf, - verbose=verbose,) + verbose=verbose) # Compute metrics with musdb. metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY) logger.info('Starting musdb evaluation (this could be long) ...') diff --git a/spleeter/utils/logging.py b/spleeter/utils/logging.py index e7503f6..1db3bd6 100644 --- a/spleeter/utils/logging.py +++ b/spleeter/utils/logging.py @@ -10,6 +10,8 @@ from os import environ # pyright: reportMissingImports=false # pylint: disable=import-error +import tensorflow as tf + from tensorflow.compat.v1 import logging as tf_logging from typer import echo # pylint: enable=import-error @@ -18,6 +20,8 @@ __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' +environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + class TyperLoggerHandler(logging.Handler): """ A custom logger handler that use Typer echo. """ @@ -42,13 +46,11 @@ def configure_logger(verbose: bool) -> None: verbose (bool): `True` to use verbose logger, `False` otherwise. """ - tf_logger = tf_logging.get_logger() + tf_logger = tf.get_logger() tf_logger.handlers = [handler] if verbose: - environ['TF_CPP_MIN_LOG_LEVEL'] = '1' tf_logging.set_verbosity(tf_logging.INFO) logger.setLevel(logging.DEBUG) else: warnings.filterwarnings('ignore') - environ['TF_CPP_MIN_LOG_LEVEL'] = '3' tf_logging.set_verbosity(tf_logging.ERROR)