From dd7ce237ed7a64c40a7315894ec3c230a8d6eefe Mon Sep 17 00:00:00 2001 From: romi1502 Date: Wed, 1 Jul 2020 15:49:06 +0200 Subject: [PATCH] Get rid of predictor --- spleeter/utils/estimator.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py index a9aa736..f49fb56 100644 --- a/spleeter/utils/estimator.py +++ b/spleeter/utils/estimator.py @@ -5,20 +5,14 @@ from pathlib import Path from os.path import join -from tempfile import gettempdir # pylint: disable=import-error import tensorflow as tf -from tensorflow.contrib import predictor -# pylint: enable=import-error -from ..model import model_fn, InputProviderFactory +from ..model import model_fn from ..model.provider import get_default_model_provider -# Default exporting directory for predictor. -DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving') - def get_default_model_dir(model_dir): @@ -57,24 +51,3 @@ def create_estimator(params, MWF): config=config ) return estimator - - -def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): - """ Exports given estimator as predictor into the given directory - and returns associated tf.predictor instance. - - :param estimator: Estimator to export. - :param directory: (Optional) path to write exported model into. - """ - - input_provider = InputProviderFactory.get(estimator.params) - def receiver(): - features = input_provider.get_input_dict_placeholders() - return tf.estimator.export.ServingInputReceiver(features, features) - - estimator.export_saved_model(directory, receiver) - versions = [ - model for model in Path(directory).iterdir() - if model.is_dir() and 'temp' not in str(model)] - latest = str(sorted(versions)[-1]) - return predictor.from_saved_model(latest)