Get rid of predictor

This commit is contained in:
romi1502
2020-07-01 15:49:06 +02:00
parent d2919f9062
commit dd7ce237ed

View File

@@ -5,20 +5,14 @@
from pathlib import Path
from os.path import join
from tempfile import gettempdir
# pylint: disable=import-error
import tensorflow as tf
from tensorflow.contrib import predictor
# pylint: enable=import-error
from ..model import model_fn, InputProviderFactory
from ..model import model_fn
from ..model.provider import get_default_model_provider
# Default exporting directory for predictor.
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
def get_default_model_dir(model_dir):
@@ -57,24 +51,3 @@ def create_estimator(params, MWF):
config=config
)
return estimator
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
""" Exports given estimator as predictor into the given directory
and returns associated tf.predictor instance.
:param estimator: Estimator to export.
:param directory: (Optional) path to write exported model into.
"""
input_provider = InputProviderFactory.get(estimator.params)
def receiver():
features = input_provider.get_input_dict_placeholders()
return tf.estimator.export.ServingInputReceiver(features, features)
estimator.export_saved_model(directory, receiver)
versions = [
model for model in Path(directory).iterdir()
if model.is_dir() and 'temp' not in str(model)]
latest = str(sorted(versions)[-1])
return predictor.from_saved_model(latest)