Adding a new argument to support chunked inference

This commit is contained in:
akhlif
2020-02-19 10:55:48 +01:00
parent d5d372dd0f
commit f93dfbc235
2 changed files with 36 additions and 7 deletions

View File

@@ -68,6 +68,18 @@ OPT_DURATION = {
'the input file)') 'the input file)')
} }
# -w opt specification (separate): caps the duration of each segment
# that is fed to the network; -1 means the signal is processed in one pass.
OPT_CHUNKED = dict(
    dest='chunk_duration',
    type=float,
    default=-1,
    help=(
        'Maximum duration of the segments that are fed to the network. '
        'Use this parameter to limit memory usage. '
        'Use -1 to process the whole signal in one pass.'),
)
# -c opt specification (separate). # -c opt specification (separate).
OPT_CODEC = { OPT_CODEC = {
'dest': 'codec', 'dest': 'codec',
@@ -176,6 +188,7 @@ def _create_separate_parser(parser_factory):
parser.add_argument('-c', '--codec', **OPT_CODEC) parser.add_argument('-c', '--codec', **OPT_CODEC)
parser.add_argument('-b', '--birate', **OPT_BITRATE) parser.add_argument('-b', '--birate', **OPT_BITRATE)
parser.add_argument('-m', '--mwf', **OPT_MWF) parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-w', '--chunk', **OPT_CHUNKED)
return parser return parser

View File

@@ -13,17 +13,14 @@
""" """
import os import os
import json
from functools import partial
from multiprocessing import Pool from multiprocessing import Pool
from pathlib import Path
from os.path import basename, join, splitext from os.path import basename, join, splitext
import numpy as np
from . import SpleeterError from . import SpleeterError
from .audio.adapter import get_default_audio_adapter from .audio.adapter import get_default_audio_adapter
from .audio.convertor import to_stereo from .audio.convertor import to_stereo
from .model import model_fn
from .utils.configuration import load_configuration from .utils.configuration import load_configuration
from .utils.estimator import create_estimator, to_predictor from .utils.estimator import create_estimator, to_predictor
@@ -90,9 +87,21 @@ class Separator(object):
prediction.pop('audio_id') prediction.pop('audio_id')
return prediction return prediction
def separate_chunked(self, waveform, sample_rate, chunk_duration=-1):
    """ Performs source separation chunk by chunk to bound memory usage.

    The waveform is cut along its first axis into consecutive segments
    of at most `chunk_duration * sample_rate` samples. Each segment is
    passed to `self.separate` on its own and the per-instrument results
    are concatenated back along the first axis, in order.

    :param waveform: Waveform to separate, indexed by sample on its
                     first axis (must expose `shape` and slicing).
    :param sample_rate: Number of samples per unit of `chunk_duration`.
    :param chunk_duration: (Optional) Maximum duration processed in one
                           pass. Any non-positive value (default -1) or
                           None processes the whole signal at once.
    :returns: Dictionary mapping each instrument name returned by
              `self.separate` to its concatenated separated data.
    """
    n_samples = waveform.shape[0]
    if chunk_duration is None or chunk_duration <= 0:
        chunk_size = n_samples
    else:
        # Slice bounds must be integers; never let the size fall to 0,
        # otherwise the chunk loop could not make progress.
        chunk_size = max(1, int(chunk_duration * sample_rate))
    if 0 < chunk_size < n_samples:
        # Stepping by chunk_size keeps the trailing partial chunk that
        # a floor division of the length would silently drop.
        starts = range(0, n_samples, chunk_size)
    else:
        # Whole signal in one pass (also covers an empty waveform).
        starts = (0,)
    collected = {}
    for start in starts:
        # Only the current segment is handed to the network.
        sources = self.separate(waveform[start:start + chunk_size])
        for instrument, data in sources.items():
            collected.setdefault(instrument, []).append(data)
    return {
        instrument: np.concatenate(chunks, axis=0)
        for instrument, chunks in collected.items()}
def separate_to_file( def separate_to_file(
self, audio_descriptor, destination, self, audio_descriptor, destination,
audio_adapter=get_default_audio_adapter(), audio_adapter=get_default_audio_adapter(), chunk_duration=-1,
offset=0, duration=600., codec='wav', bitrate='128k', offset=0, duration=600., codec='wav', bitrate='128k',
filename_format='{filename}/{instrument}.{codec}', filename_format='{filename}/{instrument}.{codec}',
synchronous=True): synchronous=True):
@@ -108,6 +117,8 @@ class Separator(object):
descriptor would be a file path. descriptor would be a file path.
:param destination: Target directory to write output to. :param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O. :param audio_adapter: (Optional) Audio adapter to use for I/O.
:param chunk_duration: (Optional) Maximum signal duration that is processed
in one pass. Default: all signal.
:param offset: (Optional) Offset of loaded song. :param offset: (Optional) Offset of loaded song.
:param duration: (Optional) Duration of loaded song. :param duration: (Optional) Duration of loaded song.
:param codec: (Optional) Export codec. :param codec: (Optional) Export codec.
@@ -115,12 +126,17 @@ class Separator(object):
:param filename_format: (Optional) Filename format. :param filename_format: (Optional) Filename format.
:param synchronous: (Optional) True is should by synchronous. :param synchronous: (Optional) True is should by synchronous.
""" """
waveform, _ = audio_adapter.load( waveform, sample_rate = audio_adapter.load(
audio_descriptor, audio_descriptor,
offset=offset, offset=offset,
duration=duration, duration=duration,
sample_rate=self._sample_rate) sample_rate=self._sample_rate)
sources = self.separate(waveform) sources = self.separate_chunked(waveform, sample_rate, chunk_duration=chunk_duration)
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous)
def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous):
filename = splitext(basename(audio_descriptor))[0] filename = splitext(basename(audio_descriptor))[0]
generated = [] generated = []
for instrument, data in sources.items(): for instrument, data in sources.items():