Fixed channels setting in dataset + added dimension check

This commit is contained in:
romi1502
2021-02-26 10:26:26 +01:00
parent 4ef593dc94
commit af4ea322fe
2 changed files with 63 additions and 29 deletions

View File

@@ -241,7 +241,8 @@ class InstrumentDatasetBuilder(object):
    def filter_shape(self, sample):
        """ Filter badly shaped sample. """
        return check_tensor_shape(
-            sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
+            sample[self._spectrogram_key],
+            (self._parent._T, self._parent._F, self._parent._n_channels),
        )

    def reshape_spectrogram(self, sample):
@@ -250,7 +251,8 @@ class InstrumentDatasetBuilder(object):
            sample,
            **{
                self._spectrogram_key: set_tensor_shape(
-                    sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
+                    sample[self._spectrogram_key],
+                    (self._parent._T, self._parent._F, self._parent._n_channels),
                )
            },
        )
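For readers without the tensor utilities in front of them: `check_tensor_shape` is used here as a dataset filter predicate, while `set_tensor_shape` pins the static shape after filtering. The sketch below is only a hedged approximation of that contract, not Spleeter's actual implementation of either helper.

import tensorflow as tf

def check_tensor_shape(tensor, target_shape):
    # Illustrative predicate for tf.data filtering: True when the runtime
    # shape matches target_shape (a None entry matches any size).
    shape = tf.shape(tensor)
    ok = tf.constant(True)
    for i, target in enumerate(target_shape):
        if target is not None:
            ok = tf.logical_and(ok, tf.equal(shape[i], target))
    return ok

def set_tensor_shape(tensor, tensor_shape):
    # Illustrative: pin the static shape so downstream graph ops see
    # (T, F, n_channels) instead of unknown dimensions.
    tensor.set_shape(tensor_shape)
    return tensor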
@@ -299,6 +301,7 @@ class DatasetBuilder(object):
        self._frame_length = audio_params["frame_length"]
        self._frame_step = audio_params["frame_step"]
        self._mix_name = audio_params["mix_name"]
+        self._n_channels = audio_params["n_channels"]
        self._instruments = [self._mix_name] + audio_params["instrument_list"]
        self._instrument_builders = None
        self._chunk_duration = chunk_duration
@@ -307,6 +310,21 @@ class DatasetBuilder(object):
        self._audio_path = audio_path
        self._random_seed = random_seed
+        self.check_parameters_compatibility()
+
+    def check_parameters_compatibility(self):
+        if self._frame_length / 2 + 1 < self._F:
+            raise ValueError(
+                "F is too large and must be set to at most frame_length/2+1. Decrease F or increase frame_length to fix."
+            )
+        if (
+            self._chunk_duration * self._sample_rate - self._frame_length
+        ) / self._frame_step < self._T:
+            raise ValueError(
+                "T is too large considering STFT parameters and chunk duration. Make sure the spectrogram time dimension of chunks is at least T (for instance, reduce T or frame_step, or increase chunk duration)."
+            )
+
    def expand_path(self, sample):
        """ Expands audio paths for the given sample. """
        return dict(
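The two guards encode simple STFT arithmetic: an rFFT over frame_length samples yields frame_length/2 + 1 frequency bins, and a chunk of chunk_duration seconds yields roughly (chunk_duration * sample_rate - frame_length) / frame_step frames. A quick numeric check, with parameter values assumed for illustration (they are not given in this diff):

# Worked example with assumed values: frame_length=4096, frame_step=1024,
# sample_rate=44100, chunk_duration=20.0, T=512, F=1024.
frame_length, frame_step = 4096, 1024
sample_rate, chunk_duration = 44100, 20.0
T, F = 512, 1024

max_F = frame_length / 2 + 1  # 2049 rFFT frequency bins available
n_frames = (chunk_duration * sample_rate - frame_length) / frame_step  # ~857.3
assert F <= max_F      # 1024 <= 2049: F is compatible
assert T <= n_frames   # 512 <= 857.3: T is compatible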
@@ -368,7 +386,7 @@ class DatasetBuilder(object):
            },
            lambda x: tf.image.random_crop(
                x,
-                (self._T, len(self._instruments) * self._F, 2),
+                (self._T, len(self._instruments) * self._F, self._n_channels),
                seed=self._random_seed,
            ),
        ),
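The crop size reveals the tensor layout: per-instrument spectrograms are stacked along the frequency axis, so a single random_crop keeps the mix and all stems aligned in time and channels. A toy shape check, with dimension values assumed for illustration:

import tensorflow as tf

# Dummy stacked spectrograms: 600 frames, 5 instruments x 1024 bins, stereo.
T, F, n_channels, n_instruments = 512, 1024, 2, 5
stacked = tf.zeros((600, n_instruments * F, n_channels))
crop = tf.image.random_crop(stacked, (T, n_instruments * F, n_channels), seed=3)
print(crop.shape)  # (512, 5120, 2)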

View File

@@ -52,17 +52,19 @@ TRAIN_CONFIG = {
}

-def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
+def generate_fake_training_dataset(path,
+                                   instrument_list=['vocals', 'other'],
+                                   n_channels=2,
+                                   n_songs=2,
+                                   fs=44100,
+                                   duration=6,
+                                   ):
    """
    generates a fake training dataset in path:
    - generates audio files
    - generates a csv file describing the dataset
    """
    aa = AudioAdapter.default()
-    n_songs = 2
-    fs = 44100
-    duration = 6
-    n_channels = 2
    rng = np.random.RandomState(seed=0)
    dataset_df = pd.DataFrame(
        columns=['mix_path'] + [
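This hunk does not show how the generator consumes n_channels; presumably each fake stem is random noise of shape (duration * fs, n_channels) written through the audio adapter. A hedged sketch of that idea, with the adapter call left commented out because its exact signature is not shown here:

import numpy as np

# Hypothetical inner loop of generate_fake_training_dataset: one white-noise
# stem per instrument, shaped (n_samples, n_channels).
rng = np.random.RandomState(seed=0)
n_channels, fs, duration = 2, 44100, 6
waveform = rng.randn(duration * fs, n_channels).astype(np.float32)
# aa.save(path_to_wav, waveform, fs)  # adapter write, signature assumed
print(waveform.shape)  # (264600, 2)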
@@ -83,26 +85,40 @@ def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
def test_train():
    with TemporaryDirectory() as path:
        # generate training dataset
-        generate_fake_training_dataset(path)
-        # set training command aruments
-        runner = CliRunner()
-        TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
-        TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
-        TRAIN_CONFIG['model_dir'] = join(path, 'model')
-        TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
-        TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
-        with open('useless_config.json', 'w') as stream:
-            json.dump(TRAIN_CONFIG, stream)
-        # execute training
-        result = runner.invoke(spleeter, [
-            'train',
-            '-p', 'useless_config.json',
-            '-d', path
-        ])
-        # assert that model checkpoint was created.
-        assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
-        assert os.path.exists(join(path, 'model', 'checkpoint'))
-        assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
-        assert result.exit_code == 0
+        for n_channels in [1, 2]:
+            TRAIN_CONFIG["n_channels"] = n_channels
+            generate_fake_training_dataset(
+                path,
+                n_channels=n_channels,
+                fs=TRAIN_CONFIG["sample_rate"],
+            )
+            # set training command arguments
+            runner = CliRunner()
+            model_dir = join(path, f'model_{n_channels}')
+            train_dir = join(path, 'train')
+            cache_dir = join(path, f'cache_{n_channels}')
+            TRAIN_CONFIG['train_csv'] = join(train_dir, 'train.csv')
+            TRAIN_CONFIG['validation_csv'] = join(train_dir, 'train.csv')
+            TRAIN_CONFIG['model_dir'] = model_dir
+            TRAIN_CONFIG['training_cache'] = join(cache_dir, 'training')
+            TRAIN_CONFIG['validation_cache'] = join(cache_dir, 'validation')
+            with open('useless_config.json', 'w') as stream:
+                json.dump(TRAIN_CONFIG, stream)
+            # execute training
+            result = runner.invoke(spleeter, [
+                'train',
+                '-p', 'useless_config.json',
+                '-d', path,
+                '--verbose',
+            ])
+            # assert that model checkpoint was created.
+            assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
+            assert os.path.exists(join(model_dir, 'checkpoint'))
+            assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
+            assert result.exit_code == 0
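A note on the loop design: mutating the module-level TRAIN_CONFIG inside a for loop works, but pytest parametrization would report the mono and stereo runs as separate test cases, so one failing channel count does not mask the other. A hedged alternative sketch, assuming the same helpers and TRAIN_CONFIG as above:

import pytest

@pytest.mark.parametrize('n_channels', [1, 2])
def test_train(n_channels):
    with TemporaryDirectory() as path:
        TRAIN_CONFIG['n_channels'] = n_channels
        generate_fake_training_dataset(
            path, n_channels=n_channels, fs=TRAIN_CONFIG['sample_rate'])
        ...  # remaining setup and assertions as in the loop body above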