️ fix heavy import on cli parsing

This commit is contained in:
Faylixe
2020-12-18 17:31:16 +01:00
parent 0428a0bde8
commit a41b8f35a6
3 changed files with 32 additions and 24 deletions

View File

@@ -5,10 +5,12 @@
Python oneliner script usage.
USAGE: python -m spleeter {train,evaluate,separate} ...
"""
# NOTE: disable TF logging before import.
from .utils.logging import configure_logger, logger
Notes:
All critical import involving TF, numpy or Pandas are deported to
command function scope to avoid heavy import on CLI evaluation,
leading to large bootstraping time.
"""
import json
@@ -17,24 +19,14 @@ from itertools import product
from glob import glob
from os.path import join
from pathlib import Path
from typing import Any, Container, Dict, List
from typing import Container, Dict, List
from . import SpleeterError
from .audio import Codec
from .audio.adapter import AudioAdapter
from .options import *
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .separator import Separator
from .utils.configuration import load_configuration
from .utils.logging import configure_logger, logger
# pyright: reportMissingImports=false
# pylint: disable=import-error
import numpy as np
import pandas as pd
import tensorflow as tf
from typer import Exit, Typer
# pylint: enable=import-error
@@ -51,6 +43,14 @@ def train(
"""
Train a source separation model
"""
from .audio.adapter import AudioAdapter
from .dataset import get_training_dataset, get_validation_dataset
from .model import model_fn
from .model.provider import ModelProvider
from .utils.configuration import load_configuration
import tensorflow as tf
configure_logger(verbose)
audio_adapter = AudioAdapter.get(adapter)
audio_path = str(data)
@@ -104,6 +104,9 @@ def separate(
"""
Separate audio file(s)
"""
from .audio.adapter import AudioAdapter
from .separator import Separator
configure_logger(verbose)
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
separator: Separator = Separator(
@@ -144,6 +147,9 @@ def _compile_metrics(metrics_output_directory) -> Dict:
Dict:
Compiled metrics as dict.
"""
import pandas as pd
import numpy as np
songs = glob(join(metrics_output_directory, 'test/*.json'))
index = pd.MultiIndex.from_tuples(
product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
@@ -178,6 +184,8 @@ def evaluate(
"""
Evaluate a model on the musDB test dataset
"""
import numpy as np
configure_logger(verbose)
try:
import musdb

View File

@@ -12,11 +12,6 @@
from enum import Enum
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
@@ -42,6 +37,12 @@ class STFTBackend(str, Enum):
@classmethod
def resolve(cls: type, backend: str) -> str:
# NOTE: import is resolved here to avoid performance issues on command
# evaluation.
# pyright: reportMissingImports=false
# pylint: disable=import-error
import tensorflow as tf
if backend not in cls.__members__.values():
raise ValueError(f'Unsupported backend {backend}')
if backend == cls.AUTO:

View File

@@ -6,11 +6,10 @@
from tempfile import gettempdir
from os.path import join
from .separator import STFTBackend
from .audio import Codec
from .audio import Codec, STFTBackend
from typer import Argument, Option
from typer.models import ArgumentInfo, OptionInfo
from typer import Option
from typer.models import OptionInfo
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'