diff --git a/spleeter/separator.py b/spleeter/separator.py index 3bc6289..3488fb5 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -145,8 +145,10 @@ class Separator(object): # TODO: fix the logic, build sometimes return, sometimes set attribute outputs = builder.outputs stft = self.stft(waveform) - if stft.shape[-1] != 2: + if stft.shape[-1] == 1: stft = np.concatenate([stft, stft], axis=-1) + elif stft.shape[-1] > 2: + stft = stft[:, :2] saver = tf.train.Saver() with tf.Session() as sess: