You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

443 lines
20KB

  1. # -*- coding: utf-8 -*-
  2. import sys
  3. from json import dumps, loads
  4. from traceback import format_exc
  5. import cv2
  6. from loguru import logger
  7. from common.Constant import COLOR
  8. from enums.BaiduSdkEnum import VehicleEnum
  9. from enums.ExceptionEnum import ExceptionType
  10. from enums.ModelTypeEnum2 import ModelType2, BAIDU_MODEL_TARGET_CONFIG2
  11. from exception.CustomerException import ServiceException
  12. from util.ImgBaiduSdk import AipBodyAnalysisClient, AipImageClassifyClient
  13. from util.PlotsUtils import get_label_arrays
  14. from util.TorchUtils import select_device
  15. import time
  16. import torch
  17. import tensorrt as trt
  18. sys.path.extend(['..', '../AIlib2'])
  19. from AI import AI_process, get_postProcess_para, get_postProcess_para_dic, AI_det_track, AI_det_track_batch, AI_det_track_batch_N
  20. from stdc import stdcModel
  21. from utilsK.jkmUtils import pre_process, post_process, get_return_data
  22. from obbUtils.shipUtils import OBB_infer, OBB_tracker, draw_obb, OBB_tracker_batch
  23. from obbUtils.load_obb_model import load_model_decoder_OBB
  24. from trackUtils.sort import Sort
  25. from trackUtils.sort_obb import OBB_Sort
  26. from DMPR import DMPRModel
# Font file handed to get_label_arrays() (as fontPath) when rendering label text.
# This is a relative path, so it is resolved against the process's current working directory.
FONT_PATH = "../AIlib2/conf/platech.ttf"
  28. class Model:
  29. __slots__ = "model_conf"
  30. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
  31. try:
  32. logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
  33. requestId)
  34. par = modeType.value[4](str(device), gpu_name)
  35. trackPar = par['trackPar']
  36. names = par['labelnames']
  37. detPostPar = par['postFile']
  38. rainbows = detPostPar["rainbows"]
  39. #第一步加载模型
  40. modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
  41. #第二步准备跟踪参数
  42. trackPar=par['trackPar']
  43. sort_tracker = Sort(max_age=trackPar['sort_max_age'],
  44. min_hits=trackPar['sort_min_hits'],
  45. iou_threshold=trackPar['sort_iou_thresh'])
  46. postProcess = par['postProcess']
  47. model_param = {
  48. "modelList": modelList,
  49. "postProcess": postProcess,
  50. "sort_tracker": sort_tracker,
  51. "trackPar": trackPar,
  52. }
  53. self.model_conf = (modeType, model_param, allowedList, names, rainbows)
  54. except Exception:
  55. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  56. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  57. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  58. def get_label_arraylist(*args):
  59. width, height, names, rainbows = args
  60. # line = int(round(0.002 * (height + width) / 2) + 1)
  61. line = max(1, int(round(width / 1920 * 3)))
  62. tf = max(line, 1)
  63. fontScale = line * 0.33
  64. text_width, text_height = cv2.getTextSize(' 0.95', 0, fontScale=fontScale, thickness=tf)[0]
  65. label_arraylist = get_label_arrays(names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
  66. return label_arraylist, (line, text_width, text_height, fontScale, tf)
  67. """
  68. 输入:
  69. imgarray_list--图像列表
  70. iframe_list -- 帧号列表
  71. modelPar--模型参数,字典,modelPar={'det_Model':,'seg_Model':}
  72. processPar--字典,存放检测相关参数,'half', 'device', 'conf_thres', 'iou_thres','trtFlag_det'
  73. sort_tracker--对象,初始化的跟踪对象。为了保持一致,即使是单帧也要有。
  74. trackPar--跟踪参数,关键字包括:det_cnt,windowsize
  75. segPar--None,分割模型相关参数。如果用不到,则为None
  76. 输入:retResults,timeInfos
  77. retResults:list
  78. retResults[0]--imgarray_list
  79. retResults[1]--所有结果用numpy格式,所有的检测结果,包括8类,每列分别是x1, y1, x2, y2, conf, detclass,iframe,trackId
  80. retResults[2]--所有结果用list表示,其中每一个元素为一个list,表示每一帧的检测结果,每一个结果是由多个list构成,每个list表示一个框,格式为[ cls , x0 ,y0 ,x1 ,y1 ,conf,ifrmae,trackId ],如 retResults[2][j][k]表示第j帧的第k个框。
  81. """
  82. def model_process(args):
  83. # (modeType, model_param, allowedList, names, rainbows)
  84. imgarray_list, iframe_list, model_param, request_id = args
  85. try:
  86. return AI_det_track_batch_N(imgarray_list, iframe_list,
  87. model_param['modelList'],
  88. model_param['postProcess'],
  89. model_param['sort_tracker'],
  90. model_param['trackPar'])
  91. except ServiceException as s:
  92. raise s
  93. except Exception:
  94. # self.num += 1
  95. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  96. logger.error("算法模型分析异常: {}, requestId: {}", format_exc(), request_id)
  97. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  98. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  99. # 船只模型
  100. class ShipModel:
  101. __slots__ = "model_conf"
  102. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
  103. s = time.time()
  104. try:
  105. logger.info("########################加载船只模型########################, requestId:{}", requestId)
  106. par = modeType.value[4](str(device), gpu_name)
  107. obbModelPar = par['obbModelPar']
  108. model, decoder2 = load_model_decoder_OBB(obbModelPar)
  109. obbModelPar['decoder'] = decoder2
  110. names = par['labelnames']
  111. rainbows = par['postFile']["rainbows"]
  112. trackPar = par['trackPar']
  113. sort_tracker = OBB_Sort(max_age=trackPar['sort_max_age'], min_hits=trackPar['sort_min_hits'],
  114. iou_threshold=trackPar['sort_iou_thresh'])
  115. modelPar = {'obbmodel': model}
  116. segPar = None
  117. model_param = {
  118. "modelPar": modelPar,
  119. "obbModelPar": obbModelPar,
  120. "sort_tracker": sort_tracker,
  121. "trackPar": trackPar,
  122. "segPar": segPar
  123. }
  124. self.model_conf = (modeType, model_param, allowedList, names, rainbows)
  125. except Exception:
  126. logger.exception("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  127. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  128. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  129. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  130. def obb_process(args):
  131. imgarray_list, iframe_list, model_param, request_id = args
  132. try:
  133. return OBB_tracker_batch(imgarray_list, iframe_list, model_param['modelPar'], model_param['obbModelPar'],
  134. model_param['sort_tracker'], model_param['trackPar'], model_param['segPar'])
  135. except ServiceException as s:
  136. raise s
  137. except Exception:
  138. # self.num += 1
  139. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  140. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
  141. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  142. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  143. # 车牌分割模型、健康码、行程码分割模型
  144. class IMModel:
  145. __slots__ = "model_conf"
  146. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
  147. try:
  148. logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
  149. requestId)
  150. img_type = 'code'
  151. if ModelType2.PLATE_MODEL == modeType:
  152. img_type = 'plate'
  153. par = {
  154. 'code': {'weights': '../AIlib2/weights/conf/jkm/health_yolov5s_v3.jit', 'img_type': 'code', 'nc': 10},
  155. 'plate': {'weights': '../AIlib2/weights/conf/jkm/plate_yolov5s_v3.jit', 'img_type': 'plate', 'nc': 1},
  156. 'conf_thres': 0.4,
  157. 'iou_thres': 0.45,
  158. 'device': 'cuda:%s' % device,
  159. 'plate_dilate': (0.5, 0.3)
  160. }
  161. new_device = torch.device(par['device'])
  162. model = torch.jit.load(par[img_type]['weights'])
  163. model_param = {
  164. "device": new_device,
  165. "model": model,
  166. "par": par,
  167. "img_type": img_type
  168. }
  169. self.model_conf = (modeType, model_param, allowedList)
  170. except Exception:
  171. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  172. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  173. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  174. def im_process(args):
  175. model_param, frame, request_id = args
  176. device, par, img_type = model_param['device'], model_param['par'], model_param['img_type']
  177. try:
  178. img, padInfos = pre_process(frame, device)
  179. pred = model_param['model'](img)
  180. boxes = post_process(pred, padInfos, device, conf_thres=par['conf_thres'],
  181. iou_thres=par['iou_thres'], nc=par[img_type]['nc']) # 后处理
  182. dataBack = get_return_data(frame, boxes, modelType=img_type, plate_dilate=par['plate_dilate'])
  183. return dataBack
  184. except ServiceException as s:
  185. raise s
  186. except Exception:
  187. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
  188. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  189. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  190. # 百度AI图片识别模型
  191. class BaiduAiImageModel:
  192. __slots__ = "model_conf"
  193. def __init__(self, device=None, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
  194. env=None):
  195. try:
  196. logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
  197. requestId)
  198. aipBodyAnalysisClient = AipBodyAnalysisClient(base_dir, env)
  199. aipImageClassifyClient = AipImageClassifyClient(base_dir, env)
  200. rainbows = COLOR
  201. vehicle_names = [VehicleEnum.CAR.value[1], VehicleEnum.TRICYCLE.value[1], VehicleEnum.MOTORBIKE.value[1],
  202. VehicleEnum.CARPLATE.value[1], VehicleEnum.TRUCK.value[1], VehicleEnum.BUS.value[1]]
  203. person_names = ['人']
  204. model_param = {
  205. "vehicle_client": aipImageClassifyClient,
  206. "person_client": aipBodyAnalysisClient,
  207. }
  208. self.model_conf = (modeType, model_param, allowedList, (vehicle_names, person_names), rainbows)
  209. except Exception:
  210. logger.exception("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  211. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  212. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  213. def baidu_process(args):
  214. model_param, target, url, request_id = args
  215. try:
  216. baiduEnum = BAIDU_MODEL_TARGET_CONFIG2.get(target)
  217. if baiduEnum is None:
  218. raise ServiceException(ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[0],
  219. ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[1]
  220. + " target: " + target)
  221. return baiduEnum.value[2](model_param['vehicle_client'], model_param['person_client'], url, request_id)
  222. except ServiceException as s:
  223. raise s
  224. except Exception:
  225. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
  226. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  227. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  228. def get_baidu_label_arraylist(*args):
  229. width, height, vehicle_names, person_names, rainbows = args
  230. # line = int(round(0.002 * (height + width) / 2) + 1)
  231. line = max(1, int(round(width / 1920 * 3) + 1))
  232. label = ' 0.97'
  233. tf = max(line, 1)
  234. fontScale = line * 0.33
  235. text_width, text_height = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=tf)[0]
  236. vehicle_label_arrays = get_label_arrays(vehicle_names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
  237. person_label_arrays = get_label_arrays(person_names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
  238. font_config = (line, text_width, text_height, fontScale, tf)
  239. return vehicle_label_arrays, person_label_arrays, font_config
  240. def one_label(width, height, model_config):
  241. # (modeType, model_param, allowedList, names, rainbows)
  242. names = model_config[3]
  243. rainbows = model_config[4]
  244. label_arraylist, font_config = get_label_arraylist(width, height, names, rainbows)
  245. model_config[1]['label_arraylist'] = label_arraylist
  246. model_config[1]['font_config'] = font_config
  247. def baidu_label(width, height, model_config):
  248. # modeType, model_param, allowedList, (vehicle_names, person_names), rainbows
  249. vehicle_names = model_config[3][0]
  250. person_names = model_config[3][1]
  251. rainbows = model_config[4]
  252. vehicle_label_arrays, person_label_arrays, font_config = get_baidu_label_arraylist(width, height, vehicle_names,
  253. person_names, rainbows)
  254. model_config[1]['vehicle_label_arrays'] = vehicle_label_arrays
  255. model_config[1]['person_label_arrays'] = person_label_arrays
  256. model_config[1]['font_config'] = font_config
  257. def model_process1(args):
  258. imgarray_list, iframe_list, model_param, request_id = args
  259. model_conf, frame, request_id = args
  260. model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
  261. # modeType, model_param, allowedList, names, rainbows = model_conf
  262. # segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId = args
  263. # model_param['digitFont'] = digitFont
  264. # model_param['label_arraylist'] = label_arraylist
  265. # model_param['font_config'] = font_config
  266. try:
  267. return AI_process([frame], model_param['model'], model_param['segmodel'], names, model_param['label_arraylist'],
  268. rainbows, objectPar=model_param['objectPar'], font=model_param['digitFont'],
  269. segPar=loads(dumps(model_param['segPar'])), mode=model_param['mode'],
  270. postPar=model_param['postPar'])
  271. except ServiceException as s:
  272. raise s
  273. except Exception:
  274. # self.num += 1
  275. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  276. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
  277. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  278. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  279. MODEL_CONFIG2 = {
  280. # 加载河道模型
  281. ModelType2.WATER_SURFACE_MODEL.value[1]: (
  282. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.WATER_SURFACE_MODEL, t, z, h),
  283. ModelType2.WATER_SURFACE_MODEL,
  284. lambda x, y, z: one_label(x, y, z),
  285. lambda x: model_process(x)
  286. ),
  287. # 加载森林模型
  288. ModelType2.FOREST_FARM_MODEL.value[1]: (
  289. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.FOREST_FARM_MODEL, t, z, h),
  290. ModelType2.FOREST_FARM_MODEL,
  291. lambda x, y, z: one_label(x, y, z),
  292. lambda x: model_process(x)
  293. ),
  294. # 加载交通模型
  295. ModelType2.TRAFFIC_FARM_MODEL.value[1]: (
  296. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.TRAFFIC_FARM_MODEL, t, z, h),
  297. ModelType2.TRAFFIC_FARM_MODEL,
  298. lambda x, y, z: one_label(x, y, z),
  299. lambda x: model_process(x)
  300. ),
  301. # 加载防疫模型
  302. ModelType2.EPIDEMIC_PREVENTION_MODEL.value[1]: (
  303. lambda x, y, r, t, z, h: IMModel(x, y, r, ModelType2.EPIDEMIC_PREVENTION_MODEL, t, z, h),
  304. ModelType2.EPIDEMIC_PREVENTION_MODEL,
  305. None,
  306. lambda x: im_process(x)),
  307. # 加载车牌模型
  308. ModelType2.PLATE_MODEL.value[1]: (
  309. lambda x, y, r, t, z, h: IMModel(x, y, r, ModelType2.PLATE_MODEL, t, z, h),
  310. ModelType2.PLATE_MODEL,
  311. None,
  312. lambda x: im_process(x)),
  313. # 加载车辆模型
  314. ModelType2.VEHICLE_MODEL.value[1]: (
  315. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.VEHICLE_MODEL, t, z, h),
  316. ModelType2.VEHICLE_MODEL,
  317. lambda x, y, z: one_label(x, y, z),
  318. lambda x: model_process(x)
  319. ),
  320. # 加载行人模型
  321. ModelType2.PEDESTRIAN_MODEL.value[1]: (
  322. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.PEDESTRIAN_MODEL, t, z, h),
  323. ModelType2.PEDESTRIAN_MODEL,
  324. lambda x, y, z: one_label(x, y, z),
  325. lambda x: model_process(x)),
  326. # 加载烟火模型
  327. ModelType2.SMOGFIRE_MODEL.value[1]: (
  328. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.SMOGFIRE_MODEL, t, z, h),
  329. ModelType2.SMOGFIRE_MODEL,
  330. lambda x, y, z: one_label(x, y, z),
  331. lambda x: model_process(x)),
  332. # 加载钓鱼游泳模型
  333. ModelType2.ANGLERSWIMMER_MODEL.value[1]: (
  334. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.ANGLERSWIMMER_MODEL, t, z, h),
  335. ModelType2.ANGLERSWIMMER_MODEL,
  336. lambda x, y, z: one_label(x, y, z),
  337. lambda x: model_process(x)),
  338. # 加载乡村模型
  339. ModelType2.COUNTRYROAD_MODEL.value[1]: (
  340. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.COUNTRYROAD_MODEL, t, z, h),
  341. ModelType2.COUNTRYROAD_MODEL,
  342. lambda x, y, z: one_label(x, y, z),
  343. lambda x: model_process(x)),
  344. # 加载船只模型
  345. ModelType2.SHIP_MODEL.value[1]: (
  346. lambda x, y, r, t, z, h: ShipModel(x, y, r, ModelType2.SHIP_MODEL, t, z, h),
  347. ModelType2.SHIP_MODEL,
  348. lambda x, y, z: one_label(x, y, z),
  349. lambda x: obb_process(x)),
  350. # 百度AI图片识别模型
  351. ModelType2.BAIDU_MODEL.value[1]: (
  352. lambda x, y, r, t, z, h: BaiduAiImageModel(x, y, r, ModelType2.BAIDU_MODEL, t, z, h),
  353. ModelType2.BAIDU_MODEL,
  354. lambda x, y, z: baidu_label(x, y, z),
  355. lambda x: baidu_process(x)),
  356. # 航道模型
  357. ModelType2.CHANNEL_EMERGENCY_MODEL.value[1]: (
  358. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.CHANNEL_EMERGENCY_MODEL, t, z, h),
  359. ModelType2.CHANNEL_EMERGENCY_MODEL,
  360. lambda x, y, z: one_label(x, y, z),
  361. lambda x: model_process(x)),
  362. # 河道检测模型
  363. ModelType2.RIVER2_MODEL.value[1]: (
  364. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.RIVER2_MODEL, t, z, h),
  365. ModelType2.RIVER2_MODEL,
  366. lambda x, y, z: one_label(x, y, z),
  367. lambda x: model_process(x)),
  368. # 城管模型
  369. ModelType2.CITY_MANGEMENT_MODEL.value[1]: (
  370. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.CITY_MANGEMENT_MODEL, t, z, h),
  371. ModelType2.CITY_MANGEMENT_MODEL,
  372. lambda x, y, z: one_label(x, y, z),
  373. lambda x: model_process(x)
  374. ),
  375. # 人员落水模型
  376. ModelType2.DROWING_MODEL.value[1]: (
  377. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.DROWING_MODEL, t, z, h),
  378. ModelType2.DROWING_MODEL,
  379. lambda x, y, z: one_label(x, y, z),
  380. lambda x: model_process(x)
  381. ),
  382. # 城市违章模型
  383. ModelType2.NOPARKING_MODEL.value[1]: (
  384. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.NOPARKING_MODEL, t, z, h),
  385. ModelType2.NOPARKING_MODEL,
  386. lambda x, y, z: one_label(x, y, z),
  387. lambda x: model_process(x)
  388. ),
  389. # 城市公路模型
  390. ModelType2.CITYROAD_MODEL.value[1]: (
  391. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.CITYROAD_MODEL, t, z, h),
  392. ModelType2.CITYROAD_MODEL,
  393. lambda x, y, z: one_label(x, y, z),
  394. lambda x: model_process(x)
  395. ),
  396. # 加载坑槽模型
  397. ModelType2.POTHOLE_MODEL.value[1]: (
  398. lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.POTHOLE_MODEL, t, z, h),
  399. ModelType2.POTHOLE_MODEL,
  400. lambda x, y, z: one_label(x, y, z),
  401. lambda x: model_process(x)
  402. ),
  403. }