mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
refactor: add probe in train
This commit is contained in:
@@ -16,6 +16,7 @@ import tensorflow as tf
|
|||||||
from ..audio.adapter import get_audio_adapter
|
from ..audio.adapter import get_audio_adapter
|
||||||
from ..dataset import get_training_dataset, get_validation_dataset
|
from ..dataset import get_training_dataset, get_validation_dataset
|
||||||
from ..model import model_fn
|
from ..model import model_fn
|
||||||
|
from ..model.provider import ModelProvider
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
__email__ = 'research@deezer.com'
|
__email__ = 'research@deezer.com'
|
||||||
@@ -95,4 +96,5 @@ def entrypoint(arguments, params):
|
|||||||
estimator,
|
estimator,
|
||||||
train_spec,
|
train_spec,
|
||||||
evaluation_spec)
|
evaluation_spec)
|
||||||
|
ModelProvider.writeProbe(params['model_dir'])
|
||||||
get_logger().info('Model training done')
|
get_logger().info('Model training done')
|
||||||
|
|||||||
@@ -38,12 +38,14 @@ class ModelProvider(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def writeProbe(self, directory):
|
@staticmethod
|
||||||
|
def writeProbe(directory):
|
||||||
""" Write a model probe file into the given directory.
|
""" Write a model probe file into the given directory.
|
||||||
|
|
||||||
:param directory: Directory to write probe into.
|
:param directory: Directory to write probe into.
|
||||||
"""
|
"""
|
||||||
with open(join(directory, self.MODEL_PROBE_PATH), 'w') as stream:
|
probe = join(directory, ModelProvider.MODEL_PROBE_PATH)
|
||||||
|
with open(probe, 'w') as stream:
|
||||||
stream.write('OK')
|
stream.write('OK')
|
||||||
|
|
||||||
def get(self, model_directory):
|
def get(self, model_directory):
|
||||||
|
|||||||
Reference in New Issue
Block a user