用Kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123 lines
4.6KB

  1. """Model store which handles pretrained models """
  2. from .fcn import *
  3. from .fcnv2 import *
  4. from .pspnet import *
  5. from .deeplabv3 import *
  6. from .deeplabv3_plus import *
  7. from .danet import *
  8. from .denseaspp import *
  9. from .bisenet import *
  10. from .encnet import *
  11. from .dunet import *
  12. from .icnet import *
  13. from .enet import *
  14. from .ocnet import *
  15. from .ccnet import *
  16. from .psanet import *
  17. from .cgnet import *
  18. from .espnet import *
  19. from .lednet import *
  20. from .dfanet import *
  21. __all__ = ['get_model', 'get_model_list', 'get_segmentation_model']
  22. _models = {
  23. 'fcn32s_vgg16_voc': get_fcn32s_vgg16_voc,
  24. 'fcn16s_vgg16_voc': get_fcn16s_vgg16_voc,
  25. 'fcn8s_vgg16_voc': get_fcn8s_vgg16_voc,
  26. 'fcn_resnet50_voc': get_fcn_resnet50_voc,
  27. 'fcn_resnet101_voc': get_fcn_resnet101_voc,
  28. 'fcn_resnet152_voc': get_fcn_resnet152_voc,
  29. 'psp_resnet50_voc': get_psp_resnet50_voc,
  30. 'psp_resnet50_ade': get_psp_resnet50_ade,
  31. 'psp_resnet101_voc': get_psp_resnet101_voc,
  32. 'psp_resnet101_ade': get_psp_resnet101_ade,
  33. 'psp_resnet101_citys': get_psp_resnet101_citys,
  34. 'psp_resnet101_coco': get_psp_resnet101_coco,
  35. 'deeplabv3_resnet50_voc': get_deeplabv3_resnet50_voc,
  36. 'deeplabv3_resnet101_voc': get_deeplabv3_resnet101_voc,
  37. 'deeplabv3_resnet152_voc': get_deeplabv3_resnet152_voc,
  38. 'deeplabv3_resnet50_ade': get_deeplabv3_resnet50_ade,
  39. 'deeplabv3_resnet101_ade': get_deeplabv3_resnet101_ade,
  40. 'deeplabv3_resnet152_ade': get_deeplabv3_resnet152_ade,
  41. 'deeplabv3_plus_xception_voc': get_deeplabv3_plus_xception_voc,
  42. 'danet_resnet50_ciyts': get_danet_resnet50_citys,
  43. 'danet_resnet101_citys': get_danet_resnet101_citys,
  44. 'danet_resnet152_citys': get_danet_resnet152_citys,
  45. 'denseaspp_densenet121_citys': get_denseaspp_densenet121_citys,
  46. 'denseaspp_densenet161_citys': get_denseaspp_densenet161_citys,
  47. 'denseaspp_densenet169_citys': get_denseaspp_densenet169_citys,
  48. 'denseaspp_densenet201_citys': get_denseaspp_densenet201_citys,
  49. 'bisenet_resnet18_citys': get_bisenet_resnet18_citys,
  50. 'encnet_resnet50_ade': get_encnet_resnet50_ade,
  51. 'encnet_resnet101_ade': get_encnet_resnet101_ade,
  52. 'encnet_resnet152_ade': get_encnet_resnet152_ade,
  53. 'dunet_resnet50_pascal_voc': get_dunet_resnet50_pascal_voc,
  54. 'dunet_resnet101_pascal_voc': get_dunet_resnet101_pascal_voc,
  55. 'dunet_resnet152_pascal_voc': get_dunet_resnet152_pascal_voc,
  56. 'icnet_resnet50_citys': get_icnet_resnet50_citys,
  57. 'icnet_resnet101_citys': get_icnet_resnet101_citys,
  58. 'icnet_resnet152_citys': get_icnet_resnet152_citys,
  59. 'enet_citys': get_enet_citys,
  60. 'base_ocnet_resnet101_citys': get_base_ocnet_resnet101_citys,
  61. 'pyramid_ocnet_resnet101_citys': get_pyramid_ocnet_resnet101_citys,
  62. 'asp_ocnet_resnet101_citys': get_asp_ocnet_resnet101_citys,
  63. 'ccnet_resnet50_citys': get_ccnet_resnet50_citys,
  64. 'ccnet_resnet101_citys': get_ccnet_resnet101_citys,
  65. 'ccnet_resnet152_citys': get_ccnet_resnet152_citys,
  66. 'ccnet_resnet50_ade': get_ccnet_resnet50_ade,
  67. 'ccnet_resnet101_ade': get_ccnet_resnet101_ade,
  68. 'ccnet_resnet152_ade': get_ccnet_resnet152_ade,
  69. 'psanet_resnet50_voc': get_psanet_resnet50_voc,
  70. 'psanet_resnet101_voc': get_psanet_resnet101_voc,
  71. 'psanet_resnet152_voc': get_psanet_resnet152_voc,
  72. 'psanet_resnet50_citys': get_psanet_resnet50_citys,
  73. 'psanet_resnet101_citys': get_psanet_resnet101_citys,
  74. 'psanet_resnet152_citys': get_psanet_resnet152_citys,
  75. 'cgnet_citys': get_cgnet_citys,
  76. 'espnet_citys': get_espnet_citys,
  77. 'lednet_citys': get_lednet_citys,
  78. 'dfanet_citys': get_dfanet_citys,
  79. }
  80. def get_model(name, **kwargs):
  81. name = name.lower()
  82. if name not in _models:
  83. err_str = '"%s" is not among the following model list:\n\t' % (name)
  84. err_str += '%s' % ('\n\t'.join(sorted(_models.keys())))
  85. raise ValueError(err_str)
  86. net = _models[name](**kwargs)
  87. return net
  88. def get_model_list():
  89. return _models.keys()
  90. def get_segmentation_model(model, **kwargs):
  91. models = {
  92. 'fcn32s': get_fcn32s,
  93. 'fcn16s': get_fcn16s,
  94. 'fcn8s': get_fcn8s,
  95. 'fcn': get_fcn,
  96. 'psp': get_psp,
  97. 'deeplabv3': get_deeplabv3,
  98. 'deeplabv3_plus': get_deeplabv3_plus,
  99. 'danet': get_danet,
  100. 'denseaspp': get_denseaspp,
  101. 'bisenet': get_bisenet,
  102. 'encnet': get_encnet,
  103. 'dunet': get_dunet,
  104. 'icnet': get_icnet,
  105. 'enet': get_enet,
  106. 'ocnet': get_ocnet,
  107. 'ccnet': get_ccnet,
  108. 'psanet': get_psanet,
  109. 'cgnet': get_cgnet,
  110. 'espnet': get_espnet,
  111. 'lednet': get_lednet,
  112. 'dfanet': get_dfanet,
  113. }
  114. return models[model](**kwargs)