From a41b8f35a693d610c169c64551686cecc6a2b509 Mon Sep 17 00:00:00 2001 From: Faylixe Date: Fri, 18 Dec 2020 17:31:16 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20=20fix=20heavy=20import=20?= =?UTF-8?q?on=20cli=20parsing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- spleeter/__main__.py | 38 +++++++++++++++++++++++--------------- spleeter/audio/__init__.py | 11 ++++++----- spleeter/options.py | 7 +++---- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/spleeter/__main__.py b/spleeter/__main__.py index 1b34680..0c71e09 100644 --- a/spleeter/__main__.py +++ b/spleeter/__main__.py @@ -5,10 +5,12 @@ Python oneliner script usage. USAGE: python -m spleeter {train,evaluate,separate} ... -""" -# NOTE: disable TF logging before import. -from .utils.logging import configure_logger, logger + Notes: + All critical import involving TF, numpy or Pandas are deported to + command function scope to avoid heavy import on CLI evaluation, + leading to large bootstraping time. +""" import json @@ -17,24 +19,14 @@ from itertools import product from glob import glob from os.path import join from pathlib import Path -from typing import Any, Container, Dict, List +from typing import Container, Dict, List from . import SpleeterError -from .audio import Codec -from .audio.adapter import AudioAdapter from .options import * -from .dataset import get_training_dataset, get_validation_dataset -from .model import model_fn -from .model.provider import ModelProvider -from .separator import Separator -from .utils.configuration import load_configuration +from .utils.logging import configure_logger, logger # pyright: reportMissingImports=false # pylint: disable=import-error -import numpy as np -import pandas as pd -import tensorflow as tf - from typer import Exit, Typer # pylint: enable=import-error @@ -51,6 +43,14 @@ def train( """ Train a source separation model """ + from .audio.adapter import AudioAdapter + from .dataset import get_training_dataset, get_validation_dataset + from .model import model_fn + from .model.provider import ModelProvider + from .utils.configuration import load_configuration + + import tensorflow as tf + configure_logger(verbose) audio_adapter = AudioAdapter.get(adapter) audio_path = str(data) @@ -104,6 +104,9 @@ def separate( """ Separate audio file(s) """ + from .audio.adapter import AudioAdapter + from .separator import Separator + configure_logger(verbose) audio_adapter: AudioAdapter = AudioAdapter.get(adapter) separator: Separator = Separator( @@ -144,6 +147,9 @@ def _compile_metrics(metrics_output_directory) -> Dict: Dict: Compiled metrics as dict. """ + import pandas as pd + import numpy as np + songs = glob(join(metrics_output_directory, 'test/*.json')) index = pd.MultiIndex.from_tuples( product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS), @@ -178,6 +184,8 @@ def evaluate( """ Evaluate a model on the musDB test dataset """ + import numpy as np + configure_logger(verbose) try: import musdb diff --git a/spleeter/audio/__init__.py b/spleeter/audio/__init__.py index c93b663..64b5b53 100644 --- a/spleeter/audio/__init__.py +++ b/spleeter/audio/__init__.py @@ -12,11 +12,6 @@ from enum import Enum -# pyright: reportMissingImports=false -# pylint: disable=import-error -import tensorflow as tf -# pylint: enable=import-error - __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' @@ -42,6 +37,12 @@ class STFTBackend(str, Enum): @classmethod def resolve(cls: type, backend: str) -> str: + # NOTE: import is resolved here to avoid performance issues on command + # evaluation. + # pyright: reportMissingImports=false + # pylint: disable=import-error + import tensorflow as tf + if backend not in cls.__members__.values(): raise ValueError(f'Unsupported backend {backend}') if backend == cls.AUTO: diff --git a/spleeter/options.py b/spleeter/options.py index 5b9547d..4660e19 100644 --- a/spleeter/options.py +++ b/spleeter/options.py @@ -6,11 +6,10 @@ from tempfile import gettempdir from os.path import join -from .separator import STFTBackend -from .audio import Codec +from .audio import Codec, STFTBackend -from typer import Argument, Option -from typer.models import ArgumentInfo, OptionInfo +from typer import Option +from typer.models import OptionInfo __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research'