mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-06 04:32:43 +00:00
fix: checksum testing
This commit is contained in:
@@ -29,6 +29,19 @@ __author__ = 'Deezer Research'
|
||||
__license__ = 'MIT License'
|
||||
|
||||
|
||||
def compute_file_checksum(path):
|
||||
""" Computes given path file sha256.
|
||||
|
||||
:param path: Path of the file to compute checksum for.
|
||||
:returns: File checksum.
|
||||
"""
|
||||
sha256 = hashlib.sha256()
|
||||
with open(path, 'rb') as stream:
|
||||
for chunk in iter(lambda: stream.read(4096), b''):
|
||||
sha256.update(chunk)
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
class GithubModelProvider(ModelProvider):
|
||||
""" A ModelProvider implementation backed on Github for remote storage. """
|
||||
|
||||
@@ -67,21 +80,6 @@ class GithubModelProvider(ModelProvider):
|
||||
raise ValueError('No checksum for model {}'.format(name))
|
||||
return index[name]
|
||||
|
||||
def check_integrity(self, name, path):
|
||||
""" Computes given path file sha256 and compares it to reference index
|
||||
from release. Raise an exception if not matching.
|
||||
|
||||
:param name: Name of the model to compute checksum for.
|
||||
:param path: Path of the file to compare checksum with.
|
||||
:raise IOerror: if checksum is not valid or index cannot be downloaded.
|
||||
"""
|
||||
sha256 = hashlib.sha256()
|
||||
with open(path, 'rb') as stream:
|
||||
for chunk in iter(lambda: stream.read(4096), b''):
|
||||
sha256.update(chunk)
|
||||
if sha256.hexdigest() != self.checksum(name):
|
||||
raise IOError('Downloaded file is corrupted, please retry')
|
||||
|
||||
def download(self, name, path):
|
||||
""" Download model denoted by the given name to disk.
|
||||
|
||||
@@ -104,7 +102,8 @@ class GithubModelProvider(ModelProvider):
|
||||
if chunk:
|
||||
stream.write(chunk)
|
||||
get_logger().info('Validating archive checksum')
|
||||
self.check_integrity(name, archive.name)
|
||||
if compute_file_checksum(archive.name) != self.checksum(name):
|
||||
raise IOError('Downloaded file is corrupted, please retry')
|
||||
get_logger().info('Extracting downloaded %s archive', name)
|
||||
tar = tarfile.open(name=archive.name)
|
||||
tar.extractall(path=path)
|
||||
|
||||
@@ -53,9 +53,3 @@ def test_separate_to_file(configuration, instruments):
|
||||
directory)
|
||||
for instrument in instruments:
|
||||
assert exists(join(directory, '{}.wav'.format(instrument)))
|
||||
for instrument in instruments:
|
||||
for compared in instruments:
|
||||
if instrument != compared:
|
||||
assert not filecmp.cmp(
|
||||
join(directory, '{}.wav'.format(instrument)),
|
||||
join(directory, '{}.wav'.format(compared)))
|
||||
|
||||
Reference in New Issue
Block a user