mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Cleaning: updated docstring and unexposed some internal methods
This commit is contained in:
@@ -85,15 +85,9 @@ class Separator(object):
|
||||
task.get()
|
||||
task.wait(timeout=timeout)
|
||||
|
||||
def separate_tensorflow(self, waveform, audio_descriptor):
|
||||
""" Performs source separation over the given waveform.
|
||||
|
||||
The separation is performed synchronously but the result
|
||||
processing is done asynchronously, allowing for instance
|
||||
to export audio in parallel (through multiprocessing).
|
||||
|
||||
Given result is passed by to the given consumer, which will
|
||||
be waited for task finishing if synchronous flag is True.
|
||||
def _separate_tensorflow(self, waveform, audio_descriptor):
|
||||
"""
|
||||
Performs source separation over the given waveform with tensorflow backend.
|
||||
|
||||
:param waveform: Waveform to apply separation on.
|
||||
:returns: Separated waveforms.
|
||||
@@ -107,7 +101,7 @@ class Separator(object):
|
||||
prediction.pop('audio_id')
|
||||
return prediction
|
||||
|
||||
def stft(self, data, inverse=False, length=None):
|
||||
def _stft(self, data, inverse=False, length=None):
|
||||
"""
|
||||
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
|
||||
channels are processed separately and are concatenated together in the result. The expected input formats are:
|
||||
@@ -134,7 +128,10 @@ class Separator(object):
|
||||
return out[0]
|
||||
return np.concatenate(out, axis=2-inverse)
|
||||
|
||||
def separate_librosa(self, waveform, audio_id):
|
||||
def _separate_librosa(self, waveform, audio_id):
|
||||
"""
|
||||
Performs separation with librosa backend for STFT.
|
||||
"""
|
||||
out = {}
|
||||
input_provider = InputProviderFactory.get(self._params)
|
||||
features = input_provider.get_input_dict_placeholders()
|
||||
@@ -144,7 +141,7 @@ class Separator(object):
|
||||
|
||||
# TODO: fix the logic, build sometimes return, sometimes set attribute
|
||||
outputs = builder.outputs
|
||||
stft = self.stft(waveform)
|
||||
stft = self._stft(waveform)
|
||||
if stft.shape[-1] == 1:
|
||||
stft = np.concatenate([stft, stft], axis=-1)
|
||||
elif stft.shape[-1] > 2:
|
||||
@@ -155,14 +152,19 @@ class Separator(object):
|
||||
saver.restore(sess, latest_checkpoint)
|
||||
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
||||
for inst in builder.instruments:
|
||||
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
|
||||
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
|
||||
return out
|
||||
|
||||
def separate(self, waveform, audio_descriptor):
|
||||
def separate(self, waveform, audio_descriptor=""):
|
||||
""" Performs separation on a waveform.
|
||||
|
||||
:param waveform: Waveform to be separated (as a numpy array)
|
||||
:param audio_descriptor: (Optional) string describing the waveform (e.g. filename).
|
||||
"""
|
||||
if self._params["stft_backend"] == "tensorflow":
|
||||
return self.separate_tensorflow(waveform, audio_descriptor)
|
||||
return self._separate_tensorflow(waveform, audio_descriptor)
|
||||
else:
|
||||
return self.separate_librosa(waveform, audio_descriptor)
|
||||
return self._separate_librosa(waveform, audio_descriptor)
|
||||
|
||||
def separate_to_file(
|
||||
self, audio_descriptor, destination,
|
||||
@@ -197,10 +199,10 @@ class Separator(object):
|
||||
duration=duration,
|
||||
sample_rate=self._sample_rate)
|
||||
sources = self.separate(waveform, audio_descriptor)
|
||||
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||
self._save_to_file(sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous)
|
||||
|
||||
def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
|
||||
def _save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
|
||||
audio_adapter, bitrate, synchronous):
|
||||
filename = splitext(basename(audio_descriptor))[0]
|
||||
generated = []
|
||||
|
||||
Reference in New Issue
Block a user