mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
🚧 typer integration
This commit is contained in:
@@ -1,13 +1,22 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
|
|
||||||
""" TO DOCUMENT """
|
"""
|
||||||
|
Python oneliner script usage.
|
||||||
|
|
||||||
|
USAGE: python -m spleeter {train,evaluate,separate} ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from itertools import product
|
||||||
|
from glob import glob
|
||||||
|
from os.path import join
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import Any, Container, Dict, List
|
||||||
|
|
||||||
|
from . import SpleeterError
|
||||||
from .audio import Codec
|
from .audio import Codec
|
||||||
from .audio.adapter import AudioAdapter
|
from .audio.adapter import AudioAdapter
|
||||||
from .options import *
|
from .options import *
|
||||||
@@ -16,11 +25,12 @@ from .model import model_fn
|
|||||||
from .model.provider import ModelProvider
|
from .model.provider import ModelProvider
|
||||||
from .separator import Separator
|
from .separator import Separator
|
||||||
from .utils.configuration import load_configuration
|
from .utils.configuration import load_configuration
|
||||||
from .utils.logging import get_logger
|
from .utils.logging import configure_logger, logger
|
||||||
|
|
||||||
|
|
||||||
# pyright: reportMissingImports=false
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from typer import Exit, Typer
|
from typer import Exit, Typer
|
||||||
@@ -39,8 +49,7 @@ def train(
|
|||||||
"""
|
"""
|
||||||
Train a source separation model
|
Train a source separation model
|
||||||
"""
|
"""
|
||||||
# TODO: try / catch or custom decorator for function handling.
|
configure_logger(verbose)
|
||||||
# TODO: handle verbose flag ?
|
|
||||||
audio_adapter = AudioAdapter.get(adapter)
|
audio_adapter = AudioAdapter.get(adapter)
|
||||||
audio_path = str(data)
|
audio_path = str(data)
|
||||||
params = load_configuration(params_filename)
|
params = load_configuration(params_filename)
|
||||||
@@ -70,120 +79,19 @@ def train(
|
|||||||
input_fn=input_fn,
|
input_fn=input_fn,
|
||||||
steps=None,
|
steps=None,
|
||||||
throttle_secs=params['throttle_secs'])
|
throttle_secs=params['throttle_secs'])
|
||||||
get_logger().info('Start model training')
|
logger.info('Start model training')
|
||||||
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
|
tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)
|
||||||
ModelProvider.writeProbe(params['model_dir'])
|
ModelProvider.writeProbe(params['model_dir'])
|
||||||
get_logger().info('Model training done')
|
logger.info('Model training done')
|
||||||
|
|
||||||
_SPLIT = 'test'
|
|
||||||
_MIXTURE = 'mixture.wav'
|
|
||||||
_AUDIO_DIRECTORY = 'audio'
|
|
||||||
_METRICS_DIRECTORY = 'metrics'
|
|
||||||
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
|
|
||||||
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_musdb_metrics(
|
|
||||||
arguments,
|
|
||||||
musdb_root_directory,
|
|
||||||
audio_output_directory):
|
|
||||||
""" Generates musdb metrics fro previsouly computed audio estimation.
|
|
||||||
|
|
||||||
:param arguments: Entrypoint arguments.
|
|
||||||
:param audio_output_directory: Directory to get audio estimation from.
|
|
||||||
:returns: Path of generated metrics directory.
|
|
||||||
"""
|
|
||||||
metrics_output_directory = join(
|
|
||||||
arguments.output_path,
|
|
||||||
_METRICS_DIRECTORY)
|
|
||||||
get_logger().info('Starting musdb evaluation (this could be long) ...')
|
|
||||||
try:
|
|
||||||
import musdb
|
|
||||||
import museval
|
|
||||||
except ImportError:
|
|
||||||
logger = get_logger()
|
|
||||||
logger.error('Extra dependencies musdb and museval not found')
|
|
||||||
logger.error('Please install musdb and museval first, abort')
|
|
||||||
raise Exit(10)
|
|
||||||
dataset = musdb.DB(
|
|
||||||
root=musdb_root_directory,
|
|
||||||
is_wav=True,
|
|
||||||
subsets=[_SPLIT])
|
|
||||||
museval.eval_mus_dir(
|
|
||||||
dataset=dataset,
|
|
||||||
estimates_dir=audio_output_directory,
|
|
||||||
output_dir=metrics_output_directory)
|
|
||||||
get_logger().info('musdb evaluation done')
|
|
||||||
return metrics_output_directory
|
|
||||||
|
|
||||||
|
|
||||||
def _compile_metrics(metrics_output_directory):
|
|
||||||
""" Compiles metrics from given directory and returns
|
|
||||||
results as dict.
|
|
||||||
|
|
||||||
:param metrics_output_directory: Directory to get metrics from.
|
|
||||||
:returns: Compiled metrics as dict.
|
|
||||||
"""
|
|
||||||
songs = glob(join(metrics_output_directory, 'test/*.json'))
|
|
||||||
index = pd.MultiIndex.from_tuples(
|
|
||||||
product(_INSTRUMENTS, _METRICS),
|
|
||||||
names=['instrument', 'metric'])
|
|
||||||
pd.DataFrame([], index=['config1', 'config2'], columns=index)
|
|
||||||
metrics = {
|
|
||||||
instrument: {k: [] for k in _METRICS}
|
|
||||||
for instrument in _INSTRUMENTS}
|
|
||||||
for song in songs:
|
|
||||||
with open(song, 'r') as stream:
|
|
||||||
data = json.load(stream)
|
|
||||||
for target in data['targets']:
|
|
||||||
instrument = target['name']
|
|
||||||
for metric in _METRICS:
|
|
||||||
sdr_med = np.median([
|
|
||||||
frame['metrics'][metric]
|
|
||||||
for frame in target['frames']
|
|
||||||
if not np.isnan(frame['metrics'][metric])])
|
|
||||||
metrics[instrument][metric].append(sdr_med)
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
|
|
||||||
@spleeter.command()
|
|
||||||
def evaluate(
|
|
||||||
adapter: str = AudioAdapterOption,
|
|
||||||
output_path: Path = AudioAdapterOption,
|
|
||||||
stft_backend: STFTBackend = AudioSTFTBackendOption,
|
|
||||||
params_filename: str = ModelParametersOption,
|
|
||||||
mus_dir: Path = MUSDBDirectoryOption,
|
|
||||||
mwf: bool = MWFOption,
|
|
||||||
verbose: bool = VerboseOption) -> None:
|
|
||||||
"""
|
|
||||||
Evaluate a model on the musDB test dataset
|
|
||||||
"""
|
|
||||||
# Separate musdb sources.
|
|
||||||
audio_output_directory = _separate_evaluation_dataset(
|
|
||||||
arguments,
|
|
||||||
mus_dir,
|
|
||||||
params)
|
|
||||||
# Compute metrics with musdb.
|
|
||||||
metrics_output_directory = _compute_musdb_metrics(
|
|
||||||
arguments,
|
|
||||||
mus_dir,
|
|
||||||
audio_output_directory)
|
|
||||||
# Compute and pretty print median metrics.
|
|
||||||
metrics = _compile_metrics(metrics_output_directory)
|
|
||||||
for instrument, metric in metrics.items():
|
|
||||||
get_logger().info('%s:', instrument)
|
|
||||||
for metric, value in metric.items():
|
|
||||||
get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
|
|
||||||
@spleeter.commmand()
|
@spleeter.commmand()
|
||||||
def separate(
|
def separate(
|
||||||
|
files: List[Path] = AudioInputArgument,
|
||||||
adapter: str = AudioAdapterOption,
|
adapter: str = AudioAdapterOption,
|
||||||
bitrate: str = AudioBitrateOption,
|
bitrate: str = AudioBitrateOption,
|
||||||
codec: Codec = AudioCodecOption,
|
codec: Codec = AudioCodecOption,
|
||||||
duration: float = AudioDurationOption,
|
duration: float = AudioDurationOption,
|
||||||
files: List[Path] = AudioInputArgument,
|
|
||||||
offset: float = AudioOffsetOption,
|
offset: float = AudioOffsetOption,
|
||||||
output_path: Path = AudioAdapterOption,
|
output_path: Path = AudioAdapterOption,
|
||||||
stft_backend: STFTBackend = AudioSTFTBackendOption,
|
stft_backend: STFTBackend = AudioSTFTBackendOption,
|
||||||
@@ -194,13 +102,7 @@ def separate(
|
|||||||
"""
|
"""
|
||||||
Separate audio file(s)
|
Separate audio file(s)
|
||||||
"""
|
"""
|
||||||
# TODO: try / catch or custom decorator for function handling.
|
configure_logger(verbose)
|
||||||
# TODO: enable_logging()
|
|
||||||
# TODO: handle MWF
|
|
||||||
if verbose:
|
|
||||||
# TODO: enable_tensorflow_logging()
|
|
||||||
pass
|
|
||||||
# PREV: params = load_configuration(arguments.configuration)
|
|
||||||
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
|
audio_adapter: AudioAdapter = AudioAdapter.get(adapter)
|
||||||
separator: Separator = Separator(
|
separator: Separator = Separator(
|
||||||
params_filename,
|
params_filename,
|
||||||
@@ -220,6 +122,102 @@ def separate(
|
|||||||
separator.join()
|
separator.join()
|
||||||
|
|
||||||
|
|
||||||
|
EVALUATION_SPLIT: str = 'test'
|
||||||
|
EVALUATION_METRICS_DIRECTORY: str = 'metrics'
|
||||||
|
EVALUATION_INSTRUMENTS: Container[str] = ('vocals', 'drums', 'bass', 'other')
|
||||||
|
EVALUATION_METRICS: Container[str] = ('SDR', 'SAR', 'SIR', 'ISR')
|
||||||
|
EVALUATION_MIXTURE: str = 'mixture.wav'
|
||||||
|
EVALUATION_AUDIO_DIRECTORY: str = 'audio'
|
||||||
|
|
||||||
|
|
||||||
|
def _compile_metrics(metrics_output_directory) -> Dict:
|
||||||
|
"""
|
||||||
|
Compiles metrics from given directory and returns results as dict.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
metrics_output_directory (str):
|
||||||
|
Directory to get metrics from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict:
|
||||||
|
Compiled metrics as dict.
|
||||||
|
"""
|
||||||
|
songs = glob(join(metrics_output_directory, 'test/*.json'))
|
||||||
|
index = pd.MultiIndex.from_tuples(
|
||||||
|
product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),
|
||||||
|
names=['instrument', 'metric'])
|
||||||
|
pd.DataFrame([], index=['config1', 'config2'], columns=index)
|
||||||
|
metrics = {
|
||||||
|
instrument: {k: [] for k in EVALUATION_METRICS}
|
||||||
|
for instrument in EVALUATION_INSTRUMENTS}
|
||||||
|
for song in songs:
|
||||||
|
with open(song, 'r') as stream:
|
||||||
|
data = json.load(stream)
|
||||||
|
for target in data['targets']:
|
||||||
|
instrument = target['name']
|
||||||
|
for metric in EVALUATION_METRICS:
|
||||||
|
sdr_med = np.median([
|
||||||
|
frame['metrics'][metric]
|
||||||
|
for frame in target['frames']
|
||||||
|
if not np.isnan(frame['metrics'][metric])])
|
||||||
|
metrics[instrument][metric].append(sdr_med)
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: the docstring of a command is kept to one line on purpose, the
# detailed contract lives in these comments instead.
#
# Steps: separate every `mixture.wav` of the musDB evaluation split, run
# museval on the produced estimates, then log and return the compiled
# metrics (`instrument -> metric -> list of per-song medians`).
#
# Raises `Exit(10)` when musdb / museval are not installed and `IOError`
# when `mus_dir` does not exist.
#
# NOTE(review): `output_path` defaults to `AudioAdapterOption`, the same
# object used for `adapter`; this looks like a copy/paste slip — confirm
# which option from `.options` is meant for the output directory.
@spleeter.command()
def evaluate(
        adapter: str = AudioAdapterOption,
        output_path: Path = AudioAdapterOption,
        stft_backend: STFTBackend = AudioSTFTBackendOption,
        params_filename: str = ModelParametersOption,
        mus_dir: Path = MUSDBDirectoryOption,
        mwf: bool = MWFOption,
        verbose: bool = VerboseOption) -> Dict:
    """
        Evaluate a model on the musDB test dataset
    """
    configure_logger(verbose)
    try:
        import musdb
        import museval
    except ImportError:
        logger.error('Extra dependencies musdb and museval not found')
        logger.error('Please install musdb and museval first, abort')
        raise Exit(10)
    # Fail fast on a missing dataset rather than separating nothing and
    # failing later inside musdb.
    if not Path(mus_dir).exists():
        raise IOError(f'musdb directory {mus_dir} not found')
    # Separate musdb sources.
    songs = glob(join(mus_dir, EVALUATION_SPLIT, '*/'))
    mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]
    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)
    # `separate` is invoked as a plain function here, so every argument is
    # given explicitly instead of relying on its `*Option` defaults. The
    # estimates go to `<audio dir>/<split>` so that museval, which is
    # pointed at `<audio dir>` below, can find them.
    separate(
        adapter=adapter,
        params_filename=params_filename,
        files=mixtures,
        output_path=Path(audio_output_directory, EVALUATION_SPLIT),
        filename_format='{foldername}/{instrument}.{codec}',
        codec=Codec.WAV,
        bitrate='128k',
        duration=600.,
        offset=0.,
        mwf=mwf,
        verbose=verbose,
        stft_backend=stft_backend)
    # Compute metrics with musdb.
    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)
    logger.info('Starting musdb evaluation (this could be long) ...')
    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])
    museval.eval_mus_dir(
        dataset=dataset,
        estimates_dir=audio_output_directory,
        output_dir=metrics_output_directory)
    logger.info('musdb evaluation done')
    # Compute and pretty print median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    for instrument, scores in metrics.items():
        logger.info('%s:', instrument)
        for metric, values in scores.items():
            logger.info('%s: %.3f', metric, np.median(values))
    return metrics
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    try:
        spleeter()
    except SpleeterError as e:
        # Library errors are reported as a single log line rather than a
        # traceback; exit with a non-zero status so that shells and
        # scripts can tell the run failed.
        logger.error(e)
        raise SystemExit(1)
|
||||||
|
|||||||
@@ -1,166 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# coding: utf8
|
|
||||||
|
|
||||||
"""
|
|
||||||
Entrypoint provider for performing model evaluation.
|
|
||||||
|
|
||||||
Evaluation is performed against musDB dataset.
|
|
||||||
|
|
||||||
USAGE: python -m spleeter evaluate \
|
|
||||||
-p /path/to/params \
|
|
||||||
-o /path/to/output/dir \
|
|
||||||
[-m] \
|
|
||||||
--mus_dir /path/to/musdb dataset
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
|
|
||||||
from argparse import Namespace
|
|
||||||
from itertools import product
|
|
||||||
from glob import glob
|
|
||||||
from os.path import join, exists
|
|
||||||
|
|
||||||
# pylint: disable=import-error
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
# pylint: enable=import-error
|
|
||||||
|
|
||||||
from .separate import entrypoint as separate_entrypoint
|
|
||||||
from ..utils.logging import get_logger
|
|
||||||
|
|
||||||
try:
|
|
||||||
import musdb
|
|
||||||
import museval
|
|
||||||
except ImportError:
|
|
||||||
logger = get_logger()
|
|
||||||
logger.error('Extra dependencies musdb and museval not found')
|
|
||||||
logger.error('Please install musdb and museval first, abort')
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
|
||||||
__author__ = 'Deezer Research'
|
|
||||||
__license__ = 'MIT License'
|
|
||||||
|
|
||||||
_SPLIT = 'test'
|
|
||||||
_MIXTURE = 'mixture.wav'
|
|
||||||
_AUDIO_DIRECTORY = 'audio'
|
|
||||||
_METRICS_DIRECTORY = 'metrics'
|
|
||||||
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
|
|
||||||
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
|
|
||||||
|
|
||||||
|
|
||||||
def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
|
|
||||||
""" Performs audio separation on the musdb dataset from
|
|
||||||
the given directory and params.
|
|
||||||
|
|
||||||
:param arguments: Entrypoint arguments.
|
|
||||||
:param musdb_root_directory: Directory to retrieve dataset from.
|
|
||||||
:param params: Spleeter configuration to apply to separation.
|
|
||||||
:returns: Separation output directory path.
|
|
||||||
"""
|
|
||||||
songs = glob(join(musdb_root_directory, _SPLIT, '*/'))
|
|
||||||
mixtures = [join(song, _MIXTURE) for song in songs]
|
|
||||||
audio_output_directory = join(
|
|
||||||
arguments.output_path,
|
|
||||||
_AUDIO_DIRECTORY)
|
|
||||||
separate_entrypoint(
|
|
||||||
Namespace(
|
|
||||||
audio_adapter=arguments.audio_adapter,
|
|
||||||
configuration=arguments.configuration,
|
|
||||||
inputs=mixtures,
|
|
||||||
output_path=join(audio_output_directory, _SPLIT),
|
|
||||||
filename_format='{foldername}/{instrument}.{codec}',
|
|
||||||
codec='wav',
|
|
||||||
duration=600.,
|
|
||||||
offset=0.,
|
|
||||||
bitrate='128k',
|
|
||||||
MWF=arguments.MWF,
|
|
||||||
verbose=arguments.verbose,
|
|
||||||
stft_backend=arguments.stft_backend),
|
|
||||||
params)
|
|
||||||
return audio_output_directory
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_musdb_metrics(
|
|
||||||
arguments,
|
|
||||||
musdb_root_directory,
|
|
||||||
audio_output_directory):
|
|
||||||
""" Generates musdb metrics fro previsouly computed audio estimation.
|
|
||||||
|
|
||||||
:param arguments: Entrypoint arguments.
|
|
||||||
:param audio_output_directory: Directory to get audio estimation from.
|
|
||||||
:returns: Path of generated metrics directory.
|
|
||||||
"""
|
|
||||||
metrics_output_directory = join(
|
|
||||||
arguments.output_path,
|
|
||||||
_METRICS_DIRECTORY)
|
|
||||||
get_logger().info('Starting musdb evaluation (this could be long) ...')
|
|
||||||
dataset = musdb.DB(
|
|
||||||
root=musdb_root_directory,
|
|
||||||
is_wav=True,
|
|
||||||
subsets=[_SPLIT])
|
|
||||||
museval.eval_mus_dir(
|
|
||||||
dataset=dataset,
|
|
||||||
estimates_dir=audio_output_directory,
|
|
||||||
output_dir=metrics_output_directory)
|
|
||||||
get_logger().info('musdb evaluation done')
|
|
||||||
return metrics_output_directory
|
|
||||||
|
|
||||||
|
|
||||||
def _compile_metrics(metrics_output_directory):
|
|
||||||
""" Compiles metrics from given directory and returns
|
|
||||||
results as dict.
|
|
||||||
|
|
||||||
:param metrics_output_directory: Directory to get metrics from.
|
|
||||||
:returns: Compiled metrics as dict.
|
|
||||||
"""
|
|
||||||
songs = glob(join(metrics_output_directory, 'test/*.json'))
|
|
||||||
index = pd.MultiIndex.from_tuples(
|
|
||||||
product(_INSTRUMENTS, _METRICS),
|
|
||||||
names=['instrument', 'metric'])
|
|
||||||
pd.DataFrame([], index=['config1', 'config2'], columns=index)
|
|
||||||
metrics = {
|
|
||||||
instrument: {k: [] for k in _METRICS}
|
|
||||||
for instrument in _INSTRUMENTS}
|
|
||||||
for song in songs:
|
|
||||||
with open(song, 'r') as stream:
|
|
||||||
data = json.load(stream)
|
|
||||||
for target in data['targets']:
|
|
||||||
instrument = target['name']
|
|
||||||
for metric in _METRICS:
|
|
||||||
sdr_med = np.median([
|
|
||||||
frame['metrics'][metric]
|
|
||||||
for frame in target['frames']
|
|
||||||
if not np.isnan(frame['metrics'][metric])])
|
|
||||||
metrics[instrument][metric].append(sdr_med)
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
|
|
||||||
def entrypoint(arguments, params):
|
|
||||||
""" Command entrypoint.
|
|
||||||
|
|
||||||
:param arguments: Command line parsed argument as argparse.Namespace.
|
|
||||||
:param params: Deserialized JSON configuration file provided in CLI args.
|
|
||||||
"""
|
|
||||||
# Parse and check musdb directory.
|
|
||||||
musdb_root_directory = arguments.mus_dir
|
|
||||||
if not exists(musdb_root_directory):
|
|
||||||
raise IOError(f'musdb directory {musdb_root_directory} not found')
|
|
||||||
# Separate musdb sources.
|
|
||||||
audio_output_directory = _separate_evaluation_dataset(
|
|
||||||
arguments,
|
|
||||||
musdb_root_directory,
|
|
||||||
params)
|
|
||||||
# Compute metrics with musdb.
|
|
||||||
metrics_output_directory = _compute_musdb_metrics(
|
|
||||||
arguments,
|
|
||||||
musdb_root_directory,
|
|
||||||
audio_output_directory)
|
|
||||||
# Compute and pretty print median metrics.
|
|
||||||
metrics = _compile_metrics(metrics_output_directory)
|
|
||||||
for instrument, metric in metrics.items():
|
|
||||||
get_logger().info('%s:', instrument)
|
|
||||||
for metric, value in metric.items():
|
|
||||||
get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
|
|
||||||
return metrics
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# coding: utf8
|
|
||||||
|
|
||||||
"""
|
|
||||||
Entrypoint provider for performing source separation.
|
|
||||||
|
|
||||||
USAGE: python -m spleeter separate \
|
|
||||||
-p /path/to/params \
|
|
||||||
-i inputfile1 inputfile2 ... inputfilen
|
|
||||||
-o /path/to/output/dir \
|
|
||||||
-i /path/to/audio1.wav /path/to/audio2.mp3
|
|
||||||
"""
|
|
||||||
|
|
||||||
from ..audio.adapter import get_audio_adapter
|
|
||||||
from ..separator import Separator
|
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
|
||||||
__author__ = 'Deezer Research'
|
|
||||||
__license__ = 'MIT License'
|
|
||||||
|
|
||||||
|
|
||||||
from typer import Option
|
|
||||||
|
|
||||||
AudioAdapter = Option()
|
|
||||||
|
|
||||||
|
|
||||||
def entrypoint(arguments, params):
|
|
||||||
""" Command entrypoint.
|
|
||||||
|
|
||||||
:param arguments: Command line parsed argument as argparse.Namespace.
|
|
||||||
:param params: Deserialized JSON configuration file provided in CLI args.
|
|
||||||
"""
|
|
||||||
# TODO: check with output naming.
|
|
||||||
audio_adapter = get_audio_adapter(arguments.audio_adapter)
|
|
||||||
separator = Separator(
|
|
||||||
arguments.configuration,
|
|
||||||
MWF=arguments.MWF,
|
|
||||||
stft_backend=arguments.stft_backend)
|
|
||||||
for filename in arguments.inputs:
|
|
||||||
separator.separate_to_file(
|
|
||||||
filename,
|
|
||||||
arguments.output_path,
|
|
||||||
audio_adapter=audio_adapter,
|
|
||||||
offset=arguments.offset,
|
|
||||||
duration=arguments.duration,
|
|
||||||
codec=arguments.codec,
|
|
||||||
bitrate=arguments.bitrate,
|
|
||||||
filename_format=arguments.filename_format,
|
|
||||||
synchronous=False
|
|
||||||
)
|
|
||||||
separator.join()
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# coding: utf8
|
|
||||||
|
|
||||||
"""
|
|
||||||
Entrypoint provider for performing model training.
|
|
||||||
|
|
||||||
USAGE: python -m spleeter train -p /path/to/params
|
|
||||||
"""
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
# pylint: disable=import-error
|
|
||||||
import tensorflow as tf
|
|
||||||
# pylint: enable=import-error
|
|
||||||
|
|
||||||
from ..audio.adapter import get_audio_adapter
|
|
||||||
from ..dataset import get_training_dataset, get_validation_dataset
|
|
||||||
from ..model import model_fn
|
|
||||||
from ..model.provider import ModelProvider
|
|
||||||
from ..utils.logging import get_logger
|
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
|
||||||
__author__ = 'Deezer Research'
|
|
||||||
__license__ = 'MIT License'
|
|
||||||
|
|
||||||
|
|
||||||
def _create_estimator(params):
|
|
||||||
""" Creates estimator.
|
|
||||||
|
|
||||||
:param params: TF params to build estimator from.
|
|
||||||
:returns: Built estimator.
|
|
||||||
"""
|
|
||||||
session_config = tf.compat.v1.ConfigProto()
|
|
||||||
session_config.gpu_options.per_process_gpu_memory_fraction = 0.45
|
|
||||||
estimator = tf.estimator.Estimator(
|
|
||||||
model_fn=model_fn,
|
|
||||||
model_dir=params['model_dir'],
|
|
||||||
params=params,
|
|
||||||
config=tf.estimator.RunConfig(
|
|
||||||
save_checkpoints_steps=params['save_checkpoints_steps'],
|
|
||||||
tf_random_seed=params['random_seed'],
|
|
||||||
save_summary_steps=params['save_summary_steps'],
|
|
||||||
session_config=session_config,
|
|
||||||
log_step_count_steps=10,
|
|
||||||
keep_checkpoint_max=2))
|
|
||||||
return estimator
|
|
||||||
|
|
||||||
|
|
||||||
def _create_train_spec(params, audio_adapter, audio_path):
|
|
||||||
""" Creates train spec.
|
|
||||||
|
|
||||||
:param params: TF params to build spec from.
|
|
||||||
:returns: Built train spec.
|
|
||||||
"""
|
|
||||||
input_fn = partial(get_training_dataset, params, audio_adapter, audio_path)
|
|
||||||
train_spec = tf.estimator.TrainSpec(
|
|
||||||
input_fn=input_fn,
|
|
||||||
max_steps=params['train_max_steps'])
|
|
||||||
return train_spec
|
|
||||||
|
|
||||||
|
|
||||||
def _create_evaluation_spec(params, audio_adapter, audio_path):
|
|
||||||
""" Setup eval spec evaluating ever n seconds
|
|
||||||
|
|
||||||
:param params: TF params to build spec from.
|
|
||||||
:returns: Built evaluation spec.
|
|
||||||
"""
|
|
||||||
input_fn = partial(
|
|
||||||
get_validation_dataset,
|
|
||||||
params,
|
|
||||||
audio_adapter,
|
|
||||||
audio_path)
|
|
||||||
evaluation_spec = tf.estimator.EvalSpec(
|
|
||||||
input_fn=input_fn,
|
|
||||||
steps=None,
|
|
||||||
throttle_secs=params['throttle_secs'])
|
|
||||||
return evaluation_spec
|
|
||||||
|
|
||||||
|
|
||||||
def entrypoint(arguments, params):
|
|
||||||
""" Command entrypoint.
|
|
||||||
|
|
||||||
:param arguments: Command line parsed argument as argparse.Namespace.
|
|
||||||
:param params: Deserialized JSON configuration file provided in CLI args.
|
|
||||||
"""
|
|
||||||
audio_adapter = get_audio_adapter(arguments.audio_adapter)
|
|
||||||
audio_path = arguments.audio_path
|
|
||||||
estimator = _create_estimator(params)
|
|
||||||
train_spec = _create_train_spec(params, audio_adapter, audio_path)
|
|
||||||
evaluation_spec = _create_evaluation_spec(
|
|
||||||
params,
|
|
||||||
audio_adapter,
|
|
||||||
audio_path)
|
|
||||||
get_logger().info('Start model training')
|
|
||||||
tf.estimator.train_and_evaluate(
|
|
||||||
estimator,
|
|
||||||
train_spec,
|
|
||||||
evaluation_spec)
|
|
||||||
ModelProvider.writeProbe(params['model_dir'])
|
|
||||||
get_logger().info('Model training done')
|
|
||||||
@@ -4,58 +4,41 @@
|
|||||||
""" Centralized logging facilities for Spleeter. """
|
""" Centralized logging facilities for Spleeter. """
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
|
|
||||||
from os import environ
|
from os import environ
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
|
# pylint: disable=import-error
|
||||||
|
from tensorflow.compat.v1 import logging as tflogging
|
||||||
|
# pylint: enable=import-error
|
||||||
|
|
||||||
__email__ = 'spleeter@deezer.com'
|
__email__ = 'spleeter@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
_FORMAT = '%(levelname)s:%(name)s:%(message)s'
|
|
||||||
|
formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logger: logging.Logger = logging.getLogger('spleeter')
|
||||||
|
logger.addHandler(handler)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
class _LoggerHolder(object):
|
def configure_logger(verbose: bool) -> None:
    """
        Configure application logger.

        In verbose mode the spleeter logger is switched to DEBUG and
        TensorFlow logging to INFO. Otherwise Python warnings are
        silenced, TensorFlow logging is reduced to ERROR and the spleeter
        logger is put back to INFO.

        Parameters:
            verbose (bool):
                `True` to use verbose logger, `False` otherwise.
    """
    if verbose:
        environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
        tflogging.set_verbosity(tflogging.INFO)
        logger.setLevel(logging.DEBUG)
    else:
        # NOTE: this filter is process-wide and is not undone by a later
        # verbose call.
        warnings.filterwarnings('ignore')
        environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        tflogging.set_verbosity(tflogging.ERROR)
        # Restore the module default so that a previous verbose call does
        # not leave the logger at DEBUG.
        logger.setLevel(logging.INFO)
|
||||||
tf_logger.set_verbosity(tf_logger.ERROR)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user