mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
feat: add checksum control
This commit is contained in:
@@ -47,14 +47,12 @@ class GithubModelProvider(ModelProvider):
|
||||
self._repository = repository
|
||||
self._release = release
|
||||
|
||||
def checksum(self, name, path):
|
||||
""" Computes given path file sha256 and compares it to reference index
|
||||
from release. Raise an exception if not matching.
|
||||
def checksum(self, name):
|
||||
""" Downloads and returns reference checksum for the given model name.
|
||||
|
||||
:param name: Name of the model to compute checksum for.
|
||||
:param path: Path of the file to compare checksum with.
|
||||
:param name: Name of the model to get checksum for.
|
||||
:returns: Checksum of the required model.
|
||||
:raise ValueError: If the given model name is not indexed.
|
||||
:raise IOerror: if checksum is not valid or index cannot be downloaded.
|
||||
"""
|
||||
url = '{}/{}/{}/{}/{}'.format(
|
||||
self._host,
|
||||
@@ -67,11 +65,21 @@ class GithubModelProvider(ModelProvider):
|
||||
index = response.json()
|
||||
if name not in index:
|
||||
raise ValueError('No checksum for model {}'.format(name))
|
||||
return index[name]
|
||||
|
||||
def check_integrity(self, name, path):
|
||||
""" Computes given path file sha256 and compares it to reference index
|
||||
from release. Raise an exception if not matching.
|
||||
|
||||
:param name: Name of the model to compute checksum for.
|
||||
:param path: Path of the file to compare checksum with.
|
||||
:raise IOerror: if checksum is not valid or index cannot be downloaded.
|
||||
"""
|
||||
sha256 = hashlib.sha256()
|
||||
with open(path, 'rb') as stream:
|
||||
for chunk in iter(lambda: stream.read(4096), b''):
|
||||
sha256.update(chunk)
|
||||
if sha256.hexdigest() != index[name]:
|
||||
if sha256.hexdigest() != self.checksum(name):
|
||||
raise IOError('Downloaded file is corrupted, please retry')
|
||||
|
||||
def download(self, name, path):
|
||||
@@ -96,7 +104,7 @@ class GithubModelProvider(ModelProvider):
|
||||
if chunk:
|
||||
stream.write(chunk)
|
||||
get_logger().info('Validating archive checksum')
|
||||
self.checksum(name, archive.name)
|
||||
self.check_integrity(name, archive.name)
|
||||
get_logger().info('Extracting downloaded %s archive', name)
|
||||
tar = tarfile.open(name=archive.name)
|
||||
tar.extractall(path=path)
|
||||
|
||||
Reference in New Issue
Block a user