️ fix heavy import on cli parsing

This commit is contained in:
Faylixe
2020-12-18 17:31:16 +01:00
parent 0428a0bde8
commit a41b8f35a6
3 changed files with 32 additions and 24 deletions

View File

@@ -5,10 +5,12 @@
Python oneliner script usage. Python oneliner script usage.
USAGE: python -m spleeter {train,evaluate,separate} ... USAGE: python -m spleeter {train,evaluate,separate} ...
"""
# NOTE: disable TF logging before import. Notes:
from .utils.logging import configure_logger, logger All critical import involving TF, numpy or Pandas are deported to
command function scope to avoid heavy import on CLI evaluation,
leading to large bootstraping time.
"""
import json import json
@@ -17,24 +19,14 @@ from itertools import product
from glob import glob from glob import glob
from os.path import join from os.path import join
from pathlib import Path from pathlib import Path
from typing import Any, Container, Dict, List from typing import Container, Dict, List
from . import SpleeterError from . import SpleeterError
from .audio import Codec
from .audio.adapter import AudioAdapter
from .options import * from .options import *
from .dataset import get_training_dataset, get_validation_dataset from .utils.logging import configure_logger, logger
from .model import model_fn
from .model.provider import ModelProvider
from .separator import Separator
from .utils.configuration import load_configuration
# pyright: reportMissingImports=false # pyright: reportMissingImports=false
# pylint: disable=import-error # pylint: disable=import-error
import numpy as np
import pandas as pd
import tensorflow as tf
from typer import Exit, Typer from typer import Exit, Typer
# pylint: enable=import-error # pylint: enable=import-error
@@ -51,6 +43,14 @@ def train(
""" """
Train a source separation model Train a source separation model
""" """
from .audio.adapter import AudioAdapter
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .utils.configuration import load_configuration
import tensorflow as tf
configure_logger(verbose) configure_logger(verbose)
audio_adapter = AudioAdapter.get(adapter) audio_adapter = AudioAdapter.get(adapter)
audio_path = str(data) audio_path = str(data)
@@ -104,6 +104,9 @@ def separate(
""" """
Separate audio file(s) Separate audio file(s)
""" """
from .audio.adapter import AudioAdapter
from .separator import Separator
configure_logger(verbose) configure_logger(verbose)
audio_adapter: AudioAdapter = AudioAdapter.get(adapter) audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
separator: Separator = Separator( separator: Separator = Separator(
@@ -144,6 +147,9 @@ def _compile_metrics(metrics_output_directory) -> Dict:
Dict: Dict:
Compiled metrics as dict. Compiled metrics as dict.
""" """
import pandas as pd
import numpy as np
songs = glob(join(metrics_output_directory, 'test/*.json')) songs = glob(join(metrics_output_directory, 'test/*.json'))
index = pd.MultiIndex.from_tuples( index = pd.MultiIndex.from_tuples(
product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS), product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
@@ -178,6 +184,8 @@ def evaluate(
""" """
Evaluate a model on the musDB test dataset Evaluate a model on the musDB test dataset
""" """
import numpy as np
configure_logger(verbose) configure_logger(verbose)
try: try:
import musdb import musdb

View File

@@ -12,11 +12,6 @@
from enum import Enum from enum import Enum
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'
__license__ = 'MIT License' __license__ = 'MIT License'
@@ -42,6 +37,12 @@ class STFTBackend(str, Enum):
@classmethod @classmethod
def resolve(cls: type, backend: str) -> str: def resolve(cls: type, backend: str) -> str:
# NOTE: import is resolved here to avoid performance issues on command
# evaluation.
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
if backend not in cls.__members__.values(): if backend not in cls.__members__.values():
raise ValueError(f'Unsupported backend {backend}') raise ValueError(f'Unsupported backend {backend}')
if backend == cls.AUTO: if backend == cls.AUTO:

View File

@@ -6,11 +6,10 @@
from tempfile import gettempdir from tempfile import gettempdir
from os.path import join from os.path import join
from .separator import STFTBackend from .audio import Codec, STFTBackend
from .audio import Codec
from typer import Argument, Option from typer import Option
from typer.models import ArgumentInfo, OptionInfo from typer.models import OptionInfo
__email__ = 'spleeter@deezer.com' __email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research' __author__ = 'Deezer Research'