123 lines
4.6 KiB
Python
123 lines
4.6 KiB
Python
"""Model store which handles pretrained models """
|
|
from .fcn import *
|
|
from .fcnv2 import *
|
|
from .pspnet import *
|
|
from .deeplabv3 import *
|
|
from .deeplabv3_plus import *
|
|
from .danet import *
|
|
from .denseaspp import *
|
|
from .bisenet import *
|
|
from .encnet import *
|
|
from .dunet import *
|
|
from .icnet import *
|
|
from .enet import *
|
|
from .ocnet import *
|
|
from .ccnet import *
|
|
from .psanet import *
|
|
from .cgnet import *
|
|
from .espnet import *
|
|
from .lednet import *
|
|
from .dfanet import *
|
|
|
|
# Names exported by ``from <this module> import *``: only the three factory
# helpers below; the per-model builders star-imported above are not re-exported.
__all__ = [
    "get_model",
    "get_model_list",
    "get_segmentation_model",
]
|
|
|
|
# Registry of pretrained-model names -> builder callables.
# ``get_model`` lower-cases the requested name, looks it up here and calls the
# builder with its keyword arguments; ``get_model_list`` returns these keys.
# Most keys follow the ``<arch>_<backbone>_<dataset>`` naming pattern.
_models = {
    'fcn32s_vgg16_voc': get_fcn32s_vgg16_voc,
    'fcn16s_vgg16_voc': get_fcn16s_vgg16_voc,
    'fcn8s_vgg16_voc': get_fcn8s_vgg16_voc,
    'fcn_resnet50_voc': get_fcn_resnet50_voc,
    'fcn_resnet101_voc': get_fcn_resnet101_voc,
    'fcn_resnet152_voc': get_fcn_resnet152_voc,
    'psp_resnet50_voc': get_psp_resnet50_voc,
    'psp_resnet50_ade': get_psp_resnet50_ade,
    'psp_resnet101_voc': get_psp_resnet101_voc,
    'psp_resnet101_ade': get_psp_resnet101_ade,
    'psp_resnet101_citys': get_psp_resnet101_citys,
    'psp_resnet101_coco': get_psp_resnet101_coco,
    'deeplabv3_resnet50_voc': get_deeplabv3_resnet50_voc,
    'deeplabv3_resnet101_voc': get_deeplabv3_resnet101_voc,
    'deeplabv3_resnet152_voc': get_deeplabv3_resnet152_voc,
    'deeplabv3_resnet50_ade': get_deeplabv3_resnet50_ade,
    'deeplabv3_resnet101_ade': get_deeplabv3_resnet101_ade,
    'deeplabv3_resnet152_ade': get_deeplabv3_resnet152_ade,
    'deeplabv3_plus_xception_voc': get_deeplabv3_plus_xception_voc,
    'danet_resnet50_citys': get_danet_resnet50_citys,
    # Misspelled legacy key ("ciyts"); kept as an alias so callers that
    # already use it keep working. Prefer 'danet_resnet50_citys'.
    'danet_resnet50_ciyts': get_danet_resnet50_citys,
    'danet_resnet101_citys': get_danet_resnet101_citys,
    'danet_resnet152_citys': get_danet_resnet152_citys,
    'denseaspp_densenet121_citys': get_denseaspp_densenet121_citys,
    'denseaspp_densenet161_citys': get_denseaspp_densenet161_citys,
    'denseaspp_densenet169_citys': get_denseaspp_densenet169_citys,
    'denseaspp_densenet201_citys': get_denseaspp_densenet201_citys,
    'bisenet_resnet18_citys': get_bisenet_resnet18_citys,
    'encnet_resnet50_ade': get_encnet_resnet50_ade,
    'encnet_resnet101_ade': get_encnet_resnet101_ade,
    'encnet_resnet152_ade': get_encnet_resnet152_ade,
    'dunet_resnet50_pascal_voc': get_dunet_resnet50_pascal_voc,
    'dunet_resnet101_pascal_voc': get_dunet_resnet101_pascal_voc,
    'dunet_resnet152_pascal_voc': get_dunet_resnet152_pascal_voc,
    'icnet_resnet50_citys': get_icnet_resnet50_citys,
    'icnet_resnet101_citys': get_icnet_resnet101_citys,
    'icnet_resnet152_citys': get_icnet_resnet152_citys,
    'enet_citys': get_enet_citys,
    'base_ocnet_resnet101_citys': get_base_ocnet_resnet101_citys,
    'pyramid_ocnet_resnet101_citys': get_pyramid_ocnet_resnet101_citys,
    'asp_ocnet_resnet101_citys': get_asp_ocnet_resnet101_citys,
    'ccnet_resnet50_citys': get_ccnet_resnet50_citys,
    'ccnet_resnet101_citys': get_ccnet_resnet101_citys,
    'ccnet_resnet152_citys': get_ccnet_resnet152_citys,
    'ccnet_resnet50_ade': get_ccnet_resnet50_ade,
    'ccnet_resnet101_ade': get_ccnet_resnet101_ade,
    'ccnet_resnet152_ade': get_ccnet_resnet152_ade,
    'psanet_resnet50_voc': get_psanet_resnet50_voc,
    'psanet_resnet101_voc': get_psanet_resnet101_voc,
    'psanet_resnet152_voc': get_psanet_resnet152_voc,
    'psanet_resnet50_citys': get_psanet_resnet50_citys,
    'psanet_resnet101_citys': get_psanet_resnet101_citys,
    'psanet_resnet152_citys': get_psanet_resnet152_citys,
    'cgnet_citys': get_cgnet_citys,
    'espnet_citys': get_espnet_citys,
    'lednet_citys': get_lednet_citys,
    'dfanet_citys': get_dfanet_citys,
}
|
|
|
|
|
|
def get_model(name, **kwargs):
    """Instantiate a registered pretrained model by name.

    Args:
        name: Registry key. It is lower-cased before the lookup in
            ``_models``, so ``'PSP_ResNet50_VOC'`` resolves like
            ``'psp_resnet50_voc'``.
        **kwargs: Forwarded unchanged to the builder registered under
            ``name``.

    Returns:
        Whatever the registered builder returns.

    Raises:
        ValueError: If the lower-cased ``name`` is not a key of ``_models``.
            The message lists every known key, one per tab-indented line,
            in sorted order.
    """
    key = name.lower()
    if key in _models:
        builder = _models[key]
        return builder(**kwargs)

    # Unknown name: report the full, sorted menu of valid keys.
    choices = '\n\t'.join(sorted(_models))
    raise ValueError(
        '"{}" is not among the following model list:\n\t{}'.format(key, choices)
    )
|
|
|
|
|
|
def get_model_list():
    """Return the names known to ``get_model``.

    Returns:
        The ``dict_keys`` view of ``_models``. It is a live view rather than a
        copy, so later changes to the registry are visible through it.
    """
    registry = _models
    return registry.keys()
|
|
|
|
|
|
def get_segmentation_model(model, **kwargs):
    """Build a segmentation network by architecture name.

    Args:
        model: Exact, case-sensitive architecture key such as ``'psp'``,
            ``'deeplabv3_plus'`` or ``'dfanet'``.
        **kwargs: Forwarded unchanged to the matching ``get_<model>`` builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        KeyError: If ``model`` is not one of the architecture keys below.
    """
    # The table is built on every call, so it binds whichever builder
    # functions are visible in this module's namespace at call time.
    builders = dict(
        fcn32s=get_fcn32s,
        fcn16s=get_fcn16s,
        fcn8s=get_fcn8s,
        fcn=get_fcn,
        psp=get_psp,
        deeplabv3=get_deeplabv3,
        deeplabv3_plus=get_deeplabv3_plus,
        danet=get_danet,
        denseaspp=get_denseaspp,
        bisenet=get_bisenet,
        encnet=get_encnet,
        dunet=get_dunet,
        icnet=get_icnet,
        enet=get_enet,
        ocnet=get_ocnet,
        ccnet=get_ccnet,
        psanet=get_psanet,
        cgnet=get_cgnet,
        espnet=get_espnet,
        lednet=get_lednet,
        dfanet=get_dfanet,
    )
    build = builders[model]
    return build(**kwargs)
|