Fixed channels setting in dataset + added dimension check

romi1502
2021-02-26 10:26:26 +01:00
parent 4ef593dc94
commit af4ea322fe
2 changed files with 63 additions and 29 deletions


@@ -241,7 +241,8 @@ class InstrumentDatasetBuilder(object):
     def filter_shape(self, sample):
         """ Filter badly shaped sample. """
         return check_tensor_shape(
-            sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
+            sample[self._spectrogram_key],
+            (self._parent._T, self._parent._F, self._parent._n_channels),
         )

     def reshape_spectrogram(self, sample):
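
For context, check_tensor_shape is used here as a tf.data filter predicate, so it is expected to return a boolean tensor rather than raise. A minimal sketch of what such a helper can look like, assuming None entries act as wildcards (an illustration, not the library's actual implementation):

    import tensorflow as tf

    def check_tensor_shape(tensor, target_shape):
        # Sketch only: True when every non-None entry of target_shape
        # matches the runtime size of the corresponding axis.
        result = tf.constant(True)
        for axis, expected in enumerate(target_shape):
            if expected is None:
                continue  # wildcard dimension
            result = tf.logical_and(
                result, tf.equal(tf.shape(tensor)[axis], expected)
            )
        return result
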
@@ -250,7 +251,8 @@ class InstrumentDatasetBuilder(object):
             sample,
             **{
                 self._spectrogram_key: set_tensor_shape(
-                    sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
+                    sample[self._spectrogram_key],
+                    (self._parent._T, self._parent._F, self._parent._n_channels),
                 )
             },
         )
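
set_tensor_shape, by contrast, does not filter: it pins the static shape so that downstream graph ops see fixed dimensions. A plausible sketch of such a utility (a guess at its behavior, not necessarily the real body):

    def set_tensor_shape(tensor, shape):
        # Sketch: pin the static shape and pass the tensor through unchanged.
        tensor.set_shape(shape)
        return tensor
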
@@ -299,6 +301,7 @@ class DatasetBuilder(object):
         self._frame_length = audio_params["frame_length"]
         self._frame_step = audio_params["frame_step"]
         self._mix_name = audio_params["mix_name"]
+        self._n_channels = audio_params["n_channels"]
         self._instruments = [self._mix_name] + audio_params["instrument_list"]
         self._instrument_builders = None
         self._chunk_duration = chunk_duration
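
Since n_channels is now read unconditionally from audio_params, training configurations must define the key. A sketch of the relevant slice of such a config dict (all values illustrative; real Spleeter configs carry more keys):

    AUDIO_PARAMS = {
        "mix_name": "mix",
        "instrument_list": ["vocals", "accompaniment"],
        "sample_rate": 44100,
        "frame_length": 4096,
        "frame_step": 1024,
        "n_channels": 2,  # key newly required by this commit
    }
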
@@ -307,6 +310,21 @@ class DatasetBuilder(object):
         self._audio_path = audio_path
         self._random_seed = random_seed
+        self.check_parameters_compatibility()
+
+    def check_parameters_compatibility(self):
+        if self._frame_length / 2 + 1 < self._F:
+            raise ValueError(
+                "F is too large and must be set to at most frame_length/2+1. Decrease F or increase frame_length to fix."
+            )
+        if (
+            self._chunk_duration * self._sample_rate - self._frame_length
+        ) / self._frame_step < self._T:
+            raise ValueError(
+                "T is too large considering STFT parameters and chunk duration. Make sure the spectrogram time dimension of chunks is at least T (for instance, reduce T or frame_step, or increase chunk duration)."
+            )
+
     def expand_path(self, sample):
         """ Expands audio paths for the given sample. """
         return dict(
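
A quick back-of-the-envelope check of the two constraints, with illustrative (not repository-default) parameters:

    sample_rate = 44100
    frame_length = 4096
    frame_step = 1024
    chunk_duration = 20  # seconds
    T, F = 512, 1024

    # A real FFT over frame_length samples yields frame_length / 2 + 1
    # frequency bins, so F may be at most 2049 here.
    assert F <= frame_length // 2 + 1

    # A chunk yields about (samples - frame_length) / frame_step time frames:
    # (20 * 44100 - 4096) // 1024 == 857, so T = 512 fits, while T = 1024
    # would trigger the second ValueError.
    assert T <= (chunk_duration * sample_rate - frame_length) // frame_step
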
@@ -368,7 +386,7 @@ class DatasetBuilder(object):
                 },
                 lambda x: tf.image.random_crop(
                     x,
-                    (self._T, len(self._instruments) * self._F, 2),
+                    (self._T, len(self._instruments) * self._F, self._n_channels),
                     seed=self._random_seed,
                 ),
             ),
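
The crop width len(self._instruments) * self._F suggests the per-instrument spectrograms are concatenated along the frequency axis, so a single random crop stays time-aligned across all stems. A minimal sketch of that idea (shapes and names are illustrative):

    import tensorflow as tf

    T, F, n_channels = 512, 1024, 2
    instruments = ["mix", "vocals", "accompaniment"]

    # One (time, frequency, channels) spectrogram per instrument, concatenated
    # on the frequency axis so one crop selects the same frames in every stem.
    stacked = tf.concat(
        [tf.random.uniform((600, F, n_channels)) for _ in instruments], axis=1
    )
    crop = tf.image.random_crop(
        stacked, (T, len(instruments) * F, n_channels), seed=0
    )
    print(crop.shape)  # (512, 3072, 2)
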


@@ -52,17 +52,19 @@ TRAIN_CONFIG = {
 }


-def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
+def generate_fake_training_dataset(path,
+                                   instrument_list=['vocals', 'other'],
+                                   n_channels=2,
+                                   n_songs=2,
+                                   fs=44100,
+                                   duration=6,
+                                   ):
     """
     generates a fake training dataset in path:
     - generates audio files
     - generates a csv file describing the dataset
     """
     aa = AudioAdapter.default()
-    n_songs = 2
-    fs = 44100
-    duration = 6
-    n_channels = 2
     rng = np.random.RandomState(seed=0)
     dataset_df = pd.DataFrame(
         columns=['mix_path'] + [
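
The generation body itself is not shown in this hunk; conceptually it writes duration seconds of random noise per stem and per song, now honoring n_channels. A sketch of that step, substituting soundfile for Spleeter's AudioAdapter (the file name and the library swap are mine):

    import numpy as np
    import soundfile as sf  # stand-in for spleeter's AudioAdapter

    rng = np.random.RandomState(seed=0)
    fs, duration, n_channels = 44100, 6, 2

    # Random noise stands in for a real stem; soundfile expects (samples, channels).
    waveform = rng.randn(duration * fs, n_channels).astype(np.float32)
    sf.write("vocals.wav", waveform, fs)
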
@@ -83,26 +85,40 @@ def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
 def test_train():
     with TemporaryDirectory() as path:
         # generate training dataset
-        generate_fake_training_dataset(path)
-        # set training command aruments
-        runner = CliRunner()
-        TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
-        TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
-        TRAIN_CONFIG['model_dir'] = join(path, 'model')
-        TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
-        TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
-        with open('useless_config.json', 'w') as stream:
-            json.dump(TRAIN_CONFIG, stream)
-        # execute training
-        result = runner.invoke(spleeter, [
-            'train',
-            '-p', 'useless_config.json',
-            '-d', path
-        ])
-        # assert that model checkpoint was created.
-        assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
-        assert os.path.exists(join(path, 'model', 'checkpoint'))
-        assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
-        assert result.exit_code == 0
+        for n_channels in [1, 2]:
+            TRAIN_CONFIG["n_channels"] = n_channels
+            generate_fake_training_dataset(path,
+                                           n_channels=n_channels,
+                                           fs=TRAIN_CONFIG["sample_rate"]
+                                           )
+            # set training command arguments
+            runner = CliRunner()
+            model_dir = join(path, f'model_{n_channels}')
+            train_dir = join(path, 'train')
+            cache_dir = join(path, f'cache_{n_channels}')
+            TRAIN_CONFIG['train_csv'] = join(train_dir, 'train.csv')
+            TRAIN_CONFIG['validation_csv'] = join(train_dir, 'train.csv')
+            TRAIN_CONFIG['model_dir'] = model_dir
+            TRAIN_CONFIG['training_cache'] = join(cache_dir, 'training')
+            TRAIN_CONFIG['validation_cache'] = join(cache_dir, 'validation')
+            with open('useless_config.json', 'w') as stream:
+                json.dump(TRAIN_CONFIG, stream)
+            # execute training
+            result = runner.invoke(spleeter, [
+                'train',
+                '-p', 'useless_config.json',
+                '-d', path,
+                "--verbose"
+            ])
+            # assert that model checkpoint was created.
+            assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
+            assert os.path.exists(join(model_dir, 'checkpoint'))
+            assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
+            assert result.exit_code == 0