Cleaning: updated docstring and unexposed some internal methods

This commit is contained in:
romi1502
2020-04-10 17:17:35 +02:00
parent 1287ef85d1
commit a23a66e683

View File

@@ -85,15 +85,9 @@ class Separator(object):
task.get()
task.wait(timeout=timeout)
def separate_tensorflow(self, waveform, audio_descriptor):
""" Performs source separation over the given waveform.
The separation is performed synchronously but the result
processing is done asynchronously, allowing for instance
to export audio in parallel (through multiprocessing).
The result is passed to the given consumer, which is waited on
until its task finishes if the synchronous flag is True.
def _separate_tensorflow(self, waveform, audio_descriptor):
"""
Performs source separation over the given waveform with tensorflow backend.
:param waveform: Waveform to apply separation on.
:returns: Separated waveforms.
@@ -107,7 +101,7 @@ class Separator(object):
prediction.pop('audio_id')
return prediction
def stft(self, data, inverse=False, length=None):
def _stft(self, data, inverse=False, length=None):
"""
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
channels are processed separately and are concatenated together in the result. The expected input formats are:
@@ -134,7 +128,10 @@ class Separator(object):
return out[0]
return np.concatenate(out, axis=2-inverse)
def separate_librosa(self, waveform, audio_id):
def _separate_librosa(self, waveform, audio_id):
"""
Performs separation with librosa backend for STFT.
"""
out = {}
input_provider = InputProviderFactory.get(self._params)
features = input_provider.get_input_dict_placeholders()
@@ -144,7 +141,7 @@ class Separator(object):
# TODO: fix the logic; build sometimes returns a value and sometimes sets an attribute
outputs = builder.outputs
stft = self.stft(waveform)
stft = self._stft(waveform)
if stft.shape[-1] == 1:
stft = np.concatenate([stft, stft], axis=-1)
elif stft.shape[-1] > 2:
@@ -155,14 +152,19 @@ class Separator(object):
saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
for inst in builder.instruments:
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
return out
def separate(self, waveform, audio_descriptor):
def separate(self, waveform, audio_descriptor=""):
""" Performs separation on a waveform.
:param waveform: Waveform to be separated (as a numpy array)
:param audio_descriptor: (Optional) string describing the waveform (e.g. filename).
"""
if self._params["stft_backend"] == "tensorflow":
return self.separate_tensorflow(waveform, audio_descriptor)
return self._separate_tensorflow(waveform, audio_descriptor)
else:
return self.separate_librosa(waveform, audio_descriptor)
return self._separate_librosa(waveform, audio_descriptor)
def separate_to_file(
self, audio_descriptor, destination,
@@ -197,10 +199,10 @@ class Separator(object):
duration=duration,
sample_rate=self._sample_rate)
sources = self.separate(waveform, audio_descriptor)
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
self._save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous)
def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
def _save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous):
filename = splitext(basename(audio_descriptor))[0]
generated = []