mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
feat: add checksum control
This commit is contained in:
@@ -47,14 +47,12 @@ class GithubModelProvider(ModelProvider):
|
||||
self._repository = repository
|
||||
self._release = release
|
||||
|
||||
def checksum(self, name, path):
|
||||
""" Computes given path file sha256 and compares it to reference index
|
||||
from release. Raise an exception if not matching.
|
||||
def checksum(self, name):
|
||||
""" Downloads and returns reference checksum for the given model name.
|
||||
|
||||
:param name: Name of the model to compute checksum for.
|
||||
:param path: Path of the file to compare checksum with.
|
||||
:param name: Name of the model to get checksum for.
|
||||
:returns: Checksum of the required model.
|
||||
:raise ValueError: If the given model name is not indexed.
|
||||
:raise IOerror: if checksum is not valid or index cannot be downloaded.
|
||||
"""
|
||||
url = '{}/{}/{}/{}/{}'.format(
|
||||
self._host,
|
||||
@@ -67,11 +65,21 @@ class GithubModelProvider(ModelProvider):
|
||||
index = response.json()
|
||||
if name not in index:
|
||||
raise ValueError('No checksum for model {}'.format(name))
|
||||
return index[name]
|
||||
|
||||
def check_integrity(self, name, path):
|
||||
""" Computes given path file sha256 and compares it to reference index
|
||||
from release. Raise an exception if not matching.
|
||||
|
||||
:param name: Name of the model to compute checksum for.
|
||||
:param path: Path of the file to compare checksum with.
|
||||
:raise IOerror: if checksum is not valid or index cannot be downloaded.
|
||||
"""
|
||||
sha256 = hashlib.sha256()
|
||||
with open(path, 'rb') as stream:
|
||||
for chunk in iter(lambda: stream.read(4096), b''):
|
||||
sha256.update(chunk)
|
||||
if sha256.hexdigest() != index[name]:
|
||||
if sha256.hexdigest() != self.checksum(name):
|
||||
raise IOError('Downloaded file is corrupted, please retry')
|
||||
|
||||
def download(self, name, path):
|
||||
@@ -96,7 +104,7 @@ class GithubModelProvider(ModelProvider):
|
||||
if chunk:
|
||||
stream.write(chunk)
|
||||
get_logger().info('Validating archive checksum')
|
||||
self.checksum(name, archive.name)
|
||||
self.check_integrity(name, archive.name)
|
||||
get_logger().info('Extracting downloaded %s archive', name)
|
||||
tar = tarfile.open(name=archive.name)
|
||||
tar.extractall(path=path)
|
||||
|
||||
21
tests/test_github_model_provider.py
Normal file
21
tests/test_github_model_provider.py
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf8
|
||||
|
||||
""" TO DOCUMENT """
|
||||
|
||||
from pytest import raises
|
||||
|
||||
from spleeter.model.provider import get_default_model_provider
|
||||
|
||||
|
||||
def test_checksum():
|
||||
""" Test archive checksum index retrieval. """
|
||||
provider = get_default_model_provider()
|
||||
assert provider.checksum('2stems') == \
|
||||
'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692'
|
||||
assert provider.checksum('4stems') == \
|
||||
'3adb4a50ad4eb18c7c4d65fcf4cf2367a07d48408a5eb7d03cd20067429dfaa8'
|
||||
assert provider.checksum('5stems') == \
|
||||
'25a1e87eb5f75cc72a4d2d5467a0a50ac75f05611f877c278793742513cc7218'
|
||||
with raises(ValueError):
|
||||
provider.checksum('laisse moi stems stems stems')
|
||||
Reference in New Issue
Block a user