mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
Merge pull request #444 from deezer/pad_waveform
Added padding at the begining to avoid tf STFT reconstruction error
This commit is contained in:
@@ -170,6 +170,7 @@ def _create_evaluate_parser(parser_factory):
|
||||
parser.add_argument('-o', '--output_path', **OPT_OUTPUT)
|
||||
parser.add_argument('--mus_dir', **OPT_MUSDB)
|
||||
parser.add_argument('-m', '--mwf', **OPT_MWF)
|
||||
parser.add_argument('-B', '--stft-backend', **OPT_STFT_BACKEND)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ def _separate_evaluation_dataset(arguments, musdb_root_directory, params):
|
||||
bitrate='128k',
|
||||
MWF=arguments.MWF,
|
||||
verbose=arguments.verbose,
|
||||
stft_backend="auto"),
|
||||
stft_backend=arguments.stft_backend),
|
||||
params)
|
||||
return audio_output_directory
|
||||
|
||||
|
||||
@@ -275,9 +275,16 @@ class EstimatorSpecBuilder(object):
|
||||
spec_name = self.spectrogram_name
|
||||
|
||||
if stft_name not in self._features:
|
||||
# pad input with a frame of zeros
|
||||
waveform = tf.concat([
|
||||
tf.zeros((self._frame_length, self._n_channels)),
|
||||
self._features['waveform']
|
||||
],
|
||||
0
|
||||
)
|
||||
stft_feature = tf.transpose(
|
||||
stft(
|
||||
tf.transpose(self._features['waveform']),
|
||||
tf.transpose(waveform),
|
||||
self._frame_length,
|
||||
self._frame_step,
|
||||
window_fn=lambda frame_length, dtype: (
|
||||
@@ -341,7 +348,7 @@ class EstimatorSpecBuilder(object):
|
||||
reshaped = tf.transpose(inversed)
|
||||
if time_crop is None:
|
||||
time_crop = tf.shape(self._features['waveform'])[0]
|
||||
return reshaped[:time_crop, :]
|
||||
return reshaped[self._frame_length:self._frame_length+time_crop, :]
|
||||
|
||||
def _build_mwf_output_waveform(self):
|
||||
""" Perform separation with multichannel Wiener Filtering using Norbert.
|
||||
|
||||
Reference in New Issue
Block a user