mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Fixed channels setting in dataset + added dimension check
This commit is contained in:
@@ -52,17 +52,19 @@ TRAIN_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
|
||||
def generate_fake_training_dataset(path,
|
||||
instrument_list=['vocals', 'other'],
|
||||
n_channels=2,
|
||||
n_songs = 2,
|
||||
fs = 44100,
|
||||
duration = 6,
|
||||
):
|
||||
"""
|
||||
generates a fake training dataset in path:
|
||||
- generates audio files
|
||||
- generates a csv file describing the dataset
|
||||
"""
|
||||
aa = AudioAdapter.default()
|
||||
n_songs = 2
|
||||
fs = 44100
|
||||
duration = 6
|
||||
n_channels = 2
|
||||
rng = np.random.RandomState(seed=0)
|
||||
dataset_df = pd.DataFrame(
|
||||
columns=['mix_path'] + [
|
||||
@@ -83,26 +85,40 @@ def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
|
||||
|
||||
|
||||
def test_train():
|
||||
|
||||
with TemporaryDirectory() as path:
|
||||
# generate training dataset
|
||||
generate_fake_training_dataset(path)
|
||||
# set training command aruments
|
||||
runner = CliRunner()
|
||||
TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
|
||||
TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
|
||||
TRAIN_CONFIG['model_dir'] = join(path, 'model')
|
||||
TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
|
||||
TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
|
||||
with open('useless_config.json', 'w') as stream:
|
||||
json.dump(TRAIN_CONFIG, stream)
|
||||
# execute training
|
||||
result = runner.invoke(spleeter, [
|
||||
'train',
|
||||
'-p', 'useless_config.json',
|
||||
'-d', path
|
||||
])
|
||||
# assert that model checkpoint was created.
|
||||
assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
|
||||
assert os.path.exists(join(path, 'model', 'checkpoint'))
|
||||
assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
|
||||
assert result.exit_code == 0
|
||||
for n_channels in [1,2]:
|
||||
TRAIN_CONFIG["n_channels"] = n_channels
|
||||
generate_fake_training_dataset(path,
|
||||
n_channels=n_channels,
|
||||
fs=TRAIN_CONFIG["sample_rate"]
|
||||
)
|
||||
# set training command arguments
|
||||
runner = CliRunner()
|
||||
|
||||
model_dir = join(path, f'model_{n_channels}')
|
||||
train_dir = join(path, f'train')
|
||||
cache_dir = join(path, f'cache_{n_channels}')
|
||||
|
||||
TRAIN_CONFIG['train_csv'] = join(train_dir, 'train.csv')
|
||||
TRAIN_CONFIG['validation_csv'] = join(train_dir, 'train.csv')
|
||||
TRAIN_CONFIG['model_dir'] = model_dir
|
||||
TRAIN_CONFIG['training_cache'] = join(cache_dir, 'training')
|
||||
TRAIN_CONFIG['validation_cache'] = join(cache_dir, 'validation')
|
||||
with open('useless_config.json', 'w') as stream:
|
||||
json.dump(TRAIN_CONFIG, stream)
|
||||
|
||||
# execute training
|
||||
result = runner.invoke(spleeter, [
|
||||
'train',
|
||||
'-p', 'useless_config.json',
|
||||
'-d', path,
|
||||
"--verbose"
|
||||
])
|
||||
|
||||
# assert that model checkpoint was created.
|
||||
assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
|
||||
assert os.path.exists(join(model_dir, 'checkpoint'))
|
||||
assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
|
||||
assert result.exit_code == 0
|
||||
|
||||
Reference in New Issue
Block a user