🎨 ⬆️ add fix dep and clean code

This commit is contained in:
Faylixe
2020-12-04 12:43:13 +01:00
parent 1181f4b54d
commit d68113ec80
2 changed files with 6 additions and 11 deletions

View File

@@ -3,18 +3,12 @@
""" Utility functions for creating estimator. """
from pathlib import Path
from os.path import join
# pylint: disable=import-error
import tensorflow as tf
import tensorflow as tf # pylint: disable=import-error
from ..model import model_fn
from ..model.provider import get_default_model_provider
def get_default_model_dir(model_dir):
"""
Transforms a string like 'spleeter:2stems' into an actual path.
@@ -24,6 +18,7 @@ def get_default_model_dir(model_dir):
model_provider = get_default_model_provider()
return model_provider.get(model_dir)
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
@@ -35,8 +30,6 @@ def create_estimator(params, MWF):
a tensorflow estimator
"""
# Load model.
params['model_dir'] = get_default_model_dir(params['model_dir'])
params['MWF'] = MWF
# Setup config