diff --git a/spleeter/separator.py b/spleeter/separator.py index 3488fb5..554e0e0 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -85,15 +85,9 @@ class Separator(object): task.get() task.wait(timeout=timeout) - def separate_tensorflow(self, waveform, audio_descriptor): - """ Performs source separation over the given waveform. - - The separation is performed synchronously but the result - processing is done asynchronously, allowing for instance - to export audio in parallel (through multiprocessing). - - Given result is passed by to the given consumer, which will - be waited for task finishing if synchronous flag is True. + def _separate_tensorflow(self, waveform, audio_descriptor): + """ + Performs source separation over the given waveform with tensorflow backend. :param waveform: Waveform to apply separation on. :returns: Separated waveforms. @@ -107,7 +101,7 @@ class Separator(object): prediction.pop('audio_id') return prediction - def stft(self, data, inverse=False, length=None): + def _stft(self, data, inverse=False, length=None): """ Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two channels are processed separately and are concatenated together in the result. The expected input formats are: @@ -134,7 +128,10 @@ class Separator(object): return out[0] return np.concatenate(out, axis=2-inverse) - def separate_librosa(self, waveform, audio_id): + def _separate_librosa(self, waveform, audio_id): + """ + Performs separation with librosa backend for STFT. + """ out = {} input_provider = InputProviderFactory.get(self._params) features = input_provider.get_input_dict_placeholders() @@ -144,7 +141,7 @@ class Separator(object): # TODO: fix the logic, build sometimes return, sometimes set attribute outputs = builder.outputs - stft = self.stft(waveform) + stft = self._stft(waveform) if stft.shape[-1] == 1: stft = np.concatenate([stft, stft], axis=-1) elif stft.shape[-1] > 2: @@ -155,14 +152,19 @@ class Separator(object): saver.restore(sess, latest_checkpoint) outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) for inst in builder.instruments: - out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0]) + out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0]) return out - def separate(self, waveform, audio_descriptor): + def separate(self, waveform, audio_descriptor=""): + """ Performs separation on a waveform. + + :param waveform: Waveform to be separated (as a numpy array) + :param audio_descriptor: (Optional) string describing the waveform (e.g. filename). + """ if self._params["stft_backend"] == "tensorflow": - return self.separate_tensorflow(waveform, audio_descriptor) + return self._separate_tensorflow(waveform, audio_descriptor) else: - return self.separate_librosa(waveform, audio_descriptor) + return self._separate_librosa(waveform, audio_descriptor) def separate_to_file( self, audio_descriptor, destination, @@ -197,10 +199,10 @@ class Separator(object): duration=duration, sample_rate=self._sample_rate) sources = self.separate(waveform, audio_descriptor) - self.save_to_file(sources, audio_descriptor, destination, filename_format, codec, + self._save_to_file(sources, audio_descriptor, destination, filename_format, codec, audio_adapter, bitrate, synchronous) - def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec, + def _save_to_file(self, sources, audio_descriptor, destination, filename_format, codec, audio_adapter, bitrate, synchronous): filename = splitext(basename(audio_descriptor))[0] generated = []