feat: add checksum control

This commit is contained in:
Félix Voituret
2019-11-19 14:56:12 +01:00
parent b2ebbe5db8
commit 52cefd1dae
2 changed files with 37 additions and 8 deletions

View File

@@ -47,14 +47,12 @@ class GithubModelProvider(ModelProvider):
self._repository = repository
self._release = release
def checksum(self, name, path):
""" Computes given path file sha256 and compares it to reference index
from release. Raise an exception if not matching.
def checksum(self, name):
""" Downloads and returns reference checksum for the given model name.
:param name: Name of the model to compute checksum for.
:param path: Path of the file to compare checksum with.
:param name: Name of the model to get checksum for.
:returns: Checksum of the required model.
:raise ValueError: If the given model name is not indexed.
:raise IOerror: if checksum is not valid or index cannot be downloaded.
"""
url = '{}/{}/{}/{}/{}'.format(
self._host,
@@ -67,11 +65,21 @@ class GithubModelProvider(ModelProvider):
index = response.json()
if name not in index:
raise ValueError('No checksum for model {}'.format(name))
return index[name]
def check_integrity(self, name, path):
""" Computes given path file sha256 and compares it to reference index
from release. Raise an exception if not matching.
:param name: Name of the model to compute checksum for.
:param path: Path of the file to compare checksum with.
:raise IOerror: if checksum is not valid or index cannot be downloaded.
"""
sha256 = hashlib.sha256()
with open(path, 'rb') as stream:
for chunk in iter(lambda: stream.read(4096), b''):
sha256.update(chunk)
if sha256.hexdigest() != index[name]:
if sha256.hexdigest() != self.checksum(name):
raise IOError('Downloaded file is corrupted, please retry')
def download(self, name, path):
@@ -96,7 +104,7 @@ class GithubModelProvider(ModelProvider):
if chunk:
stream.write(chunk)
get_logger().info('Validating archive checksum')
self.checksum(name, archive.name)
self.check_integrity(name, archive.name)
get_logger().info('Extracting downloaded %s archive', name)
tar = tarfile.open(name=archive.name)
tar.extractall(path=path)

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env python
# coding: utf8
""" TO DOCUMENT """
from pytest import raises
from spleeter.model.provider import get_default_model_provider
def test_checksum():
""" Test archive checksum index retrieval. """
provider = get_default_model_provider()
assert provider.checksum('2stems') == \
'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692'
assert provider.checksum('4stems') == \
'3adb4a50ad4eb18c7c4d65fcf4cf2367a07d48408a5eb7d03cd20067429dfaa8'
assert provider.checksum('5stems') == \
'25a1e87eb5f75cc72a4d2d5467a0a50ac75f05611f877c278793742513cc7218'
with raises(ValueError):
provider.checksum('laisse moi stems stems stems')