mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-30 20:24:31 +00:00
🎨 finalize refactoring
This commit is contained in:
@@ -19,13 +19,15 @@ import os
|
||||
|
||||
from multiprocessing import Pool
|
||||
from os.path import basename, join, splitext, dirname
|
||||
from typing import Generator, Optional
|
||||
from spleeter.model.provider import ModelProvider
|
||||
from typing import Dict, Generator, Optional
|
||||
|
||||
from . import SpleeterError
|
||||
from .audio import STFTBackend
|
||||
from .audio.adapter import get_default_audio_adapter
|
||||
from .audio import Codec, STFTBackend
|
||||
from .audio.adapter import AudioAdapter
|
||||
from .audio.convertor import to_stereo
|
||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||
from .model import model_fn
|
||||
from .utils.configuration import load_configuration
|
||||
|
||||
# pyright: reportMissingImports=false
|
||||
@@ -65,18 +67,6 @@ class DataGenerator(object):
|
||||
buffer = self._current_data
|
||||
|
||||
|
||||
def get_backend(backend: str) -> str:
|
||||
"""
|
||||
"""
|
||||
if backend not in SUPPORTED_BACKEND:
|
||||
raise ValueError(f'Unsupported backend {backend}')
|
||||
if backend == 'auto':
|
||||
if len(tf.config.list_physical_devices('GPU')):
|
||||
return 'tensorflow'
|
||||
return 'librosa'
|
||||
return backend
|
||||
|
||||
|
||||
def create_estimator(params, MWF):
|
||||
"""
|
||||
Initialize tensorflow estimator that will perform separation
|
||||
@@ -137,18 +127,21 @@ class Separator(object):
|
||||
else:
|
||||
self._pool = None
|
||||
self._tasks = []
|
||||
self._params['stft_backend'] = get_backend(stft_backend)
|
||||
self._params['stft_backend'] = stft_backend
|
||||
self._data_generator = DataGenerator()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self._session:
|
||||
self._session.close()
|
||||
|
||||
def _get_prediction_generator(self):
|
||||
""" Lazy loading access method for internal prediction generator
|
||||
returned by the predict method of a tensorflow estimator.
|
||||
def _get_prediction_generator(self) -> Generator:
|
||||
"""
|
||||
Lazy loading access method for internal prediction generator
|
||||
returned by the predict method of a tensorflow estimator.
|
||||
|
||||
:returns: generator of prediction.
|
||||
Returns:
|
||||
Generator:
|
||||
Generator of prediction.
|
||||
"""
|
||||
if self._prediction_generator is None:
|
||||
estimator = create_estimator(self._params, self._MWF)
|
||||
@@ -181,17 +174,30 @@ class Separator(object):
|
||||
task.get()
|
||||
task.wait(timeout=timeout)
|
||||
|
||||
def _stft(self, data, inverse: bool = False, length=None):
|
||||
""" Single entrypoint for both stft and istft. This computes stft and
|
||||
istft with librosa on stereo data. The two channels are processed
|
||||
separately and are concatenated together in the result. The expected
|
||||
input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
|
||||
def _stft(
|
||||
self,
|
||||
data: np.ndarray,
|
||||
inverse: bool = False,
|
||||
length: Optional[int] = None) -> np.ndarray:
|
||||
"""
|
||||
Single entrypoint for both stft and istft. This computes stft and
|
||||
istft with librosa on stereo data. The two channels are processed
|
||||
separately and are concatenated together in the result. The
|
||||
expected input formats are: (n_samples, 2) for stft and (T, F, 2)
|
||||
for istft.
|
||||
|
||||
:param data: np.array with either the waveform or the complex
|
||||
spectrogram depending on the parameter inverse
|
||||
:param inverse: should a stft or an istft be computed.
|
||||
:returns: Stereo data as numpy array for the transform.
|
||||
The channels are stored in the last dimension.
|
||||
Parameters:
|
||||
data (numpy.array):
|
||||
Array with either the waveform or the complex spectrogram
|
||||
depending on the parameter inverse
|
||||
inverse (bool):
|
||||
(Optional) Should a stft or an istft be computed.
|
||||
length (Optional[int]):
|
||||
|
||||
Returns:
|
||||
numpy.ndarray:
|
||||
Stereo data as numpy array for the transform. The channels
|
||||
are stored in the last dimension.
|
||||
"""
|
||||
assert not (inverse and length is None)
|
||||
data = np.asfortranarray(data)
|
||||
@@ -238,19 +244,24 @@ class Separator(object):
|
||||
def _get_session(self):
|
||||
if self._session is None:
|
||||
saver = tf.compat.v1.train.Saver()
|
||||
latest_checkpoint = tf.train.latest_checkpoint(
|
||||
get_default_model_dir(self._params['model_dir']))
|
||||
provider = ModelProvider.default()
|
||||
model_directory: str = provider.get(self._params['model_dir'])
|
||||
latest_checkpoint = tf.train.latest_checkpoint(model_directory)
|
||||
self._session = tf.compat.v1.Session()
|
||||
saver.restore(self._session, latest_checkpoint)
|
||||
return self._session
|
||||
|
||||
def _separate_librosa(self, waveform: np.ndarray, audio_id):
|
||||
def _separate_librosa(
|
||||
self,
|
||||
waveform: np.ndarray,
|
||||
audio_descriptor: str) -> Dict:
|
||||
"""
|
||||
Performs separation with librosa backend for STFT.
|
||||
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
Waveform to be separated (as a numpy array)
|
||||
audio_descriptor (str):
|
||||
"""
|
||||
with self._tf_graph.as_default():
|
||||
out = {}
|
||||
@@ -269,7 +280,7 @@ class Separator(object):
|
||||
feed_dict=self._get_input_provider().get_feed_dict(
|
||||
features,
|
||||
stft,
|
||||
audio_id))
|
||||
audio_descriptor))
|
||||
for inst in self._get_builder().instruments:
|
||||
out[inst] = self._stft(
|
||||
outputs[inst],
|
||||
@@ -277,7 +288,10 @@ class Separator(object):
|
||||
length=waveform.shape[0])
|
||||
return out
|
||||
|
||||
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
|
||||
def _separate_tensorflow(
|
||||
self,
|
||||
waveform: np.ndarray,
|
||||
audio_descriptor: str) -> Dict:
|
||||
"""
|
||||
Performs source separation over the given waveform with tensorflow
|
||||
backend.
|
||||
@@ -285,6 +299,7 @@ class Separator(object):
|
||||
Parameters:
|
||||
waveform (numpy.ndarray):
|
||||
Waveform to be separated (as a numpy array)
|
||||
audio_descriptor (str):
|
||||
|
||||
Returns:
|
||||
Separated waveforms.
|
||||
@@ -314,44 +329,61 @@ class Separator(object):
|
||||
audio_descriptor (str):
|
||||
(Optional) string describing the waveform (e.g. filename).
|
||||
"""
|
||||
if self._params['stft_backend'] == 'tensorflow':
|
||||
backend: str = self._params['stft_backend']
|
||||
if backend == STFTBackend.TENSORFLOW:
|
||||
return self._separate_tensorflow(waveform, audio_descriptor)
|
||||
else:
|
||||
elif backend == STFTBackend.LIBROSA:
|
||||
return self._separate_librosa(waveform, audio_descriptor)
|
||||
raise ValueError(f'Unsupported STFT backend {backend}')
|
||||
|
||||
def separate_to_file(
|
||||
self,
|
||||
audio_descriptor,
|
||||
destination,
|
||||
audio_adapter=get_default_audio_adapter(),
|
||||
offset=0,
|
||||
duration=600.,
|
||||
codec='wav',
|
||||
bitrate='128k',
|
||||
filename_format='{filename}/{instrument}.{codec}',
|
||||
synchronous=True):
|
||||
""" Performs source separation and export result to file using
|
||||
given audio adapter.
|
||||
|
||||
Filename format should be a Python formattable string that could use
|
||||
following parameters : {instrument}, {filename}, {foldername} and
|
||||
{codec}.
|
||||
|
||||
:param audio_descriptor: Describe song to separate, used by audio
|
||||
adapter to retrieve and load audio data,
|
||||
in case of file based audio adapter, such
|
||||
descriptor would be a file path.
|
||||
:param destination: Target directory to write output to.
|
||||
:param audio_adapter: (Optional) Audio adapter to use for I/O.
|
||||
:param offset: (Optional) Offset of loaded song.
|
||||
:param duration: (Optional) Duration of loaded song
|
||||
(default: 600s).
|
||||
:param codec: (Optional) Export codec.
|
||||
:param bitrate: (Optional) Export bitrate.
|
||||
:param filename_format: (Optional) Filename format.
|
||||
:param synchronous: (Optional) True if the operation should be synchronous.
|
||||
audio_descriptor: str,
|
||||
destination: str,
|
||||
audio_adapter: Optional[AudioAdapter] = None,
|
||||
offset: int = 0,
|
||||
duration: float = 600.,
|
||||
codec: Codec = Codec.WAV,
|
||||
bitrate: str = '128k',
|
||||
filename_format: str = '{filename}/{instrument}.{codec}',
|
||||
synchronous: bool = True) -> None:
|
||||
"""
|
||||
waveform, sample_rate = audio_adapter.load(
|
||||
Performs source separation and export result to file using
|
||||
given audio adapter.
|
||||
|
||||
Filename format should be a Python formattable string that could
|
||||
use following parameters :
|
||||
|
||||
- {instrument}
|
||||
- {filename}
|
||||
- {foldername}
|
||||
- {codec}.
|
||||
|
||||
Parameters:
|
||||
audio_descriptor (str):
|
||||
Describe song to separate, used by audio adapter to
|
||||
retrieve and load audio data, in case of file based
|
||||
audio adapter, such descriptor would be a file path.
|
||||
destination (str):
|
||||
Target directory to write output to.
|
||||
audio_adapter (Optional[AudioAdapter]):
|
||||
(Optional) Audio adapter to use for I/O.
|
||||
offset (int):
|
||||
(Optional) Offset of loaded song.
|
||||
duration (float):
|
||||
(Optional) Duration of loaded song (default: 600s).
|
||||
codec (Codec):
|
||||
(Optional) Export codec.
|
||||
bitrate (str):
|
||||
(Optional) Export bitrate.
|
||||
filename_format (str):
|
||||
(Optional) Filename format.
|
||||
synchronous (bool):
|
||||
(Optional) True if the operation should be synchronous.
|
||||
"""
|
||||
if audio_adapter is None:
|
||||
audio_adapter = AudioAdapter.default()
|
||||
waveform, _ = audio_adapter.load(
|
||||
audio_descriptor,
|
||||
offset=offset,
|
||||
duration=duration,
|
||||
@@ -369,33 +401,42 @@ class Separator(object):
|
||||
|
||||
def save_to_file(
|
||||
self,
|
||||
sources,
|
||||
audio_descriptor,
|
||||
destination,
|
||||
filename_format='{filename}/{instrument}.{codec}',
|
||||
codec='wav',
|
||||
audio_adapter=get_default_audio_adapter(),
|
||||
bitrate='128k',
|
||||
synchronous=True):
|
||||
""" Export dictionary of sources to files.
|
||||
|
||||
:param sources: Dictionary of sources to be exported. The
|
||||
keys are the name of the instruments, and
|
||||
the values are Nx2 numpy arrays containing
|
||||
the corresponding instrument waveform, as
|
||||
returned by the separate method
|
||||
:param audio_descriptor: Describe song to separate, used by audio
|
||||
adapter to retrieve and load audio data,
|
||||
in case of file based audio adapter, such
|
||||
descriptor would be a file path.
|
||||
:param destination: Target directory to write output to.
|
||||
:param filename_format: (Optional) Filename format.
|
||||
:param codec: (Optional) Export codec.
|
||||
:param audio_adapter: (Optional) Audio adapter to use for I/O.
|
||||
:param bitrate: (Optional) Export bitrate.
|
||||
:param synchronous: (Optional) True if the operation should be synchronous.
|
||||
|
||||
sources: Dict,
|
||||
audio_descriptor: str,
|
||||
destination: str,
|
||||
filename_format: str = '{filename}/{instrument}.{codec}',
|
||||
codec: Codec = Codec.WAV,
|
||||
audio_adapter: Optional[AudioAdapter] = None,
|
||||
bitrate: str = '128k',
|
||||
synchronous: bool = True) -> None:
|
||||
"""
|
||||
Export dictionary of sources to files.
|
||||
|
||||
Parameters:
|
||||
sources (Dict):
|
||||
Dictionary of sources to be exported. The keys are the name
|
||||
of the instruments, and the values are `N x 2` numpy arrays
|
||||
containing the corresponding instrument waveform, as
|
||||
returned by the separate method
|
||||
audio_descriptor (str):
|
||||
Describe song to separate, used by audio adapter to
|
||||
retrieve and load audio data, in case of file based audio
|
||||
adapter, such descriptor would be a file path.
|
||||
destination (str):
|
||||
Target directory to write output to.
|
||||
filename_format (str):
|
||||
(Optional) Filename format.
|
||||
codec (Codec):
|
||||
(Optional) Export codec.
|
||||
audio_adapter (Optional[AudioAdapter]):
|
||||
(Optional) Audio adapter to use for I/O.
|
||||
bitrate (str):
|
||||
(Optional) Export bitrate.
|
||||
synchronous (bool):
|
||||
(Optional) True if the operation should be synchronous.
|
||||
"""
|
||||
if audio_adapter is None:
|
||||
audio_adapter = AudioAdapter.default()
|
||||
foldername = basename(dirname(audio_descriptor))
|
||||
filename = splitext(basename(audio_descriptor))[0]
|
||||
generated = []
|
||||
|
||||
Reference in New Issue
Block a user