🎨 finalize refactoring

This commit is contained in:
Faylixe
2020-12-08 12:10:45 +01:00
parent 075bb97f82
commit ed7bd4b945
6 changed files with 269 additions and 162 deletions

View File

@@ -19,13 +19,15 @@ import os
from multiprocessing import Pool
from os.path import basename, join, splitext, dirname
from typing import Generator, Optional
from spleeter.model.provider import ModelProvider
from typing import Dict, Generator, Optional
from . import SpleeterError
from .audio import STFTBackend
from .audio.adapter import get_default_audio_adapter
from .audio import Codec, STFTBackend
from .audio.adapter import AudioAdapter
from .audio.convertor import to_stereo
from .model import EstimatorSpecBuilder, InputProviderFactory
from .model import model_fn
from .utils.configuration import load_configuration
# pyright: reportMissingImports=false
@@ -65,18 +67,6 @@ class DataGenerator(object):
buffer = self._current_data
def get_backend(backend: str) -> str:
"""
"""
if backend not in SUPPORTED_BACKEND:
raise ValueError(f'Unsupported backend {backend}')
if backend == 'auto':
if len(tf.config.list_physical_devices('GPU')):
return 'tensorflow'
return 'librosa'
return backend
def create_estimator(params, MWF):
"""
Initialize tensorflow estimator that will perform separation
@@ -137,18 +127,21 @@ class Separator(object):
else:
self._pool = None
self._tasks = []
self._params['stft_backend'] = get_backend(stft_backend)
self._params['stft_backend'] = stft_backend
self._data_generator = DataGenerator()
def __del__(self) -> None:
if self._session:
self._session.close()
def _get_prediction_generator(self):
""" Lazy loading access method for internal prediction generator
returned by the predict method of a tensorflow estimator.
def _get_prediction_generator(self) -> Generator:
"""
Lazy loading access method for internal prediction generator
returned by the predict method of a tensorflow estimator.
:returns: generator of prediction.
Returns:
Generator:
Generator of prediction.
"""
if self._prediction_generator is None:
estimator = create_estimator(self._params, self._MWF)
@@ -181,17 +174,30 @@ class Separator(object):
task.get()
task.wait(timeout=timeout)
def _stft(self, data, inverse: bool = False, length=None):
""" Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The expected
input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
def _stft(
self,
data: np.ndarray,
inverse: bool = False,
length: Optional[int] = None) -> np.ndarray:
"""
Single entrypoint for both stft and istft. This computes stft and
istft with librosa on stereo data. The two channels are processed
separately and are concatenated together in the result. The
expected input formats are: (n_samples, 2) for stft and (T, F, 2)
for istft.
:param data: np.array with either the waveform or the complex
spectrogram depending on the parameter inverse
:param inverse: should a stft or an istft be computed.
:returns: Stereo data as numpy array for the transform.
The channels are stored in the last dimension.
Parameters:
data (numpy.array):
Array with either the waveform or the complex spectrogram
depending on the parameter inverse
inverse (bool):
(Optional) Should a stft or an istft be computed.
length (Optional[int]):
Returns:
numpy.ndarray:
Stereo data as numpy array for the transform. The channels
are stored in the last dimension.
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
@@ -238,19 +244,24 @@ class Separator(object):
def _get_session(self):
if self._session is None:
saver = tf.compat.v1.train.Saver()
latest_checkpoint = tf.train.latest_checkpoint(
get_default_model_dir(self._params['model_dir']))
provider = ModelProvider.default()
model_directory: str = provider.get(self._params['model_dir'])
latest_checkpoint = tf.train.latest_checkpoint(model_directory)
self._session = tf.compat.v1.Session()
saver.restore(self._session, latest_checkpoint)
return self._session
def _separate_librosa(self, waveform: np.ndarray, audio_id):
def _separate_librosa(
self,
waveform: np.ndarray,
audio_descriptor: str) -> Dict:
"""
Performs separation with librosa backend for STFT.
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
"""
with self._tf_graph.as_default():
out = {}
@@ -269,7 +280,7 @@ class Separator(object):
feed_dict=self._get_input_provider().get_feed_dict(
features,
stft,
audio_id))
audio_descriptor))
for inst in self._get_builder().instruments:
out[inst] = self._stft(
outputs[inst],
@@ -277,7 +288,10 @@ class Separator(object):
length=waveform.shape[0])
return out
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
def _separate_tensorflow(
self,
waveform: np.ndarray,
audio_descriptor: str) -> Dict:
"""
Performs source separation over the given waveform with tensorflow
backend.
@@ -285,6 +299,7 @@ class Separator(object):
Parameters:
waveform (numpy.ndarray):
Waveform to be separated (as a numpy array)
audio_descriptor (str):
Returns:
Separated waveforms.
@@ -314,44 +329,61 @@ class Separator(object):
audio_descriptor (str):
(Optional) string describing the waveform (e.g. filename).
"""
if self._params['stft_backend'] == 'tensorflow':
backend: str = self._params['stft_backend']
if backend == STFTBackend.TENSORFLOW:
return self._separate_tensorflow(waveform, audio_descriptor)
else:
elif backend == STFTBackend.LIBROSA:
return self._separate_librosa(waveform, audio_descriptor)
raise ValueError(f'Unsupported STFT backend {backend}')
def separate_to_file(
self,
audio_descriptor,
destination,
audio_adapter=get_default_audio_adapter(),
offset=0,
duration=600.,
codec='wav',
bitrate='128k',
filename_format='{filename}/{instrument}.{codec}',
synchronous=True):
""" Performs source separation and export result to file using
given audio adapter.
Filename format should be a Python formattable string that could use
following parameters : {instrument}, {filename}, {foldername} and
{codec}.
:param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data,
in case of file based audio adapter, such
descriptor would be a file path.
:param destination: Target directory to write output to.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param offset: (Optional) Offset of loaded song.
:param duration: (Optional) Duration of loaded song
(default: 600s).
:param codec: (Optional) Export codec.
:param bitrate: (Optional) Export bitrate.
:param filename_format: (Optional) Filename format.
:param synchronous: (Optional) True is should by synchronous.
audio_descriptor: str,
destination: str,
audio_adapter: Optional[AudioAdapter] = None,
offset: int = 0,
duration: float = 600.,
codec: Codec = Codec.WAV,
bitrate: str = '128k',
filename_format: str = '{filename}/{instrument}.{codec}',
synchronous: bool = True) -> None:
"""
waveform, sample_rate = audio_adapter.load(
Performs source separation and export result to file using
given audio adapter.
Filename format should be a Python formattable string that could
use following parameters :
- {instrument}
- {filename}
- {foldername}
- {codec}.
Parameters:
audio_descriptor (str):
Describe song to separate, used by audio adapter to
retrieve and load audio data, in case of file based
audio adapter, such descriptor would be a file path.
destination (str):
Target directory to write output to.
audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
offset (int):
(Optional) Offset of loaded song.
duration (float):
(Optional) Duration of loaded song (default: 600s).
codec (Codec):
(Optional) Export codec.
bitrate (str):
(Optional) Export bitrate.
filename_format (str):
(Optional) Filename format.
synchronous (bool):
(Optional) True if it should be synchronous.
"""
if audio_adapter is None:
audio_adapter = AudioAdapter.default()
waveform, _ = audio_adapter.load(
audio_descriptor,
offset=offset,
duration=duration,
@@ -369,33 +401,42 @@ class Separator(object):
def save_to_file(
self,
sources,
audio_descriptor,
destination,
filename_format='{filename}/{instrument}.{codec}',
codec='wav',
audio_adapter=get_default_audio_adapter(),
bitrate='128k',
synchronous=True):
""" Export dictionary of sources to files.
:param sources: Dictionary of sources to be exported. The
keys are the name of the instruments, and
the values are Nx2 numpy arrays containing
the corresponding intrument waveform, as
returned by the separate method
:param audio_descriptor: Describe song to separate, used by audio
adapter to retrieve and load audio data,
in case of file based audio adapter, such
descriptor would be a file path.
:param destination: Target directory to write output to.
:param filename_format: (Optional) Filename format.
:param codec: (Optional) Export codec.
:param audio_adapter: (Optional) Audio adapter to use for I/O.
:param bitrate: (Optional) Export bitrate.
:param synchronous: (Optional) True is should by synchronous.
sources: Dict,
audio_descriptor: str,
destination: str,
filename_format: str = '{filename}/{instrument}.{codec}',
codec: Codec = Codec.WAV,
audio_adapter: Optional[AudioAdapter] = None,
bitrate: str = '128k',
synchronous: bool = True) -> None:
"""
Export dictionary of sources to files.
Parameters:
sources (Dict):
Dictionary of sources to be exported. The keys are the name
of the instruments, and the values are `N x 2` numpy arrays
containing the corresponding instrument waveform, as
returned by the separate method
audio_descriptor (str):
Describe song to separate, used by audio adapter to
retrieve and load audio data, in case of file based audio
adapter, such descriptor would be a file path.
destination (str):
Target directory to write output to.
filename_format (str):
(Optional) Filename format.
codec (Codec):
(Optional) Export codec.
audio_adapter (Optional[AudioAdapter]):
(Optional) Audio adapter to use for I/O.
bitrate (str):
(Optional) Export bitrate.
synchronous (bool):
(Optional) True if it should be synchronous.
"""
if audio_adapter is None:
audio_adapter = AudioAdapter.default()
foldername = basename(dirname(audio_descriptor))
filename = splitext(basename(audio_descriptor))[0]
generated = []