feat: add checksum control

This commit is contained in:
Félix Voituret
2019-11-19 14:56:12 +01:00
parent b2ebbe5db8
commit 52cefd1dae
2 changed files with 37 additions and 8 deletions

View File

@@ -47,14 +47,12 @@ class GithubModelProvider(ModelProvider):
self._repository = repository
self._release = release
def checksum(self, name, path):
""" Computes given path file sha256 and compares it to reference index
from release. Raise an exception if not matching.
def checksum(self, name):
""" Downloads and returns reference checksum for the given model name.
:param name: Name of the model to compute checksum for.
:param path: Path of the file to compare checksum with.
:param name: Name of the model to get checksum for.
:returns: Checksum of the required model.
:raise ValueError: If the given model name is not indexed.
:raise IOerror: if checksum is not valid or index cannot be downloaded.
"""
url = '{}/{}/{}/{}/{}'.format(
self._host,
@@ -67,11 +65,21 @@ class GithubModelProvider(ModelProvider):
index = response.json()
if name not in index:
raise ValueError('No checksum for model {}'.format(name))
return index[name]
def check_integrity(self, name, path):
""" Computes given path file sha256 and compares it to reference index
from release. Raise an exception if not matching.
:param name: Name of the model to compute checksum for.
:param path: Path of the file to compare checksum with.
:raise IOerror: if checksum is not valid or index cannot be downloaded.
"""
sha256 = hashlib.sha256()
with open(path, 'rb') as stream:
for chunk in iter(lambda: stream.read(4096), b''):
sha256.update(chunk)
if sha256.hexdigest() != index[name]:
if sha256.hexdigest() != self.checksum(name):
raise IOError('Downloaded file is corrupted, please retry')
def download(self, name, path):
@@ -96,7 +104,7 @@ class GithubModelProvider(ModelProvider):
if chunk:
stream.write(chunk)
get_logger().info('Validating archive checksum')
self.checksum(name, archive.name)
self.check_integrity(name, archive.name)
get_logger().info('Extracting downloaded %s archive', name)
tar = tarfile.open(name=archive.name)
tar.extractall(path=path)