Mirror of https://github.com/YuzuZensai/spleeter.git — synced 2026-01-31 14:58:23 +00:00
Cleaning: updated docstring and unexposed some internal methods
@@ -85,15 +85,9 @@ class Separator(object):
             task.get()
             task.wait(timeout=timeout)
 
-    def separate_tensorflow(self, waveform, audio_descriptor):
-        """ Performs source separation over the given waveform.
-
-        The separation is performed synchronously but the result
-        processing is done asynchronously, allowing for instance
-        to export audio in parallel (through multiprocessing).
-
-        Given result is passed by to the given consumer, which will
-        be waited for task finishing if synchronous flag is True.
-
+    def _separate_tensorflow(self, waveform, audio_descriptor):
+        """
+        Performs source separation over the given waveform with tensorflow backend.
+
         :param waveform: Waveform to apply separation on.
         :returns: Separated waveforms.
@@ -107,7 +101,7 @@ class Separator(object):
         prediction.pop('audio_id')
         return prediction
 
-    def stft(self, data, inverse=False, length=None):
+    def _stft(self, data, inverse=False, length=None):
         """
         Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
         channels are processed separately and are concatenated together in the result. The expected input formats are:
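
Note: as a rough illustration of the per-channel pattern this docstring describes (not the project's actual `_stft` implementation), the sketch below transforms each channel independently with librosa and stacks the results; `stereo_stft`, `frame_length` and `frame_step` are hypothetical names chosen for the example.

import numpy as np
import librosa

def stereo_stft(data, frame_length=4096, frame_step=1024, inverse=False, length=None):
    # Hypothetical helper: each channel is transformed independently,
    # then the per-channel results are concatenated back together.
    out = []
    for c in range(data.shape[-1]):
        if inverse:
            # data is a (T, F, n_channels) complex spectrogram
            s = librosa.istft(data[:, :, c].T, hop_length=frame_step, length=length)
            out.append(np.expand_dims(s, axis=1))    # -> (samples, 1)
        else:
            # data is a (samples, n_channels) waveform
            s = librosa.stft(data[:, c], n_fft=frame_length, hop_length=frame_step)
            out.append(np.expand_dims(s.T, axis=2))  # -> (T, F, 1)
    if len(out) == 1:
        return out[0]
    return np.concatenate(out, axis=1 if inverse else 2)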
@@ -134,7 +128,10 @@ class Separator(object):
             return out[0]
         return np.concatenate(out, axis=2-inverse)
 
-    def separate_librosa(self, waveform, audio_id):
+    def _separate_librosa(self, waveform, audio_id):
+        """
+        Performs separation with librosa backend for STFT.
+        """
         out = {}
         input_provider = InputProviderFactory.get(self._params)
         features = input_provider.get_input_dict_placeholders()
@@ -144,7 +141,7 @@ class Separator(object):
 
         # TODO: fix the logic, build sometimes return, sometimes set attribute
         outputs = builder.outputs
-        stft = self.stft(waveform)
+        stft = self._stft(waveform)
         if stft.shape[-1] == 1:
             stft = np.concatenate([stft, stft], axis=-1)
         elif stft.shape[-1] > 2:
@@ -155,14 +152,19 @@ class Separator(object):
             saver.restore(sess, latest_checkpoint)
             outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
             for inst in builder.instruments:
-                out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
+                out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0])
         return out
 
-    def separate(self, waveform, audio_descriptor):
+    def separate(self, waveform, audio_descriptor=""):
+        """ Performs separation on a waveform.
+
+        :param waveform: Waveform to be separated (as a numpy array)
+        :param audio_descriptor: (Optional) string describing the waveform (e.g. filename).
+        """
         if self._params["stft_backend"] == "tensorflow":
-            return self.separate_tensorflow(waveform, audio_descriptor)
+            return self._separate_tensorflow(waveform, audio_descriptor)
         else:
-            return self.separate_librosa(waveform, audio_descriptor)
+            return self._separate_librosa(waveform, audio_descriptor)
 
     def separate_to_file(
         self, audio_descriptor, destination,
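
Note: a minimal usage sketch of the public entry point after this change. The waveform is assumed to be a NumPy array of shape (samples, channels); the model name 'spleeter:2stems' and the example file are illustrative only.

import numpy as np
from spleeter.separator import Separator
from spleeter.audio.adapter import get_default_audio_adapter

separator = Separator('spleeter:2stems')
audio_adapter = get_default_audio_adapter()

# Load a stereo waveform as a (samples, channels) float array.
waveform, _ = audio_adapter.load('audio_example.mp3', sample_rate=44100)

# audio_descriptor is now optional; the backend (_separate_tensorflow or
# _separate_librosa) is chosen internally from the 'stft_backend' parameter.
prediction = separator.separate(waveform)
print(prediction.keys())  # e.g. dict_keys(['vocals', 'accompaniment'])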
@@ -197,10 +199,10 @@ class Separator(object):
             duration=duration,
             sample_rate=self._sample_rate)
         sources = self.separate(waveform, audio_descriptor)
-        self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
+        self._save_to_file(sources, audio_descriptor, destination, filename_format, codec,
                           audio_adapter, bitrate, synchronous)
 
-    def save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
+    def _save_to_file(self, sources, audio_descriptor, destination, filename_format, codec,
                      audio_adapter, bitrate, synchronous):
         filename = splitext(basename(audio_descriptor))[0]
         generated = []
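
Note: file-based use still goes through separate_to_file, which now delegates export to the internal _save_to_file. A short sketch (output directory and codec are illustrative):

from spleeter.separator import Separator

separator = Separator('spleeter:2stems')

# Loads the file, runs separate(), and writes one file per instrument
# via the now-internal _save_to_file helper.
separator.separate_to_file('audio_example.mp3', 'output/',
                           codec='wav', synchronous=True)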