Updating tests to test for librosa backend

This commit is contained in:
akhlif
2020-02-27 15:38:46 +01:00
parent 922fcd85bb
commit 3cba6985f4
3 changed files with 36 additions and 26 deletions

View File

@@ -85,7 +85,7 @@ class Separator(object):
task.get()
task.wait(timeout=timeout)
def separate(self, waveform):
def separate_tensorflow(self, waveform, audio_descriptor):
""" Performs source separation over the given waveform.
The separation is performed synchronously but the result
@@ -103,11 +103,11 @@ class Separator(object):
predictor = self._get_predictor()
prediction = predictor({
'waveform': waveform,
'audio_id': ''})
'audio_id': audio_descriptor})
prediction.pop('audio_id')
return prediction
def stft(self, data, inverse=False):
def stft(self, data, inverse=False, length=None):
"""
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
channels are processed separately and are concatenated together in the result. The expected input formats are:
@@ -116,11 +116,13 @@ class Separator(object):
:param inverse: should a stft or an istft be computed.
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
"""
assert not (inverse and length is None)
data = np.asfortranarray(data)
N = self._params["frame_length"]
H = self._params["frame_step"]
win = hann(N, sym=False)
fstft = istft if inverse else stft
win_len_arg = {"win_length": None} if inverse else {"n_fft": N}
win_len_arg = {"win_length": None, "length": length} if inverse else {"n_fft": N}
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
@@ -145,9 +147,15 @@ class Separator(object):
saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
for inst in builder.instruments:
out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :]
out[inst] = self.stft(outputs[inst], inverse=True, length=waveform.shape[0])
return out
def separate(self, waveform, audio_descriptor):
if self._params["stft_backend"] == "tensorflow":
return self.separate_tensorflow(waveform, audio_descriptor)
else:
return self.separate_librosa(waveform, audio_descriptor)
def separate_to_file(
self, audio_descriptor, destination,
audio_adapter=get_default_audio_adapter(),
@@ -180,10 +188,7 @@ class Separator(object):
offset=offset,
duration=duration,
sample_rate=self._sample_rate)
if self._params["stft_backend"] == "tensorflow":
sources = self.separate(waveform)
else:
sources = self.separate_librosa(waveform, audio_descriptor)
sources = self.separate(waveform, audio_descriptor)
self.save_to_file(sources, audio_descriptor, destination, filename_format, codec,
audio_adapter, bitrate, synchronous)