Fixing tests when running in single process

This commit is contained in:
akhlif
2020-03-27 11:12:05 +01:00
parent 4e16e0fe23
commit 38bfff833e
2 changed files with 23 additions and 10 deletions

View File

@@ -127,7 +127,7 @@ class Separator(object):
out = []
for c in range(n_channels):
d = data[:, :, c].T if inverse else data[:, c]
s = fstft(dl, hop_length=H, window=win, center=False, **win_len_arg)
s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
s = np.expand_dims(s.T, 2-inverse)
out.append(s)
if len(out) == 1:
@@ -144,9 +144,11 @@ class Separator(object):
# TODO: fix the logic, build sometimes return, sometimes set attribute
outputs = builder.outputs
stft = self.stft(waveform)
if stft.shape[-1] != 2:
stft = np.concatenate([stft, stft], axis=-1)
saver = tf.train.Saver()
stft = self.stft(waveform)
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id))