mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Initial commit from private spleeter
This commit is contained in:
154
spleeter/commands/evaluate.py
Normal file
154
spleeter/commands/evaluate.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
"""
|
||||
Entrypoint provider for performing model evaluation.
|
||||
|
||||
Evaluation is performed against musDB dataset.
|
||||
|
||||
USAGE: python -m spleeter evaluate \
|
||||
-p /path/to/params \
|
||||
-o /path/to/output/dir \
|
||||
[-m] \
|
||||
--mus_dir /path/to/musdb dataset
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from argparse import Namespace
|
||||
from itertools import product
|
||||
from glob import glob
|
||||
from os.path import join, exists
|
||||
|
||||
# pylint: disable=import-error
|
||||
import musdb
|
||||
import museval
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
# pylint: enable=import-error
|
||||
|
||||
from .separate import entrypoint as separate_entrypoint
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
__email__ = 'research@deezer.com'
|
||||
__author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
_SPLIT = 'test'
|
||||
_MIXTURE = 'mixture.wav'
|
||||
_NAMING = 'directory'
|
||||
_AUDIO_DIRECTORY = 'audio'
|
||||
_METRICS_DIRECTORY = 'metrics'
|
||||
_INSTRUMENTS = ('vocals', 'drums', 'bass', 'other')
|
||||
_METRICS = ('SDR', 'SAR', 'SIR', 'ISR')
|
||||
|
||||
|
||||
def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
|
||||
""" Performs audio separation on the musdb dataset from
|
||||
the given directory and params.
|
||||
|
||||
:param arguments: Entrypoint arguments.
|
||||
:param musdb_root_directory: Directory to retrieve dataset from.
|
||||
:param params: Spleeter configuration to apply to separation.
|
||||
:returns: Separation output directory path.
|
||||
"""
|
||||
songs = glob(join(musdb_root_directory, _SPLIT, '*/'))
|
||||
mixtures = [join(song, _MIXTURE) for song in songs]
|
||||
audio_output_directory = join(
|
||||
arguments.output_path,
|
||||
_AUDIO_DIRECTORY)
|
||||
separate_entrypoint(
|
||||
Namespace(
|
||||
audio_adapter=arguments.audio_adapter,
|
||||
audio_filenames=mixtures,
|
||||
audio_codec='wav',
|
||||
output_path=join(audio_output_directory, _SPLIT),
|
||||
output_naming=_NAMING,
|
||||
max_duration=600.,
|
||||
MWF=arguments.MWF,
|
||||
verbose=arguments.verbose),
|
||||
params)
|
||||
return audio_output_directory
|
||||
|
||||
|
||||
def _compute_musdb_metrics(
|
||||
arguments,
|
||||
musdb_root_directory,
|
||||
audio_output_directory):
|
||||
""" Generates musdb metrics fro previsouly computed audio estimation.
|
||||
|
||||
:param arguments: Entrypoint arguments.
|
||||
:param audio_output_directory: Directory to get audio estimation from.
|
||||
:returns: Path of generated metrics directory.
|
||||
"""
|
||||
metrics_output_directory = join(
|
||||
arguments.output_path,
|
||||
_METRICS_DIRECTORY)
|
||||
get_logger().info('Starting musdb evaluation (this could be long) ...')
|
||||
dataset = musdb.DB(
|
||||
root=musdb_root_directory,
|
||||
is_wav=True,
|
||||
subsets=[_SPLIT])
|
||||
museval.eval_mus_dir(
|
||||
dataset=dataset,
|
||||
estimates_dir=audio_output_directory,
|
||||
output_dir=metrics_output_directory)
|
||||
get_logger().info('musdb evaluation done')
|
||||
return metrics_output_directory
|
||||
|
||||
|
||||
def _compile_metrics(metrics_output_directory):
|
||||
""" Compiles metrics from given directory and returns
|
||||
results as dict.
|
||||
|
||||
:param metrics_output_directory: Directory to get metrics from.
|
||||
:returns: Compiled metrics as dict.
|
||||
"""
|
||||
songs = glob(join(metrics_output_directory, 'test/*.json'))
|
||||
index = pd.MultiIndex.from_tuples(
|
||||
product(_INSTRUMENTS, _METRICS),
|
||||
names=['instrument', 'metric'])
|
||||
pd.DataFrame([], index=['config1', 'config2'], columns=index)
|
||||
metrics = {
|
||||
instrument: {k: [] for k in _METRICS}
|
||||
for instrument in _INSTRUMENTS}
|
||||
for song in songs:
|
||||
with open(song, 'r') as stream:
|
||||
data = json.load(stream)
|
||||
for target in data['targets']:
|
||||
instrument = target['name']
|
||||
for metric in _METRICS:
|
||||
sdr_med = np.median([
|
||||
frame['metrics'][metric]
|
||||
for frame in target['frames']
|
||||
if not np.isnan(frame['metrics'][metric])])
|
||||
metrics[instrument][metric].append(sdr_med)
|
||||
return metrics
|
||||
|
||||
|
||||
def entrypoint(arguments, params):
|
||||
""" Command entrypoint.
|
||||
|
||||
:param arguments: Command line parsed argument as argparse.Namespace.
|
||||
:param params: Deserialized JSON configuration file provided in CLI args.
|
||||
"""
|
||||
# Parse and check musdb directory.
|
||||
musdb_root_directory = arguments.mus_dir
|
||||
if not exists(musdb_root_directory):
|
||||
raise IOError(f'musdb directory {musdb_root_directory} not found')
|
||||
# Separate musdb sources.
|
||||
audio_output_directory = _separate_evaluation_dataset(
|
||||
arguments,
|
||||
musdb_root_directory,
|
||||
params)
|
||||
# Compute metrics with musdb.
|
||||
metrics_output_directory = _compute_musdb_metrics(
|
||||
arguments,
|
||||
musdb_root_directory,
|
||||
audio_output_directory)
|
||||
# Compute and pretty print median metrics.
|
||||
metrics = _compile_metrics(metrics_output_directory)
|
||||
for instrument, metric in metrics.items():
|
||||
get_logger().info('%s:', instrument)
|
||||
for metric, value in metric.items():
|
||||
get_logger().info('%s: %s', metric, f'{np.median(value):.3f}')
|
||||
Reference in New Issue
Block a user