diff --git a/spleeter/model/provider/github.py b/spleeter/model/provider/github.py index aad0c44..3afecd1 100644 --- a/spleeter/model/provider/github.py +++ b/spleeter/model/provider/github.py @@ -18,7 +18,6 @@ import tarfile from os import environ from tempfile import TemporaryFile -from shutil import copyfileobj import requests @@ -60,14 +59,18 @@ class GithubModelProvider(ModelProvider): self._release, name) get_logger().info('Downloading model archive %s', url) - response = requests.get(url, stream=True) - if response.status_code != 200: - raise IOError(f'Resource {url} not found') - with TemporaryFile() as stream: - copyfileobj(response.raw, stream) - get_logger().info('Extracting downloaded %s archive', name) - stream.seek(0) - tar = tarfile.open(fileobj=stream) - tar.extractall(path=path) - tar.close() + with requests.get(url, stream=True) as response: + # Note: check for error logging here or upstream ? + response.raise_for_status() + with TemporaryFile() as stream: + # Note: check for chunk size parameters ? + for chunk in response.iter_content(chunk_size=8192): + if chunk: + stream.write(chunk) + get_logger().info('Extracting downloaded %s archive', name) + stream.seek(0) + tar = tarfile.open(fileobj=stream) + tar.extractall(path=path) + tar.close() + # TODO: perform checksum control get_logger().info('%s model file(s) extracted', name)