Moving get_backend to separator

This commit is contained in:
akhlif
2020-02-27 14:13:59 +01:00
parent 6001ae12a9
commit d177525ea7
3 changed files with 25 additions and 14 deletions

View File

@@ -11,8 +11,6 @@
-i /path/to/audio1.wav /path/to/audio2.mp3
"""
import tensorflow as tf
from ..audio.adapter import get_audio_adapter
from ..separator import Separator
@@ -21,11 +19,6 @@ __author__ = 'Deezer Research'
__license__ = 'MIT License'
def get_backend(backend):
if backend == "auto":
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
return backend
def entrypoint(arguments, params):
""" Command entrypoint.
@@ -38,7 +31,7 @@ def entrypoint(arguments, params):
separator = Separator(
arguments.configuration,
MWF=arguments.MWF,
stft_backend=get_backend(arguments.stft_backend))
stft_backend=arguments.stft_backend)
for filename in arguments.inputs:
separator.separate_to_file(
filename,

View File

@@ -320,6 +320,12 @@ class EstimatorSpecBuilder(object):
self._build_masks()
return self._masks
@property
def masked_stfts(self):
if not hasattr(self, "_masked_stfts"):
self._build_masked_stfts()
return self._masked_stfts
def _inverse_stft(self, stft, time_crop=None):
""" Inverse and reshape the given STFT
@@ -397,6 +403,10 @@ class EstimatorSpecBuilder(object):
return tf.concat([mask, extension], axis=2)
def _build_masks(self):
"""
Compute masks from the output spectrograms of the model.
:return:
"""
output_dict = self.model_outputs
stft_feature = self.stft_feature
separation_exponent = self._params['separation_exponent']
@@ -425,12 +435,12 @@ class EstimatorSpecBuilder(object):
out[instrument] = instrument_mask
self._masks = out
def _build_masked_stft(self):
def _build_masked_stfts(self):
input_stft = self.stft_feature
out = {}
for instrument, mask in self.masks.items():
out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft
return out
self._masked_stfts = out
def _build_manual_output_waveform(self, masked_stft):
""" Perform ratio mask separation
@@ -461,11 +471,10 @@ class EstimatorSpecBuilder(object):
return output_waveform
def _build_outputs(self):
masked_stft = self._build_masked_stft()
if self.include_stft_computations():
self._outputs = self._build_output_waveform(masked_stft)
self._outputs = self._build_output_waveform(self.masked_stfts)
else:
self._outputs = masked_stft
self._outputs = self.masked_stfts
if 'audio_id' in self._features:
self._outputs['audio_id'] = self._features['audio_id']

View File

@@ -39,6 +39,14 @@ __license__ = 'MIT License'
logger = logging.getLogger("spleeter")
def get_backend(backend):
assert backend in ["auto", "tensorflow", "librosa"]
if backend == "auto":
return "tensorflow" if tf.test.is_gpu_available() else "librosa"
return backend
class Separator(object):
""" A wrapper class for performing separation. """
@@ -48,13 +56,14 @@ class Separator(object):
:param params_descriptor: Descriptor for TF params to be used.
:param MWF: (Optional) True if MWF should be used, False otherwise.
"""
self._params = load_configuration(params_descriptor)
self._sample_rate = self._params['sample_rate']
self._MWF = MWF
self._predictor = None
self._pool = Pool() if multiprocess else None
self._tasks = []
self._params["stft_backend"] = stft_backend
self._params["stft_backend"] = get_backend(stft_backend)
def _get_predictor(self):
""" Lazy loading access method for internal predictor instance.