69 lines
2.1 KiB
Python
69 lines
2.1 KiB
Python
import os
|
|
import hashlib
|
|
import errno
|
|
import tarfile
|
|
from six.moves import urllib
|
|
from torch.utils.model_zoo import tqdm
|
|
|
|
def gen_bar_updater():
    """Build a ``reporthook`` callback that drives a tqdm progress bar.

    Returns:
        A function ``bar_update(count, block_size, total_size)`` that moves a
        progress bar created here forward to ``count * block_size`` bytes.
    """
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        # urlretrieve passes -1 when the server does not report a size; only
        # adopt a real, positive size as the bar's total.
        if pbar.total is None and total_size and total_size > 0:
            pbar.total = total_size
        progress_bytes = count * block_size
        if pbar.total is not None:
            # The last block is usually partial, so count * block_size can
            # overshoot the real size; never report more than the total.
            progress_bytes = min(progress_bytes, pbar.total)
        # tqdm.update takes an increment, so pass the delta from its counter.
        pbar.update(progress_bytes - pbar.n)

    return bar_update
|
|
|
|
def check_integrity(fpath, md5=None):
    """Check a file on disk against an expected MD5 hex digest.

    Args:
        fpath: Path of the file to check.
        md5: Expected MD5 hex digest, or None to skip the check.

    Returns:
        True when ``md5`` is None or the file's digest equals ``md5``;
        False when the file does not exist or the digest differs.
    """
    if md5 is None:
        return True
    if not os.path.isfile(fpath):
        return False

    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        # Hash in 1MB pieces so the whole file is never held in memory.
        while True:
            block = stream.read(1024 * 1024)
            if block == b'':
                break
            digest.update(block)

    return digest.hexdigest() == md5
|
|
|
|
def makedir_exist_ok(dirpath):
    """Create ``dirpath`` (and missing parents), tolerating an existing path.

    Works without ``os.makedirs(..., exist_ok=True)``.

    Args:
        dirpath: Directory path to create.

    Raises:
        OSError: For any failure other than the path already existing
            (e.g. permission denied, a parent that is not a directory).
    """
    try:
        os.makedirs(dirpath)
    except OSError as e:
        # Only "already exists" is benign; anything else is a real failure
        # that callers need to see.
        if e.errno != errno.EEXIST:
            raise
|
|
|
|
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url: URL to download the file from.
        root: Directory to place the file in. ``~`` is expanded and the
            directory is created if it does not exist.
        filename: Name to save the file under. If falsy, the basename of
            ``url`` is used.
        md5: Expected MD5 hex digest. Used only to decide whether a file
            already present at the target path can be reused; with None any
            existing file is reused.

    Raises:
        OSError: If the download fails and the URL is not https, or if the
            http retry of an https URL fails as well.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # Reuse a file that is already present and passes the integrity check.
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    # NOTE(review): the freshly downloaded file is not re-verified against md5.
    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
    except OSError:
        if url[:5] != 'https':
            # No fallback applies; surface the failure instead of returning
            # as if the download had succeeded.
            raise
        # Rewrite only the scheme, not any later "https:" inside the URL.
        url = url.replace('https:', 'http:', 1)
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
|
|
|
|
def download_extract(url, root, filename, md5):
    """Download a tar archive into ``root`` and unpack it there.

    Args:
        url: URL of the archive.
        root: Directory that receives both the archive and its contents.
        filename: Name to save the archive under. If falsy, the basename of
            ``url`` is used.
        md5: Expected MD5 hex digest, forwarded to ``download_url``.
    """
    # download_url expands "~" in root and falls back to the URL basename for
    # a falsy filename; do the same here so the path opened below is the file
    # it actually wrote.
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)

    download_url(url, root, filename, md5)

    # NOTE(review): extractall trusts member paths stored in the archive;
    # only use this with archives from trusted sources.
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        tar.extractall(path=root)