diff --git a/spleeter/utils/logging.py b/spleeter/utils/logging.py index 633549e..bb1df18 100644 --- a/spleeter/utils/logging.py +++ b/spleeter/utils/logging.py @@ -11,6 +11,7 @@ from os import environ # pyright: reportMissingImports=false # pylint: disable=import-error from tensorflow.compat.v1 import logging as tflogging +from typer import echo # pylint: enable=import-error __email__ = 'spleeter@deezer.com' @@ -18,8 +19,15 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' +class TyperLoggerHandler(logging.Handler): + """ A custom logger handler that use Typer echo. """ + + def emit(self, record: logging.LogRecord) -> None: + echo(self.format(record)) + + formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s') -handler = logging.StreamHandler() +handler = TyperLoggerHandler() handler.setFormatter(formatter) logger: logging.Logger = logging.getLogger('spleeter') logger.addHandler(handler) diff --git a/tests/test_train.py b/tests/test_train.py index 7c4bc48..dd6f306 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -101,6 +101,7 @@ def test_train(): '-p', 'useless_config.json', '-d', path ]) + raise IOError(f'STDOUT: {result.stdout}, STDERR: {result.stderr}') assert result.exit_code == 0 # assert that model checkpoint was created. assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))