You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1328 lines
60KB

  1. # -*- coding: utf-8 -*-
  2. import sys
  3. from pickle import dumps, loads
  4. from traceback import format_exc
  5. from loguru import logger
  6. from common.Constant import COLOR
  7. from enums.ExceptionEnum import ExceptionType
  8. from enums.ModelTypeEnum import ModelType, BAIDU_MODEL_TARGET_CONFIG
  9. from exception.CustomerException import ServiceException
  10. from util.ImgBaiduSdk import AipBodyAnalysisClient, AipImageClassifyClient
  11. from util.PlotsUtils import get_label_arrays
  12. from util.TorchUtils import select_device
  13. sys.path.extend(['..', '../AIlib2'])
  14. from AI import AI_process, AI_process_forest, get_postProcess_para, AI_Seg_process, ocr_process
  15. import time
  16. from segutils.segmodel import SegModel
  17. from models.experimental import attempt_load
  18. from utils.torch_utils import select_device
  19. from utilsK.queRiver import get_labelnames, get_label_arrays, save_problem_images
  20. from obbUtils.shipUtils import OBB_infer
  21. from obbUtils.load_obb_model import load_model_decoder_OBB
  22. import torch
  23. import tensorrt as trt
  24. from utilsK.jkmUtils import pre_process, post_process, get_return_data
  25. FONTPATH = "../AIlib2/conf/platech.ttf"
  26. # 河道模型
  27. class RiverModel:
  28. __slots__ = "model_conf"
  29. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  30. try:
  31. logger.info("########################加载河道模型########################, requestId:{}", requestId)
  32. trtFlag_det = True
  33. trtFlag_seg = True
  34. # 公共变量
  35. par = {
  36. 'device': str(device),
  37. 'labelnames': ["排口", "水生植被", "其它", "漂浮物", "污染排口", "菜地", "违建", "岸坡垃圾"],
  38. 'detModelpara': [],
  39. 'seg_nclass': 2,
  40. 'segRegionCnt': 1,
  41. 'slopeIndex': [5, 6, 7],
  42. 'segPar': {
  43. 'modelSize': (640, 360),
  44. 'mean': (0.485, 0.456, 0.406),
  45. 'std': (0.229, 0.224, 0.225),
  46. 'numpy': False,
  47. 'RGB_convert_first': True
  48. },
  49. 'postFile': {
  50. "name": "post_process",
  51. "conf_thres": 0.25,
  52. "iou_thres": 0.45,
  53. "classes": 5,
  54. "rainbows": COLOR
  55. }
  56. }
  57. if trtFlag_det:
  58. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], gpu_name)
  59. else:
  60. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  61. if trtFlag_seg:
  62. par['segweights'] = '../AIlib2/weights/%s/stdc_360X640_%s_fp16.engine' % (modeType.value[3], gpu_name)
  63. else:
  64. par['segweights'] = '../AIlib2/weights/conf/%s/stdc_360X640.pth' % modeType.value[3]
  65. mode = 'others'
  66. postPar = None
  67. segPar = par.get('segPar')
  68. new_device = select_device(par.get('device'))
  69. names = par.get('labelnames')
  70. half = new_device.type != 'cpu'
  71. Detweights = par.get('detweights') # 升级后的检测模型
  72. if trtFlag_det:
  73. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  74. model = runtime.deserialize_cuda_engine(f.read())
  75. else:
  76. model = attempt_load(Detweights, map_location=new_device)
  77. if half:
  78. model.half()
  79. Segweights = par.get('segweights')
  80. if trtFlag_seg:
  81. with open(Segweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  82. segmodel = runtime.deserialize_cuda_engine(f.read())
  83. else:
  84. segmodel = SegModel(nclass=par.get('seg_nclass'), weights=Segweights, device=new_device)
  85. conf_thres = par.get('postFile').get("conf_thres")
  86. iou_thres = par.get('postFile').get("iou_thres")
  87. # classes = par.get('postFile').get("classes")
  88. rainbows = par.get('postFile').get("rainbows")
  89. objectPar = {
  90. 'half': half,
  91. 'device': new_device,
  92. 'conf_thres': conf_thres,
  93. 'iou_thres': iou_thres,
  94. 'allowedList': [],
  95. 'slopeIndex': par.get('slopeIndex'),
  96. 'segRegionCnt': par.get('segRegionCnt'),
  97. 'trtFlag_det': trtFlag_det,
  98. 'trtFlag_seg': trtFlag_seg
  99. }
  100. """
  101. frame, model, segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId
  102. """
  103. model_param = [None, model, segmodel, names, None, rainbows, objectPar, None, segPar, mode, postPar,
  104. requestId]
  105. self.model_conf = (modeType, allowedList, model_param)
  106. except Exception:
  107. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  108. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  109. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  110. def get_label_arraylist(width, names, rainbows):
  111. fontsize = int(width / 1920 * 40)
  112. line_thickness = 1
  113. boxLine_thickness = 1
  114. waterLineWidth = 1
  115. if width >= 960:
  116. line_thickness = int(round(width / 1920 * 3) - 1)
  117. boxLine_thickness = int(round(width / 1920 * 3))
  118. waterLineWidth = int(round(width / 1920 * 3))
  119. numFontSize = float(format(width / 1920 * 1.1, '.1f')) # 文字大小
  120. digitFont = {'line_thickness': line_thickness,
  121. 'boxLine_thickness': boxLine_thickness,
  122. 'fontSize': numFontSize,
  123. 'waterLineColor': (0, 255, 255),
  124. 'segLineShow': False,
  125. 'waterLineWidth': waterLineWidth}
  126. label_arraylist = get_label_arrays(names, rainbows, outfontsize=fontsize,
  127. fontpath=FONTPATH)
  128. return digitFont, label_arraylist
  129. '''
  130. #输入参数
  131. # im0s---原始图像列表
  132. # model---检测模型,segmodel---分割模型(如若没有用到,则为None)
  133. #输出:两个元素(列表,字符)构成的元组,[im0s[0],im0,det_xywh,iframe],strout
  134. # [im0s[0],im0,det_xywh,iframe]中,
  135. # im0s[0]--原始图像,im0--AI处理后的图像,iframe--帧号/暂时不需用到。
  136. # det_xywh--检测结果,是一个列表。
  137. # 其中每一个元素表示一个目标构成如:[float(cls_c), xc,yc,w,h, float(conf_c)]
  138. # #cls_c--类别,如0,1,2,3; xc,yc,w,h--中心点坐标及宽;conf_c--得分, 取值范围在0-1之间
  139. # #strout---统计AI处理个环节的时间
  140. '''
  141. def model_process(param):
  142. # frame, model, segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId
  143. try:
  144. return AI_process([param[0]], param[1], param[2], param[3], param[4], param[5], objectPar=param[6],
  145. font=param[7], segPar=param[8], mode=param[9], postPar=param[10])
  146. except ServiceException as s:
  147. raise s
  148. except Exception:
  149. # self.num += 1
  150. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  151. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), param[11])
  152. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  153. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  154. # 河道检测模型
  155. class River2Model:
  156. __slots__ = "model_conf"
  157. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  158. try:
  159. logger.info("########################加载河道检测模型########################, requestId:{}", requestId)
  160. trtFlag_det = True, # 检测模型是否采用TRT
  161. trtFlag_seg = True, # 分割模型是否采用TRT
  162. # 公共变量
  163. par = {
  164. 'device': str(device),
  165. 'labelnames': [
  166. "漂浮物", "岸坡垃圾", "排口", "违建", "菜地", "水生植物", "河湖人员", "钓鱼人员", "船只", "蓝藻"],
  167. 'detModelpara': [],
  168. 'seg_nclass': 2,
  169. 'segRegionCnt': 1,
  170. 'slopeIndex': [1, 3, 4, 7],
  171. 'segPar': {
  172. 'modelSize': (640, 360),
  173. 'mean': (0.485, 0.456, 0.406),
  174. 'std': (0.229, 0.224, 0.225),
  175. 'numpy': False,
  176. 'RGB_convert_first': True
  177. },
  178. 'postFile': {
  179. "name": "post_process",
  180. "conf_thres": 0.25,
  181. "iou_thres": 0.45,
  182. "ovlap_thres_crossCategory": 0.65,
  183. "classes": 5,
  184. "rainbows": COLOR
  185. }
  186. }
  187. if trtFlag_det:
  188. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], gpu_name)
  189. else:
  190. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  191. if trtFlag_seg:
  192. par['segweights'] = '../AIlib2/weights/%s/stdc_360X640_%s_fp16.engine' % (modeType.value[3], gpu_name)
  193. else:
  194. par['segweights'] = '../AIlib2/weights/conf/%s/stdc_360X640.pth' % modeType.value[3]
  195. mode = 'others'
  196. postPar = None
  197. new_device = select_device(par.get('device')) # 指定GPU
  198. names = par.get('labelnames')
  199. half = new_device.type != 'cpu'
  200. Detweights = par.get('detweights') # 升级后的检测模型
  201. if trtFlag_det:
  202. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  203. model = runtime.deserialize_cuda_engine(f.read())
  204. else:
  205. model = attempt_load(Detweights, map_location=new_device)
  206. if half:
  207. model.half()
  208. segPar = par.get('segPar')
  209. Segweights = par.get('segweights')
  210. if trtFlag_seg:
  211. with open(Segweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  212. segmodel = runtime.deserialize_cuda_engine(f.read())
  213. else:
  214. segmodel = SegModel(nclass=par.get('seg_nclass'), weights=Segweights, device=new_device)
  215. conf_thres = par.get('postFile').get("conf_thres")
  216. iou_thres = par.get('postFile').get("iou_thres")
  217. # classes = par.get('postFile').get("classes")
  218. rainbows = par.get('postFile').get("rainbows")
  219. postFile = par.get('postFile')
  220. ovlap_thres_crossCategory = postFile.get('ovlap_thres_crossCategory')
  221. objectPar = {
  222. 'half': half,
  223. 'device': new_device,
  224. 'conf_thres': conf_thres,
  225. 'ovlap_thres_crossCategory': ovlap_thres_crossCategory,
  226. 'iou_thres': iou_thres,
  227. 'allowedList': [],
  228. 'slopeIndex': par.get('slopeIndex'),
  229. 'segRegionCnt': par.get('segRegionCnt'),
  230. 'trtFlag_det': trtFlag_det,
  231. 'trtFlag_seg': trtFlag_seg
  232. }
  233. """
  234. frame, model, segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar,
  235. requestId
  236. """
  237. model_param = [None, model, segmodel, names, None, rainbows, objectPar, None, segPar, mode, postPar,
  238. requestId]
  239. self.model_conf = (modeType, allowedList, model_param)
  240. except Exception:
  241. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  242. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  243. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  244. # 高速模型
  245. class HighWayModel:
  246. __slots__ = "model_conf"
  247. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  248. s = time.time()
  249. try:
  250. logger.info("########################加载高速模型########################, requestId:{}", requestId)
  251. trtFlag_det = True, # 检测模型是否采用TRT
  252. trtFlag_seg = True, # 分割模型是否采用TRT
  253. # 公共变量
  254. par = {
  255. 'device': str(device),
  256. 'gpu_name': gpu_name,
  257. 'labelnames': ["行人", "车辆", "纵向裂缝", "横向裂缝", "修补", "网状裂纹", "坑槽", "块状裂纹", "积水",
  258. "事故"],
  259. 'slopeIndex': [],
  260. 'seg_nclass': 3,
  261. 'segRegionCnt': 2,
  262. 'segPar': {
  263. 'modelSize': (640, 360),
  264. 'mean': (0.485, 0.456, 0.406),
  265. 'std': (0.229, 0.224, 0.225),
  266. 'predResize': True,
  267. 'numpy': False,
  268. 'RGB_convert_first': True
  269. },
  270. 'postPar': {
  271. 'label_csv': '../AIlib2/weights/conf/%s/class_dict.csv' % modeType.value[3],
  272. 'speedRoadArea': 16000,
  273. 'vehicleArea': 10,
  274. 'speedRoadVehicleAngleMin': 15,
  275. 'speedRoadVehicleAngleMax': 75,
  276. 'roundness': 0.7,
  277. 'cls': 9,
  278. 'vehicleFactor': 0.1
  279. },
  280. 'mode': 'highWay3.0',
  281. 'postFile': {
  282. "name": "post_process",
  283. "conf_thres": 0.25,
  284. "iou_thres": 0.25,
  285. "classes": 9,
  286. "rainbows": COLOR
  287. }
  288. }
  289. if trtFlag_det:
  290. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name'])
  291. else:
  292. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  293. if trtFlag_seg:
  294. par['segweights'] = '../AIlib2/weights/%s/stdc_360X640_%s_fp16.engine' % (
  295. modeType.value[3], par['gpu_name'])
  296. else:
  297. par['segweights'] = '../AIlib2/weights/conf/%s/stdc_360X640.pth' % modeType.value[3]
  298. mode = par.get('mode', 'others')
  299. postPar = par.get('postPar', None)
  300. new_device = select_device(par.get('device')) # 指定GPU
  301. names = par.get('labelnames')
  302. half = new_device.type != 'cpu'
  303. Detweights = par.get('detweights') # 升级后的检测模型
  304. if trtFlag_det:
  305. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  306. model = runtime.deserialize_cuda_engine(f.read())
  307. else:
  308. model = attempt_load(Detweights, map_location=new_device)
  309. if half:
  310. model.half()
  311. segPar = par.get('segPar')
  312. Segweights = par.get('segweights')
  313. if trtFlag_seg:
  314. with open(Segweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  315. segmodel = runtime.deserialize_cuda_engine(f.read())
  316. else:
  317. seg_nclass = par.get('seg_nclass')
  318. segmodel = SegModel(nclass=seg_nclass, weights=Segweights, device=new_device)
  319. conf_thres = par.get('postFile').get("conf_thres")
  320. iou_thres = par.get('postFile').get("iou_thres")
  321. # classes = par.get('postFile').get("classes")
  322. rainbows = par.get('postFile').get("rainbows")
  323. objectPar = {
  324. 'half': half,
  325. 'device': new_device,
  326. 'conf_thres': conf_thres,
  327. 'iou_thres': iou_thres,
  328. 'allowedList': [],
  329. 'slopeIndex': par.get('slopeIndex'),
  330. 'segRegionCnt': par.get('segRegionCnt'),
  331. 'trtFlag_det': trtFlag_det,
  332. 'trtFlag_seg': trtFlag_seg
  333. }
  334. """
  335. frame, model, segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId
  336. """
  337. model_param = [None, model, segmodel, names, None, rainbows, objectPar, None, segPar, mode, postPar,
  338. requestId]
  339. self.model_conf = (modeType, allowedList, model_param)
  340. # self.num = 0
  341. except Exception:
  342. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  343. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  344. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  345. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  346. '''
  347. #输入参数
  348. # im0s---原始图像列表
  349. # model---检测模型,segmodel---分割模型(如若没有用到,则为None)
  350. #输出:两个元素(列表,字符)构成的元组,[im0s[0],im0,det_xywh,iframe],strout
  351. # [im0s[0],im0,det_xywh,iframe]中,
  352. # im0s[0]--原始图像,im0--AI处理后的图像,iframe--帧号/暂时不需用到。
  353. # det_xywh--检测结果,是一个列表。
  354. # 其中每一个元素表示一个目标构成如:[float(cls_c), xc,yc,w,h, float(conf_c)]
  355. # #cls_c--类别,如0,1,2,3; xc,yc,w,h--中心点坐标及宽;conf_c--得分, 取值范围在0-1之间
  356. # #strout---统计AI处理个环节的时间
  357. '''
  358. def high_process(param):
  359. try:
  360. # frame, model, segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId
  361. return AI_process([param[0]], param[1], param[2], param[3], param[4], param[5], objectPar=param[6],
  362. font=param[7], segPar=param[8], mode=param[9], postPar=loads(dumps(param[10])))
  363. except ServiceException as s:
  364. raise s
  365. except Exception:
  366. # self.num += 1
  367. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  368. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), param[11])
  369. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  370. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  371. # 森林模型
  372. class ForestModel:
  373. __slots__ = "model_conf"
  374. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  375. s = time.time()
  376. try:
  377. logger.info("########################加载森林模型########################, requestId:{}", requestId)
  378. trtFlag_det = True, # 检测模型是否采用TRT
  379. # trtFlag_seg = False, # 分割模型是否采用TRT
  380. # 公共变量
  381. par = {
  382. 'device': str(device1),
  383. 'gpu_name': gpu_name,
  384. 'labelnames': ("林斑", "病死树", "行人", "火焰", "烟雾"),
  385. 'seg_nclass': 2,
  386. 'segRegionCnt': 0,
  387. 'slopeIndex': [],
  388. 'segPar': None,
  389. 'postFile': {
  390. "name": "post_process",
  391. "conf_thres": 0.25,
  392. "iou_thres": 0.45,
  393. "classes": 5,
  394. "rainbows": COLOR
  395. },
  396. 'segweights': None
  397. }
  398. if trtFlag_det:
  399. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name'])
  400. else:
  401. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  402. device = select_device(par.get('device'))
  403. names = par.get('labelnames')
  404. half = device.type != 'cpu'
  405. Detweights = par.get('detweights')
  406. if trtFlag_det:
  407. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  408. model = runtime.deserialize_cuda_engine(f.read())
  409. else:
  410. model = attempt_load(Detweights, map_location=device)
  411. if half:
  412. model.half()
  413. segmodel = None
  414. conf_thres = par.get('postFile').get("conf_thres")
  415. iou_thres = par.get('postFile').get("iou_thres")
  416. # classes = par.get('postFile').get("classes")
  417. rainbows = par.get('postFile').get("rainbows")
  418. """
  419. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  420. allowedList, digitFont, trtFlag_det, requestId
  421. """
  422. model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None,
  423. trtFlag_det, requestId]
  424. self.model_conf = (modeType, allowedList, model_param)
  425. except Exception:
  426. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  427. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  428. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  429. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  430. def forest_process(param):
  431. try:
  432. """
  433. # frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres, allowedList,
  434. digitFont, trtFlag_det, requestId
  435. """
  436. return AI_process_forest([param[0]], param[1], param[2], param[3], param[4], param[5], param[6], param[7],
  437. param[8], param[9], param[10], font=param[11], trtFlag_det=param[12])
  438. except ServiceException as s:
  439. raise s
  440. except Exception:
  441. # self.num += 1
  442. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  443. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), param[13])
  444. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  445. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  446. # 车辆模型
  447. class VehicleModel:
  448. __slots__ = "model_conf"
  449. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  450. s = time.time()
  451. try:
  452. logger.info("########################加载车辆模型########################, requestId:{}", requestId)
  453. trtFlag_det = True, # 检测模型是否采用TRT
  454. # trtFlag_seg = False, # 分割模型是否采用TRT
  455. # 公共变量
  456. par = {
  457. 'device': str(device1),
  458. 'gpu_name': gpu_name,
  459. 'labelnames': ["车辆"],
  460. 'seg_nclass': 2, # 分割模型类别数目,默认2类
  461. 'segRegionCnt': 0,
  462. 'slopeIndex': [],
  463. 'segPar': None,
  464. 'postFile': {
  465. "name": "post_process",
  466. "conf_thres": 0.25,
  467. "iou_thres": 0.45,
  468. "classes": 5,
  469. "rainbows": COLOR
  470. },
  471. 'segweights': None
  472. }
  473. if trtFlag_det:
  474. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name'])
  475. else:
  476. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  477. device = select_device(par.get('device'))
  478. names = par.get('labelnames')
  479. half = device.type != 'cpu'
  480. Detweights = par.get('detweights')
  481. if trtFlag_det:
  482. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  483. model = runtime.deserialize_cuda_engine(f.read())
  484. else:
  485. model = attempt_load(Detweights, map_location=device)
  486. if half:
  487. model.half()
  488. segmodel = None
  489. conf_thres = par.get('postFile').get("conf_thres")
  490. iou_thres = par.get('postFile').get("iou_thres")
  491. # classes = par.get('postFile').get("classes")
  492. rainbows = par.get('postFile').get("rainbows")
  493. """
  494. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  495. allowedList, digitFont, trtFlag_det, requestId
  496. """
  497. model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None,
  498. trtFlag_det, requestId]
  499. self.model_conf = (modeType, allowedList, model_param)
  500. except Exception:
  501. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  502. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  503. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  504. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  505. def get_label_arraylist_1(width, names, rainbows):
  506. fontsize = int(width / 1920 * 40)
  507. line_thickness = 1
  508. boxLine_thickness = 1
  509. waterLineWidth = 1
  510. if width >= 960:
  511. line_thickness = int(round(width / 1920 * 3) - 1)
  512. boxLine_thickness = int(round(width / 1920 * 3))
  513. waterLineWidth = int(round(width / 1920 * 3))
  514. numFontSize = float(format(width / 1920 * 1.1, '.1f')) # 文字大小
  515. digitFont = {'line_thickness': line_thickness,
  516. 'boxLine_thickness': boxLine_thickness,
  517. 'fontSize': numFontSize,
  518. 'segLineShow': False,
  519. 'waterLineColor': (0, 255, 255),
  520. 'waterLineWidth': waterLineWidth}
  521. label_arraylist = get_label_arrays(names, rainbows, outfontsize=fontsize, fontpath=FONTPATH)
  522. return digitFont, label_arraylist
  523. # 行人模型
  524. class PedestrianModel:
  525. __slots__ = "model_conf"
  526. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  527. s = time.time()
  528. try:
  529. logger.info("########################加载行人模型########################, requestId:{}", requestId)
  530. trtFlag_det = True, # 检测模型是否采用TRT
  531. # trtFlag_seg = False, # 分割模型是否采用TRT
  532. # 公共变量
  533. par = {
  534. 'device': str(device1),
  535. 'gpu_name': gpu_name,
  536. 'labelnames': ["行人"],
  537. 'seg_nclass': 2,
  538. 'segRegionCnt': 0,
  539. 'slopeIndex': [],
  540. 'segPar': None,
  541. 'postFile': {
  542. "name": "post_process",
  543. "conf_thres": 0.25,
  544. "iou_thres": 0.45,
  545. "classes": 5,
  546. "rainbows": COLOR
  547. },
  548. 'segweights': None
  549. }
  550. if trtFlag_det:
  551. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name'])
  552. else:
  553. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  554. device = select_device(par.get('device')) # 指定GPU
  555. names = par.get('labelnames')
  556. half = device.type != 'cpu'
  557. Detweights = par.get('detweights') # 升级后的检测模型
  558. if trtFlag_det:
  559. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  560. model = runtime.deserialize_cuda_engine(f.read())
  561. else:
  562. model = attempt_load(Detweights, map_location=device)
  563. if half:
  564. model.half()
  565. segmodel = None
  566. conf_thres = par.get('postFile').get("conf_thres")
  567. iou_thres = par.get('postFile').get("iou_thres")
  568. # classes = par.get('postFile').get("classes")
  569. rainbows = par.get('postFile').get("rainbows")
  570. """
  571. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  572. allowedList, digitFont, trtFlag_det, requestId
  573. """
  574. model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None,
  575. trtFlag_det, requestId]
  576. self.model_conf = (modeType, allowedList, model_param)
  577. except Exception:
  578. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  579. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  580. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  581. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  582. # 烟火模型
  583. class SmogfireModel:
  584. __slots__ = "model_conf"
  585. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  586. s = time.time()
  587. try:
  588. logger.info("########################加载烟火模型########################, requestId:{}", requestId)
  589. trtFlag_det = True, # 检测模型是否采用TRT
  590. # trtFlag_seg = False, # 分割模型是否采用TRT
  591. # 公共变量
  592. par = {
  593. 'device': str(device1),
  594. 'gpu_name': gpu_name,
  595. 'labelnames': ["烟雾", "火焰"],
  596. 'seg_nclass': 2, # 分割模型类别数目,默认2类
  597. 'segRegionCnt': 0,
  598. 'slopeIndex': [],
  599. 'segPar': None,
  600. 'postFile': {
  601. "name": "post_process",
  602. "conf_thres": 0.25,
  603. "iou_thres": 0.45,
  604. "classes": 5,
  605. "rainbows": COLOR
  606. },
  607. 'segweights': None
  608. }
  609. if trtFlag_det:
  610. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name'])
  611. else:
  612. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  613. device = select_device(par.get('device')) # 指定GPU
  614. names = par.get('labelnames')
  615. half = device.type != 'cpu'
  616. Detweights = par.get('detweights') # 升级后的检测模型
  617. if trtFlag_det:
  618. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  619. model = runtime.deserialize_cuda_engine(f.read())
  620. else:
  621. model = attempt_load(Detweights, map_location=device)
  622. if half:
  623. model.half()
  624. segmodel = None
  625. conf_thres = par.get('postFile').get("conf_thres")
  626. iou_thres = par.get('postFile').get("iou_thres")
  627. # classes = par.get('postFile').get("classes")
  628. rainbows = par.get('postFile').get("rainbows")
  629. """
  630. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  631. allowedList, digitFont, trtFlag_det, requestId
  632. """
  633. model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None,
  634. trtFlag_det, requestId]
  635. self.model_conf = (modeType, allowedList, model_param)
  636. except Exception:
  637. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  638. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  639. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  640. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  641. # 钓鱼模型
  642. class AnglerSwimmerModel:
  643. __slots__ = "model_conf"
  644. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  645. s = time.time()
  646. try:
  647. logger.info("########################加载钓鱼模型########################, requestId:{}", requestId)
  648. trtFlag_det = True, # 检测模型是否采用TRT
  649. # trtFlag_seg = False, # 分割模型是否采用TRT
  650. # 公共变量
  651. par = {
  652. 'device': str(device1),
  653. 'gpu_name': gpu_name,
  654. 'labelnames': ["钓鱼", "游泳"],
  655. 'seg_nclass': 2, # 分割模型类别数目,默认2类
  656. 'segRegionCnt': 0,
  657. 'slopeIndex': [],
  658. 'segPar': None,
  659. 'postFile': {
  660. "name": "post_process",
  661. "conf_thres": 0.25,
  662. "iou_thres": 0.45,
  663. "classes": 5,
  664. "rainbows": COLOR
  665. },
  666. 'segweights': None
  667. }
  668. if trtFlag_det:
  669. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name'])
  670. else:
  671. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  672. device = select_device(par.get('device')) # 指定GPU
  673. names = par.get('labelnames')
  674. half = device.type != 'cpu'
  675. Detweights = par.get('detweights') # 升级后的检测模型
  676. if trtFlag_det:
  677. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  678. model = runtime.deserialize_cuda_engine(f.read())
  679. else:
  680. model = attempt_load(Detweights, map_location=device)
  681. if half:
  682. model.half()
  683. segmodel = None
  684. conf_thres = par.get('postFile').get("conf_thres")
  685. iou_thres = par.get('postFile').get("iou_thres")
  686. # classes = par.get('postFile').get("classes")
  687. rainbows = par.get('postFile').get("rainbows")
  688. """
  689. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  690. allowedList, digitFont, trtFlag_det, requestId
  691. """
  692. model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None,
  693. trtFlag_det, requestId]
  694. self.model_conf = (modeType, allowedList, model_param)
  695. except Exception:
  696. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  697. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  698. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  699. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  700. # 乡村模型
  701. class CountryRoadModel:
  702. __slots__ = "model_conf"
  703. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  704. s = time.time()
  705. try:
  706. logger.info("########################加载乡村模型########################, requestId:{}", requestId)
  707. trtFlag_det = True, # 检测模型是否采用TRT
  708. # trtFlag_seg = False, # 分割模型是否采用TRT
  709. # 公共变量
  710. par = {
  711. 'device': str(device1),
  712. 'gpu_name': gpu_name,
  713. 'labelnames': ["违法种植"],
  714. 'seg_nclass': 2, # 分割模型类别数目,默认2类
  715. 'segRegionCnt': 0,
  716. 'slopeIndex': [],
  717. 'segPar': None,
  718. 'postFile': {
  719. "name": "post_process",
  720. "conf_thres": 0.25,
  721. "iou_thres": 0.45,
  722. "classes": 5,
  723. "rainbows": COLOR
  724. },
  725. 'segweights': None
  726. }
  727. if trtFlag_det:
  728. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name'])
  729. else:
  730. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  731. device = select_device(par.get('device')) # 指定GPU
  732. names = par.get('labelnames')
  733. half = device.type != 'cpu'
  734. Detweights = par.get('detweights') # 升级后的检测模型
  735. if trtFlag_det:
  736. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  737. model = runtime.deserialize_cuda_engine(f.read())
  738. else:
  739. model = attempt_load(Detweights, map_location=device)
  740. if half:
  741. model.half()
  742. segmodel = None
  743. conf_thres = par.get('postFile').get("conf_thres")
  744. iou_thres = par.get('postFile').get("iou_thres")
  745. # classes = par.get('postFile').get("classes")
  746. rainbows = par.get('postFile').get("rainbows")
  747. """
  748. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  749. allowedList, digitFont, trtFlag_det, requestId
  750. """
  751. model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None,
  752. trtFlag_det, requestId]
  753. self.model_conf = (modeType, allowedList, model_param)
  754. except Exception:
  755. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  756. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  757. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  758. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  759. # 航道模型
  760. class ChannelEmergencyModel:
  761. __slots__ = "model_conf"
  762. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  763. s = time.time()
  764. try:
  765. logger.info("########################加载航道模型########################, requestId:{}", requestId)
  766. trtFlag_det = True, # 检测模型是否采用TRT
  767. # trtFlag_seg = False, # 分割模型是否采用TRT
  768. # 公共变量
  769. par = {
  770. 'device': str(device1),
  771. 'gpu_name': gpu_name,
  772. 'labelnames': ["人"],
  773. 'seg_nclass': 2, # 分割模型类别数目,默认2类
  774. 'segRegionCnt': 0,
  775. 'slopeIndex': [],
  776. 'segPar': None,
  777. 'postFile': {
  778. "name": "post_process",
  779. "conf_thres": 0.25,
  780. "iou_thres": 0.45,
  781. "classes": 5,
  782. "rainbows": COLOR
  783. },
  784. 'segweights': None
  785. }
  786. if trtFlag_det:
  787. par['detweights'] = "../AIlib2/weights/%s/yolov5_%s_fp16.engine" % (modeType.value[3], par['gpu_name'])
  788. else:
  789. par['detweights'] = "../AIlib2/weights/conf/%s/yolov5.pt" % modeType.value[3]
  790. device = select_device(par.get('device')) # 指定GPU
  791. names = par.get('labelnames')
  792. half = device.type != 'cpu'
  793. Detweights = par.get('detweights') # 升级后的检测模型
  794. if trtFlag_det:
  795. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  796. model = runtime.deserialize_cuda_engine(f.read())
  797. else:
  798. model = attempt_load(Detweights, map_location=device)
  799. if half:
  800. model.half()
  801. segmodel = None
  802. conf_thres = par.get('postFile').get("conf_thres")
  803. iou_thres = par.get('postFile').get("iou_thres")
  804. # classes = par.get('postFile').get("classes")
  805. rainbows = par.get('postFile').get("rainbows")
  806. """
  807. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  808. allowedList, digitFont, trtFlag_det, requestId
  809. """
  810. model_param = [None, model, segmodel, names, None, rainbows, half, device, conf_thres, iou_thres, [], None,
  811. trtFlag_det, requestId]
  812. self.model_conf = (modeType, allowedList, model_param)
  813. except Exception:
  814. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  815. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  816. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  817. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  818. # 船只模型
  819. class ShipModel:
  820. __slots__ = "model_conf"
  821. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  822. s = time.time()
  823. try:
  824. logger.info("########################加载船只模型########################, requestId:{}", requestId)
  825. # 公共变量
  826. par = {
  827. 'model_size': (608, 608),
  828. 'K': 100,
  829. 'conf_thresh': 0.18,
  830. 'device': 'cuda:%s' % str(device1),
  831. 'down_ratio': 4,
  832. 'num_classes': 15,
  833. 'weights': '../AIlib2/weights/%s/obb_608X608_%s_fp16.engine' % (modeType.value[3], gpu_name),
  834. 'dataset': 'dota',
  835. 'half': False,
  836. 'mean': (0.5, 0.5, 0.5),
  837. 'std': (1, 1, 1),
  838. 'heads': {'hm': None, 'wh': 10, 'reg': 2, 'cls_theta': 1},
  839. 'decoder': None,
  840. 'test_flag': True,
  841. "rainbows": COLOR,
  842. 'postFile': {
  843. "name": "post_process",
  844. "conf_thres": 0.25,
  845. "iou_thres": 0.45,
  846. "classes": 5,
  847. "rainbows": COLOR
  848. },
  849. 'drawBox': False,
  850. 'label_array': None,
  851. 'labelnames': ("0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "船只"),
  852. }
  853. model, decoder2 = load_model_decoder_OBB(par)
  854. par['decoder'] = decoder2
  855. names = par.get('labelnames')
  856. # conf_thres = par.get('postFile').get("conf_thres")
  857. # iou_thres = par.get('postFile').get("iou_thres")
  858. # classes = par.get('postFile').get("classes")
  859. rainbows = par.get('postFile').get("rainbows")
  860. """
  861. [frame, par, model, requestId]
  862. """
  863. model_param = [None, par, model, names, None, rainbows, requestId]
  864. self.model_conf = (modeType, allowedList, model_param)
  865. except Exception:
  866. logger.exception("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  867. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  868. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  869. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  870. def get_label_arraylist_2(width, par, names, rainbows):
  871. fontsize = int(width / 1920 * 40)
  872. line_thickness = 1
  873. boxLine_thickness = 1
  874. if width >= 960:
  875. line_thickness = int(round(width / 1920 * 3) - 1)
  876. boxLine_thickness = int(round(width / 1920 * 3))
  877. numFontSize = float(format(width / 1920 * 1.1, '.1f')) # 文字大小
  878. par["digitWordFont"] = {
  879. 'line_thickness': line_thickness,
  880. 'boxLine_thickness': boxLine_thickness,
  881. 'wordSize': fontsize,
  882. 'fontSize': numFontSize,
  883. 'label_location': 'leftTop'
  884. }
  885. par["label_array"] = get_label_arrays(names, rainbows, outfontsize=fontsize, fontpath=FONTPATH)
  886. def obb_process(param):
  887. try:
  888. # [frame, par, model, requestId]
  889. return OBB_infer(param[2], param[0], param[1])
  890. except ServiceException as s:
  891. raise s
  892. except Exception:
  893. # self.num += 1
  894. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  895. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), param[3])
  896. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  897. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  898. # 车牌分割模型、健康码、行程码分割模型
  899. class IMModel:
  900. __slots__ = "model_conf"
  901. def __init__(self, device1, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  902. try:
  903. logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
  904. requestId)
  905. img_type = 'code'
  906. if ModelType.PLATE_MODEL == modeType:
  907. img_type = 'plate'
  908. par = {
  909. 'code': {'weights': '../AIlib2/weights/conf/jkm/health_yolov5s_v3.jit', 'img_type': 'code', 'nc': 10},
  910. 'plate': {'weights': '../AIlib2/weights/conf/jkm/plate_yolov5s_v3.jit', 'img_type': 'plate', 'nc': 1},
  911. 'conf_thres': 0.4,
  912. 'iou_thres': 0.45,
  913. 'device': 'cuda:%s' % device1,
  914. 'plate_dilate': (0.5, 0.3)
  915. }
  916. device = torch.device(par['device'])
  917. model = torch.jit.load(par[img_type]['weights'])
  918. """
  919. [frame, device, model, par, img_type, requestId]
  920. """
  921. model_param = [None, device, model, par, img_type, requestId]
  922. self.model_conf = (modeType, allowedList, model_param)
  923. except Exception:
  924. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  925. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  926. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  927. def im_process(param):
  928. try:
  929. # [frame, device, model, par, img_type, requestId]
  930. img, padInfos = pre_process(param[0], param[1])
  931. pred = param[2](img)
  932. boxes = post_process(pred, padInfos, param[1], conf_thres=param[3]['conf_thres'],
  933. iou_thres=param[3]['iou_thres'], nc=param[3][param[4]]['nc']) # 后处理
  934. dataBack = get_return_data(param[0], boxes, modelType=param[4], plate_dilate=param[3]['plate_dilate'])
  935. return dataBack
  936. except ServiceException as s:
  937. raise s
  938. except Exception:
  939. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), param[5])
  940. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  941. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  942. # class OCR_Model:
  943. #
  944. # def __init__(self, device=None, logger=None, requestId=None):
  945. # try:
  946. # logger.info("######################## 加载OCR模型 ########################, requestId:{}", requestId)
  947. # self.__requestId = requestId
  948. # self.__logger = logger
  949. # self.__trtFlag_det = True
  950. # if self.__trtFlag_det:
  951. # gpu = get_all_gpu_ids()[int(device)]
  952. # if '3090' in gpu.name:
  953. # TRTfile = "../AIlib2/weights/ocr_en/english_g2_3090_fp16_448X32.engine"
  954. # elif '2080' in gpu.name:
  955. # TRTfile = "../AIlib2/weights/ocr_en/english_2080Ti_g2_h64_fp16.engine"
  956. # elif '4090' in gpu.name:
  957. # TRTfile = "../AIlib2/weights/ocr_en/english_g2_4090_fp16_448X32.engine"
  958. # elif 'A10' in gpu.name:
  959. # TRTfile = "../AIlib2/weights/ocr_en/english_g2_A10_fp16_448X32.engine"
  960. # else:
  961. # raise Exception("未匹配到该GPU名称的模型, GPU: " + gpu.name)
  962. # else:
  963. # TRTfile = "../AIlib2/weights/conf/ocr_en/english_g2.pth"
  964. # par = {
  965. # 'TRTfile': TRTfile,
  966. # 'device': 'cuda: %s' % device,
  967. # 'dict_list': {'en': '../AIlib2/weights/conf/ocr_en/en.txt'},
  968. # 'char_file': '../AIlib2/weights/conf/ocr_en/en_character.csv',
  969. # 'imgH': 100,
  970. # 'imgW': 400
  971. # }
  972. # TRTfile = par['TRTfile']
  973. # self.__device = par['device']
  974. # dict_list = par['dict_list']
  975. # char_file = par['char_file']
  976. # imgH = par['imgH']
  977. # imgW = par['imgW']
  978. # logger = trt.Logger(trt.Logger.ERROR)
  979. # with open(TRTfile, "rb") as f, trt.Runtime(logger) as runtime:
  980. # self.engine = runtime.deserialize_cuda_engine(f.read()) # 输入trt本地文件,返回ICudaEngine对象
  981. # print('#####load TRT file:', TRTfile, 'success #####')
  982. # self.context = self.engine.create_execution_context()
  983. #
  984. # with open(char_file, 'r') as fp:
  985. # characters = fp.readlines()[0].strip()
  986. # self.converter = CTCLabelConverter(characters, {}, dict_list)
  987. # self.AlignCollate_normal = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=True)
  988. # except Exception as ee:
  989. # self.__logger.exception("模型加载异常:{}, requestId:{}", ee, requestId)
  990. # raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  991. # ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  992. #
  993. # def process(self, frame):
  994. # try:
  995. # gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
  996. # par = [gray_frame, self.engine, self.context, self.converter, self.AlignCollate_normal, self.__device]
  997. # return ocr_process(par)
  998. # except ServiceException as s:
  999. # raise s
  1000. # except Exception as ee:
  1001. # self.__logger.exception("ocr坐标识别异常:{}, requestId:{}", ee, self.__requestId)
  1002. # raise ServiceException(ExceptionType.COORDINATE_ACQUISITION_FAILED.value[0],
  1003. # ExceptionType.COORDINATE_ACQUISITION_FAILED.value[1])
  1004. # 百度AI图片识别模型
  1005. class BaiduAiImageModel:
  1006. __slots__ = "model_conf"
  1007. def __init__(self, device=None, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None):
  1008. try:
  1009. logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
  1010. requestId)
  1011. aipBodyAnalysisClient = AipBodyAnalysisClient(base_dir)
  1012. aipImageClassifyClient = AipImageClassifyClient(base_dir)
  1013. rainbows = COLOR
  1014. """
  1015. [target, url, aipImageClassifyClient, aipBodyAnalysisClient, requestId]
  1016. """
  1017. model_param = [None, None, aipImageClassifyClient, aipBodyAnalysisClient, requestId]
  1018. self.model_conf = (modeType, allowedList, model_param, rainbows)
  1019. except Exception:
  1020. logger.exception("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  1021. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  1022. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  1023. def baidu_process(param):
  1024. try:
  1025. # [target, url, aipImageClassifyClient, aipBodyAnalysisClient, requestId]
  1026. baiduEnum = BAIDU_MODEL_TARGET_CONFIG.get(param[0])
  1027. if baiduEnum is None:
  1028. raise ServiceException(ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[0],
  1029. ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[1]
  1030. + " target: " + param[0])
  1031. return baiduEnum.value[2](param[2], param[3], param[1], param[4])
  1032. except ServiceException as s:
  1033. raise s
  1034. except Exception:
  1035. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), param[4])
  1036. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  1037. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  1038. def river_label(width, model_param):
  1039. names = model_param[3]
  1040. rainbows = model_param[5]
  1041. digitFont, label_arraylist = get_label_arraylist(width, names, rainbows)
  1042. """
  1043. frame, model, segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId
  1044. """
  1045. model_param[4] = label_arraylist
  1046. model_param[7] = digitFont
  1047. def river2_label(width, model_param):
  1048. names = model_param[3]
  1049. rainbows = model_param[5]
  1050. digitFont, label_arraylist = get_label_arraylist(width, names, rainbows)
  1051. """
  1052. frame, model, segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId
  1053. """
  1054. model_param[4] = label_arraylist
  1055. model_param[7] = digitFont
  1056. def high_label(width, model_param):
  1057. names = model_param[3]
  1058. rainbows = model_param[5]
  1059. digitFont, label_arraylist = get_label_arraylist(width, names, rainbows)
  1060. """
  1061. frame, model, segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId
  1062. """
  1063. model_param[4] = label_arraylist
  1064. model_param[7] = digitFont
  1065. def forest_label(width, model_param):
  1066. names = model_param[3]
  1067. rainbows = model_param[5]
  1068. digitFont, label_arraylist = get_label_arraylist(width, names, rainbows)
  1069. """
  1070. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  1071. allowedList, digitFont, trtFlag_det, requestId
  1072. """
  1073. model_param[4] = label_arraylist
  1074. model_param[11] = digitFont
  1075. def vehicle_label(width, model_param):
  1076. names = model_param[3]
  1077. rainbows = model_param[5]
  1078. digitFont, label_arraylist = get_label_arraylist_1(width, names, rainbows)
  1079. """
  1080. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  1081. allowedList, digitFont, trtFlag_det, requestId
  1082. """
  1083. model_param[4] = label_arraylist
  1084. model_param[11] = digitFont
  1085. def pedestrian_label(width, model_param):
  1086. names = model_param[3]
  1087. rainbows = model_param[5]
  1088. digitFont, label_arraylist = get_label_arraylist_1(width, names, rainbows)
  1089. """
  1090. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  1091. allowedList, digitFont, trtFlag_det, requestId
  1092. """
  1093. model_param[4] = label_arraylist
  1094. model_param[11] = digitFont
  1095. def channel_emergency_label(width, model_param):
  1096. names = model_param[3]
  1097. rainbows = model_param[5]
  1098. digitFont, label_arraylist = get_label_arraylist_1(width, names, rainbows)
  1099. """
  1100. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  1101. allowedList, digitFont, trtFlag_det, requestId
  1102. """
  1103. model_param[4] = label_arraylist
  1104. model_param[11] = digitFont
  1105. def ship_label(width, model_param):
  1106. # [None, par, model, names, None, rainbows, requestId]
  1107. par = model_param[1]
  1108. names = model_param[3]
  1109. rainbows = model_param[5]
  1110. get_label_arraylist_2(width, par, names, rainbows)
  1111. model_param[4] = par["label_array"]
  1112. def countryroad_label(width, model_param):
  1113. names = model_param[3]
  1114. rainbows = model_param[5]
  1115. digitFont, label_arraylist = get_label_arraylist_1(width, names, rainbows)
  1116. """
  1117. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  1118. allowedList, digitFont, trtFlag_det, requestId
  1119. """
  1120. model_param[4] = label_arraylist
  1121. model_param[11] = digitFont
  1122. def anglerswimmer_label(width, model_param):
  1123. names = model_param[3]
  1124. rainbows = model_param[5]
  1125. digitFont, label_arraylist = get_label_arraylist_1(width, names, rainbows)
  1126. """
  1127. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  1128. allowedList, digitFont, trtFlag_det, requestId
  1129. """
  1130. model_param[4] = label_arraylist
  1131. model_param[11] = digitFont
  1132. def smogfire_label(width, model_param):
  1133. names = model_param[3]
  1134. rainbows = model_param[5]
  1135. digitFont, label_arraylist = get_label_arraylist_1(width, names, rainbows)
  1136. """
  1137. frame, model, segmodel, names, label_arraylist, rainbows, half, device, conf_thres, iou_thres,
  1138. allowedList, digitFont, trtFlag_det, requestId
  1139. """
  1140. model_param[4] = label_arraylist
  1141. model_param[11] = digitFont
  1142. MODEL_CONFIG = {
  1143. # 加载河道模型
  1144. ModelType.WATER_SURFACE_MODEL.value[1]: (
  1145. lambda x, y, r, t, z : RiverModel(x, y, r, ModelType.WATER_SURFACE_MODEL, t, z),
  1146. ModelType.WATER_SURFACE_MODEL,
  1147. lambda x, y: river_label(x, y),
  1148. lambda x: model_process(x)
  1149. ),
  1150. # 加载森林模型
  1151. ModelType.FOREST_FARM_MODEL.value[1]: (
  1152. lambda x, y, r, t, z: ForestModel(x, y, r, ModelType.FOREST_FARM_MODEL, t, z),
  1153. ModelType.FOREST_FARM_MODEL,
  1154. lambda x, y: forest_label(x, y),
  1155. lambda x: forest_process(x)
  1156. ),
  1157. # 加载交通模型
  1158. ModelType.TRAFFIC_FARM_MODEL.value[1]: (
  1159. lambda x, y, r, t, z: HighWayModel(x, y, r, ModelType.TRAFFIC_FARM_MODEL, t, z),
  1160. ModelType.TRAFFIC_FARM_MODEL,
  1161. lambda x, y: high_label(x, y),
  1162. lambda x: high_process(x)
  1163. ),
  1164. # 加载防疫模型
  1165. ModelType.EPIDEMIC_PREVENTION_MODEL.value[1]: (
  1166. lambda x, y, r, t, z: IMModel(x, y, r, ModelType.EPIDEMIC_PREVENTION_MODEL, t, z),
  1167. ModelType.EPIDEMIC_PREVENTION_MODEL,
  1168. None,
  1169. lambda x: im_process(x)),
  1170. # 加载车牌模型
  1171. ModelType.PLATE_MODEL.value[1]: (
  1172. lambda x, y, r, t, z: IMModel(x, y, r, ModelType.PLATE_MODEL, t, z),
  1173. ModelType.PLATE_MODEL,
  1174. None,
  1175. lambda x: im_process(x)),
  1176. # 加载车辆模型
  1177. ModelType.VEHICLE_MODEL.value[1]: (
  1178. lambda x, y, r, t, z: VehicleModel(x, y, r, ModelType.VEHICLE_MODEL, t, z),
  1179. ModelType.VEHICLE_MODEL,
  1180. lambda x, y: vehicle_label(x, y),
  1181. lambda x: forest_process(x)
  1182. ),
  1183. # 加载行人模型
  1184. ModelType.PEDESTRIAN_MODEL.value[1]: (
  1185. lambda x, y, r, t, z: PedestrianModel(x, y, r, ModelType.PEDESTRIAN_MODEL, t, z),
  1186. ModelType.PEDESTRIAN_MODEL,
  1187. lambda x, y: pedestrian_label(x, y),
  1188. lambda x: forest_process(x)),
  1189. # 加载烟火模型
  1190. ModelType.SMOGFIRE_MODEL.value[1]: (
  1191. lambda x, y, r, t, z: SmogfireModel(x, y, r, ModelType.SMOGFIRE_MODEL, t, z),
  1192. ModelType.SMOGFIRE_MODEL,
  1193. lambda x, y: smogfire_label(x, y),
  1194. lambda x: forest_process(x)),
  1195. # 加载钓鱼游泳模型
  1196. ModelType.ANGLERSWIMMER_MODEL.value[1]: (
  1197. lambda x, y, r, t, z: AnglerSwimmerModel(x, y, r, ModelType.ANGLERSWIMMER_MODEL, t, z),
  1198. ModelType.ANGLERSWIMMER_MODEL,
  1199. lambda x, y: anglerswimmer_label(x, y),
  1200. lambda x: forest_process(x)),
  1201. # 加载乡村模型
  1202. ModelType.COUNTRYROAD_MODEL.value[1]: (
  1203. lambda x, y, r, t, z: CountryRoadModel(x, y, r, ModelType.COUNTRYROAD_MODEL, t, z),
  1204. ModelType.COUNTRYROAD_MODEL,
  1205. lambda x, y: countryroad_label(x, y),
  1206. lambda x: forest_process(x)),
  1207. # 加载船只模型
  1208. ModelType.SHIP_MODEL.value[1]: (
  1209. lambda x, y, r, t, z: ShipModel(x, y, r, ModelType.SHIP_MODEL, t, z),
  1210. ModelType.SHIP_MODEL,
  1211. lambda x, y: ship_label(x, y),
  1212. lambda x: obb_process(x)),
  1213. # 百度AI图片识别模型
  1214. ModelType.BAIDU_MODEL.value[1]: (
  1215. lambda x, y, r, t, z: BaiduAiImageModel(x, y, r, ModelType.BAIDU_MODEL, t, z),
  1216. ModelType.BAIDU_MODEL,
  1217. None,
  1218. lambda x: baidu_process(x)),
  1219. # 航道模型
  1220. ModelType.CHANNEL_EMERGENCY_MODEL.value[1]: (
  1221. lambda x, y, r, t, z: ChannelEmergencyModel(x, y, r, ModelType.CHANNEL_EMERGENCY_MODEL, t, z),
  1222. ModelType.CHANNEL_EMERGENCY_MODEL,
  1223. lambda x, y: channel_emergency_label(x, y),
  1224. lambda x: forest_process(x)),
  1225. # 河道检测模型
  1226. ModelType.RIVER2_MODEL.value[1]: (
  1227. lambda x, y, r, t, z: River2Model(x, y, r, ModelType.RIVER2_MODEL, t, z),
  1228. ModelType.RIVER2_MODEL,
  1229. lambda x, y: river2_label(x, y),
  1230. lambda x: model_process(x)
  1231. )
  1232. }
  1233. # ModelConfig = namedtuple('ModelConfig', ('x', 'y', 'z'))