diff --git a/spleeter/model/provider/github.py b/spleeter/model/provider/github.py index fdd3d80..cd4d904 100644 --- a/spleeter/model/provider/github.py +++ b/spleeter/model/provider/github.py @@ -16,6 +16,7 @@ import hashlib import tarfile +import os from tempfile import NamedTemporaryFile @@ -96,16 +97,18 @@ class GithubModelProvider(ModelProvider): with requests.get(url, stream=True) as response: response.raise_for_status() archive = NamedTemporaryFile(delete=False) - with archive as stream: - # Note: check for chunk size parameters ? - for chunk in response.iter_content(chunk_size=8192): - if chunk: - stream.write(chunk) - get_logger().info('Validating archive checksum') - if compute_file_checksum(archive.name) != self.checksum(name): - raise IOError('Downloaded file is corrupted, please retry') - get_logger().info('Extracting downloaded %s archive', name) - tar = tarfile.open(name=archive.name) - tar.extractall(path=path) - tar.close() + try: + with archive as stream: + # Note: check for chunk size parameters ? + for chunk in response.iter_content(chunk_size=8192): + if chunk: + stream.write(chunk) + get_logger().info('Validating archive checksum') + if compute_file_checksum(archive.name) != self.checksum(name): + raise IOError('Downloaded file is corrupted, please retry') + get_logger().info('Extracting downloaded %s archive', name) + with tarfile.open(name=archive.name) as tar: + tar.extractall(path=path) + finally: + os.unlink(archive.name) get_logger().info('%s model file(s) extracted', name)