mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Fixing the stft/istft computations
This commit is contained in:
@@ -98,14 +98,23 @@ class Separator(object):
|
|||||||
prediction.pop('audio_id')
|
prediction.pop('audio_id')
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
def stft(self, waveform, inverse=False):
|
def stft(self, data, inverse=False):
|
||||||
|
"""
|
||||||
|
Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two
|
||||||
|
channels are processed separately and are concatenated together in the result. The expected input formats are:
|
||||||
|
(n_samples, 2) for stft and (T, F, 2) for istft.
|
||||||
|
:param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse
|
||||||
|
:param inverse: should a stft or an istft be computed.
|
||||||
|
:return: Stereo data as numpy array for the transform. The channels are stored in the last dimension
|
||||||
|
"""
|
||||||
N = self._params["frame_length"]
|
N = self._params["frame_length"]
|
||||||
H = self._params["frame_step"]
|
H = self._params["frame_step"]
|
||||||
win = hann(N, sym=False)
|
win = hann(N, sym=False)
|
||||||
fstft = istft if inverse else stft
|
fstft = istft if inverse else stft
|
||||||
win_len_arg = "win_length" if inverse else "n_fft"
|
win_len_arg = {"win_length": None} if inverse else {"n_fft": N}
|
||||||
s1 = fstft(waveform[:, 0], hop_length=H, window=win, center=False, **{win_len_arg: N})
|
dl, dr = (data[:, :, 0].T, data[:, :, 1].T) if inverse else (data[:, 0], data[:, 1])
|
||||||
s2 = fstft(waveform[:, 1], hop_length=H, window=win, center=False, **{win_len_arg: N})
|
s1 = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
|
||||||
|
s2 = fstft(dr, hop_length=H, window=win, center=False, **win_len_arg)
|
||||||
s1 = np.expand_dims(s1.T, 2-inverse)
|
s1 = np.expand_dims(s1.T, 2-inverse)
|
||||||
s2 = np.expand_dims(s2.T, 2-inverse)
|
s2 = np.expand_dims(s2.T, 2-inverse)
|
||||||
return np.concatenate([s1, s2], axis=2-inverse)
|
return np.concatenate([s1, s2], axis=2-inverse)
|
||||||
@@ -127,7 +136,7 @@ class Separator(object):
|
|||||||
saver.restore(sess, latest_checkpoint)
|
saver.restore(sess, latest_checkpoint)
|
||||||
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))
|
||||||
for inst in builder.instruments:
|
for inst in builder.instruments:
|
||||||
out[inst] = self.stft(outputs[inst], inverse=True)
|
out[inst] = self.stft(outputs[inst], inverse=True)[:waveform.shape[0], :]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def separate_to_file(
|
def separate_to_file(
|
||||||
|
|||||||
Reference in New Issue
Block a user