AIlib2/segutils/core/models/model_store.py

69 lines
2.2 KiB
Python

"""Model store which provides pretrained models."""
from __future__ import print_function
import os
import zipfile
from ..utils.download import download, check_sha1
__all__ = ['get_model_file', 'get_resnet_file']
# Full SHA-1 checksum of each pretrained weight file, keyed by model name.
# The first 8 hex characters are also used as a suffix in the file name.
_model_sha1 = dict(
    resnet50='25c4b50959ef024fcc050213a06b614899f94b3d',
    resnet101='2a57e44de9c853fa015b172309a1ee7e2d0e4e2a',
    resnet152='0d43d698c66aceaa2bc0309f55efdd7ff4b143af',
)

# Default base URL of the model repository; the ENCODING_REPO environment
# variable takes precedence over it when a download is performed.
encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'

# Template of the archive URL; ``repo_url`` is expected to end with '/'.
_url_format = '{repo_url}encoding/models/{file_name}.zip'
def short_hash(name):
    """Return the first 8 hex characters of the SHA-1 registered for *name*.

    Parameters
    ----------
    name : str
        Model name; must be a key of ``_model_sha1``.

    Returns
    -------
    str
        The 8-character prefix of the model's SHA-1 checksum.

    Raises
    ------
    ValueError
        If no checksum is registered for *name*.
    """
    if name not in _model_sha1:
        raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
    return _model_sha1[name][:8]
def get_resnet_file(name, root='~/.torch/models'):
    """Return the local path of a pretrained ResNet weight file, fetching it if needed.

    The file is looked up as ``<root>/<name>-<short_hash>.pth``. If it exists
    and its SHA-1 matches the registered checksum it is returned directly;
    otherwise the matching ``.zip`` archive is downloaded, extracted into
    *root*, and the extracted file is verified.

    Parameters
    ----------
    name : str
        Model name; must be a key of ``_model_sha1``.
    root : str
        Directory holding the model files; ``~`` is expanded. Created if
        it does not exist.

    Returns
    -------
    str
        Path of the verified ``.pth`` file.

    Raises
    ------
    ValueError
        If *name* has no registered checksum, or if the file obtained from
        the download does not match the expected SHA-1.
    """
    file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.pth')
    sha1_hash = _model_sha1[name]

    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            return file_path
        # Parenthesize the concatenated literal so that .format() applies to
        # the whole message, not just to its second half.
        print(('Mismatch in the content of model file {} detected.'
               ' Downloading again.').format(file_path))
    else:
        print('Model file {} is not found. Downloading.'.format(file_path))

    if not os.path.exists(root):
        os.makedirs(root)

    zip_file_path = os.path.join(root, file_name + '.zip')
    # An empty ENCODING_REPO is treated like an unset one; indexing [-1] on
    # an empty string would otherwise raise IndexError.
    repo_url = os.environ.get('ENCODING_REPO') or encoding_repo_url
    if not repo_url.endswith('/'):
        repo_url += '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    try:
        with zipfile.ZipFile(zip_file_path) as zf:
            zf.extractall(root)
    finally:
        # Remove the archive even when extraction fails so that a broken
        # download is not left behind in the model directory.
        os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return file_path
    raise ValueError('Downloaded file has different hash. Please try again.')
def get_model_file(name, root='~/.torch/models'):
    """Return the path of an already-present model file ``<root>/<name>.pth``.

    Nothing is downloaded by this function; it only checks that the file
    exists on disk.

    Parameters
    ----------
    name : str
        File name of the model without the ``.pth`` extension.
    root : str
        Directory holding the model files; ``~`` is expanded.

    Returns
    -------
    str
        Path of the existing ``.pth`` file.

    Raises
    ------
    ValueError
        If the file does not exist.
    """
    root = os.path.expanduser(root)
    file_path = os.path.join(root, name + '.pth')
    if os.path.exists(file_path):
        return file_path
    raise ValueError('Model file is not found. Downloading or trainning.')