diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py index 2fac8aa..158190c 100644 --- a/spleeter/commands/separate.py +++ b/spleeter/commands/separate.py @@ -11,8 +11,6 @@ -i /path/to/audio1.wav /path/to/audio2.mp3 """ -import tensorflow as tf - from ..audio.adapter import get_audio_adapter from ..separator import Separator @@ -21,11 +19,6 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' -def get_backend(backend): - if backend == "auto": - return "tensorflow" if tf.test.is_gpu_available() else "librosa" - return backend - def entrypoint(arguments, params): """ Command entrypoint. @@ -38,7 +31,7 @@ def entrypoint(arguments, params): separator = Separator( arguments.configuration, MWF=arguments.MWF, - stft_backend=get_backend(arguments.stft_backend)) + stft_backend=arguments.stft_backend) for filename in arguments.inputs: separator.separate_to_file( filename, diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 4c3e200..b99a9dd 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -320,6 +320,12 @@ class EstimatorSpecBuilder(object): self._build_masks() return self._masks + @property + def masked_stfts(self): + if not hasattr(self, "_masked_stfts"): + self._build_masked_stfts() + return self._masked_stfts + def _inverse_stft(self, stft, time_crop=None): """ Inverse and reshape the given STFT @@ -397,6 +403,10 @@ class EstimatorSpecBuilder(object): return tf.concat([mask, extension], axis=2) def _build_masks(self): + """ + Compute masks from the output spectrograms of the model. + :return: + """ output_dict = self.model_outputs stft_feature = self.stft_feature separation_exponent = self._params['separation_exponent'] @@ -425,12 +435,12 @@ class EstimatorSpecBuilder(object): out[instrument] = instrument_mask self._masks = out - def _build_masked_stft(self): + def _build_masked_stfts(self): input_stft = self.stft_feature out = {} for instrument, mask in self.masks.items(): out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft - return out + self._masked_stfts = out def _build_manual_output_waveform(self, masked_stft): """ Perform ratio mask separation @@ -461,11 +471,10 @@ class EstimatorSpecBuilder(object): return output_waveform def _build_outputs(self): - masked_stft = self._build_masked_stft() if self.include_stft_computations(): - self._outputs = self._build_output_waveform(masked_stft) + self._outputs = self._build_output_waveform(self.masked_stfts) else: - self._outputs = masked_stft + self._outputs = self.masked_stfts if 'audio_id' in self._features: self._outputs['audio_id'] = self._features['audio_id'] diff --git a/spleeter/separator.py b/spleeter/separator.py index 73a1a89..668b59d 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -39,6 +39,14 @@ __license__ = 'MIT License' logger = logging.getLogger("spleeter") + +def get_backend(backend): + assert backend in ["auto", "tensorflow", "librosa"] + if backend == "auto": + return "tensorflow" if tf.test.is_gpu_available() else "librosa" + return backend + + class Separator(object): """ A wrapper class for performing separation. """ @@ -48,13 +56,14 @@ class Separator(object): :param params_descriptor: Descriptor for TF params to be used. :param MWF: (Optional) True if MWF should be used, False otherwise. """ + self._params = load_configuration(params_descriptor) self._sample_rate = self._params['sample_rate'] self._MWF = MWF self._predictor = None self._pool = Pool() if multiprocess else None self._tasks = [] - self._params["stft_backend"] = stft_backend + self._params["stft_backend"] = get_backend(stft_backend) def _get_predictor(self): """ Lazy loading access method for internal predictor instance.