"""Model store which handles pretrained models """ from .fcn import * from .fcnv2 import * from .pspnet import * from .deeplabv3 import * from .deeplabv3_plus import * from .danet import * from .denseaspp import * from .bisenet import * from .encnet import * from .dunet import * from .icnet import * from .enet import * from .ocnet import * from .ccnet import * from .psanet import * from .cgnet import * from .espnet import * from .lednet import * from .dfanet import * __all__ = ['get_model', 'get_model_list', 'get_segmentation_model'] _models = { 'fcn32s_vgg16_voc': get_fcn32s_vgg16_voc, 'fcn16s_vgg16_voc': get_fcn16s_vgg16_voc, 'fcn8s_vgg16_voc': get_fcn8s_vgg16_voc, 'fcn_resnet50_voc': get_fcn_resnet50_voc, 'fcn_resnet101_voc': get_fcn_resnet101_voc, 'fcn_resnet152_voc': get_fcn_resnet152_voc, 'psp_resnet50_voc': get_psp_resnet50_voc, 'psp_resnet50_ade': get_psp_resnet50_ade, 'psp_resnet101_voc': get_psp_resnet101_voc, 'psp_resnet101_ade': get_psp_resnet101_ade, 'psp_resnet101_citys': get_psp_resnet101_citys, 'psp_resnet101_coco': get_psp_resnet101_coco, 'deeplabv3_resnet50_voc': get_deeplabv3_resnet50_voc, 'deeplabv3_resnet101_voc': get_deeplabv3_resnet101_voc, 'deeplabv3_resnet152_voc': get_deeplabv3_resnet152_voc, 'deeplabv3_resnet50_ade': get_deeplabv3_resnet50_ade, 'deeplabv3_resnet101_ade': get_deeplabv3_resnet101_ade, 'deeplabv3_resnet152_ade': get_deeplabv3_resnet152_ade, 'deeplabv3_plus_xception_voc': get_deeplabv3_plus_xception_voc, 'danet_resnet50_ciyts': get_danet_resnet50_citys, 'danet_resnet101_citys': get_danet_resnet101_citys, 'danet_resnet152_citys': get_danet_resnet152_citys, 'denseaspp_densenet121_citys': get_denseaspp_densenet121_citys, 'denseaspp_densenet161_citys': get_denseaspp_densenet161_citys, 'denseaspp_densenet169_citys': get_denseaspp_densenet169_citys, 'denseaspp_densenet201_citys': get_denseaspp_densenet201_citys, 'bisenet_resnet18_citys': get_bisenet_resnet18_citys, 'encnet_resnet50_ade': get_encnet_resnet50_ade, 'encnet_resnet101_ade': get_encnet_resnet101_ade, 'encnet_resnet152_ade': get_encnet_resnet152_ade, 'dunet_resnet50_pascal_voc': get_dunet_resnet50_pascal_voc, 'dunet_resnet101_pascal_voc': get_dunet_resnet101_pascal_voc, 'dunet_resnet152_pascal_voc': get_dunet_resnet152_pascal_voc, 'icnet_resnet50_citys': get_icnet_resnet50_citys, 'icnet_resnet101_citys': get_icnet_resnet101_citys, 'icnet_resnet152_citys': get_icnet_resnet152_citys, 'enet_citys': get_enet_citys, 'base_ocnet_resnet101_citys': get_base_ocnet_resnet101_citys, 'pyramid_ocnet_resnet101_citys': get_pyramid_ocnet_resnet101_citys, 'asp_ocnet_resnet101_citys': get_asp_ocnet_resnet101_citys, 'ccnet_resnet50_citys': get_ccnet_resnet50_citys, 'ccnet_resnet101_citys': get_ccnet_resnet101_citys, 'ccnet_resnet152_citys': get_ccnet_resnet152_citys, 'ccnet_resnet50_ade': get_ccnet_resnet50_ade, 'ccnet_resnet101_ade': get_ccnet_resnet101_ade, 'ccnet_resnet152_ade': get_ccnet_resnet152_ade, 'psanet_resnet50_voc': get_psanet_resnet50_voc, 'psanet_resnet101_voc': get_psanet_resnet101_voc, 'psanet_resnet152_voc': get_psanet_resnet152_voc, 'psanet_resnet50_citys': get_psanet_resnet50_citys, 'psanet_resnet101_citys': get_psanet_resnet101_citys, 'psanet_resnet152_citys': get_psanet_resnet152_citys, 'cgnet_citys': get_cgnet_citys, 'espnet_citys': get_espnet_citys, 'lednet_citys': get_lednet_citys, 'dfanet_citys': get_dfanet_citys, } def get_model(name, **kwargs): name = name.lower() if name not in _models: err_str = '"%s" is not among the following model list:\n\t' % (name) err_str += '%s' % ('\n\t'.join(sorted(_models.keys()))) raise ValueError(err_str) net = _models[name](**kwargs) return net def get_model_list(): return _models.keys() def get_segmentation_model(model, **kwargs): models = { 'fcn32s': get_fcn32s, 'fcn16s': get_fcn16s, 'fcn8s': get_fcn8s, 'fcn': get_fcn, 'psp': get_psp, 'deeplabv3': get_deeplabv3, 'deeplabv3_plus': get_deeplabv3_plus, 'danet': get_danet, 'denseaspp': get_denseaspp, 'bisenet': get_bisenet, 'encnet': get_encnet, 'dunet': get_dunet, 'icnet': get_icnet, 'enet': get_enet, 'ocnet': get_ocnet, 'ccnet': get_ccnet, 'psanet': get_psanet, 'cgnet': get_cgnet, 'espnet': get_espnet, 'lednet': get_lednet, 'dfanet': get_dfanet, } return models[model](**kwargs)