mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Get rid of predictor
This commit is contained in:
@@ -5,20 +5,14 @@
|
||||
|
||||
from pathlib import Path
|
||||
from os.path import join
|
||||
from tempfile import gettempdir
|
||||
|
||||
# pylint: disable=import-error
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib import predictor
|
||||
# pylint: enable=import-error
|
||||
|
||||
from ..model import model_fn, InputProviderFactory
|
||||
from ..model import model_fn
|
||||
from ..model.provider import get_default_model_provider
|
||||
|
||||
# Default exporting directory for predictor.
|
||||
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
|
||||
|
||||
|
||||
|
||||
def get_default_model_dir(model_dir):
|
||||
@@ -57,24 +51,3 @@ def create_estimator(params, MWF):
|
||||
config=config
|
||||
)
|
||||
return estimator
|
||||
|
||||
|
||||
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
||||
""" Exports given estimator as predictor into the given directory
|
||||
and returns associated tf.predictor instance.
|
||||
|
||||
:param estimator: Estimator to export.
|
||||
:param directory: (Optional) path to write exported model into.
|
||||
"""
|
||||
|
||||
input_provider = InputProviderFactory.get(estimator.params)
|
||||
def receiver():
|
||||
features = input_provider.get_input_dict_placeholders()
|
||||
return tf.estimator.export.ServingInputReceiver(features, features)
|
||||
|
||||
estimator.export_saved_model(directory, receiver)
|
||||
versions = [
|
||||
model for model in Path(directory).iterdir()
|
||||
if model.is_dir() and 'temp' not in str(model)]
|
||||
latest = str(sorted(versions)[-1])
|
||||
return predictor.from_saved_model(latest)
|
||||
|
||||
Reference in New Issue
Block a user