mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
🎨 finalize refactoring
This commit is contained in:
@@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||||||
|
|
||||||
from .. import SpleeterError
|
from .. import SpleeterError
|
||||||
from ..types import AudioDescriptor, Signal
|
from ..types import AudioDescriptor, Signal
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import logger
|
||||||
|
|
||||||
# pyright: reportMissingImports=false
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
@@ -101,7 +101,6 @@ class AudioAdapter(ABC):
|
|||||||
|
|
||||||
# Defined safe loading function.
|
# Defined safe loading function.
|
||||||
def safe_load(path, offset, duration, sample_rate, dtype):
|
def safe_load(path, offset, duration, sample_rate, dtype):
|
||||||
logger = get_logger()
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f'Loading audio {path} from {offset} to {offset + duration}')
|
f'Loading audio {path} from {offset} to {offset + duration}')
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from . import Codec
|
|||||||
from .adapter import AudioAdapter
|
from .adapter import AudioAdapter
|
||||||
from .. import SpleeterError
|
from .. import SpleeterError
|
||||||
from ..types import Signal
|
from ..types import Signal
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import logger
|
||||||
|
|
||||||
# pyright: reportMissingImports=false
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
@@ -161,7 +161,7 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
|
|||||||
if not os.path.exists(directory):
|
if not os.path.exists(directory):
|
||||||
raise SpleeterError(
|
raise SpleeterError(
|
||||||
f'output directory does not exists: {directory}')
|
f'output directory does not exists: {directory}')
|
||||||
get_logger().debug(f'Writing file {path}')
|
logger.debug(f'Writing file {path}')
|
||||||
input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
|
input_kwargs = {'ar': sample_rate, 'ac': data.shape[1]}
|
||||||
output_kwargs = {'ar': sample_rate, 'strict': '-2'}
|
output_kwargs = {'ar': sample_rate, 'strict': '-2'}
|
||||||
if bitrate:
|
if bitrate:
|
||||||
@@ -180,4 +180,4 @@ class FFMPEGProcessAudioAdapter(AudioAdapter):
|
|||||||
process.wait()
|
process.wait()
|
||||||
except IOError:
|
except IOError:
|
||||||
raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
|
raise SpleeterError(f'FFMPEG error: {process.stderr.read()}')
|
||||||
get_logger().info(f'File {path} written succesfully')
|
logger.info(f'File {path} written succesfully')
|
||||||
|
|||||||
@@ -18,12 +18,14 @@ import time
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from os.path import exists, sep as SEPARATOR
|
from os.path import exists, sep as SEPARATOR
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from .audio.adapter import AudioAdapter
|
||||||
from .audio.convertor import db_uint_spectrogram_to_gain
|
from .audio.convertor import db_uint_spectrogram_to_gain
|
||||||
from .audio.convertor import spectrogram_to_db_uint
|
from .audio.convertor import spectrogram_to_db_uint
|
||||||
from .audio.spectrogram import compute_spectrogram_tf
|
from .audio.spectrogram import compute_spectrogram_tf
|
||||||
from .audio.spectrogram import random_pitch_shift, random_time_stretch
|
from .audio.spectrogram import random_pitch_shift, random_time_stretch
|
||||||
from .utils.logging import get_logger
|
from .utils.logging import logger
|
||||||
from .utils.tensor import check_tensor_shape, dataset_from_csv
|
from .utils.tensor import check_tensor_shape, dataset_from_csv
|
||||||
from .utils.tensor import set_tensor_shape, sync_apply
|
from .utils.tensor import set_tensor_shape, sync_apply
|
||||||
|
|
||||||
@@ -37,24 +39,34 @@ __author__ = 'Deezer Research'
|
|||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
# Default audio parameters to use.
|
# Default audio parameters to use.
|
||||||
DEFAULT_AUDIO_PARAMS = {
|
DEFAULT_AUDIO_PARAMS: Dict = {
|
||||||
'instrument_list': ('vocals', 'accompaniment'),
|
'instrument_list': ('vocals', 'accompaniment'),
|
||||||
'mix_name': 'mix',
|
'mix_name': 'mix',
|
||||||
'sample_rate': 44100,
|
'sample_rate': 44100,
|
||||||
'frame_length': 4096,
|
'frame_length': 4096,
|
||||||
'frame_step': 1024,
|
'frame_step': 1024,
|
||||||
'T': 512,
|
'T': 512,
|
||||||
'F': 1024
|
'F': 1024}
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_training_dataset(audio_params, audio_adapter, audio_path):
|
def get_training_dataset(
|
||||||
""" Builds training dataset.
|
audio_params: Dict,
|
||||||
|
audio_adapter: AudioAdapter,
|
||||||
|
audio_path: str) -> Any:
|
||||||
|
"""
|
||||||
|
Builds training dataset.
|
||||||
|
|
||||||
:param audio_params: Audio parameters.
|
Parameters:
|
||||||
:param audio_adapter: Adapter to load audio from.
|
audio_params (Dict):
|
||||||
:param audio_path: Path of directory containing audio.
|
Audio parameters.
|
||||||
:returns: Built dataset.
|
audio_adapter (AudioAdapter):
|
||||||
|
Adapter to load audio from.
|
||||||
|
audio_path (str):
|
||||||
|
Path of directory containing audio.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any:
|
||||||
|
Built dataset.
|
||||||
"""
|
"""
|
||||||
builder = DatasetBuilder(
|
builder = DatasetBuilder(
|
||||||
audio_params,
|
audio_params,
|
||||||
@@ -72,13 +84,24 @@ def get_training_dataset(audio_params, audio_adapter, audio_path):
|
|||||||
wait_for_cache=False)
|
wait_for_cache=False)
|
||||||
|
|
||||||
|
|
||||||
def get_validation_dataset(audio_params, audio_adapter, audio_path):
|
def get_validation_dataset(
|
||||||
""" Builds validation dataset.
|
audio_params: Dict,
|
||||||
|
audio_adapter: AudioAdapter,
|
||||||
|
audio_path: str) -> Any:
|
||||||
|
"""
|
||||||
|
Builds validation dataset.
|
||||||
|
|
||||||
:param audio_params: Audio parameters.
|
Parameters:
|
||||||
:param audio_adapter: Adapter to load audio from.
|
audio_params (Dict):
|
||||||
:param audio_path: Path of directory containing audio.
|
Audio parameters.
|
||||||
:returns: Built dataset.
|
audio_adapter (AudioAdapter):
|
||||||
|
Adapter to load audio from.
|
||||||
|
audio_path (str):
|
||||||
|
Path of directory containing audio.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any:
|
||||||
|
Built dataset.
|
||||||
"""
|
"""
|
||||||
builder = DatasetBuilder(
|
builder = DatasetBuilder(
|
||||||
audio_params,
|
audio_params,
|
||||||
@@ -102,11 +125,15 @@ def get_validation_dataset(audio_params, audio_adapter, audio_path):
|
|||||||
class InstrumentDatasetBuilder(object):
|
class InstrumentDatasetBuilder(object):
|
||||||
""" Instrument based filter and mapper provider. """
|
""" Instrument based filter and mapper provider. """
|
||||||
|
|
||||||
def __init__(self, parent, instrument):
|
def __init__(self, parent, instrument) -> None:
|
||||||
""" Default constructor.
|
"""
|
||||||
|
Default constructor.
|
||||||
|
|
||||||
:param parent: Parent dataset builder.
|
Parameters:
|
||||||
:param instrument: Target instrument.
|
parent:
|
||||||
|
Parent dataset builder.
|
||||||
|
instrument:
|
||||||
|
Target instrument.
|
||||||
"""
|
"""
|
||||||
self._parent = parent
|
self._parent = parent
|
||||||
self._instrument = instrument
|
self._instrument = instrument
|
||||||
@@ -181,7 +208,7 @@ class InstrumentDatasetBuilder(object):
|
|||||||
self._parent._T, self._parent._F, 2))
|
self._parent._T, self._parent._F, 2))
|
||||||
|
|
||||||
def reshape_spectrogram(self, sample):
|
def reshape_spectrogram(self, sample):
|
||||||
""" """
|
""" Reshape given sample. """
|
||||||
return dict(sample, **{
|
return dict(sample, **{
|
||||||
self._spectrogram_key: set_tensor_shape(
|
self._spectrogram_key: set_tensor_shape(
|
||||||
sample[self._spectrogram_key],
|
sample[self._spectrogram_key],
|
||||||
@@ -190,27 +217,35 @@ class InstrumentDatasetBuilder(object):
|
|||||||
|
|
||||||
class DatasetBuilder(object):
|
class DatasetBuilder(object):
|
||||||
"""
|
"""
|
||||||
|
TO BE DOCUMENTED.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Margin at beginning and end of songs in seconds.
|
MARGIN: float = 0.5
|
||||||
MARGIN = 0.5
|
""" Margin at beginning and end of songs in seconds. """
|
||||||
|
|
||||||
# Wait period for cache (in seconds).
|
WAIT_PERIOD: int = 60
|
||||||
WAIT_PERIOD = 60
|
""" Wait period for cache (in seconds). """
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
audio_params, audio_adapter, audio_path,
|
audio_params: Dict,
|
||||||
random_seed=0, chunk_duration=20.0):
|
audio_adapter: AudioAdapter,
|
||||||
""" Default constructor.
|
audio_path: str,
|
||||||
|
random_seed: int = 0,
|
||||||
|
chunk_duration: float = 20.0) -> None:
|
||||||
|
"""
|
||||||
|
Default constructor.
|
||||||
|
|
||||||
NOTE: Probably need for AudioAdapter.
|
NOTE: Probably need for AudioAdapter.
|
||||||
|
|
||||||
:param audio_params: Audio parameters to use.
|
Parameters:
|
||||||
:param audio_adapter: Audio adapter to use.
|
audio_params (Dict):
|
||||||
:param audio_path:
|
Audio parameters to use.
|
||||||
:param random_seed:
|
audio_adapter (AudioAdapter):
|
||||||
:param chunk_duration:
|
Audio adapter to use.
|
||||||
|
audio_path (str):
|
||||||
|
random_seed (int):
|
||||||
|
chunk_duration (float):
|
||||||
"""
|
"""
|
||||||
# Length of segment in frames (if fs=22050 and
|
# Length of segment in frames (if fs=22050 and
|
||||||
# frame_step=512, then T=512 corresponds to 11.89s)
|
# frame_step=512, then T=512 corresponds to 11.89s)
|
||||||
@@ -298,12 +333,22 @@ class DatasetBuilder(object):
|
|||||||
for instrument in self._audio_params['instrument_list']}
|
for instrument in self._audio_params['instrument_list']}
|
||||||
return (input_, output)
|
return (input_, output)
|
||||||
|
|
||||||
def compute_segments(self, dataset, n_chunks_per_song):
|
def compute_segments(
|
||||||
""" Computes segments for each song of the dataset.
|
self,
|
||||||
|
dataset: Any,
|
||||||
|
n_chunks_per_song: int) -> Any:
|
||||||
|
"""
|
||||||
|
Computes segments for each song of the dataset.
|
||||||
|
|
||||||
:param dataset: Dataset to compute segments for.
|
Parameters:
|
||||||
:param n_chunks_per_song: Number of segment per song to compute.
|
dataset (Any):
|
||||||
:returns: Segmented dataset.
|
Dataset to compute segments for.
|
||||||
|
n_chunks_per_song (int):
|
||||||
|
Number of segment per song to compute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any:
|
||||||
|
Segmented dataset.
|
||||||
"""
|
"""
|
||||||
if n_chunks_per_song <= 0:
|
if n_chunks_per_song <= 0:
|
||||||
raise ValueError('n_chunks_per_song must be positif')
|
raise ValueError('n_chunks_per_song must be positif')
|
||||||
@@ -327,10 +372,13 @@ class DatasetBuilder(object):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def instruments(self):
|
def instruments(self) -> Any:
|
||||||
""" Instrument dataset builder generator.
|
"""
|
||||||
|
Instrument dataset builder generator.
|
||||||
|
|
||||||
:yield InstrumentBuilder instance.
|
Yields:
|
||||||
|
Any:
|
||||||
|
InstrumentBuilder instance.
|
||||||
"""
|
"""
|
||||||
if self._instrument_builders is None:
|
if self._instrument_builders is None:
|
||||||
self._instrument_builders = []
|
self._instrument_builders = []
|
||||||
@@ -340,22 +388,33 @@ class DatasetBuilder(object):
|
|||||||
for builder in self._instrument_builders:
|
for builder in self._instrument_builders:
|
||||||
yield builder
|
yield builder
|
||||||
|
|
||||||
def cache(self, dataset, cache, wait):
|
def cache(
|
||||||
""" Cache the given dataset if cache is enabled. Eventually waits for
|
self,
|
||||||
cache to be available (useful if another process is already computing
|
dataset: Any,
|
||||||
cache) if provided wait flag is True.
|
cache: str,
|
||||||
|
wait: bool) -> Any:
|
||||||
|
"""
|
||||||
|
Cache the given dataset if cache is enabled. Eventually waits for
|
||||||
|
cache to be available (useful if another process is already
|
||||||
|
computing cache) if provided wait flag is `True`.
|
||||||
|
|
||||||
:param dataset: Dataset to be cached if cache is required.
|
Parameters:
|
||||||
:param cache: Path of cache directory to be used, None if no cache.
|
dataset (Any):
|
||||||
:param wait: If caching is enabled, True is cache should be waited.
|
Dataset to be cached if cache is required.
|
||||||
:returns: Cached dataset if needed, original dataset otherwise.
|
cache (str):
|
||||||
|
Path of cache directory to be used, None if no cache.
|
||||||
|
wait (bool):
|
||||||
|
If caching is enabled, True is cache should be waited.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any:
|
||||||
|
Cached dataset if needed, original dataset otherwise.
|
||||||
"""
|
"""
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
if wait:
|
if wait:
|
||||||
while not exists(f'{cache}.index'):
|
while not exists(f'{cache}.index'):
|
||||||
get_logger().info(
|
logger.info(
|
||||||
'Cache not available, wait %s',
|
f'Cache not available, wait {self.WAIT_PERIOD}')
|
||||||
self.WAIT_PERIOD)
|
|
||||||
time.sleep(self.WAIT_PERIOD)
|
time.sleep(self.WAIT_PERIOD)
|
||||||
cache_path = os.path.split(cache)[0]
|
cache_path = os.path.split(cache)[0]
|
||||||
os.makedirs(cache_path, exist_ok=True)
|
os.makedirs(cache_path, exist_ok=True)
|
||||||
@@ -363,13 +422,20 @@ class DatasetBuilder(object):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self, csv_path,
|
self,
|
||||||
batch_size=8, shuffle=True, convert_to_uint=True,
|
csv_path: str,
|
||||||
random_data_augmentation=False, random_time_crop=True,
|
batch_size: int = 8,
|
||||||
infinite_generator=True, cache_directory=None,
|
shuffle: bool = True,
|
||||||
wait_for_cache=False, num_parallel_calls=4, n_chunks_per_song=2,):
|
convert_to_uint: bool = True,
|
||||||
|
random_data_augmentation: bool = False,
|
||||||
|
random_time_crop: bool = True,
|
||||||
|
infinite_generator: bool = True,
|
||||||
|
cache_directory: Optional[str] = None,
|
||||||
|
wait_for_cache: bool = False,
|
||||||
|
num_parallel_calls: int = 4,
|
||||||
|
n_chunks_per_song: float = 2,) -> Any:
|
||||||
"""
|
"""
|
||||||
TO BE DOCUMENTED.
|
TO BE DOCUMENTED.
|
||||||
"""
|
"""
|
||||||
dataset = dataset_from_csv(csv_path)
|
dataset = dataset_from_csv(csv_path)
|
||||||
dataset = self.compute_segments(dataset, n_chunks_per_song)
|
dataset = self.compute_segments(dataset, n_chunks_per_song)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from tempfile import NamedTemporaryFile
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from . import ModelProvider
|
from . import ModelProvider
|
||||||
from ...utils.logging import get_logger
|
from ...utils.logging import logger
|
||||||
|
|
||||||
# pyright: reportMissingImports=false
|
# pyright: reportMissingImports=false
|
||||||
# pylint: disable=import-error
|
# pylint: disable=import-error
|
||||||
@@ -138,7 +138,7 @@ class GithubModelProvider(ModelProvider):
|
|||||||
self._release,
|
self._release,
|
||||||
name))
|
name))
|
||||||
url = f'{url}.tar.gz'
|
url = f'{url}.tar.gz'
|
||||||
get_logger().info(f'Downloading model archive {url}')
|
logger.info(f'Downloading model archive {url}')
|
||||||
with httpx.Client(http2=True) as client:
|
with httpx.Client(http2=True) as client:
|
||||||
with client.strema('GET', url) as response:
|
with client.strema('GET', url) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@@ -147,14 +147,14 @@ class GithubModelProvider(ModelProvider):
|
|||||||
with archive as stream:
|
with archive as stream:
|
||||||
for chunk in response.iter_raw():
|
for chunk in response.iter_raw():
|
||||||
stream.write(chunk)
|
stream.write(chunk)
|
||||||
get_logger().info('Validating archive checksum')
|
logger.info('Validating archive checksum')
|
||||||
checksum: str = compute_file_checksum(archive.name)
|
checksum: str = compute_file_checksum(archive.name)
|
||||||
if checksum != self.checksum(name):
|
if checksum != self.checksum(name):
|
||||||
raise IOError(
|
raise IOError(
|
||||||
'Downloaded file is corrupted, please retry')
|
'Downloaded file is corrupted, please retry')
|
||||||
get_logger().info(f'Extracting downloaded {name} archive')
|
logger.info(f'Extracting downloaded {name} archive')
|
||||||
with tarfile.open(name=archive.name) as tar:
|
with tarfile.open(name=archive.name) as tar:
|
||||||
tar.extractall(path=path)
|
tar.extractall(path=path)
|
||||||
finally:
|
finally:
|
||||||
os.unlink(archive.name)
|
os.unlink(archive.name)
|
||||||
get_logger().info(f'{name} model file(s) extracted')
|
logger.info(f'{name} model file(s) extracted')
|
||||||
|
|||||||
@@ -19,13 +19,15 @@ import os
|
|||||||
|
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from os.path import basename, join, splitext, dirname
|
from os.path import basename, join, splitext, dirname
|
||||||
from typing import Generator, Optional
|
from spleeter.model.provider import ModelProvider
|
||||||
|
from typing import Dict, Generator, Optional
|
||||||
|
|
||||||
from . import SpleeterError
|
from . import SpleeterError
|
||||||
from .audio import STFTBackend
|
from .audio import Codec, STFTBackend
|
||||||
from .audio.adapter import get_default_audio_adapter
|
from .audio.adapter import AudioAdapter
|
||||||
from .audio.convertor import to_stereo
|
from .audio.convertor import to_stereo
|
||||||
from .model import EstimatorSpecBuilder, InputProviderFactory
|
from .model import EstimatorSpecBuilder, InputProviderFactory
|
||||||
|
from .model import model_fn
|
||||||
from .utils.configuration import load_configuration
|
from .utils.configuration import load_configuration
|
||||||
|
|
||||||
# pyright: reportMissingImports=false
|
# pyright: reportMissingImports=false
|
||||||
@@ -65,18 +67,6 @@ class DataGenerator(object):
|
|||||||
buffer = self._current_data
|
buffer = self._current_data
|
||||||
|
|
||||||
|
|
||||||
def get_backend(backend: str) -> str:
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
if backend not in SUPPORTED_BACKEND:
|
|
||||||
raise ValueError(f'Unsupported backend {backend}')
|
|
||||||
if backend == 'auto':
|
|
||||||
if len(tf.config.list_physical_devices('GPU')):
|
|
||||||
return 'tensorflow'
|
|
||||||
return 'librosa'
|
|
||||||
return backend
|
|
||||||
|
|
||||||
|
|
||||||
def create_estimator(params, MWF):
|
def create_estimator(params, MWF):
|
||||||
"""
|
"""
|
||||||
Initialize tensorflow estimator that will perform separation
|
Initialize tensorflow estimator that will perform separation
|
||||||
@@ -137,18 +127,21 @@ class Separator(object):
|
|||||||
else:
|
else:
|
||||||
self._pool = None
|
self._pool = None
|
||||||
self._tasks = []
|
self._tasks = []
|
||||||
self._params['stft_backend'] = get_backend(stft_backend)
|
self._params['stft_backend'] = stft_backend
|
||||||
self._data_generator = DataGenerator()
|
self._data_generator = DataGenerator()
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
if self._session:
|
if self._session:
|
||||||
self._session.close()
|
self._session.close()
|
||||||
|
|
||||||
def _get_prediction_generator(self):
|
def _get_prediction_generator(self) -> Generator:
|
||||||
""" Lazy loading access method for internal prediction generator
|
"""
|
||||||
returned by the predict method of a tensorflow estimator.
|
Lazy loading access method for internal prediction generator
|
||||||
|
returned by the predict method of a tensorflow estimator.
|
||||||
|
|
||||||
:returns: generator of prediction.
|
Returns:
|
||||||
|
Generator:
|
||||||
|
Generator of prediction.
|
||||||
"""
|
"""
|
||||||
if self._prediction_generator is None:
|
if self._prediction_generator is None:
|
||||||
estimator = create_estimator(self._params, self._MWF)
|
estimator = create_estimator(self._params, self._MWF)
|
||||||
@@ -181,17 +174,30 @@ class Separator(object):
|
|||||||
task.get()
|
task.get()
|
||||||
task.wait(timeout=timeout)
|
task.wait(timeout=timeout)
|
||||||
|
|
||||||
def _stft(self, data, inverse: bool = False, length=None):
|
def _stft(
|
||||||
""" Single entrypoint for both stft and istft. This computes stft and
|
self,
|
||||||
istft with librosa on stereo data. The two channels are processed
|
data: np.ndarray,
|
||||||
separately and are concatenated together in the result. The expected
|
inverse: bool = False,
|
||||||
input formats are: (n_samples, 2) for stft and (T, F, 2) for istft.
|
length: Optional[int] = None) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Single entrypoint for both stft and istft. This computes stft and
|
||||||
|
istft with librosa on stereo data. The two channels are processed
|
||||||
|
separately and are concatenated together in the result. The
|
||||||
|
expected input formats are: (n_samples, 2) for stft and (T, F, 2)
|
||||||
|
for istft.
|
||||||
|
|
||||||
:param data: np.array with either the waveform or the complex
|
Parameters:
|
||||||
spectrogram depending on the parameter inverse
|
data (numpy.array):
|
||||||
:param inverse: should a stft or an istft be computed.
|
Array with either the waveform or the complex spectrogram
|
||||||
:returns: Stereo data as numpy array for the transform.
|
depending on the parameter inverse
|
||||||
The channels are stored in the last dimension.
|
inverse (bool):
|
||||||
|
(Optional) Should a stft or an istft be computed.
|
||||||
|
length (Optional[int]):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy.ndarray:
|
||||||
|
Stereo data as numpy array for the transform. The channels
|
||||||
|
are stored in the last dimension.
|
||||||
"""
|
"""
|
||||||
assert not (inverse and length is None)
|
assert not (inverse and length is None)
|
||||||
data = np.asfortranarray(data)
|
data = np.asfortranarray(data)
|
||||||
@@ -238,19 +244,24 @@ class Separator(object):
|
|||||||
def _get_session(self):
|
def _get_session(self):
|
||||||
if self._session is None:
|
if self._session is None:
|
||||||
saver = tf.compat.v1.train.Saver()
|
saver = tf.compat.v1.train.Saver()
|
||||||
latest_checkpoint = tf.train.latest_checkpoint(
|
provider = ModelProvider.default()
|
||||||
get_default_model_dir(self._params['model_dir']))
|
model_directory: str = provider.get(self._params['model_dir'])
|
||||||
|
latest_checkpoint = tf.train.latest_checkpoint(model_directory)
|
||||||
self._session = tf.compat.v1.Session()
|
self._session = tf.compat.v1.Session()
|
||||||
saver.restore(self._session, latest_checkpoint)
|
saver.restore(self._session, latest_checkpoint)
|
||||||
return self._session
|
return self._session
|
||||||
|
|
||||||
def _separate_librosa(self, waveform: np.ndarray, audio_id):
|
def _separate_librosa(
|
||||||
|
self,
|
||||||
|
waveform: np.ndarray,
|
||||||
|
audio_descriptor: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
Performs separation with librosa backend for STFT.
|
Performs separation with librosa backend for STFT.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
waveform (numpy.ndarray):
|
waveform (numpy.ndarray):
|
||||||
Waveform to be separated (as a numpy array)
|
Waveform to be separated (as a numpy array)
|
||||||
|
audio_descriptor (str):
|
||||||
"""
|
"""
|
||||||
with self._tf_graph.as_default():
|
with self._tf_graph.as_default():
|
||||||
out = {}
|
out = {}
|
||||||
@@ -269,7 +280,7 @@ class Separator(object):
|
|||||||
feed_dict=self._get_input_provider().get_feed_dict(
|
feed_dict=self._get_input_provider().get_feed_dict(
|
||||||
features,
|
features,
|
||||||
stft,
|
stft,
|
||||||
audio_id))
|
audio_descriptor))
|
||||||
for inst in self._get_builder().instruments:
|
for inst in self._get_builder().instruments:
|
||||||
out[inst] = self._stft(
|
out[inst] = self._stft(
|
||||||
outputs[inst],
|
outputs[inst],
|
||||||
@@ -277,7 +288,10 @@ class Separator(object):
|
|||||||
length=waveform.shape[0])
|
length=waveform.shape[0])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor):
|
def _separate_tensorflow(
|
||||||
|
self,
|
||||||
|
waveform: np.ndarray,
|
||||||
|
audio_descriptor: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
Performs source separation over the given waveform with tensorflow
|
Performs source separation over the given waveform with tensorflow
|
||||||
backend.
|
backend.
|
||||||
@@ -285,6 +299,7 @@ class Separator(object):
|
|||||||
Parameters:
|
Parameters:
|
||||||
waveform (numpy.ndarray):
|
waveform (numpy.ndarray):
|
||||||
Waveform to be separated (as a numpy array)
|
Waveform to be separated (as a numpy array)
|
||||||
|
audio_descriptor (str):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Separated waveforms.
|
Separated waveforms.
|
||||||
@@ -314,44 +329,61 @@ class Separator(object):
|
|||||||
audio_descriptor (str):
|
audio_descriptor (str):
|
||||||
(Optional) string describing the waveform (e.g. filename).
|
(Optional) string describing the waveform (e.g. filename).
|
||||||
"""
|
"""
|
||||||
if self._params['stft_backend'] == 'tensorflow':
|
backend: str = self._params['stft_backend']
|
||||||
|
if backend == STFTBackend.TENSORFLOW:
|
||||||
return self._separate_tensorflow(waveform, audio_descriptor)
|
return self._separate_tensorflow(waveform, audio_descriptor)
|
||||||
else:
|
elif backend == STFTBackend.LIBROSA:
|
||||||
return self._separate_librosa(waveform, audio_descriptor)
|
return self._separate_librosa(waveform, audio_descriptor)
|
||||||
|
raise ValueError(f'Unsupported STFT backend {backend}')
|
||||||
|
|
||||||
def separate_to_file(
|
def separate_to_file(
|
||||||
self,
|
self,
|
||||||
audio_descriptor,
|
audio_descriptor: str,
|
||||||
destination,
|
destination: str,
|
||||||
audio_adapter=get_default_audio_adapter(),
|
audio_adapter: Optional[AudioAdapter] = None,
|
||||||
offset=0,
|
offset: int = 0,
|
||||||
duration=600.,
|
duration: float = 600.,
|
||||||
codec='wav',
|
codec: Codec = Codec.WAV,
|
||||||
bitrate='128k',
|
bitrate: str = '128k',
|
||||||
filename_format='{filename}/{instrument}.{codec}',
|
filename_format: str = '{filename}/{instrument}.{codec}',
|
||||||
synchronous=True):
|
synchronous: bool = True) -> None:
|
||||||
""" Performs source separation and export result to file using
|
|
||||||
given audio adapter.
|
|
||||||
|
|
||||||
Filename format should be a Python formattable string that could use
|
|
||||||
following parameters : {instrument}, {filename}, {foldername} and
|
|
||||||
{codec}.
|
|
||||||
|
|
||||||
:param audio_descriptor: Describe song to separate, used by audio
|
|
||||||
adapter to retrieve and load audio data,
|
|
||||||
in case of file based audio adapter, such
|
|
||||||
descriptor would be a file path.
|
|
||||||
:param destination: Target directory to write output to.
|
|
||||||
:param audio_adapter: (Optional) Audio adapter to use for I/O.
|
|
||||||
:param offset: (Optional) Offset of loaded song.
|
|
||||||
:param duration: (Optional) Duration of loaded song
|
|
||||||
(default: 600s).
|
|
||||||
:param codec: (Optional) Export codec.
|
|
||||||
:param bitrate: (Optional) Export bitrate.
|
|
||||||
:param filename_format: (Optional) Filename format.
|
|
||||||
:param synchronous: (Optional) True is should by synchronous.
|
|
||||||
"""
|
"""
|
||||||
waveform, sample_rate = audio_adapter.load(
|
Performs source separation and export result to file using
|
||||||
|
given audio adapter.
|
||||||
|
|
||||||
|
Filename format should be a Python formattable string that could
|
||||||
|
use following parameters :
|
||||||
|
|
||||||
|
- {instrument}
|
||||||
|
- {filename}
|
||||||
|
- {foldername}
|
||||||
|
- {codec}.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
audio_descriptor (str):
|
||||||
|
Describe song to separate, used by audio adapter to
|
||||||
|
retrieve and load audio data, in case of file based
|
||||||
|
audio adapter, such descriptor would be a file path.
|
||||||
|
destination (str):
|
||||||
|
Target directory to write output to.
|
||||||
|
audio_adapter (Optional[AudioAdapter]):
|
||||||
|
(Optional) Audio adapter to use for I/O.
|
||||||
|
offset (int):
|
||||||
|
(Optional) Offset of loaded song.
|
||||||
|
duration (float):
|
||||||
|
(Optional) Duration of loaded song (default: 600s).
|
||||||
|
codec (Codec):
|
||||||
|
(Optional) Export codec.
|
||||||
|
bitrate (str):
|
||||||
|
(Optional) Export bitrate.
|
||||||
|
filename_format (str):
|
||||||
|
(Optional) Filename format.
|
||||||
|
synchronous (bool):
|
||||||
|
(Optional) True is should by synchronous.
|
||||||
|
"""
|
||||||
|
if audio_adapter is None:
|
||||||
|
audio_adapter = AudioAdapter.default()
|
||||||
|
waveform, _ = audio_adapter.load(
|
||||||
audio_descriptor,
|
audio_descriptor,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
@@ -369,33 +401,42 @@ class Separator(object):
|
|||||||
|
|
||||||
def save_to_file(
|
def save_to_file(
|
||||||
self,
|
self,
|
||||||
sources,
|
sources: Dict,
|
||||||
audio_descriptor,
|
audio_descriptor: str,
|
||||||
destination,
|
destination: str,
|
||||||
filename_format='{filename}/{instrument}.{codec}',
|
filename_format: str = '{filename}/{instrument}.{codec}',
|
||||||
codec='wav',
|
codec: Codec = Codec.WAV,
|
||||||
audio_adapter=get_default_audio_adapter(),
|
audio_adapter: Optional[AudioAdapter] = None,
|
||||||
bitrate='128k',
|
bitrate: str = '128k',
|
||||||
synchronous=True):
|
synchronous: bool = True) -> None:
|
||||||
""" Export dictionary of sources to files.
|
|
||||||
|
|
||||||
:param sources: Dictionary of sources to be exported. The
|
|
||||||
keys are the name of the instruments, and
|
|
||||||
the values are Nx2 numpy arrays containing
|
|
||||||
the corresponding intrument waveform, as
|
|
||||||
returned by the separate method
|
|
||||||
:param audio_descriptor: Describe song to separate, used by audio
|
|
||||||
adapter to retrieve and load audio data,
|
|
||||||
in case of file based audio adapter, such
|
|
||||||
descriptor would be a file path.
|
|
||||||
:param destination: Target directory to write output to.
|
|
||||||
:param filename_format: (Optional) Filename format.
|
|
||||||
:param codec: (Optional) Export codec.
|
|
||||||
:param audio_adapter: (Optional) Audio adapter to use for I/O.
|
|
||||||
:param bitrate: (Optional) Export bitrate.
|
|
||||||
:param synchronous: (Optional) True is should by synchronous.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Export dictionary of sources to files.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
sources (Dict):
|
||||||
|
Dictionary of sources to be exported. The keys are the name
|
||||||
|
of the instruments, and the values are `N x 2` numpy arrays
|
||||||
|
containing the corresponding intrument waveform, as
|
||||||
|
returned by the separate method
|
||||||
|
audio_descriptor (str):
|
||||||
|
Describe song to separate, used by audio adapter to
|
||||||
|
retrieve and load audio data, in case of file based audio
|
||||||
|
adapter, such descriptor would be a file path.
|
||||||
|
destination (str):
|
||||||
|
Target directory to write output to.
|
||||||
|
filename_format (str):
|
||||||
|
(Optional) Filename format.
|
||||||
|
codec (Codec):
|
||||||
|
(Optional) Export codec.
|
||||||
|
audio_adapter (Optional[AudioAdapter]):
|
||||||
|
(Optional) Audio adapter to use for I/O.
|
||||||
|
bitrate (str):
|
||||||
|
(Optional) Export bitrate.
|
||||||
|
synchronous (bool):
|
||||||
|
(Optional) True is should by synchronous.
|
||||||
|
"""
|
||||||
|
if audio_adapter is None:
|
||||||
|
audio_adapter = AudioAdapter.default()
|
||||||
foldername = basename(dirname(audio_descriptor))
|
foldername = basename(dirname(audio_descriptor))
|
||||||
filename = splitext(basename(audio_descriptor))[0]
|
filename = splitext(basename(audio_descriptor))[0]
|
||||||
generated = []
|
generated = []
|
||||||
|
|||||||
Reference in New Issue
Block a user