mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-30 12:22:58 +00:00
Adding a new argument to support chunked inference
This commit is contained in:
@@ -68,6 +68,18 @@ OPT_DURATION = {
|
||||
'the input file)')
|
||||
}
|
||||
|
||||
# -w opt specification (separate): keyword arguments forwarded to
# ``parser.add_argument('-w', '--chunk', ...)``.
OPT_CHUNKED = dict(
    dest='chunk_duration',
    type=float,
    # -1 is the sentinel meaning "do not split the signal".
    default=-1,
    help=(
        'Maximum duration of the segments that are fed to the network. '
        'Use this parameter to limit memory usage. '
        'Use -1 to process the whole signal in one pass.'
    ),
)
|
||||
|
||||
|
||||
# -c opt specification (separate).
|
||||
OPT_CODEC = {
|
||||
'dest': 'codec',
|
||||
@@ -176,6 +188,7 @@ def _create_separate_parser(parser_factory):
|
||||
parser.add_argument('-c', '--codec', **OPT_CODEC)
|
||||
parser.add_argument('-b', '--birate', **OPT_BITRATE)
|
||||
parser.add_argument('-m', '--mwf', **OPT_MWF)
|
||||
parser.add_argument('-w', '--chunk', **OPT_CHUNKED)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -13,17 +13,14 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
from functools import partial
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
from os.path import basename, join, splitext
|
||||
import numpy as np
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .model import model_fn
|
||||
from .utils.configuration import load_configuration
|
||||
from .utils.estimator import create_estimator, to_predictor
|
||||
|
||||
@@ -90,9 +87,21 @@ class Separator(object):
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def separate_chunked(self, waveform, sample_rate, chunk_duration=-1):
    """ Performs source separation chunk by chunk to bound memory usage.

    The waveform is cut along its first axis into consecutive segments
    of at most ``chunk_duration * sample_rate`` samples. Each segment is
    passed to ``self.separate`` and the per-instrument results are
    concatenated back together along the first axis, in order.

    :param waveform: Waveform to separate; its first axis is the sample
                     axis (``waveform.shape[0]`` samples).
    :param sample_rate: Number of samples per unit of ``chunk_duration``
                        (presumably Hz, i.e. durations in seconds).
    :param chunk_duration: (Optional) Maximum duration processed in one
                           pass. Any non positive value (default -1)
                           processes the whole signal at once.
    :returns: Dictionary mapping each instrument name to its separated
              waveform, as returned by ``self.separate``.
    """
    n_samples = waveform.shape[0]
    if chunk_duration <= 0:
        chunk_size = n_samples
    else:
        # Slice bounds must be integers; never let the size reach zero,
        # otherwise the chunk loop could not make progress.
        chunk_size = max(1, int(chunk_duration * sample_rate))
    # Nothing to split (also covers an empty signal): single pass.
    if chunk_size >= n_samples:
        return self.separate(waveform)
    out = {}
    # Stepping by chunk_size keeps the trailing, shorter segment instead
    # of silently dropping it.
    for start in range(0, n_samples, chunk_size):
        sources = self.separate(waveform[start:start + chunk_size])
        for inst, data in sources.items():
            out.setdefault(inst, []).append(data)
    return {
        inst: np.concatenate(chunks, axis=0)
        for inst, chunks in out.items()}
|
||||
|
||||
def separate_to_file(
|
||||
self, audio_descriptor, destination,
|
||||
audio_adapter=get_default_audio_adapter(),
|
||||
audio_adapter=get_default_audio_adapter(), chunk_duration=-1,
|
||||
offset=0, duration=600., codec='wav', bitrate='128k',
|
||||
filename_format='{filename}/{instrument}.{codec}',
|
||||
synchronous=True):
|
||||
@@ -108,6 +117,8 @@ class Separator(object):
|
||||
descriptor would be a file path.
|
||||
:param destination: Target directory to write output to.
|
||||
:param audio_adapter: (Optional) Audio adapter to use for I/O.
|
||||
:param chunk_duration: (Optional) Maximum signal duration that is processed
|
||||
in one pass. Default: all signal.
|
||||
:param offset: (Optional) Offset of loaded song.
|
||||
:param duration: (Optional) Duration of loaded song.
|
||||
:param codec: (Optional) Export codec.
|
||||
@@ -115,12 +126,17 @@ class Separator(object):
|
||||
:param filename_format: (Optional) Filename format.
|
||||
:param synchronous: (Optional) True if it should be synchronous.
|
||||
"""
|
||||
waveform, _ = audio_adapter.load(
|
||||
waveform, sample_rate = audio_adapter.load(
|
||||
audio_descriptor,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
sample_rate=self._sample_rate)
|
||||
sources = self.separate(waveform)
|
||||
sources = self.separate_chunked(waveform, sample_rate, chunk_duration=chunk_duration)
|
||||
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous)
|
||||
|
||||
def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous):
|
||||
filename = splitext(basename(audio_descriptor))[0]
|
||||
generated = []
|
||||
for instrument, data in sources.items():
|
||||
|
||||
Reference in New Issue
Block a user