mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
166 lines
5.2 KiB
Python
166 lines
5.2 KiB
Python
#!/usr/bin/env python
|
|
# coding: utf8
|
|
|
|
"""
|
|
Entrypoint provider for performing model evaluation.
|
|
|
|
Evaluation is performed against the musdb dataset.
|
|
|
|
USAGE: python -m spleeter evaluate \
|
|
-p /path/to/params \
|
|
-o /path/to/output/dir \
|
|
[-m] \
|
|
--mus_dir /path/to/musdb/dataset
|
|
"""
|
|
|
|
import sys
|
|
import json
|
|
|
|
from argparse import Namespace
|
|
from itertools import product
|
|
from glob import glob
|
|
from os.path import join, exists
|
|
|
|
# pylint: disable=import-error
|
|
import numpy as np
|
|
import pandas as pd
|
|
# pylint: enable=import-error
|
|
|
|
from .separate import entrypoint as separate_entrypoint
|
|
from ..utils.logging import get_logger
|
|
|
|
try:
|
|
import musdb
|
|
import museval
|
|
except ImportError:
|
|
logger = get_logger()
|
|
logger.error('Extra dependencies musdb and museval not found')
|
|
logger.error('Please install musdb and museval first, abort')
|
|
sys.exit(1)
|
|
|
|
__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'

# musdb subset that is separated and evaluated.
_SPLIT = 'test'
# Name of the mixture file inside each musdb song directory.
_MIXTURE = 'mixture.wav'

# Sub-directories created under the user supplied output path.
_AUDIO_DIRECTORY = 'audio'
_METRICS_DIRECTORY = 'metrics'

# Instruments pre-registered in the compiled report.
_INSTRUMENTS = (
    'vocals',
    'drums',
    'bass',
    'other',
)
# Metric names read from each frame of the museval JSON reports.
_METRICS = (
    'SDR',
    'SAR',
    'SIR',
    'ISR',
)
|
|
|
|
|
|
def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
    """ Performs audio separation on the musdb dataset from
    the given directory and params.

    Every ``mixture.wav`` found in a song directory of
    ``<musdb_root_directory>/test`` is separated through the ``separate``
    entrypoint. Estimates are written below ``<output_path>/audio/test``
    using the ``{foldername}/{instrument}.wav`` filename format.

    :param arguments: Entrypoint arguments (``audio_adapter``,
        ``configuration``, ``output_path``, ``MWF`` and ``verbose`` are
        read).
    :param musdb_root_directory: Directory to retrieve dataset from.
    :param params: Spleeter configuration to apply to separation.
    :returns: Separation output directory path (``<output_path>/audio``).
    :raises IOError: If no song directory is found in the test split.
    """
    split_directory = join(musdb_root_directory, _SPLIT)
    # glob() yields entries in arbitrary file-system order; sort them so
    # songs are processed in a reproducible order.
    songs = sorted(glob(join(split_directory, '*/')))
    if not songs:
        # Without this check an empty/misplaced split silently produces an
        # evaluation made only of NaN medians.
        raise IOError(f'no song found in musdb directory {split_directory}')
    mixtures = [join(song, _MIXTURE) for song in songs]
    audio_output_directory = join(arguments.output_path, _AUDIO_DIRECTORY)
    separate_entrypoint(
        Namespace(
            audio_adapter=arguments.audio_adapter,
            configuration=arguments.configuration,
            inputs=mixtures,
            output_path=join(audio_output_directory, _SPLIT),
            # NOTE(review): '{foldername}' presumably expands to the
            # mixture's parent directory (the song name), reproducing the
            # <split>/<song>/ layout museval reads back - confirm.
            filename_format='{foldername}/{instrument}.{codec}',
            codec='wav',
            # NOTE(review): presumably seconds, chosen to exceed any musdb
            # track length so whole songs are separated - confirm.
            duration=600.,
            offset=0.,
            bitrate='128k',
            MWF=arguments.MWF,
            verbose=arguments.verbose,
            stft_backend="auto"),
        params)
    return audio_output_directory
|
|
|
|
|
|
def _compute_musdb_metrics(
        arguments,
        musdb_root_directory,
        audio_output_directory):
    """ Generates musdb metrics from previously computed audio estimations.

    The musdb test split (WAV version) is loaded from
    ``musdb_root_directory`` and handed to ``museval.eval_mus_dir``. That
    call compares it with the estimates found in ``audio_output_directory``
    and writes its reports under ``<output_path>/metrics``.

    :param arguments: Entrypoint arguments (only ``output_path`` is read).
    :param musdb_root_directory: Root directory of the musdb dataset.
    :param audio_output_directory: Directory to get audio estimations from.
    :returns: Path of generated metrics directory.
    """
    logger = get_logger()
    metrics_directory = join(arguments.output_path, _METRICS_DIRECTORY)
    logger.info('Starting musdb evaluation (this could be long) ...')
    reference_tracks = musdb.DB(
        root=musdb_root_directory,
        is_wav=True,
        subsets=[_SPLIT])
    museval.eval_mus_dir(
        dataset=reference_tracks,
        estimates_dir=audio_output_directory,
        output_dir=metrics_directory)
    logger.info('musdb evaluation done')
    return metrics_directory
|
|
|
|
|
|
def _compile_metrics(metrics_output_directory):
    """ Compiles metrics from given directory and returns
    results as dict.

    Every JSON report found in ``<metrics_output_directory>/test`` is read.
    For each target and each metric of ``_METRICS``, the median over the
    frames whose value is not NaN is appended, giving one value per song.

    :param metrics_output_directory: Directory to get metrics from.
    :returns: Compiled metrics as dict, mapping a target name to a dict
        mapping a metric name to its list of per-song medians. Entries for
        ``_INSTRUMENTS`` are always present (possibly with empty lists);
        any other target name found in the reports is added as well.
    """
    # Sorted so the per-song lists are built in a reproducible order.
    songs = sorted(glob(join(metrics_output_directory, _SPLIT, '*.json')))
    metrics = {
        instrument: {metric: [] for metric in _METRICS}
        for instrument in _INSTRUMENTS}
    for song in songs:
        with open(song, 'r', encoding='utf-8') as stream:
            data = json.load(stream)
        for target in data['targets']:
            # setdefault: a target outside _INSTRUMENTS (e.g. produced by a
            # model with other stems) used to raise KeyError here.
            target_metrics = metrics.setdefault(
                target['name'],
                {metric: [] for metric in _METRICS})
            for metric in _METRICS:
                # NaN frames are excluded before taking the median.
                valid_values = [
                    frame['metrics'][metric]
                    for frame in target['frames']
                    if not np.isnan(frame['metrics'][metric])]
                target_metrics[metric].append(np.median(valid_values))
    return metrics
|
|
|
|
|
|
def entrypoint(arguments, params):
    """ Command entrypoint.

    Separates the musdb test split, evaluates the estimates with museval,
    then logs, for each instrument and metric, the median of the per-song
    medians.

    :param arguments: Command line parsed argument as argparse.Namespace.
    :param params: Deserialized JSON configuration file provided in CLI args.
    :returns: Compiled metrics (target name -> metric name -> list of
        per-song medians), so callers can reuse them programmatically.
    :raises IOError: If the musdb directory does not exist.
    """
    # Parse and check musdb directory.
    musdb_root_directory = arguments.mus_dir
    if not exists(musdb_root_directory):
        raise IOError(f'musdb directory {musdb_root_directory} not found')
    # Separate musdb sources.
    audio_output_directory = _separate_evaluation_dataset(
        arguments,
        musdb_root_directory,
        params)
    # Compute metrics with musdb.
    metrics_output_directory = _compute_musdb_metrics(
        arguments,
        musdb_root_directory,
        audio_output_directory)
    # Compute and pretty print median metrics.
    metrics = _compile_metrics(metrics_output_directory)
    logger = get_logger()
    for instrument, instrument_metrics in metrics.items():
        logger.info('%s:', instrument)
        for metric, values in instrument_metrics.items():
            # np.median([]) is nan but also emits a RuntimeWarning; the
            # logged text ('nan') is the same either way.
            median = np.median(values) if values else float('nan')
            logger.info('%s: %s', metric, f'{median:.3f}')
    return metrics
|