Files
spleeter/spleeter/commands/train.py
mmoussallam 227020c256 replace mail
2020-07-17 13:30:42 +02:00

101 lines
3.0 KiB
Python

#!/usr/bin/env python
# coding: utf8
"""
Entrypoint provider for performing model training.
USAGE: python -m spleeter train -p /path/to/params
"""
from functools import partial
# pylint: disable=import-error
import tensorflow as tf
# pylint: enable=import-error
from ..audio.adapter import get_audio_adapter
from ..dataset import get_training_dataset, get_validation_dataset
from ..model import model_fn
from ..model.provider import ModelProvider
from ..utils.logging import get_logger
__email__ = 'spleeter@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'
def _create_estimator(params):
    """ Builds the TensorFlow estimator used for model training.

    :param params: TF params to build estimator from. The keys
                   'model_dir', 'save_checkpoints_steps', 'random_seed'
                   and 'save_summary_steps' are read here, and the whole
                   mapping is forwarded to the estimator as its params.
    :returns: Built estimator.
    """
    # Session configuration: set the per-process GPU memory fraction to 0.45.
    gpu_session = tf.compat.v1.ConfigProto()
    gpu_session.gpu_options.per_process_gpu_memory_fraction = 0.45
    model_dir = params['model_dir']
    # Checkpointing, seeding and summary cadence come from params, while the
    # step logging period (10) and retained checkpoint count (2) are fixed.
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params['save_checkpoints_steps'],
        tf_random_seed=params['random_seed'],
        save_summary_steps=params['save_summary_steps'],
        session_config=gpu_session,
        log_step_count_steps=10,
        keep_checkpoint_max=2)
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        params=params,
        config=run_config)
def _create_train_spec(params, audio_adapter, audio_path):
    """ Builds the training specification.

    :param params: TF params to build spec from; 'train_max_steps' is read
                   here and the whole mapping is bound to the dataset builder.
    :param audio_adapter: Audio adapter bound to the training dataset builder.
    :param audio_path: Audio path bound to the training dataset builder.
    :returns: Built train spec.
    """
    # Pre-bind the dataset arguments: the dataset builder is only invoked
    # later, when the returned callable is called.
    training_input = partial(
        get_training_dataset, params, audio_adapter, audio_path)
    return tf.estimator.TrainSpec(
        input_fn=training_input,
        max_steps=params['train_max_steps'])
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Builds the evaluation specification, throttled by the
    'throttle_secs' value found in params.

    :param params: TF params to build spec from; 'throttle_secs' is read
                   here and the whole mapping is bound to the dataset builder.
    :param audio_adapter: Audio adapter bound to the validation dataset
                          builder.
    :param audio_path: Audio path bound to the validation dataset builder.
    :returns: Built evaluation spec.
    """
    # Pre-bind the dataset arguments: the dataset builder is only invoked
    # later, when the returned callable is called.
    validation_input = partial(
        get_validation_dataset, params, audio_adapter, audio_path)
    # NOTE(review): steps=None presumably means the evaluation is not capped
    # to a fixed number of steps -- confirm against tf.estimator.EvalSpec.
    return tf.estimator.EvalSpec(
        input_fn=validation_input,
        steps=None,
        throttle_secs=params['throttle_secs'])
def entrypoint(arguments, params):
    """ Train command entrypoint: builds the estimator together with its
    train and evaluation specs, then runs training and evaluation.

    :param arguments: Command line parsed argument as argparse.Namespace;
                      its 'audio_adapter' and 'audio_path' attributes are
                      read here.
    :param params: Deserialized JSON configuration file provided in CLI args.
    """
    adapter = get_audio_adapter(arguments.audio_adapter)
    # Both specs are built from the same (params, adapter, path) triple.
    spec_args = (params, adapter, arguments.audio_path)
    estimator = _create_estimator(params)
    training = _create_train_spec(*spec_args)
    evaluation = _create_evaluation_spec(*spec_args)
    get_logger().info('Start model training')
    tf.estimator.train_and_evaluate(estimator, training, evaluation)
    # Once training has returned, hand the model directory to
    # ModelProvider.writeProbe.
    ModelProvider.writeProbe(params['model_dir'])
    get_logger().info('Model training done')