Merge pull request #444 from deezer/pad_waveform

Added padding at the begining to avoid tf STFT reconstruction error
This commit is contained in:
Moussallam
2020-07-24 15:28:30 +02:00
committed by GitHub
4 changed files with 80 additions and 35 deletions

View File

@@ -170,6 +170,7 @@ def _create_evaluate_parser(parser_factory):
parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
parser.add_argument('--mus_dir', **OPT_MUSDB)
parser.add_argument('-m', '--mwf', **OPT_MWF)
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
return parser

View File

@@ -77,7 +77,7 @@ def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
bitrate='128k',
MWF=arguments.MWF,
verbose=arguments.verbose,
stft_backend="auto"),
stft_backend=arguments.stft_backend),
params)
return audio_output_directory

View File

@@ -275,9 +275,16 @@ class EstimatorSpecBuilder(object):
spec_name = self.spectrogram_name
if stft_name not in self._features:
# pad input with a frame of zeros
waveform = tf.concat([
tf.zeros((self._frame_length, self._n_channels)),
self._features['waveform']
],
0
)
stft_feature = tf.transpose(
stft(
tf.transpose(self._features['waveform']),
tf.transpose(waveform),
self._frame_length,
self._frame_step,
window_fn=lambda frame_length, dtype: (
@@ -341,7 +348,7 @@ class EstimatorSpecBuilder(object):
reshaped = tf.transpose(inversed)
if time_crop is None:
time_crop = tf.shape(self._features['waveform'])[0]
return reshaped[:time_crop, :]
return reshaped[self._frame_length:self._frame_length+time_crop, :]
def _build_mwf_output_waveform(self):
""" Perform separation with multichannel Wiener Filtering using Norbert.