From af4ea322fe423e2d9e7e6539e2ab63bc43993430 Mon Sep 17 00:00:00 2001
From: romi1502
Date: Fri, 26 Feb 2021 10:26:26 +0100
Subject: [PATCH] Fixed channels setting in dataset + added dimension check

---
 spleeter/dataset.py | 24 ++++++++++++++--
 tests/test_train.py | 68 ++++++++++++++++++++++++++++-----------------
 2 files changed, 63 insertions(+), 29 deletions(-)

diff --git a/spleeter/dataset.py b/spleeter/dataset.py
index 4ee096d..b73e414 100644
--- a/spleeter/dataset.py
+++ b/spleeter/dataset.py
@@ -241,7 +241,8 @@ class InstrumentDatasetBuilder(object):
     def filter_shape(self, sample):
         """ Filter badly shaped sample. """
         return check_tensor_shape(
-            sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
+            sample[self._spectrogram_key],
+            (self._parent._T, self._parent._F, self._parent._n_channels),
         )

     def reshape_spectrogram(self, sample):
         """ Reshape given sample. """
         return dict(
             sample,
             **{
                 self._spectrogram_key: set_tensor_shape(
-                    sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
+                    sample[self._spectrogram_key],
+                    (self._parent._T, self._parent._F, self._parent._n_channels),
                 )
             },
         )
@@ -299,6 +301,7 @@ class DatasetBuilder(object):
         self._frame_length = audio_params["frame_length"]
         self._frame_step = audio_params["frame_step"]
         self._mix_name = audio_params["mix_name"]
+        self._n_channels = audio_params["n_channels"]
         self._instruments = [self._mix_name] + audio_params["instrument_list"]
         self._instrument_builders = None
         self._chunk_duration = chunk_duration
@@ -307,6 +310,21 @@ class DatasetBuilder(object):
         self._audio_path = audio_path
         self._random_seed = random_seed

+        self.check_parameters_compatibility()
+
+    def check_parameters_compatibility(self):
+        if self._frame_length / 2 + 1 < self._F:
+            raise ValueError(
+                "F is too large and must be set to at most frame_length/2+1. Decrease F or increase frame_length to fix."
+            )
+
+        if (
+            self._chunk_duration * self._sample_rate - self._frame_length
+        ) / self._frame_step < self._T:
+            raise ValueError(
+                "T is too large considering STFT parameters and chunk duration. Make sure spectrogram time dimension of chunks is larger than T (for instance reducing T or frame_step or increasing chunk duration)."
+            )
+
     def expand_path(self, sample):
         """ Expands audio paths for the given sample. """
         return dict(
@@ -368,7 +386,7 @@ class DatasetBuilder(object):
             },
             lambda x: tf.image.random_crop(
                 x,
-                (self._T, len(self._instruments) * self._F, 2),
+                (self._T, len(self._instruments) * self._F, self._n_channels),
                 seed=self._random_seed,
             ),
         ),
diff --git a/tests/test_train.py b/tests/test_train.py
index 47ce747..4a5280d 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -52,17 +52,19 @@ TRAIN_CONFIG = {
 }


-def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
+def generate_fake_training_dataset(path,
+                                   instrument_list=['vocals', 'other'],
+                                   n_channels=2,
+                                   n_songs = 2,
+                                   fs = 44100,
+                                   duration = 6,
+                                   ):
     """ generates a fake training dataset in path:
     - generates audio files
     - generates a csv file describing the dataset
     """
     aa = AudioAdapter.default()
-    n_songs = 2
-    fs = 44100
-    duration = 6
-    n_channels = 2
     rng = np.random.RandomState(seed=0)
     dataset_df = pd.DataFrame(
         columns=['mix_path'] + [
@@ -83,26 +85,40 @@


 def test_train():
+
     with TemporaryDirectory() as path:
         # generate training dataset
-        generate_fake_training_dataset(path)
-        # set training command aruments
-        runner = CliRunner()
-        TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
-        TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
-        TRAIN_CONFIG['model_dir'] = join(path, 'model')
-        TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
-        TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
-        with open('useless_config.json', 'w') as stream:
-            json.dump(TRAIN_CONFIG, stream)
-        # execute training
-        result = runner.invoke(spleeter, [
-            'train',
-            '-p', 'useless_config.json',
-            '-d', path
-        ])
-        # assert that model checkpoint was created.
-        assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
-        assert os.path.exists(join(path, 'model', 'checkpoint'))
-        assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
-        assert result.exit_code == 0
+        for n_channels in [1,2]:
+            TRAIN_CONFIG["n_channels"] = n_channels
+            generate_fake_training_dataset(path,
+                                           n_channels=n_channels,
+                                           fs=TRAIN_CONFIG["sample_rate"]
+                                           )
+            # set training command arguments
+            runner = CliRunner()
+
+            model_dir = join(path, f'model_{n_channels}')
+            train_dir = join(path, f'train')
+            cache_dir = join(path, f'cache_{n_channels}')
+
+            TRAIN_CONFIG['train_csv'] = join(train_dir, 'train.csv')
+            TRAIN_CONFIG['validation_csv'] = join(train_dir, 'train.csv')
+            TRAIN_CONFIG['model_dir'] = model_dir
+            TRAIN_CONFIG['training_cache'] = join(cache_dir, 'training')
+            TRAIN_CONFIG['validation_cache'] = join(cache_dir, 'validation')
+            with open('useless_config.json', 'w') as stream:
+                json.dump(TRAIN_CONFIG, stream)
+
+            # execute training
+            result = runner.invoke(spleeter, [
+                'train',
+                '-p', 'useless_config.json',
+                '-d', path,
+                "--verbose"
+            ])
+
+            # assert that model checkpoint was created.
+            assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
+            assert os.path.exists(join(model_dir, 'checkpoint'))
+            assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
+            assert result.exit_code == 0