From fbe5f290ef6af8d58a7c69f99841c8f925e93a18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Voituret?= Date: Wed, 20 Nov 2019 15:40:25 +0100 Subject: [PATCH] refactor: add probe in train --- spleeter/commands/train.py | 2 ++ spleeter/model/provider/__init__.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/spleeter/commands/train.py b/spleeter/commands/train.py index bb48c41..2a40c84 100644 --- a/spleeter/commands/train.py +++ b/spleeter/commands/train.py @@ -16,6 +16,7 @@ import tensorflow as tf from ..audio.adapter import get_audio_adapter from ..dataset import get_training_dataset, get_validation_dataset from ..model import model_fn +from ..model.provider import ModelProvider from ..utils.logging import get_logger __email__ = 'research@deezer.com' @@ -95,4 +96,5 @@ def entrypoint(arguments, params): estimator, train_spec, evaluation_spec) + ModelProvider.writeProbe(params['model_dir']) get_logger().info('Model training done') diff --git a/spleeter/model/provider/__init__.py b/spleeter/model/provider/__init__.py index 854b065..3aa3d8d 100644 --- a/spleeter/model/provider/__init__.py +++ b/spleeter/model/provider/__init__.py @@ -38,12 +38,14 @@ class ModelProvider(ABC): """ pass - def writeProbe(self, directory): + @staticmethod + def writeProbe(directory): """ Write a model probe file into the given directory. :param directory: Directory to write probe into. """ - with open(join(directory, self.MODEL_PROBE_PATH), 'w') as stream: + probe = join(directory, ModelProvider.MODEL_PROBE_PATH) + with open(probe, 'w') as stream: stream.write('OK') def get(self, model_directory):