mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-30 12:22:58 +00:00
Change the EstimatorBuilder so that its _build methods set attributes instead of returning tensors. Add InputProvider classes to handle the different backend cases. Add a new method to Separator.
81 lines
2.3 KiB
Python
81 lines
2.3 KiB
Python
#!/usr/bin/env python
|
|
# coding: utf8
|
|
|
|
""" Utility functions for creating estimator. """
|
|
|
|
from pathlib import Path
|
|
from os.path import join
|
|
from tempfile import gettempdir
|
|
|
|
# pylint: disable=import-error
|
|
import tensorflow as tf
|
|
|
|
from tensorflow.contrib import predictor
|
|
# pylint: enable=import-error
|
|
|
|
from ..model import model_fn, InputProviderFactory
|
|
from ..model.provider import get_default_model_provider
|
|
|
|
# Default exporting directory for predictor.
# NOTE(review): this sits under the system temp directory (gettempdir()), so it
# is shared with any other process that exports using the same default.
DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving')
|
|
|
|
|
|
|
|
def get_default_model_dir(model_dir):
    """
    Resolve a model reference into a local model directory.

    Delegates to the default model provider, which turns a string like
    'spleeter:2stems' into an actual path.

    :param model_dir: model reference (or directory) to resolve.
    :return: the value returned by the default provider's ``get`` for
        ``model_dir``.
    """
    return get_default_model_provider().get(model_dir)
|
|
|
|
def create_estimator(params, MWF, *, gpu_memory_fraction=0.7):
    """
    Initialize tensorflow estimator that will perform separation

    Note that ``params`` is updated in place: its ``'model_dir'`` entry is
    replaced by the resolved model directory and an ``'MWF'`` entry is set.

    Params:
    - params: a dictionary of parameters for building the model; must
      contain a ``'model_dir'`` key (a model reference such as
      'spleeter:2stems' or a path).
    - MWF: value stored under ``params['MWF']`` so that it is part of the
      parameters handed to the estimator.
    - gpu_memory_fraction: (optional, keyword-only) value assigned to the
      session's ``per_process_gpu_memory_fraction``. Defaults to 0.7, the
      value previously hard-coded here.

    Returns:
        a tensorflow estimator
    """
    # Resolve the model reference into an actual directory before the
    # estimator is built, so both model_dir and params agree on it.
    params['model_dir'] = get_default_model_dir(params['model_dir'])
    params['MWF'] = MWF
    # Setup config: cap how much GPU memory the session may take.
    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = (
        gpu_memory_fraction)
    config = tf.estimator.RunConfig(session_config=session_config)
    # Setup estimator
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=config,
    )
|
|
|
|
|
|
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
    """ Exports given estimator as predictor into the given directory
    and returns associated tf.predictor instance.

    :param estimator: Estimator to export.
    :param directory: (Optional) path to write exported model into.
    :returns: A predictor loaded from the SavedModel written by this call.
    """
    # Function-scope import keeps this block self-contained.
    from os import fsdecode

    input_provider = InputProviderFactory.get(estimator.params)

    def receiver():
        # The placeholders serve both as the receiver tensors and as the
        # features passed to the model function.
        features = input_provider.get_input_dict_placeholders()
        return tf.estimator.export.ServingInputReceiver(features, features)

    # Load exactly the export written by this call. Scanning ``directory``
    # for the "latest" version is fragile: the directory may be shared
    # (the default lives in the system temp dir), so another process's
    # export could be picked up, and filtering temp dirs by substring on
    # the full path rejects every version whenever a parent directory
    # name contains "temp".
    export_dir = estimator.export_saved_model(directory, receiver)
    # export_saved_model returns the path as bytes; normalise it to str.
    return predictor.from_saved_model(fsdecode(export_dir))
|