mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Get rid of predictor
This commit is contained in:
@@ -5,20 +5,14 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from tempfile import gettempdir
|
|
||||||
|
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.contrib import predictor
|
|
||||||
# pylint: enable=import-error
|
|
||||||
|
|
||||||
from ..model import model_fn, InputProviderFactory
|
from ..model import model_fn
|
||||||
from ..model.provider import get_default_model_provider
|
from ..model.provider import get_default_model_provider
|
||||||
|
|
||||||
# Default exporting directory for predictor.
|
|
||||||
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_model_dir(model_dir):
|
def get_default_model_dir(model_dir):
|
||||||
@@ -57,24 +51,3 @@ def create_estimator(params, MWF):
|
|||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
return estimator
|
return estimator
|
||||||
|
|
||||||
|
|
||||||
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
|
||||||
""" Exports given estimator as predictor into the given directory
|
|
||||||
and returns associated tf.predictor instance.
|
|
||||||
|
|
||||||
:param estimator: Estimator to export.
|
|
||||||
:param directory: (Optional) path to write exported model into.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input_provider = InputProviderFactory.get(estimator.params)
|
|
||||||
def receiver():
|
|
||||||
features = input_provider.get_input_dict_placeholders()
|
|
||||||
return tf.estimator.export.ServingInputReceiver(features, features)
|
|
||||||
|
|
||||||
estimator.export_saved_model(directory, receiver)
|
|
||||||
versions = [
|
|
||||||
model for model in Path(directory).iterdir()
|
|
||||||
if model.is_dir() and 'temp' not in str(model)]
|
|
||||||
latest = str(sorted(versions)[-1])
|
|
||||||
return predictor.from_saved_model(latest)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user