refactor: add probe in train

This commit is contained in:
Félix Voituret
2019-11-20 15:40:25 +01:00
parent 151c2a7b93
commit fbe5f290ef
2 changed files with 6 additions and 2 deletions

View File

@@ -16,6 +16,7 @@ import tensorflow as tf
from ..audio.adapter import get_audio_adapter
from ..dataset import get_training_dataset, get_validation_dataset
from ..model import model_fn
from ..model.provider import ModelProvider
from ..utils.logging import get_logger
__email__ = 'research@deezer.com'
@@ -95,4 +96,5 @@ def entrypoint(arguments, params):
estimator,
train_spec,
evaluation_spec)
ModelProvider.writeProbe(params['model_dir'])
get_logger().info('Model training done')

View File

@@ -38,12 +38,14 @@ class ModelProvider(ABC):
"""
pass
def writeProbe(self, directory):
@staticmethod
def writeProbe(directory):
""" Write a model probe file into the given directory.
:param directory: Directory to write probe into.
"""
with open(join(directory, self.MODEL_PROBE_PATH), 'w') as stream:
probe = join(directory, ModelProvider.MODEL_PROBE_PATH)
with open(probe, 'w') as stream:
stream.write('OK')
def get(self, model_directory):