"""Model store which provides pretrained models.""" from __future__ import print_function import os import zipfile from ..utils.download import download, check_sha1 __all__ = ['get_model_file', 'get_resnet_file'] _model_sha1 = {name: checksum for checksum, name in [ ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'), ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'), ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'), ]} encoding_repo_url = 'https://hangzh.s3.amazonaws.com/' _url_format = '{repo_url}encoding/models/{file_name}.zip' def short_hash(name): if name not in _model_sha1: raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) return _model_sha1[name][:8] def get_resnet_file(name, root='~/.torch/models'): file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name)) root = os.path.expanduser(root) file_path = os.path.join(root, file_name + '.pth') sha1_hash = _model_sha1[name] if os.path.exists(file_path): if check_sha1(file_path, sha1_hash): return file_path else: print('Mismatch in the content of model file {} detected.' + ' Downloading again.'.format(file_path)) else: print('Model file {} is not found. Downloading.'.format(file_path)) if not os.path.exists(root): os.makedirs(root) zip_file_path = os.path.join(root, file_name + '.zip') repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url) if repo_url[-1] != '/': repo_url = repo_url + '/' download(_url_format.format(repo_url=repo_url, file_name=file_name), path=zip_file_path, overwrite=True) with zipfile.ZipFile(zip_file_path) as zf: zf.extractall(root) os.remove(zip_file_path) if check_sha1(file_path, sha1_hash): return file_path else: raise ValueError('Downloaded file has different hash. Please try again.') def get_model_file(name, root='~/.torch/models'): root = os.path.expanduser(root) file_path = os.path.join(root, name + '.pth') if os.path.exists(file_path): return file_path else: raise ValueError('Model file is not found. Downloading or trainning.')