mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Moving get_backend to separator
This commit is contained in:
@@ -11,8 +11,6 @@
|
||||
-i /path/to/audio1.wav /path/to/audio2.mp3
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from ..audio.adapter import get_audio_adapter
|
||||
from ..separator import Separator
|
||||
|
||||
@@ -21,11 +19,6 @@ __author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def get_backend(backend):
|
||||
if backend == "auto":
|
||||
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
||||
return backend
|
||||
|
||||
|
||||
def entrypoint(arguments, params):
|
||||
""" Command entrypoint.
|
||||
@@ -38,7 +31,7 @@ def entrypoint(arguments, params):
|
||||
separator = Separator(
|
||||
arguments.configuration,
|
||||
MWF=arguments.MWF,
|
||||
stft_backend=get_backend(arguments.stft_backend))
|
||||
stft_backend=arguments.stft_backend)
|
||||
for filename in arguments.inputs:
|
||||
separator.separate_to_file(
|
||||
filename,
|
||||
|
||||
@@ -320,6 +320,12 @@ class EstimatorSpecBuilder(object):
|
||||
self._build_masks()
|
||||
return self._masks
|
||||
|
||||
@property
|
||||
def masked_stfts(self):
|
||||
if not hasattr(self, "_masked_stfts"):
|
||||
self._build_masked_stfts()
|
||||
return self._masked_stfts
|
||||
|
||||
def _inverse_stft(self, stft, time_crop=None):
|
||||
""" Inverse and reshape the given STFT
|
||||
|
||||
@@ -397,6 +403,10 @@ class EstimatorSpecBuilder(object):
|
||||
return tf.concat([mask, extension], axis=2)
|
||||
|
||||
def _build_masks(self):
|
||||
"""
|
||||
Compute masks from the output spectrograms of the model.
|
||||
:return:
|
||||
"""
|
||||
output_dict = self.model_outputs
|
||||
stft_feature = self.stft_feature
|
||||
separation_exponent = self._params['separation_exponent']
|
||||
@@ -425,12 +435,12 @@ class EstimatorSpecBuilder(object):
|
||||
out[instrument] = instrument_mask
|
||||
self._masks = out
|
||||
|
||||
def _build_masked_stft(self):
|
||||
def _build_masked_stfts(self):
|
||||
input_stft = self.stft_feature
|
||||
out = {}
|
||||
for instrument, mask in self.masks.items():
|
||||
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
|
||||
return out
|
||||
self._masked_stfts = out
|
||||
|
||||
def _build_manual_output_waveform(self, masked_stft):
|
||||
""" Perform ratio mask separation
|
||||
@@ -461,11 +471,10 @@ class EstimatorSpecBuilder(object):
|
||||
return output_waveform
|
||||
|
||||
def _build_outputs(self):
|
||||
masked_stft = self._build_masked_stft()
|
||||
if self.include_stft_computations():
|
||||
self._outputs = self._build_output_waveform(masked_stft)
|
||||
self._outputs = self._build_output_waveform(self.masked_stfts)
|
||||
else:
|
||||
self._outputs = masked_stft
|
||||
self._outputs = self.masked_stfts
|
||||
|
||||
if 'audio_id' in self._features:
|
||||
self._outputs['audio_id'] = self._features['audio_id']
|
||||
|
||||
@@ -39,6 +39,14 @@ __license__ = 'MIT License'
|
||||
logger = logging.getLogger("spleeter")
|
||||
|
||||
|
||||
|
||||
def get_backend(backend):
|
||||
assert backend in ["auto", "tensorflow", "librosa"]
|
||||
if backend == "auto":
|
||||
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
|
||||
return backend
|
||||
|
||||
|
||||
class Separator(object):
|
||||
""" A wrapper class for performing separation. """
|
||||
|
||||
@@ -48,13 +56,14 @@ class Separator(object):
|
||||
:param params_descriptor: Descriptor for TF params to be used.
|
||||
:param MWF: (Optional) True if MWF should be used, False otherwise.
|
||||
"""
|
||||
|
||||
self._params = load_configuration(params_descriptor)
|
||||
self._sample_rate = self._params['sample_rate']
|
||||
self._MWF = MWF
|
||||
self._predictor = None
|
||||
self._pool = Pool() if multiprocess else None
|
||||
self._tasks = []
|
||||
self._params["stft_backend"] = stft_backend
|
||||
self._params["stft_backend"] = get_backend(stft_backend)
|
||||
|
||||
def _get_predictor(self):
|
||||
""" Lazy loading access method for internal predictor instance.
|
||||
|
||||
Reference in New Issue
Block a user