mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Adding option to use librosa backend.
Changes in the EstimatorBuilder to set attributes instead of returning tensors for the _build methods. InputProvider classes to handle the different backend cases. New method in Separator.
This commit is contained in:
@@ -13,7 +13,7 @@ import tensorflow as tf
|
||||
from tensorflow.contrib import predictor
|
||||
# pylint: enable=import-error
|
||||
|
||||
from ..model import model_fn
|
||||
from ..model import model_fn, InputProviderFactory
|
||||
from ..model.provider import get_default_model_provider
|
||||
|
||||
# Default exporting directory for predictor.
|
||||
@@ -59,14 +59,6 @@ def create_estimator(params, MWF):
|
||||
return estimator
|
||||
|
||||
|
||||
def get_input_dict_placeholders(params):
|
||||
shape = (None, params['n_channels'])
|
||||
features = {
|
||||
'waveform': tf.compat.v1.placeholder(tf.float32, shape=shape, name="waveform"),
|
||||
'audio_id': tf.compat.v1.placeholder(tf.string, name="audio_id")}
|
||||
return features
|
||||
|
||||
|
||||
def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
||||
""" Exports given estimator as predictor into the given directory
|
||||
and returns associated tf.predictor instance.
|
||||
@@ -74,8 +66,10 @@ def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY):
|
||||
:param estimator: Estimator to export.
|
||||
:param directory: (Optional) path to write exported model into.
|
||||
"""
|
||||
|
||||
input_provider = InputProviderFactory.get(estimator.params)
|
||||
def receiver():
|
||||
features = get_input_dict_placeholders(estimator.params)
|
||||
features = input_provider.get_input_dict_placeholders()
|
||||
return tf.estimator.export.ServingInputReceiver(features, features)
|
||||
|
||||
estimator.export_saved_model(directory, receiver)
|
||||
|
||||
Reference in New Issue
Block a user