diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 3233489..afb7d6d 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -7,7 +7,7 @@ jobs: strategy: matrix: platform: [cpu, gpu] - distribution: [3.6, 3.7, conda] + distribution: [3.6, 3.7, 3.8, conda] model: [modelless, 2stems, 4stems, 5stems] fail-fast: true steps: @@ -69,13 +69,13 @@ jobs: run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin - name: Push deezer/spleeter:${{ env.tag }} image run: docker push deezer/spleeter:${{ env.tag }} - - if: ${{ env.tag == 'spleeter:3.7' }} + - if: ${{ env.tag == 'spleeter:3.8' }} name: Push deezer/spleeter:latest image run: | - docker tag deezer/spleeter:3.7 deezer/spleeter:latest + docker tag deezer/spleeter:3.8 deezer/spleeter:latest docker push deezer/spleeter:latest - - if: ${{ env.tag == 'spleeter:3.7-gpu' }} + - if: ${{ env.tag == 'spleeter:3.8-gpu' }} name: Push deezer/spleeter:gpu image run: | - docker tag deezer/spleeter:3.7-gpu deezer/spleeter:gpu + docker tag deezer/spleeter:3.8-gpu deezer/spleeter:gpu docker push deezer/spleeter:gpu diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index d4b2965..36ed370 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7] + python-version: [3.6, 3.7, 3.8] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/setup.py b/setup.py index 22b6ec3..d32cd4a 100644 --- a/setup.py +++ b/setup.py @@ -14,9 +14,9 @@ __license__ = 'MIT License' # Default project values. 
project_name = 'spleeter' -project_version = '1.5.4' +project_version = '2.0' tensorflow_dependency = 'tensorflow' -tensorflow_version = '1.15.2' +tensorflow_version = '2.3.0' here = path.abspath(path.dirname(__file__)) readme_path = path.join(here, 'README.md') with open(readme_path, 'r') as stream: @@ -47,17 +47,16 @@ setup( 'spleeter.utils', ], package_data={'spleeter.resources': ['*.json']}, - python_requires='>=3.6, <3.8', + python_requires='>=3.6, <3.9', include_package_data=True, install_requires=[ 'ffmpeg-python', 'importlib_resources ; python_version<"3.7"', 'norbert==0.2.1', - 'pandas==0.25.1', + 'pandas==1.1.2', 'requests', 'setuptools>=41.0.0', - 'librosa==0.7.2', - 'numba==0.48.0', + 'librosa==0.8.0', '{}=={}'.format(tensorflow_dependency, tensorflow_version), ], extras_require={ diff --git a/spleeter/audio/adapter.py b/spleeter/audio/adapter.py index 2bfcc9b..994c8df 100644 --- a/spleeter/audio/adapter.py +++ b/spleeter/audio/adapter.py @@ -13,7 +13,7 @@ from os.path import exists import numpy as np import tensorflow as tf -from tensorflow.contrib.signal import stft, hann_window +from tensorflow.signal import stft, hann_window # pylint: enable=import-error from .. import SpleeterError diff --git a/spleeter/audio/spectrogram.py b/spleeter/audio/spectrogram.py index 80ff65a..a1a79b3 100644 --- a/spleeter/audio/spectrogram.py +++ b/spleeter/audio/spectrogram.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf -from tensorflow.contrib.signal import stft, hann_window +from tensorflow.signal import stft, hann_window # pylint: enable=import-error __email__ = 'spleeter@deezer.com' diff --git a/spleeter/dataset.py b/spleeter/dataset.py index d6d5d03..5b11969 100644 --- a/spleeter/dataset.py +++ b/spleeter/dataset.py @@ -238,7 +238,7 @@ class DatasetBuilder(object): def expand_path(self, sample): """ Expands audio paths for the given sample. 
""" - return dict(sample, **{f'{instrument}_path': tf.string_join( + return dict(sample, **{f'{instrument}_path': tf.strings.join( (self._audio_path, sample[f'{instrument}_path']), SEPARATOR) for instrument in self._instruments}) diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index b79e424..8b8f511 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -8,7 +8,7 @@ import importlib # pylint: disable=import-error import tensorflow as tf -from tensorflow.contrib.signal import stft, inverse_stft, hann_window +from tensorflow.signal import stft, inverse_stft, hann_window # pylint: enable=import-error from ..utils.tensor import pad_and_partition, pad_and_reshape diff --git a/spleeter/separator.py b/spleeter/separator.py index 2c636b4..131bdac 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -12,14 +12,18 @@ >>> separator.separate_to_file(...) """ +import atexit import os import logging -from time import time from multiprocessing import Pool from os.path import basename, join, splitext, dirname +from time import time +from typing import Container, NoReturn + import numpy as np import tensorflow as tf + from librosa.core import stft, istft from scipy.signal.windows import hann @@ -27,64 +31,114 @@ from . import SpleeterError from .audio.adapter import get_default_audio_adapter from .audio.convertor import to_stereo from .utils.configuration import load_configuration -from .utils.estimator import create_estimator, to_predictor, get_default_model_dir +from .utils.estimator import create_estimator, get_default_model_dir from .model import EstimatorSpecBuilder, InputProviderFactory - __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' - -logger = logging.getLogger("spleeter") +SUPPORTED_BACKEND: Container[str] = ('auto', 'tensorflow', 'librosa') +""" """ +class DataGenerator(): + """ + Generator object that store a sample and generate it once while called. 
+ Used to feed a tensorflow estimator without knowing the whole data at + build time. + """ -def get_backend(backend): - assert backend in ["auto", "tensorflow", "librosa"] - if backend == "auto": - return "tensorflow" if tf.test.is_gpu_available() else "librosa" + def __init__(self): + """ Default constructor. """ + self._current_data = None + + def update_data(self, data): + """ Replace internal data. """ + self._current_data = data + + def __call__(self): + """ Generation process. """ + buffer = self._current_data + while buffer: + yield buffer + buffer = self._current_data + + +def get_backend(backend: str) -> str: + """ + """ + if backend not in SUPPORTED_BACKEND: + raise ValueError(f'Unsupported backend {backend}') + if backend == 'auto': + if len(tf.config.list_physical_devices('GPU')): + return 'tensorflow' + return 'librosa' return backend class Separator(object): """ A wrapper class for performing separation. """ - def __init__(self, params_descriptor, MWF=False, stft_backend="auto", multiprocess=True): + def __init__( + self, + params_descriptor, + MWF: bool = False, + stft_backend: str = 'auto', + multiprocess: bool = True): """ Default constructor. :param params_descriptor: Descriptor for TF params to be used. :param MWF: (Optional) True if MWF should be used, False otherwise. 
""" - self._params = load_configuration(params_descriptor) self._sample_rate = self._params['sample_rate'] self._MWF = MWF self._tf_graph = tf.Graph() - self._predictor = None + self._prediction_generator = None self._input_provider = None self._builder = None self._features = None self._session = None - self._pool = Pool() if multiprocess else None + if multiprocess: + self._pool = Pool() + atexit.register(self._pool.close) + else: + self._pool = None self._tasks = [] - self._params["stft_backend"] = get_backend(stft_backend) + self._params['stft_backend'] = get_backend(stft_backend) + self._data_generator = DataGenerator() def __del__(self): + """ """ if self._session: self._session.close() - def _get_predictor(self): - """ Lazy loading access method for internal predictor instance. + def _get_prediction_generator(self): + """ Lazy loading access method for internal prediction generator + returned by the predict method of a tensorflow estimator. - :returns: Predictor to use for source separation. + :returns: generator of prediction. """ - if self._predictor is None: + if self._prediction_generator is None: estimator = create_estimator(self._params, self._MWF) - self._predictor = to_predictor(estimator) - return self._predictor - def join(self, timeout=200): + def get_dataset(): + return tf.data.Dataset.from_generator( + self._data_generator, + output_types={ + 'waveform': tf.float32, + 'audio_id': tf.string}, + output_shapes={ + 'waveform': (None, 2), + 'audio_id': ()}) + + self._prediction_generator = estimator.predict( + get_dataset, + yield_single_examples=False) + return self._prediction_generator + + def join(self, timeout: int = 200) -> NoReturn: """ Wait for all pending tasks to be finished. :param timeout: (Optional) task waiting timeout. 
@@ -94,44 +148,52 @@ class Separator(object): task.get() task.wait(timeout=timeout) - def _separate_tensorflow(self, waveform, audio_descriptor): - """ - Performs source separation over the given waveform with tensorflow backend. + def _separate_tensorflow(self, waveform: np.ndarray, audio_descriptor): + """ Performs source separation over the given waveform with tensorflow + backend. :param waveform: Waveform to apply separation on. :returns: Separated waveforms. """ if not waveform.shape[-1] == 2: waveform = to_stereo(waveform) - predictor = self._get_predictor() - prediction = predictor({ + prediction_generator = self._get_prediction_generator() + # NOTE: update data in generator before performing separation. + self._data_generator.update_data({ 'waveform': waveform, - 'audio_id': audio_descriptor}) + 'audio_id': np.array(audio_descriptor)}) + # NOTE: perform separation. + prediction = next(prediction_generator) prediction.pop('audio_id') return prediction - def _stft(self, data, inverse=False, length=None): - """ - Single entrypoint for both stft and istft. This computes stft and istft with librosa on stereo data. The two - channels are processed separately and are concatenated together in the result. The expected input formats are: - (n_samples, 2) for stft and (T, F, 2) for istft. - :param data: np.array with either the waveform or the complex spectrogram depending on the parameter inverse + def _stft(self, data, inverse: bool = False, length=None): + """ Single entrypoint for both stft and istft. This computes stft and + istft with librosa on stereo data. The two channels are processed + separately and are concatenated together in the result. The expected + input formats are: (n_samples, 2) for stft and (T, F, 2) for istft. + + :param data: np.array with either the waveform or the complex + spectrogram depending on the parameter inverse :param inverse: should a stft or an istft be computed. - :return: Stereo data as numpy array for the transform. 
The channels are stored in the last dimension + :returns: Stereo data as numpy array for the transform. + The channels are stored in the last dimension. """ assert not (inverse and length is None) data = np.asfortranarray(data) - N = self._params["frame_length"] - H = self._params["frame_step"] - + N = self._params['frame_length'] + H = self._params['frame_step'] win = hann(N, sym=False) fstft = istft if inverse else stft - win_len_arg = {"win_length": None, - "length": None} if inverse else {"n_fft": N} + win_len_arg = { + 'win_length': None, + 'length': None} if inverse else {'n_fft': N} n_channels = data.shape[-1] out = [] for c in range(n_channels): - d = np.concatenate((np.zeros((N, )), data[:, c], np.zeros((N, )))) if not inverse else data[:, :, c].T + d = np.concatenate( + (np.zeros((N, )), data[:, c], np.zeros((N, ))) + ) if not inverse else data[:, :, c].T s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg) if inverse: s = s[N:N+length] @@ -141,7 +203,6 @@ class Separator(object): return out[0] return np.concatenate(out, axis=2-inverse) - def _get_input_provider(self): if self._input_provider is None: self._input_provider = InputProviderFactory.get(self._params) @@ -149,66 +210,83 @@ class Separator(object): def _get_features(self): if self._features is None: - self._features = self._get_input_provider().get_input_dict_placeholders() + provider = self._get_input_provider() + self._features = provider.get_input_dict_placeholders() return self._features def _get_builder(self): if self._builder is None: - self._builder = EstimatorSpecBuilder(self._get_features(), self._params) + self._builder = EstimatorSpecBuilder( + self._get_features(), + self._params) return self._builder def _get_session(self): if self._session is None: - saver = tf.train.Saver() - latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) - self._session = tf.Session() + saver = tf.compat.v1.train.Saver() + latest_checkpoint = 
tf.train.latest_checkpoint( + get_default_model_dir(self._params['model_dir'])) + self._session = tf.compat.v1.Session() saver.restore(self._session, latest_checkpoint) return self._session - def _separate_librosa(self, waveform, audio_id): - """ - Performs separation with librosa backend for STFT. + def _separate_librosa(self, waveform: np.ndarray, audio_id): + """ Performs separation with librosa backend for STFT. """ with self._tf_graph.as_default(): out = {} features = self._get_features() - - # TODO: fix the logic, build sometimes return, sometimes set attribute + # TODO: fix the logic, build sometimes return, + # sometimes set attribute. outputs = self._get_builder().outputs stft = self._stft(waveform) if stft.shape[-1] == 1: stft = np.concatenate([stft, stft], axis=-1) elif stft.shape[-1] > 2: stft = stft[:, :2] - sess = self._get_session() - outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id)) + outputs = sess.run( + outputs, + feed_dict=self._get_input_provider().get_feed_dict( + features, + stft, + audio_id)) for inst in self._get_builder().instruments: - out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0]) + out[inst] = self._stft( + outputs[inst], + inverse=True, + length=waveform.shape[0]) return out - def separate(self, waveform, audio_descriptor=""): + def separate(self, waveform: np.ndarray, audio_descriptor=''): """ Performs separation on a waveform. :param waveform: Waveform to be separated (as a numpy array) - :param audio_descriptor: (Optional) string describing the waveform (e.g. filename). + :param audio_descriptor: (Optional) string describing the waveform + (e.g. filename). 
""" - if self._params["stft_backend"] == "tensorflow": + if self._params['stft_backend'] == 'tensorflow': return self._separate_tensorflow(waveform, audio_descriptor) else: return self._separate_librosa(waveform, audio_descriptor) def separate_to_file( - self, audio_descriptor, destination, + self, + audio_descriptor, + destination, audio_adapter=get_default_audio_adapter(), - offset=0, duration=600., codec='wav', bitrate='128k', + offset=0, + duration=600., + codec='wav', + bitrate='128k', filename_format='{filename}/{instrument}.{codec}', synchronous=True): """ Performs source separation and export result to file using given audio adapter. Filename format should be a Python formattable string that could use - following parameters : {instrument}, {filename}, {foldername} and {codec}. + following parameters : {instrument}, {filename}, {foldername} and + {codec}. :param audio_descriptor: Describe song to separate, used by audio adapter to retrieve and load audio data, @@ -217,8 +295,8 @@ class Separator(object): :param destination: Target directory to write output to. :param audio_adapter: (Optional) Audio adapter to use for I/O. :param offset: (Optional) Offset of loaded song. - :param duration: (Optional) Duration of loaded song (default: - 600s). + :param duration: (Optional) Duration of loaded song + (default: 600s). :param codec: (Optional) Export codec. :param bitrate: (Optional) Export bitrate. :param filename_format: (Optional) Filename format. 
@@ -230,16 +308,27 @@ class Separator(object): duration=duration, sample_rate=self._sample_rate) sources = self.separate(waveform, audio_descriptor) - self.save_to_file( sources, audio_descriptor, destination, - filename_format, codec, audio_adapter, - bitrate, synchronous) + self.save_to_file( + sources, + audio_descriptor, + destination, + filename_format, + codec, + audio_adapter, + bitrate, + synchronous) def save_to_file( - self, sources, audio_descriptor, destination, + self, + sources, + audio_descriptor, + destination, filename_format='{filename}/{instrument}.{codec}', - codec='wav', audio_adapter=get_default_audio_adapter(), - bitrate='128k', synchronous=True): - """ export dictionary of sources to files. + codec='wav', + audio_adapter=get_default_audio_adapter(), + bitrate='128k', + synchronous=True): + """ Export dictionary of sources to files. :param sources: Dictionary of sources to be exported. The keys are the name of the instruments, and @@ -258,7 +347,6 @@ class Separator(object): :param synchronous: (Optional) True is should by synchronous. 
""" - foldername = basename(dirname(audio_descriptor)) filename = splitext(basename(audio_descriptor))[0] generated = [] @@ -286,6 +374,11 @@ class Separator(object): bitrate)) self._tasks.append(task) else: - audio_adapter.save(path, data, self._sample_rate, codec, bitrate) + audio_adapter.save( + path, + data, + self._sample_rate, + codec, + bitrate) if synchronous and self._pool: self.join() diff --git a/spleeter/utils/estimator.py b/spleeter/utils/estimator.py index a9aa736..f49fb56 100644 --- a/spleeter/utils/estimator.py +++ b/spleeter/utils/estimator.py @@ -5,20 +5,14 @@ from pathlib import Path from os.path import join -from tempfile import gettempdir # pylint: disable=import-error import tensorflow as tf -from tensorflow.contrib import predictor -# pylint: enable=import-error -from ..model import model_fn, InputProviderFactory +from ..model import model_fn from ..model.provider import get_default_model_provider -# Default exporting directory for predictor. -DEFAULT_EXPORT_DIRECTORY = join(gettempdir(), 'serving') - def get_default_model_dir(model_dir): @@ -57,24 +51,3 @@ def create_estimator(params, MWF): config=config ) return estimator - - -def to_predictor(estimator, directory=DEFAULT_EXPORT_DIRECTORY): - """ Exports given estimator as predictor into the given directory - and returns associated tf.predictor instance. - - :param estimator: Estimator to export. - :param directory: (Optional) path to write exported model into. 
- """ - - input_provider = InputProviderFactory.get(estimator.params) - def receiver(): - features = input_provider.get_input_dict_placeholders() - return tf.estimator.export.ServingInputReceiver(features, features) - - estimator.export_saved_model(directory, receiver) - versions = [ - model for model in Path(directory).iterdir() - if model.is_dir() and 'temp' not in str(model)] - latest = str(sorted(versions)[-1]) - return predictor.from_saved_model(latest) diff --git a/tests/test_eval.py b/tests/test_eval.py index 97540a9..f3764b6 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -56,6 +56,9 @@ res_4stems = { } def generate_fake_eval_dataset(path): + """ + generate fake evaluation dataset + """ aa = get_default_audio_adapter() n_songs = 2 fs = 44100 @@ -71,6 +74,7 @@ def generate_fake_eval_dataset(path): aa.save(filename, data, fs) + @pytest.mark.parametrize('backend', TEST_CONFIGURATIONS) def test_evaluate(backend): with TemporaryDirectory() as directory: @@ -81,4 +85,4 @@ def test_evaluate(backend): metrics = evaluate.entrypoint(arguments, params) for instrument, metric in metrics.items(): for m, value in metric.items(): - assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3) + assert np.allclose(np.median(value), res_4stems[instrument][m], atol=1e-3) \ No newline at end of file diff --git a/tests/test_separator.py b/tests/test_separator.py index 3094900..e757abf 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -7,7 +7,6 @@ __email__ = 'spleeter@deezer.com' __author__ = 'Deezer Research' __license__ = 'MIT License' -import filecmp import itertools from os.path import splitext, basename, exists, join from tempfile import TemporaryDirectory @@ -33,7 +32,8 @@ MODEL_TO_INST = { MODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS)) -TEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS)) +TEST_CONFIGURATIONS = list(itertools.product( + 
TEST_AUDIO_DESCRIPTORS, MODELS, BACKENDS)) print("RUNNING TESTS WITH TF VERSION {}".format(tf.__version__)) @@ -44,8 +44,10 @@ def test_separator_backends(test_file): adapter = get_default_audio_adapter() waveform, _ = adapter.load(test_file) - separator_lib = Separator("spleeter:2stems", stft_backend="librosa") - separator_tf = Separator("spleeter:2stems", stft_backend="tensorflow") + separator_lib = Separator( + "spleeter:2stems", stft_backend="librosa", multiprocess=False) + separator_tf = Separator( + "spleeter:2stems", stft_backend="tensorflow", multiprocess=False) # Test the stft and inverse stft provides exact reconstruction stft_matrix = separator_lib._stft(waveform) @@ -68,7 +70,8 @@ def test_separate(test_file, configuration, backend): instruments = MODEL_TO_INST[configuration] adapter = get_default_audio_adapter() waveform, _ = adapter.load(test_file) - separator = Separator(configuration, stft_backend=backend, multiprocess=False) + separator = Separator( + configuration, stft_backend=backend, multiprocess=False) prediction = separator.separate(waveform, test_file) assert len(prediction) == len(instruments) for instrument in instruments: @@ -80,14 +83,14 @@ def test_separate(test_file, configuration, backend): for compared in instruments: if instrument != compared: assert not np.allclose(track, prediction[compared]) - @pytest.mark.parametrize('test_file, configuration, backend', TEST_CONFIGURATIONS) def test_separate_to_file(test_file, configuration, backend): """ Test file based separation. 
""" instruments = MODEL_TO_INST[configuration] - separator = Separator(configuration, stft_backend=backend, multiprocess=False) + separator = Separator( + configuration, stft_backend=backend, multiprocess=False) name = splitext(basename(test_file))[0] with TemporaryDirectory() as directory: separator.separate_to_file( @@ -103,7 +106,8 @@ def test_separate_to_file(test_file, configuration, backend): def test_filename_format(test_file, configuration, backend): """ Test custom filename format. """ instruments = MODEL_TO_INST[configuration] - separator = Separator(configuration, stft_backend=backend, multiprocess=False) + separator = Separator( + configuration, stft_backend=backend, multiprocess=False) name = splitext(basename(test_file))[0] with TemporaryDirectory() as directory: separator.separate_to_file( diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 0000000..8d9533a --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# coding: utf8 + +""" Unit testing for the train command.
""" + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + +import filecmp +import itertools +import os +from os import makedirs +from os.path import splitext, basename, exists, join +from tempfile import TemporaryDirectory + +import numpy as np +import pandas as pd +import json + +import tensorflow as tf + +from spleeter.audio.adapter import get_default_audio_adapter +from spleeter.commands import create_argument_parser + +from spleeter.commands import train + +from spleeter.utils.configuration import load_configuration + +TRAIN_CONFIG = { + "mix_name": "mix", + "instrument_list": ["vocals", "other"], + "sample_rate":44100, + "frame_length":4096, + "frame_step":1024, + "T":128, + "F":128, + "n_channels":2, + "chunk_duration":4, + "n_chunks_per_song":1, + "separation_exponent":2, + "mask_extension":"zeros", + "learning_rate": 1e-4, + "batch_size":2, + "train_max_steps": 10, + "throttle_secs":20, + "save_checkpoints_steps":100, + "save_summary_steps":5, + "random_seed":0, + "model":{ + "type":"unet.unet", + "params":{ + "conv_activation":"ELU", + "deconv_activation":"ELU" + } + } +} + + +def generate_fake_training_dataset(path, instrument_list=["vocals", "other"]): + """ + generates a fake training dataset in path: + - generates audio files + - generates a csv file describing the dataset + """ + aa = get_default_audio_adapter() + n_songs = 2 + fs = 44100 + duration = 6 + n_channels = 2 + rng = np.random.RandomState(seed=0) + dataset_df = pd.DataFrame(columns=["mix_path"]+[f"{instr}_path" for instr in instrument_list]+["duration"]) + for song in range(n_songs): + song_path = join(path, "train", f"song{song}") + makedirs(song_path, exist_ok=True) + dataset_df.loc[song, f"duration"] = duration + for instr in instrument_list+["mix"]: + filename = join(song_path, f"{instr}.wav") + data = rng.rand(duration*fs, n_channels)-0.5 + aa.save(filename, data, fs) + dataset_df.loc[song, f"{instr}_path"] = join("train", f"song{song}", 
f"{instr}.wav") + + dataset_df.to_csv(join(path, "train", "train.csv"), index=False) + + + +def test_train(): + + + with TemporaryDirectory() as path: + + # generate training dataset + generate_fake_training_dataset(path) + + # set training command arguments + p = create_argument_parser() + arguments = p.parse_args(["train", "-p", "useless_config.json", "-d", path]) + TRAIN_CONFIG["train_csv"] = join(path, "train", "train.csv") + TRAIN_CONFIG["validation_csv"] = join(path, "train", "train.csv") + TRAIN_CONFIG["model_dir"] = join(path, "model") + TRAIN_CONFIG["training_cache"] = join(path, "cache", "training") + TRAIN_CONFIG["validation_cache"] = join(path, "cache", "validation") + + # execute training + res = train.entrypoint(arguments, TRAIN_CONFIG) + + # assert that model checkpoint was created. + assert os.path.exists(join(path,'model','model.ckpt-10.index')) + assert os.path.exists(join(path,'model','checkpoint')) + assert os.path.exists(join(path,'model','model.ckpt-0.meta')) + +if __name__=="__main__": + test_train() \ No newline at end of file