mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Moving get_backend to separator
This commit is contained in:
@@ -11,8 +11,6 @@
|
|||||||
-i /path/to/audio1.wav /path/to/audio2.mp3
|
-i /path/to/audio1.wav /path/to/audio2.mp3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
from ..audio.adapter import get_audio_adapter
|
from ..audio.adapter import get_audio_adapter
|
||||||
from ..separator import Separator
|
from ..separator import Separator
|
||||||
|
|
||||||
@@ -21,11 +19,6 @@ __author__ = 'Deezer Research'
|
|||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
def get_backend(backend):
|
|
||||||
if backend == "auto":
|
|
||||||
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
|
||||||
return backend
|
|
||||||
|
|
||||||
|
|
||||||
def entrypoint(arguments, params):
|
def entrypoint(arguments, params):
|
||||||
""" Command entrypoint.
|
""" Command entrypoint.
|
||||||
@@ -38,7 +31,7 @@ def entrypoint(arguments, params):
|
|||||||
separator = Separator(
|
separator = Separator(
|
||||||
arguments.configuration,
|
arguments.configuration,
|
||||||
MWF=arguments.MWF,
|
MWF=arguments.MWF,
|
||||||
stft_backend=get_backend(arguments.stft_backend))
|
stft_backend=arguments.stft_backend)
|
||||||
for filename in arguments.inputs:
|
for filename in arguments.inputs:
|
||||||
separator.separate_to_file(
|
separator.separate_to_file(
|
||||||
filename,
|
filename,
|
||||||
|
|||||||
@@ -320,6 +320,12 @@ class EstimatorSpecBuilder(object):
|
|||||||
self._build_masks()
|
self._build_masks()
|
||||||
return self._masks
|
return self._masks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def masked_stfts(self):
|
||||||
|
if not hasattr(self, "_masked_stfts"):
|
||||||
|
self._build_masked_stfts()
|
||||||
|
return self._masked_stfts
|
||||||
|
|
||||||
def _inverse_stft(self, stft, time_crop=None):
|
def _inverse_stft(self, stft, time_crop=None):
|
||||||
""" Inverse and reshape the given STFT
|
""" Inverse and reshape the given STFT
|
||||||
|
|
||||||
@@ -397,6 +403,10 @@ class EstimatorSpecBuilder(object):
|
|||||||
return tf.concat([mask, extension], axis=2)
|
return tf.concat([mask, extension], axis=2)
|
||||||
|
|
||||||
def _build_masks(self):
|
def _build_masks(self):
|
||||||
|
"""
|
||||||
|
Compute masks from the output spectrograms of the model.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
output_dict = self.model_outputs
|
output_dict = self.model_outputs
|
||||||
stft_feature = self.stft_feature
|
stft_feature = self.stft_feature
|
||||||
separation_exponent = self._params['separation_exponent']
|
separation_exponent = self._params['separation_exponent']
|
||||||
@@ -425,12 +435,12 @@ class EstimatorSpecBuilder(object):
|
|||||||
out[instrument] = instrument_mask
|
out[instrument] = instrument_mask
|
||||||
self._masks = out
|
self._masks = out
|
||||||
|
|
||||||
def _build_masked_stft(self):
|
def _build_masked_stfts(self):
|
||||||
input_stft = self.stft_feature
|
input_stft = self.stft_feature
|
||||||
out = {}
|
out = {}
|
||||||
for instrument, mask in self.masks.items():
|
for instrument, mask in self.masks.items():
|
||||||
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
|
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
|
||||||
return out
|
self._masked_stfts = out
|
||||||
|
|
||||||
def _build_manual_output_waveform(self, masked_stft):
|
def _build_manual_output_waveform(self, masked_stft):
|
||||||
""" Perform ratio mask separation
|
""" Perform ratio mask separation
|
||||||
@@ -461,11 +471,10 @@ class EstimatorSpecBuilder(object):
|
|||||||
return output_waveform
|
return output_waveform
|
||||||
|
|
||||||
def _build_outputs(self):
|
def _build_outputs(self):
|
||||||
masked_stft = self._build_masked_stft()
|
|
||||||
if self.include_stft_computations():
|
if self.include_stft_computations():
|
||||||
self._outputs = self._build_output_waveform(masked_stft)
|
self._outputs = self._build_output_waveform(self.masked_stfts)
|
||||||
else:
|
else:
|
||||||
self._outputs = masked_stft
|
self._outputs = self.masked_stfts
|
||||||
|
|
||||||
if 'audio_id' in self._features:
|
if 'audio_id' in self._features:
|
||||||
self._outputs['audio_id'] = self._features['audio_id']
|
self._outputs['audio_id'] = self._features['audio_id']
|
||||||
|
|||||||
@@ -39,6 +39,14 @@ __license__ = 'MIT License'
|
|||||||
logger = logging.getLogger("spleeter")
|
logger = logging.getLogger("spleeter")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend(backend):
|
||||||
|
assert backend in ["auto", "tensorflow", "librosa"]
|
||||||
|
if backend == "auto":
|
||||||
|
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
||||||
|
return backend
|
||||||
|
|
||||||
|
|
||||||
class Separator(object):
|
class Separator(object):
|
||||||
""" A wrapper class for performing separation. """
|
""" A wrapper class for performing separation. """
|
||||||
|
|
||||||
@@ -48,13 +56,14 @@ class Separator(object):
|
|||||||
:param params_descriptor: Descriptor for TF params to be used.
|
:param params_descriptor: Descriptor for TF params to be used.
|
||||||
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._params = load_configuration(params_descriptor)
|
self._params = load_configuration(params_descriptor)
|
||||||
self._sample_rate = self._params['sample_rate']
|
self._sample_rate = self._params['sample_rate']
|
||||||
self._MWF = MWF
|
self._MWF = MWF
|
||||||
self._predictor = None
|
self._predictor = None
|
||||||
self._pool = Pool() if multiprocess else None
|
self._pool = Pool() if multiprocess else None
|
||||||
self._tasks = []
|
self._tasks = []
|
||||||
self._params["stft_backend"] = stft_backend
|
self._params["stft_backend"] = get_backend(stft_backend)
|
||||||
|
|
||||||
def _get_predictor(self):
|
def _get_predictor(self):
|
||||||
""" Lazy loading access method for internal predictor instance.
|
""" Lazy loading access method for internal predictor instance.
|
||||||
|
|||||||
Reference in New Issue
Block a user