diff --git a/spleeter/model/provider/github.py b/spleeter/model/provider/github.py index 2f6d7a3..d409bca 100644 --- a/spleeter/model/provider/github.py +++ b/spleeter/model/provider/github.py @@ -47,14 +47,12 @@ class GithubModelProvider(ModelProvider): self._repository = repository self._release = release - def checksum(self, name, path): - """ Computes given path file sha256 and compares it to reference index - from release. Raise an exception if not matching. + def checksum(self, name): + """ Downloads and returns reference checksum for the given model name. - :param name: Name of the model to compute checksum for. - :param path: Path of the file to compare checksum with. + :param name: Name of the model to get checksum for. + :returns: Checksum of the required model. :raise ValueError: If the given model name is not indexed. - :raise IOerror: if checksum is not valid or index cannot be downloaded. """ url = '{}/{}/{}/{}/{}'.format( self._host, @@ -67,11 +65,21 @@ class GithubModelProvider(ModelProvider): index = response.json() if name not in index: raise ValueError('No checksum for model {}'.format(name)) + return index[name] + + def check_integrity(self, name, path): + """ Computes given path file sha256 and compares it to reference index + from release. Raise an exception if not matching. + + :param name: Name of the model to compute checksum for. + :param path: Path of the file to compare checksum with. + :raise IOError: if checksum is not valid or index cannot be downloaded. 
+ """ sha256 = hashlib.sha256() with open(path, 'rb') as stream: for chunk in iter(lambda: stream.read(4096), b''): sha256.update(chunk) - if sha256.hexdigest() != index[name]: + if sha256.hexdigest() != self.checksum(name): raise IOError('Downloaded file is corrupted, please retry') def download(self, name, path): @@ -96,7 +104,7 @@ class GithubModelProvider(ModelProvider): if chunk: stream.write(chunk) get_logger().info('Validating archive checksum') - self.checksum(name, archive.name) + self.check_integrity(name, archive.name) get_logger().info('Extracting downloaded %s archive', name) tar = tarfile.open(name=archive.name) tar.extractall(path=path) diff --git a/tests/test_github_model_provider.py b/tests/test_github_model_provider.py new file mode 100644 index 0000000..248b1d5 --- /dev/null +++ b/tests/test_github_model_provider.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# coding: utf8 + +""" Tests for the GitHub model provider. """ + +from pytest import raises + +from spleeter.model.provider import get_default_model_provider + + +def test_checksum(): + """ Test archive checksum index retrieval. """ + provider = get_default_model_provider() + assert provider.checksum('2stems') == \ + 'f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692' + assert provider.checksum('4stems') == \ + '3adb4a50ad4eb18c7c4d65fcf4cf2367a07d48408a5eb7d03cd20067429dfaa8' + assert provider.checksum('5stems') == \ + '25a1e87eb5f75cc72a4d2d5467a0a50ac75f05611f877c278793742513cc7218' + with raises(ValueError): + provider.checksum('laisse moi stems stems stems')