mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-30 20:24:31 +00:00
Fixed channels setting in dataset + added dimension check
This commit is contained in:
@@ -241,7 +241,8 @@ class InstrumentDatasetBuilder(object):
|
|||||||
def filter_shape(self, sample):
|
def filter_shape(self, sample):
|
||||||
""" Filter badly shaped sample. """
|
""" Filter badly shaped sample. """
|
||||||
return check_tensor_shape(
|
return check_tensor_shape(
|
||||||
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
|
sample[self._spectrogram_key],
|
||||||
|
(self._parent._T, self._parent._F, self._parent._n_channels),
|
||||||
)
|
)
|
||||||
|
|
||||||
def reshape_spectrogram(self, sample):
|
def reshape_spectrogram(self, sample):
|
||||||
@@ -250,7 +251,8 @@ class InstrumentDatasetBuilder(object):
|
|||||||
sample,
|
sample,
|
||||||
**{
|
**{
|
||||||
self._spectrogram_key: set_tensor_shape(
|
self._spectrogram_key: set_tensor_shape(
|
||||||
sample[self._spectrogram_key], (self._parent._T, self._parent._F, 2)
|
sample[self._spectrogram_key],
|
||||||
|
(self._parent._T, self._parent._F, self._parent._n_channels),
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -299,6 +301,7 @@ class DatasetBuilder(object):
|
|||||||
self._frame_length = audio_params["frame_length"]
|
self._frame_length = audio_params["frame_length"]
|
||||||
self._frame_step = audio_params["frame_step"]
|
self._frame_step = audio_params["frame_step"]
|
||||||
self._mix_name = audio_params["mix_name"]
|
self._mix_name = audio_params["mix_name"]
|
||||||
|
self._n_channels = audio_params["n_channels"]
|
||||||
self._instruments = [self._mix_name] + audio_params["instrument_list"]
|
self._instruments = [self._mix_name] + audio_params["instrument_list"]
|
||||||
self._instrument_builders = None
|
self._instrument_builders = None
|
||||||
self._chunk_duration = chunk_duration
|
self._chunk_duration = chunk_duration
|
||||||
@@ -307,6 +310,21 @@ class DatasetBuilder(object):
|
|||||||
self._audio_path = audio_path
|
self._audio_path = audio_path
|
||||||
self._random_seed = random_seed
|
self._random_seed = random_seed
|
||||||
|
|
||||||
|
self.check_parameters_compatibility()
|
||||||
|
|
||||||
|
def check_parameters_compatibility(self):
|
||||||
|
if self._frame_length / 2 + 1 < self._F:
|
||||||
|
raise ValueError(
|
||||||
|
"F is too large and must be set to at most frame_length/2+1. Decrease F or increase frame_length to fix."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._chunk_duration * self._sample_rate - self._frame_length
|
||||||
|
) / self._frame_step < self._T:
|
||||||
|
raise ValueError(
|
||||||
|
"T is too large considering STFT parameters and chunk duratoin. Make sure spectrogram time dimension of chunks is larger than T (for instance reducing T or frame_step or increasing chunk duration)."
|
||||||
|
)
|
||||||
|
|
||||||
def expand_path(self, sample):
|
def expand_path(self, sample):
|
||||||
""" Expands audio paths for the given sample. """
|
""" Expands audio paths for the given sample. """
|
||||||
return dict(
|
return dict(
|
||||||
@@ -368,7 +386,7 @@ class DatasetBuilder(object):
|
|||||||
},
|
},
|
||||||
lambda x: tf.image.random_crop(
|
lambda x: tf.image.random_crop(
|
||||||
x,
|
x,
|
||||||
(self._T, len(self._instruments) * self._F, 2),
|
(self._T, len(self._instruments) * self._F, self._n_channels),
|
||||||
seed=self._random_seed,
|
seed=self._random_seed,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -52,17 +52,19 @@ TRAIN_CONFIG = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
|
def generate_fake_training_dataset(path,
|
||||||
|
instrument_list=['vocals', 'other'],
|
||||||
|
n_channels=2,
|
||||||
|
n_songs = 2,
|
||||||
|
fs = 44100,
|
||||||
|
duration = 6,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
generates a fake training dataset in path:
|
generates a fake training dataset in path:
|
||||||
- generates audio files
|
- generates audio files
|
||||||
- generates a csv file describing the dataset
|
- generates a csv file describing the dataset
|
||||||
"""
|
"""
|
||||||
aa = AudioAdapter.default()
|
aa = AudioAdapter.default()
|
||||||
n_songs = 2
|
|
||||||
fs = 44100
|
|
||||||
duration = 6
|
|
||||||
n_channels = 2
|
|
||||||
rng = np.random.RandomState(seed=0)
|
rng = np.random.RandomState(seed=0)
|
||||||
dataset_df = pd.DataFrame(
|
dataset_df = pd.DataFrame(
|
||||||
columns=['mix_path'] + [
|
columns=['mix_path'] + [
|
||||||
@@ -83,26 +85,40 @@ def generate_fake_training_dataset(path, instrument_list=['vocals', 'other']):
|
|||||||
|
|
||||||
|
|
||||||
def test_train():
|
def test_train():
|
||||||
|
|
||||||
with TemporaryDirectory() as path:
|
with TemporaryDirectory() as path:
|
||||||
# generate training dataset
|
# generate training dataset
|
||||||
generate_fake_training_dataset(path)
|
for n_channels in [1,2]:
|
||||||
# set training command aruments
|
TRAIN_CONFIG["n_channels"] = n_channels
|
||||||
runner = CliRunner()
|
generate_fake_training_dataset(path,
|
||||||
TRAIN_CONFIG['train_csv'] = join(path, 'train', 'train.csv')
|
n_channels=n_channels,
|
||||||
TRAIN_CONFIG['validation_csv'] = join(path, 'train', 'train.csv')
|
fs=TRAIN_CONFIG["sample_rate"]
|
||||||
TRAIN_CONFIG['model_dir'] = join(path, 'model')
|
)
|
||||||
TRAIN_CONFIG['training_cache'] = join(path, 'cache', 'training')
|
# set training command arguments
|
||||||
TRAIN_CONFIG['validation_cache'] = join(path, 'cache', 'validation')
|
runner = CliRunner()
|
||||||
with open('useless_config.json', 'w') as stream:
|
|
||||||
json.dump(TRAIN_CONFIG, stream)
|
model_dir = join(path, f'model_{n_channels}')
|
||||||
# execute training
|
train_dir = join(path, f'train')
|
||||||
result = runner.invoke(spleeter, [
|
cache_dir = join(path, f'cache_{n_channels}')
|
||||||
'train',
|
|
||||||
'-p', 'useless_config.json',
|
TRAIN_CONFIG['train_csv'] = join(train_dir, 'train.csv')
|
||||||
'-d', path
|
TRAIN_CONFIG['validation_csv'] = join(train_dir, 'train.csv')
|
||||||
])
|
TRAIN_CONFIG['model_dir'] = model_dir
|
||||||
# assert that model checkpoint was created.
|
TRAIN_CONFIG['training_cache'] = join(cache_dir, 'training')
|
||||||
assert os.path.exists(join(path, 'model', 'model.ckpt-10.index'))
|
TRAIN_CONFIG['validation_cache'] = join(cache_dir, 'validation')
|
||||||
assert os.path.exists(join(path, 'model', 'checkpoint'))
|
with open('useless_config.json', 'w') as stream:
|
||||||
assert os.path.exists(join(path, 'model', 'model.ckpt-0.meta'))
|
json.dump(TRAIN_CONFIG, stream)
|
||||||
assert result.exit_code == 0
|
|
||||||
|
# execute training
|
||||||
|
result = runner.invoke(spleeter, [
|
||||||
|
'train',
|
||||||
|
'-p', 'useless_config.json',
|
||||||
|
'-d', path,
|
||||||
|
"--verbose"
|
||||||
|
])
|
||||||
|
|
||||||
|
# assert that model checkpoint was created.
|
||||||
|
assert os.path.exists(join(model_dir, 'model.ckpt-10.index'))
|
||||||
|
assert os.path.exists(join(model_dir, 'checkpoint'))
|
||||||
|
assert os.path.exists(join(model_dir, 'model.ckpt-0.meta'))
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user