diff --git a/spleeter/separator.py b/spleeter/separator.py index 554e0e0..8baa58c 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -61,6 +61,9 @@ class Separator(object): self._sample_rate = self._params['sample_rate'] self._MWF = MWF self._predictor = None + self._input_provider = None + self._builder = None + self._features = None self._pool = Pool() if multiprocess else None self._tasks = [] self._params["stft_backend"] = get_backend(stft_backend) @@ -128,19 +131,33 @@ class Separator(object): return out[0] return np.concatenate(out, axis=2-inverse) + + def _get_input_provider(self): + if self._input_provider is None: + self._input_provider = InputProviderFactory.get(self._params) + return self._input_provider + + def _get_features(self): + if self._features is None: + self._features = self._get_input_provider().get_input_dict_placeholders() + return self._features + + def _get_builder(self): + if self._builder is None: + self._builder = EstimatorSpecBuilder(self._get_features(), self._params) + return self._builder + def _separate_librosa(self, waveform, audio_id): """ Performs separation with librosa backend for STFT. """ out = {} - input_provider = InputProviderFactory.get(self._params) - features = input_provider.get_input_dict_placeholders() + features = self._get_features() - builder = EstimatorSpecBuilder(features, self._params) latest_checkpoint = tf.train.latest_checkpoint(get_default_model_dir(self._params['model_dir'])) # TODO: fix the logic, build sometimes return, sometimes set attribute - outputs = builder.outputs + outputs = self._get_builder().outputs stft = self._stft(waveform) if stft.shape[-1] == 1: stft = np.concatenate([stft, stft], axis=-1) @@ -150,8 +167,8 @@ class Separator(object): saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, latest_checkpoint) - outputs = sess.run(outputs, feed_dict=input_provider.get_feed_dict(features, stft, audio_id)) - for inst in builder.instruments: + outputs = sess.run(outputs, feed_dict=self._get_input_provider().get_feed_dict(features, stft, audio_id)) + for inst in self._get_builder().instruments: out[inst] = self._stft(outputs[inst], inverse=True, length=waveform.shape[0]) return out