mirror of
https://github.com/YuzuZensai/spleeter.git
synced 2026-01-31 14:58:23 +00:00
fix: remove unused dep
feat: add checksum control
This commit is contained in:
@@ -11,20 +11,8 @@
|
|||||||
-i /path/to/audio1.wav /path/to/audio2.mp3
|
-i /path/to/audio1.wav /path/to/audio2.mp3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from multiprocessing import Pool
|
|
||||||
from os.path import isabs, join, split, splitext
|
|
||||||
from tempfile import gettempdir
|
|
||||||
|
|
||||||
# pylint: disable=import-error
|
|
||||||
import tensorflow as tf
|
|
||||||
import numpy as np
|
|
||||||
# pylint: enable=import-error
|
|
||||||
|
|
||||||
from ..audio.adapter import get_audio_adapter
|
from ..audio.adapter import get_audio_adapter
|
||||||
from ..audio.convertor import to_n_channels
|
|
||||||
from ..separator import Separator
|
from ..separator import Separator
|
||||||
from ..utils.estimator import create_estimator
|
|
||||||
from ..utils.tensor import set_tensor_shape
|
|
||||||
|
|
||||||
__email__ = 'research@deezer.com'
|
__email__ = 'research@deezer.com'
|
||||||
__author__ = 'Deezer Research'
|
__author__ = 'Deezer Research'
|
||||||
|
|||||||
@@ -14,10 +14,10 @@
|
|||||||
>>> provider.download('2stems', '/path/to/local/storage')
|
>>> provider.download('2stems', '/path/to/local/storage')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import tarfile
|
import tarfile
|
||||||
|
|
||||||
from os import environ
|
from tempfile import NamedTemporaryFile
|
||||||
from tempfile import TemporaryFile
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -34,6 +34,7 @@ class GithubModelProvider(ModelProvider):
|
|||||||
|
|
||||||
LATEST_RELEASE = 'v1.4.0'
|
LATEST_RELEASE = 'v1.4.0'
|
||||||
RELEASE_PATH = 'releases/download'
|
RELEASE_PATH = 'releases/download'
|
||||||
|
CHECKSUM_INDEX = 'checksum.json'
|
||||||
|
|
||||||
def __init__(self, host, repository, release):
|
def __init__(self, host, repository, release):
|
||||||
""" Default constructor.
|
""" Default constructor.
|
||||||
@@ -46,6 +47,33 @@ class GithubModelProvider(ModelProvider):
|
|||||||
self._repository = repository
|
self._repository = repository
|
||||||
self._release = release
|
self._release = release
|
||||||
|
|
||||||
|
def checksum(self, name, path):
|
||||||
|
""" Computes given path file sha256 and compares it to reference index
|
||||||
|
from release. Raise an exception if not matching.
|
||||||
|
|
||||||
|
:param name: Name of the model to compute checksum for.
|
||||||
|
:param path: Path of the file to compare checksum with.
|
||||||
|
:raise ValueError: If the given model name is not indexed.
|
||||||
|
:raise IOerror: if checksum is not valid or index cannot be downloaded.
|
||||||
|
"""
|
||||||
|
url = '{}/{}/{}/{}/{}'.format(
|
||||||
|
self._host,
|
||||||
|
self._repository,
|
||||||
|
self.RELEASE_PATH,
|
||||||
|
self._release,
|
||||||
|
self.CHECKSUM_INDEX)
|
||||||
|
response = requests.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
index = response.json()
|
||||||
|
if name not in index:
|
||||||
|
raise ValueError('No checksum for model {}'.format(name))
|
||||||
|
sha256 = hashlib.sha256()
|
||||||
|
with open(path, 'rb') as stream:
|
||||||
|
for chunk in iter(lambda: stream.read(4096), b''):
|
||||||
|
sha256.update(chunk)
|
||||||
|
if sha256.hexdigest() != index[name]:
|
||||||
|
raise IOError('Downloaded file is corrupted, please retry')
|
||||||
|
|
||||||
def download(self, name, path):
|
def download(self, name, path):
|
||||||
""" Download model denoted by the given name to disk.
|
""" Download model denoted by the given name to disk.
|
||||||
|
|
||||||
@@ -60,17 +88,17 @@ class GithubModelProvider(ModelProvider):
|
|||||||
name)
|
name)
|
||||||
get_logger().info('Downloading model archive %s', url)
|
get_logger().info('Downloading model archive %s', url)
|
||||||
with requests.get(url, stream=True) as response:
|
with requests.get(url, stream=True) as response:
|
||||||
# Note: check for error logging here or upstream ?
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
with TemporaryFile() as stream:
|
archive = NamedTemporaryFile(delete=False)
|
||||||
|
with archive as stream:
|
||||||
# Note: check for chunk size parameters ?
|
# Note: check for chunk size parameters ?
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
if chunk:
|
if chunk:
|
||||||
stream.write(chunk)
|
stream.write(chunk)
|
||||||
get_logger().info('Extracting downloaded %s archive', name)
|
get_logger().info('Validating archive checksum')
|
||||||
stream.seek(0)
|
self.checksum(name, archive.name)
|
||||||
tar = tarfile.open(fileobj=stream)
|
get_logger().info('Extracting downloaded %s archive', name)
|
||||||
tar.extractall(path=path)
|
tar = tarfile.open(name=archive.name)
|
||||||
tar.close()
|
tar.extractall(path=path)
|
||||||
# TODO: perform checksum control
|
tar.close()
|
||||||
get_logger().info('%s model file(s) extracted', name)
|
get_logger().info('%s model file(s) extracted', name)
|
||||||
|
|||||||
Reference in New Issue
Block a user