mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
🐛 fix tensorflow logging
This commit is contained in:
@@ -7,6 +7,9 @@
|
|||||||
USAGE: python -m spleeter {train,evaluate,separate} ...
|
USAGE: python -m spleeter {train,evaluate,separate} ...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# NOTE: disable TF logging before import.
|
||||||
|
from .utils.logging import configure_logger, logger
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -25,7 +28,6 @@ from .model import model_fn
|
|||||||
from .model.provider import ModelProvider
|
from .model.provider import ModelProvider
|
||||||
from .separator import Separator
|
from .separator import Separator
|
||||||
from .utils.configuration import load_configuration
|
from .utils.configuration import load_configuration
|
||||||
from .utils.logging import configure_logger, logger
|
|
||||||
|
|
||||||
# pyright: reportMissingImports=false
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
@@ -200,7 +202,7 @@ def evaluate(
|
|||||||
filename_format='{foldername}/{instrument}.{codec}',
|
filename_format='{foldername}/{instrument}.{codec}',
|
||||||
params_filename=params_filename,
|
params_filename=params_filename,
|
||||||
mwf=mwf,
|
mwf=mwf,
|
||||||
verbose=verbose,)
|
verbose=verbose)
|
||||||
# Compute metrics with musdb.
|
# Compute metrics with musdb.
|
||||||
metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
|
metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
|
||||||
logger.info('Starting musdb evaluation (this could be long) ...')
|
logger.info('Starting musdb evaluation (this could be long) ...')
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from os import environ
|
|||||||
|
|
||||||
# pyright: reportMissingImports=false
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.compat.v1 import logging as tf_logging
|
from tensorflow.compat.v1 import logging as tf_logging
|
||||||
from typer import echo
|
from typer import echo
|
||||||
# pylint: enable=import-error
|
# pylint: enable=import-error
|
||||||
@@ -18,6 +20,8 @@ __email__ = 'spleeter@deezer.com'
|
|||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
|
|
||||||
class TyperLoggerHandler(logging.Handler):
|
class TyperLoggerHandler(logging.Handler):
|
||||||
""" A custom logger handler that use Typer echo. """
|
""" A custom logger handler that use Typer echo. """
|
||||||
@@ -42,13 +46,11 @@ def configure_logger(verbose: bool) -> None:
|
|||||||
verbose (bool):
|
verbose (bool):
|
||||||
`True` to use verbose logger, `False` otherwise.
|
`True` to use verbose logger, `False` otherwise.
|
||||||
"""
|
"""
|
||||||
tf_logger = tf_logging.get_logger()
|
tf_logger = tf.get_logger()
|
||||||
tf_logger.handlers = [handler]
|
tf_logger.handlers = [handler]
|
||||||
if verbose:
|
if verbose:
|
||||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
|
|
||||||
tf_logging.set_verbosity(tf_logging.INFO)
|
tf_logging.set_verbosity(tf_logging.INFO)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
else:
|
else:
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
|
||||||
tf_logging.set_verbosity(tf_logging.ERROR)
|
tf_logging.set_verbosity(tf_logging.ERROR)
|
||||||
|
|||||||
Reference in New Issue
Block a user