69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
"""Model store which provides pretrained models."""
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import zipfile
|
|
|
|
from ..utils.download import download, check_sha1
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ['get_model_file', 'get_resnet_file']
|
|
|
|
# Full SHA-1 checksum of each pretrained checkpoint, keyed by model name.
_model_sha1 = {
    'resnet50': '25c4b50959ef024fcc050213a06b614899f94b3d',
    'resnet101': '2a57e44de9c853fa015b172309a1ee7e2d0e4e2a',
    'resnet152': '0d43d698c66aceaa2bc0309f55efdd7ff4b143af',
}

# Default repository the model archives are downloaded from.
encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
# Archive URL template: {repo_url} is the repository base URL (with a
# trailing slash) and {file_name} is the archive name without '.zip'.
_url_format = '{repo_url}encoding/models/{file_name}.zip'
|
|
|
|
|
|
def short_hash(name):
    """Return the first eight characters of the SHA-1 registered for ``name``.

    Raises
    ------
    ValueError
        If ``name`` has no entry in ``_model_sha1``.
    """
    checksum = _model_sha1.get(name)
    if checksum is None:
        raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
    return checksum[:8]
|
|
|
|
|
|
def get_resnet_file(name, root='~/.torch/models'):
    """Return the local path of a pretrained ResNet checkpoint, downloading it if needed.

    The checkpoint is expected at ``<root>/<name>-<short_hash>.pth``. If that
    file exists and ``check_sha1`` accepts it against the checksum stored in
    ``_model_sha1``, its path is returned immediately. Otherwise the archive
    ``<name>-<short_hash>.zip`` is downloaded into ``root``, extracted there,
    deleted, and the extracted checkpoint is checked again.

    Parameters
    ----------
    name : str
        Model name; must be a key of ``_model_sha1``.
    root : str, default '~/.torch/models'
        Directory holding the checkpoints. A leading ``~`` is expanded and the
        directory is created when it does not exist.

    Returns
    -------
    str
        Path of the checkpoint file that passed the checksum test.

    Raises
    ------
    ValueError
        If ``name`` has no registered checksum, or if the file obtained from
        the download does not pass the checksum test.

    Notes
    -----
    The repository base URL defaults to ``encoding_repo_url`` and can be
    overridden through the ``ENCODING_REPO`` environment variable.
    """
    file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)

    file_path = os.path.join(root, file_name + '.pth')
    sha1_hash = _model_sha1[name]
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return file_path
        # The whole message must be built before .format() is applied:
        # a method call binds tighter than '+', so formatting only the last
        # literal would leave the '{}' placeholder unfilled.
        print(('Mismatch in the content of model file {} detected.'
               ' Downloading again.').format(file_path))
    else:
        print('Model file {} is not found. Downloading.'.format(file_path))

    if not os.path.exists(root):
        os.makedirs(root)

    zip_file_path = os.path.join(root, file_name + '.zip')
    repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url)
    # endswith() is safe for an empty ENCODING_REPO value, where indexing
    # the last character would raise IndexError.
    if not repo_url.endswith('/'):
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    # NOTE(review): assumes the archive holds '<file_name>.pth' at its top
    # level, since that is the path verified below - confirm against the repo.
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return file_path
    raise ValueError('Downloaded file has different hash. Please try again.')
|
|
|
|
|
|
def get_model_file(name, root='~/.torch/models'):
    """Return the path of a model file that is already stored locally.

    Looks for ``<root>/<name>.pth``. Nothing is downloaded and no checksum
    is verified.

    Parameters
    ----------
    name : str
        File name of the model without the ``.pth`` extension.
    root : str, default '~/.torch/models'
        Directory to search; a leading ``~`` is expanded.

    Returns
    -------
    str
        Path of the existing model file.

    Raises
    ------
    ValueError
        If the file does not exist.
    """
    root = os.path.expanduser(root)
    file_path = os.path.join(root, name + '.pth')
    if os.path.exists(file_path):
        return file_path
    # Include the path that was looked up so the caller can see what is missing.
    raise ValueError(
        'Model file {} is not found. Download or train the model first.'.format(file_path))
|