mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
Fixed channels setting in dataset + added dimension check
This commit is contained in:
@@ -241,7 +241,8 @@ class InstrumentDatasetBuilder(object):
|
||||
def filter_shape(self, sample):
|
||||
""" Filter badly shaped sample. """
|
||||
return check_tensor_shape(
|
||||
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
|
||||
sample[self._spectrogram_key],
|
||||
(self._parent._T, self._parent._F, self._parent._n_channels),
|
||||
)
|
||||
|
||||
def reshape_spectrogram(self, sample):
|
||||
@@ -250,7 +251,8 @@ class InstrumentDatasetBuilder(object):
|
||||
sample,
|
||||
**{
|
||||
self._spectrogram_key: set_tensor_shape(
|
||||
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
|
||||
sample[self._spectrogram_key],
|
||||
(self._parent._T, self._parent._F, self._parent._n_channels),
|
||||
)
|
||||
},
|
||||
)
|
||||
@@ -299,6 +301,7 @@ class DatasetBuilder(object):
|
||||
self._frame_length = audio_params["frame_length"]
|
||||
self._frame_step = audio_params["frame_step"]
|
||||
self._mix_name = audio_params["mix_name"]
|
||||
self._n_channels = audio_params["n_channels"]
|
||||
self._instruments = [self._mix_name] + audio_params["instrument_list"]
|
||||
self._instrument_builders = None
|
||||
self._chunk_duration = chunk_duration
|
||||
@@ -307,6 +310,21 @@ class DatasetBuilder(object):
|
||||
self._audio_path = audio_path
|
||||
self._random_seed = random_seed
|
||||
|
||||
self.check_parameters_compatibility()
|
||||
|
||||
def check_parameters_compatibility(self):
|
||||
if self._frame_length / 2 + 1 < self._F:
|
||||
raise ValueError(
|
||||
"F is too large and must be set to at most frame_length/2+1. Decrease F or increase frame_length to fix."
|
||||
)
|
||||
|
||||
if (
|
||||
self._chunk_duration * self._sample_rate - self._frame_length
|
||||
) / self._frame_step < self._T:
|
||||
raise ValueError(
|
||||
"T is too large considering STFT parameters and chunk duratoin. Make sure spectrogram time dimension of chunks is larger than T (for instance reducing T or frame_step or increasing chunk duration)."
|
||||
)
|
||||
|
||||
def expand_path(self, sample):
|
||||
""" Expands audio paths for the given sample. """
|
||||
return dict(
|
||||
@@ -368,7 +386,7 @@ class DatasetBuilder(object):
|
||||
},
|
||||
lambda x: tf.image.random_crop(
|
||||
x,
|
||||
(self._T, len(self._instruments) * self._F, 2),
|
||||
(self._T, len(self._instruments) * self._F, self._n_channels),
|
||||
seed=self._random_seed,
|
||||
),
|
||||
),
|
||||
|
||||
@@ -52,17 +52,19 @@ TRAIN_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
|
||||
def generate_fake_training_dataset(path,
|
||||
instrument_list=['vocals', 'other'],
|
||||
n_channels=2,
|
||||
n_songs = 2,
|
||||
fs = 44100,
|
||||
duration = 6,
|
||||
):
|
||||
"""
|
||||
generates a fake training dataset in path:
|
||||
- generates audio files
|
||||
- generates a csv file describing the dataset
|
||||
"""
|
||||
aa = AudioAdapter.default()
|
||||
n_songs = 2
|
||||
fs = 44100
|
||||
duration = 6
|
||||
n_channels = 2
|
||||
rng = np.random.RandomState(seed=0)
|
||||
dataset_df = pd.DataFrame(
|
||||
columns=['mix_path'] + [
|
||||
@@ -83,26 +85,40 @@ def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
|
||||
|
||||
|
||||
def test_train():
|
||||
|
||||
with TemporaryDirectory() as path:
|
||||
# generate training dataset
|
||||
generate_fake_training_dataset(path)
|
||||
# set training command aruments
|
||||
runner = CliRunner()
|
||||
TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
|
||||
TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
|
||||
TRAIN_CONFIG['model_dir'] = join(path, 'model')
|
||||
TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
|
||||
TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
|
||||
with open('useless_config.json', 'w') as stream:
|
||||
json.dump(TRAIN_CONFIG, stream)
|
||||
# execute training
|
||||
result = runner.invoke(spleeter, [
|
||||
'train',
|
||||
'-p', 'useless_config.json',
|
||||
'-d', path
|
||||
])
|
||||
# assert that model checkpoint was created.
|
||||
assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
|
||||
assert os.path.exists(join(path, 'model', 'checkpoint'))
|
||||
assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
|
||||
assert result.exit_code == 0
|
||||
for n_channels in [1,2]:
|
||||
TRAIN_CONFIG["n_channels"] = n_channels
|
||||
generate_fake_training_dataset(path,
|
||||
n_channels=n_channels,
|
||||
fs=TRAIN_CONFIG["sample_rate"]
|
||||
)
|
||||
# set training command arguments
|
||||
runner = CliRunner()
|
||||
|
||||
model_dir = join(path, f'model_{n_channels}')
|
||||
train_dir = join(path, f'train')
|
||||
cache_dir = join(path, f'cache_{n_channels}')
|
||||
|
||||
TRAIN_CONFIG['train_csv'] = join(train_dir, 'train.csv')
|
||||
TRAIN_CONFIG['validation_csv'] = join(train_dir, 'train.csv')
|
||||
TRAIN_CONFIG['model_dir'] = model_dir
|
||||
TRAIN_CONFIG['training_cache'] = join(cache_dir, 'training')
|
||||
TRAIN_CONFIG['validation_cache'] = join(cache_dir, 'validation')
|
||||
with open('useless_config.json', 'w') as stream:
|
||||
json.dump(TRAIN_CONFIG, stream)
|
||||
|
||||
# execute training
|
||||
result = runner.invoke(spleeter, [
|
||||
'train',
|
||||
'-p', 'useless_config.json',
|
||||
'-d', path,
|
||||
"--verbose"
|
||||
])
|
||||
|
||||
# assert that model checkpoint was created.
|
||||
assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
|
||||
assert os.path.exists(join(model_dir, 'checkpoint'))
|
||||
assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
|
||||
assert result.exit_code == 0
|
||||
|
||||
Reference in New Issue
Block a user