diff --git a/spleeter/commands/separate.py b/spleeter/commands/separate.py index df88229..2fbbfee 100644 --- a/spleeter/commands/separate.py +++ b/spleeter/commands/separate.py @@ -11,20 +11,8 @@ -i /path/to/audio1.wav /path/to/audio2.mp3 """ -from multiprocessing import Pool -from os.path import isabs, join, split, splitext -from tempfile import gettempdir - -# pylint: disable=import-error -import tensorflow as tf -import numpy as np -# pylint: enable=import-error - from ..audio.adapter import get_audio_adapter -from ..audio.convertor import to_n_channels from ..separator import Separator -from ..utils.estimator import create_estimator -from ..utils.tensor import set_tensor_shape __email__ = 'research@deezer.com' __author__ = 'Deezer Research' diff --git a/spleeter/model/provider/github.py b/spleeter/model/provider/github.py index 3afecd1..2f6d7a3 100644 --- a/spleeter/model/provider/github.py +++ b/spleeter/model/provider/github.py @@ -14,10 +14,10 @@ >>> provider.download('2stems', '/path/to/local/storage') """ +import hashlib import tarfile -from os import environ -from tempfile import TemporaryFile +from tempfile import NamedTemporaryFile import requests @@ -34,6 +34,7 @@ class GithubModelProvider(ModelProvider): LATEST_RELEASE = 'v1.4.0' RELEASE_PATH = 'releases/download' + CHECKSUM_INDEX = 'checksum.json' def __init__(self, host, repository, release): """ Default constructor. @@ -46,6 +47,33 @@ class GithubModelProvider(ModelProvider): self._repository = repository self._release = release + def checksum(self, name, path): + """ Computes given path file sha256 and compares it to reference index + from release. Raise an exception if not matching. + + :param name: Name of the model to compute checksum for. + :param path: Path of the file to compare checksum with. + :raise ValueError: If the given model name is not indexed. + :raise IOerror: if checksum is not valid or index cannot be downloaded. + """ + url = '{}/{}/{}/{}/{}'.format( + self._host, + self._repository, + self.RELEASE_PATH, + self._release, + self.CHECKSUM_INDEX) + response = requests.get(url) + response.raise_for_status() + index = response.json() + if name not in index: + raise ValueError('No checksum for model {}'.format(name)) + sha256 = hashlib.sha256() + with open(path, 'rb') as stream: + for chunk in iter(lambda: stream.read(4096), b''): + sha256.update(chunk) + if sha256.hexdigest() != index[name]: + raise IOError('Downloaded file is corrupted, please retry') + def download(self, name, path): """ Download model denoted by the given name to disk. @@ -60,17 +88,17 @@ class GithubModelProvider(ModelProvider): name) get_logger().info('Downloading model archive %s', url) with requests.get(url, stream=True) as response: - # Note: check for error logging here or upstream ? response.raise_for_status() - with TemporaryFile() as stream: + archive = NamedTemporaryFile(delete=False) + with archive as stream: # Note: check for chunk size parameters ? for chunk in response.iter_content(chunk_size=8192): if chunk: stream.write(chunk) - get_logger().info('Extracting downloaded %s archive', name) - stream.seek(0) - tar = tarfile.open(fileobj=stream) - tar.extractall(path=path) - tar.close() - # TODO: perform checksum control + get_logger().info('Validating archive checksum') + self.checksum(name, archive.name) + get_logger().info('Extracting downloaded %s archive', name) + tar = tarfile.open(name=archive.name) + tar.extractall(path=path) + tar.close() get_logger().info('%s model file(s) extracted', name)