mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
refactor: add probe in train
This commit is contained in:
@@ -16,6 +16,7 @@ import tensorflow as tf
|
||||
from ..audio.adapter import get_audio_adapter
|
||||
from ..dataset import get_training_dataset, get_validation_dataset
|
||||
from ..model import model_fn
|
||||
from ..model.provider import ModelProvider
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
@@ -95,4 +96,5 @@ def entrypoint(arguments, params):
|
||||
estimator,
|
||||
train_spec,
|
||||
evaluation_spec)
|
||||
ModelProvider.writeProbe(params['model_dir'])
|
||||
get_logger().info('Model training done')
|
||||
|
||||
@@ -38,12 +38,14 @@ class ModelProvider(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def writeProbe(self, directory):
|
||||
@staticmethod
|
||||
def writeProbe(directory):
|
||||
""" Write a model probe file into the given directory.
|
||||
|
||||
:param directory: Directory to write probe into.
|
||||
"""
|
||||
with open(join(directory, self.MODEL_PROBE_PATH), 'w') as stream:
|
||||
probe = join(directory, ModelProvider.MODEL_PROBE_PATH)
|
||||
with open(probe, 'w') as stream:
|
||||
stream.write('OK')
|
||||
|
||||
def get(self, model_directory):
|
||||
|
||||
Reference in New Issue
Block a user