mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
101 lines
3.0 KiB
Python
101 lines
3.0 KiB
Python
#!/usr/bin/env python
|
|
# coding: utf8
|
|
|
|
"""
|
|
Entrypoint provider for performing model training.
|
|
|
|
USAGE: python -m spleeter train -p /path/to/params
|
|
"""
|
|
|
|
from functools import partial
|
|
|
|
# pylint: disable=import-error
|
|
import tensorflow as tf
|
|
# pylint: enable=import-error
|
|
|
|
from ..audio.adapter import get_audio_adapter
|
|
from ..dataset import get_training_dataset, get_validation_dataset
|
|
from ..model import model_fn
|
|
from ..model.provider import ModelProvider
|
|
from ..utils.logging import get_logger
|
|
|
|
__email__ = 'spleeter@deezer.com'
|
|
__author__ = 'Deezer Research'
|
|
__license__ = 'MIT License'
|
|
|
|
|
|
def _create_estimator(params):
    """ Creates the estimator used for training.

    Checkpointing, seeding and summary settings are read from ``params``.
    Three run settings are optional keys of ``params``, each with a default:
    ``gpu_memory_fraction`` (default ``0.45``), ``log_step_count_steps``
    (default ``10``) and ``keep_checkpoint_max`` (default ``2``).

    :param params: TF params to build estimator from. Must be a dict
        containing ``model_dir``, ``save_checkpoints_steps``,
        ``random_seed`` and ``save_summary_steps``.
    :returns: Built estimator.
    """
    session_config = tf.compat.v1.ConfigProto()
    # Fraction of the available GPU memory TensorFlow allocates for this
    # process; keeps training from grabbing the whole device by default.
    session_config.gpu_options.per_process_gpu_memory_fraction = params.get(
        'gpu_memory_fraction', 0.45)
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=params['save_checkpoints_steps'],
        tf_random_seed=params['random_seed'],
        save_summary_steps=params['save_summary_steps'],
        session_config=session_config,
        log_step_count_steps=params.get('log_step_count_steps', 10),
        keep_checkpoint_max=params.get('keep_checkpoint_max', 2))
    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=params['model_dir'],
        params=params,
        config=run_config)
|
|
|
|
|
|
def _create_train_spec(params, audio_adapter, audio_path):
    """ Builds the training specification for the estimator.

    The input function is ``get_training_dataset`` with all three
    arguments bound in advance, so the estimator can call it with no
    arguments.

    :param params: TF params to build spec from; ``train_max_steps``
        bounds the number of training steps.
    :param audio_adapter: Audio adapter forwarded to the training dataset
        builder.
    :param audio_path: Audio path forwarded to the training dataset builder.
    :returns: Built train spec.
    """
    return tf.estimator.TrainSpec(
        input_fn=partial(
            get_training_dataset, params, audio_adapter, audio_path),
        max_steps=params['train_max_steps'])
|
|
|
|
|
|
def _create_evaluation_spec(params, audio_adapter, audio_path):
    """ Builds the evaluation specification for the estimator.

    Evaluation is throttled to run at most every ``params['throttle_secs']``
    seconds. ``steps=None`` lets each evaluation consume the validation
    input until it is exhausted.

    :param params: TF params to build spec from.
    :param audio_adapter: Audio adapter forwarded to the validation dataset
        builder.
    :param audio_path: Audio path forwarded to the validation dataset builder.
    :returns: Built evaluation spec.
    """
    validation_input_fn = partial(
        get_validation_dataset, params, audio_adapter, audio_path)
    return tf.estimator.EvalSpec(
        input_fn=validation_input_fn,
        steps=None,
        throttle_secs=params['throttle_secs'])
|
|
|
|
|
|
def entrypoint(arguments, params):
    """ Runs the ``train`` command: trains and evaluates the model.

    After ``tf.estimator.train_and_evaluate`` returns, it calls
    ``ModelProvider.writeProbe`` on ``params['model_dir']``. Presumably
    this marks the directory as holding a trained model (TODO confirm).

    :param arguments: Command line parsed argument as argparse.Namespace;
        its ``audio_adapter`` and ``audio_path`` attributes are used.
    :param params: Deserialized JSON configuration file provided in CLI args.
    """
    adapter = get_audio_adapter(arguments.audio_adapter)
    source_path = arguments.audio_path
    estimator = _create_estimator(params)
    # (train_spec, eval_spec), in the order train_and_evaluate expects.
    specs = (
        _create_train_spec(params, adapter, source_path),
        _create_evaluation_spec(params, adapter, source_path),
    )
    get_logger().info('Start model training')
    tf.estimator.train_and_evaluate(estimator, *specs)
    ModelProvider.writeProbe(params['model_dir'])
    get_logger().info('Model training done')
|