mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
✨ add dedicated Typer logger
This commit is contained in:
@@ -11,6 +11,7 @@ from os import environ
|
|||||||
# pyright: reportMissingImports=false
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
from tensorflow.compat.v1 import logging as tflogging
|
from tensorflow.compat.v1 import logging as tflogging
|
||||||
|
from typer import echo
|
||||||
# pylint: enable=import-error
|
# pylint: enable=import-error
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
__email__ = 'spleeter@deezer.com'
|
||||||
@@ -18,8 +19,15 @@ __author__ = 'Deezer Research'
|
|||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
class TyperLoggerHandler(logging.Handler):
|
||||||
|
""" A custom logger handler that use Typer echo. """
|
||||||
|
|
||||||
|
def emit(self, record: logging.LogRecord) -> None:
|
||||||
|
echo(self.format(record))
|
||||||
|
|
||||||
|
|
||||||
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
|
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
|
||||||
handler = logging.StreamHandler()
|
handler = TyperLoggerHandler()
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger: logging.Logger = logging.getLogger('spleeter')
|
logger: logging.Logger = logging.getLogger('spleeter')
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ def test_train():
|
|||||||
'-p', 'useless_config.json',
|
'-p', 'useless_config.json',
|
||||||
'-d', path
|
'-d', path
|
||||||
])
|
])
|
||||||
|
raise IOError(f'STDOUT: {result.stdout}, STDERR: {result.stderr}')
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
# assert that model checkpoint was created.
|
# assert that model checkpoint was created.
|
||||||
assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
|
assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
|
||||||
|
|||||||
Reference in New Issue
Block a user