mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
fix: checksum testing
This commit is contained in:
@@ -29,6 +29,19 @@ __author__ = 'Deezer Research'
|
|||||||
__license__ = 'MIT License'
|
__license__ = 'MIT License'
|
||||||
|
|
||||||
|
|
||||||
|
def compute_file_checksum(path):
|
||||||
|
""" Computes given path file sha256.
|
||||||
|
|
||||||
|
:param path: Path of the file to compute checksum for.
|
||||||
|
:returns: File checksum.
|
||||||
|
"""
|
||||||
|
sha256 = hashlib.sha256()
|
||||||
|
with open(path, 'rb') as stream:
|
||||||
|
for chunk in iter(lambda: stream.read(4096), b''):
|
||||||
|
sha256.update(chunk)
|
||||||
|
return sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class GithubModelProvider(ModelProvider):
|
class GithubModelProvider(ModelProvider):
|
||||||
""" A ModelProvider implementation backed on Github for remote storage. """
|
""" A ModelProvider implementation backed on Github for remote storage. """
|
||||||
|
|
||||||
@@ -67,21 +80,6 @@ class GithubModelProvider(ModelProvider):
|
|||||||
raise ValueError('No checksum for model {}'.format(name))
|
raise ValueError('No checksum for model {}'.format(name))
|
||||||
return index[name]
|
return index[name]
|
||||||
|
|
||||||
def check_integrity(self, name, path):
|
|
||||||
""" Computes given path file sha256 and compares it to reference index
|
|
||||||
from release. Raise an exception if not matching.
|
|
||||||
|
|
||||||
:param name: Name of the model to compute checksum for.
|
|
||||||
:param path: Path of the file to compare checksum with.
|
|
||||||
:raise IOerror: if checksum is not valid or index cannot be downloaded.
|
|
||||||
"""
|
|
||||||
sha256 = hashlib.sha256()
|
|
||||||
with open(path, 'rb') as stream:
|
|
||||||
for chunk in iter(lambda: stream.read(4096), b''):
|
|
||||||
sha256.update(chunk)
|
|
||||||
if sha256.hexdigest() != self.checksum(name):
|
|
||||||
raise IOError('Downloaded file is corrupted, please retry')
|
|
||||||
|
|
||||||
def download(self, name, path):
|
def download(self, name, path):
|
||||||
""" Download model denoted by the given name to disk.
|
""" Download model denoted by the given name to disk.
|
||||||
|
|
||||||
@@ -104,7 +102,8 @@ class GithubModelProvider(ModelProvider):
|
|||||||
if chunk:
|
if chunk:
|
||||||
stream.write(chunk)
|
stream.write(chunk)
|
||||||
get_logger().info('Validating archive checksum')
|
get_logger().info('Validating archive checksum')
|
||||||
self.check_integrity(name, archive.name)
|
if compute_file_checksum(archive.name) != self.checksum(name):
|
||||||
|
raise IOError('Downloaded file is corrupted, please retry')
|
||||||
get_logger().info('Extracting downloaded %s archive', name)
|
get_logger().info('Extracting downloaded %s archive', name)
|
||||||
tar = tarfile.open(name=archive.name)
|
tar = tarfile.open(name=archive.name)
|
||||||
tar.extractall(path=path)
|
tar.extractall(path=path)
|
||||||
|
|||||||
@@ -53,9 +53,3 @@ def test_separate_to_file(configuration, instruments):
|
|||||||
directory)
|
directory)
|
||||||
for instrument in instruments:
|
for instrument in instruments:
|
||||||
assert exists(join(directory, '{}.wav'.format(instrument)))
|
assert exists(join(directory, '{}.wav'.format(instrument)))
|
||||||
for instrument in instruments:
|
|
||||||
for compared in instruments:
|
|
||||||
if instrument != compared:
|
|
||||||
assert not filecmp.cmp(
|
|
||||||
join(directory, '{}.wav'.format(instrument)),
|
|
||||||
join(directory, '{}.wav'.format(compared)))
|
|
||||||
|
|||||||
Reference in New Issue
Block a user