From eb32348c348b965b345c28e64769a9df999658c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Voituret?= Date: Tue, 19 Nov 2019 15:04:32 +0100 Subject: [PATCH] fix: checksum testing --- spleeter/model/provider/github.py | 31 +++++++++++++++---------------- tests/test_separator.py | 6 ------ 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/spleeter/model/provider/github.py b/spleeter/model/provider/github.py index d409bca..fdd3d80 100644 --- a/spleeter/model/provider/github.py +++ b/spleeter/model/provider/github.py @@ -29,6 +29,19 @@ __author__ = 'Deezer Research' __license__ = 'MIT License' +def compute_file_checksum(path): + """ Computes given path file sha256. + + :param path: Path of the file to compute checksum for. + :returns: File checksum. + """ + sha256 = hashlib.sha256() + with open(path, 'rb') as stream: + for chunk in iter(lambda: stream.read(4096), b''): + sha256.update(chunk) + return sha256.hexdigest() + + class GithubModelProvider(ModelProvider): """ A ModelProvider implementation backed on Github for remote storage. """ @@ -67,21 +80,6 @@ class GithubModelProvider(ModelProvider): raise ValueError('No checksum for model {}'.format(name)) return index[name] - def check_integrity(self, name, path): - """ Computes given path file sha256 and compares it to reference index - from release. Raise an exception if not matching. - - :param name: Name of the model to compute checksum for. - :param path: Path of the file to compare checksum with. - :raise IOerror: if checksum is not valid or index cannot be downloaded. - """ - sha256 = hashlib.sha256() - with open(path, 'rb') as stream: - for chunk in iter(lambda: stream.read(4096), b''): - sha256.update(chunk) - if sha256.hexdigest() != self.checksum(name): - raise IOError('Downloaded file is corrupted, please retry') - def download(self, name, path): """ Download model denoted by the given name to disk. @@ -104,7 +102,8 @@ class GithubModelProvider(ModelProvider): if chunk: stream.write(chunk) get_logger().info('Validating archive checksum') - self.check_integrity(name, archive.name) + if compute_file_checksum(archive.name) != self.checksum(name): + raise IOError('Downloaded file is corrupted, please retry') get_logger().info('Extracting downloaded %s archive', name) tar = tarfile.open(name=archive.name) tar.extractall(path=path) diff --git a/tests/test_separator.py b/tests/test_separator.py index d2648c2..a95128b 100644 --- a/tests/test_separator.py +++ b/tests/test_separator.py @@ -53,9 +53,3 @@ def test_separate_to_file(configuration, instruments): directory) for instrument in instruments: assert exists(join(directory, '{}.wav'.format(instrument))) - for instrument in instruments: - for compared in instruments: - if instrument != compared: - assert not filecmp.cmp( - join(directory, '{}.wav'.format(instrument)), - join(directory, '{}.wav'.format(compared)))