Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

448 lines
21KB

  1. # -*- coding: utf-8 -*-
  2. import sys
  3. from json import dumps, loads
  4. from traceback import format_exc
  5. import cv2
  6. from loguru import logger
  7. from common.Constant import COLOR
  8. from enums.BaiduSdkEnum import VehicleEnum
  9. from enums.ExceptionEnum import ExceptionType
  10. from enums.ModelTypeEnum2 import ModelType2, BAIDU_MODEL_TARGET_CONFIG2
  11. from exception.CustomerException import ServiceException
  12. from util.ImgBaiduSdk import AipBodyAnalysisClient, AipImageClassifyClient
  13. from util.PlotsUtils import get_label_arrays
  14. from util.TorchUtils import select_device
  15. import time
  16. import torch
  17. import tensorrt as trt
  18. sys.path.extend(['..', '../AIlib2'])
  19. from AI import AI_process, get_postProcess_para, get_postProcess_para_dic, AI_det_track, AI_det_track_batch
  20. from stdc import stdcModel
  21. from utilsK.jkmUtils import pre_process, post_process, get_return_data
  22. from obbUtils.shipUtils import OBB_infer, OBB_tracker, draw_obb, OBB_tracker_batch
  23. from obbUtils.load_obb_model import load_model_decoder_OBB
  24. from trackUtils.sort import Sort
  25. from trackUtils.sort_obb import OBB_Sort
  26. from DMPR import DMPRModel
# Font file used to rasterise label text; passed as ``fontPath`` to get_label_arrays.
# Relative path: presumably resolved against the process working directory — confirm deployment cwd.
FONT_PATH = "../AIlib2/conf/platech.ttf"
  28. class Model:
  29. __slots__ = "model_conf"
  30. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
  31. try:
  32. logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
  33. requestId)
  34. par = modeType.value[4](str(device), gpu_name)
  35. new_device = select_device(par['device'])
  36. Detweights = par['Detweights']
  37. with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
  38. model = runtime.deserialize_cuda_engine(f.read())
  39. Segweights = par['Segweights']
  40. if Segweights:
  41. if modeType.value[3] == 'cityMangement2':
  42. segmodel = DMPRModel(weights=par['Segweights'], par = par['segPar'])
  43. else:
  44. segmodel = stdcModel(weights=par['Segweights'], par = par['segPar'])
  45. else:
  46. segmodel = None
  47. trackPar = par['trackPar']
  48. sort_tracker = Sort(max_age=trackPar['sort_max_age'], min_hits=trackPar['sort_min_hits'],
  49. iou_threshold=trackPar['sort_iou_thresh'])
  50. names, segPar = par['labelnames'], par['segPar']
  51. detPostPar = par['postFile']
  52. rainbows = detPostPar["rainbows"]
  53. modelPar = {'det_Model': model, 'seg_Model': segmodel}
  54. processPar = {
  55. 'half': par['half'],
  56. 'device': new_device,
  57. 'conf_thres': detPostPar["conf_thres"],
  58. 'iou_thres': detPostPar["iou_thres"],
  59. 'trtFlag_det': par['trtFlag_det'],
  60. 'iou2nd': detPostPar.get("ovlap_thres_crossCategory")
  61. }
  62. model_param = {
  63. "modelPar": modelPar,
  64. "processPar": processPar,
  65. "sort_tracker": sort_tracker,
  66. "trackPar": trackPar,
  67. "segPar": segPar
  68. }
  69. self.model_conf = (modeType, model_param, allowedList, names, rainbows)
  70. except Exception:
  71. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  72. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  73. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  74. def get_label_arraylist(*args):
  75. width, height, names, rainbows = args
  76. # line = int(round(0.002 * (height + width) / 2) + 1)
  77. line = max(1, int(round(width / 1920 * 3)))
  78. tf = max(line, 1)
  79. fontScale = line * 0.33
  80. text_width, text_height = cv2.getTextSize(' 0.95', 0, fontScale=fontScale, thickness=tf)[0]
  81. label_arraylist = get_label_arrays(names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
  82. return label_arraylist, (line, text_width, text_height, fontScale, tf)
  83. """
  84. 输入:
  85. imgarray_list--图像列表
  86. iframe_list -- 帧号列表
  87. modelPar--模型参数,字典,modelPar={'det_Model':,'seg_Model':}
  88. processPar--字典,存放检测相关参数,'half', 'device', 'conf_thres', 'iou_thres','trtFlag_det'
  89. sort_tracker--对象,初始化的跟踪对象。为了保持一致,即使是单帧也要有。
  90. trackPar--跟踪参数,关键字包括:det_cnt,windowsize
  91. segPar--None,分割模型相关参数。如果用不到,则为None
  92. 输入:retResults,timeInfos
  93. retResults:list
  94. retResults[0]--imgarray_list
  95. retResults[1]--所有结果用numpy格式,所有的检测结果,包括8类,每列分别是x1, y1, x2, y2, conf, detclass,iframe,trackId
  96. retResults[2]--所有结果用list表示,其中每一个元素为一个list,表示每一帧的检测结果,每一个结果是由多个list构成,每个list表示一个框,格式为[ cls , x0 ,y0 ,x1 ,y1 ,conf,ifrmae,trackId ],如 retResults[2][j][k]表示第j帧的第k个框。
  97. """
  98. def model_process(args):
  99. # (modeType, model_param, allowedList, names, rainbows)
  100. imgarray_list, iframe_list, model_param, request_id = args
  101. try:
  102. return AI_det_track_batch(imgarray_list, iframe_list, model_param['modelPar'], model_param['processPar'],
  103. model_param['sort_tracker'], model_param['trackPar'], model_param['segPar'])
  104. except ServiceException as s:
  105. raise s
  106. except Exception:
  107. # self.num += 1
  108. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  109. logger.error("算法模型分析异常: {}, requestId: {}", format_exc(), request_id)
  110. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  111. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  112. # 船只模型
  113. class ShipModel:
  114. __slots__ = "model_conf"
  115. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
  116. s = time.time()
  117. try:
  118. logger.info("########################加载船只模型########################, requestId:{}", requestId)
  119. par = modeType.value[4](str(device), gpu_name)
  120. obbModelPar = par['obbModelPar']
  121. model, decoder2 = load_model_decoder_OBB(obbModelPar)
  122. obbModelPar['decoder'] = decoder2
  123. names = par['labelnames']
  124. rainbows = par['postFile']["rainbows"]
  125. trackPar = par['trackPar']
  126. sort_tracker = OBB_Sort(max_age=trackPar['sort_max_age'], min_hits=trackPar['sort_min_hits'],
  127. iou_threshold=trackPar['sort_iou_thresh'])
  128. modelPar = {'obbmodel': model}
  129. segPar = None
  130. model_param = {
  131. "modelPar": modelPar,
  132. "obbModelPar": obbModelPar,
  133. "sort_tracker": sort_tracker,
  134. "trackPar": trackPar,
  135. "segPar": segPar
  136. }
  137. self.model_conf = (modeType, model_param, allowedList, names, rainbows)
  138. except Exception:
  139. logger.exception("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  140. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  141. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  142. logger.info("模型初始化时间:{}, requestId:{}", time.time() - s, requestId)
  143. def obb_process(args):
  144. imgarray_list, iframe_list, model_param, request_id = args
  145. try:
  146. return OBB_tracker_batch(imgarray_list, iframe_list, model_param['modelPar'], model_param['obbModelPar'],
  147. model_param['sort_tracker'], model_param['trackPar'], model_param['segPar'])
  148. except ServiceException as s:
  149. raise s
  150. except Exception:
  151. # self.num += 1
  152. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  153. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
  154. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  155. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  156. # 车牌分割模型、健康码、行程码分割模型
  157. class IMModel:
  158. __slots__ = "model_conf"
  159. def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
  160. try:
  161. logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
  162. requestId)
  163. img_type = 'code'
  164. if ModelType2.PLATE_MODEL == modeType:
  165. img_type = 'plate'
  166. par = {
  167. 'code': {'weights': '../AIlib2/weights/conf/jkm/health_yolov5s_v3.jit', 'img_type': 'code', 'nc': 10},
  168. 'plate': {'weights': '../AIlib2/weights/conf/jkm/plate_yolov5s_v3.jit', 'img_type': 'plate', 'nc': 1},
  169. 'conf_thres': 0.4,
  170. 'iou_thres': 0.45,
  171. 'device': 'cuda:%s' % device,
  172. 'plate_dilate': (0.5, 0.3)
  173. }
  174. new_device = torch.device(par['device'])
  175. model = torch.jit.load(par[img_type]['weights'])
  176. model_param = {
  177. "device": new_device,
  178. "model": model,
  179. "par": par,
  180. "img_type": img_type
  181. }
  182. self.model_conf = (modeType, model_param, allowedList)
  183. except Exception:
  184. logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  185. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  186. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  187. def im_process(args):
  188. model_param, frame, request_id = args
  189. device, par, img_type = model_param['device'], model_param['par'], model_param['img_type']
  190. try:
  191. img, padInfos = pre_process(frame, device)
  192. pred = model_param['model'](img)
  193. boxes = post_process(pred, padInfos, device, conf_thres=par['conf_thres'],
  194. iou_thres=par['iou_thres'], nc=par[img_type]['nc']) # 后处理
  195. dataBack = get_return_data(frame, boxes, modelType=img_type, plate_dilate=par['plate_dilate'])
  196. return dataBack
  197. except ServiceException as s:
  198. raise s
  199. except Exception:
  200. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
  201. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  202. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  203. # 百度AI图片识别模型
  204. class BaiduAiImageModel:
  205. __slots__ = "model_conf"
  206. def __init__(self, device=None, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
  207. env=None):
  208. try:
  209. logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
  210. requestId)
  211. aipBodyAnalysisClient = AipBodyAnalysisClient(base_dir, env)
  212. aipImageClassifyClient = AipImageClassifyClient(base_dir, env)
  213. rainbows = COLOR
  214. vehicle_names = [VehicleEnum.CAR.value[1], VehicleEnum.TRICYCLE.value[1], VehicleEnum.MOTORBIKE.value[1],
  215. VehicleEnum.CARPLATE.value[1], VehicleEnum.TRUCK.value[1], VehicleEnum.BUS.value[1]]
  216. person_names = ['人']
  217. model_param = {
  218. "vehicle_client": aipImageClassifyClient,
  219. "person_client": aipBodyAnalysisClient,
  220. }
  221. self.model_conf = (modeType, model_param, allowedList, (vehicle_names, person_names), rainbows)
  222. except Exception:
  223. logger.exception("模型加载异常:{}, requestId:{}", format_exc(), requestId)
  224. raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
  225. ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
  226. def baidu_process(args):
  227. model_param, target, url, request_id = args
  228. try:
  229. baiduEnum = BAIDU_MODEL_TARGET_CONFIG2.get(target)
  230. if baiduEnum is None:
  231. raise ServiceException(ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[0],
  232. ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[1]
  233. + " target: " + target)
  234. return baiduEnum.value[2](model_param['vehicle_client'], model_param['person_client'], url, request_id)
  235. except ServiceException as s:
  236. raise s
  237. except Exception:
  238. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
  239. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  240. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
  241. def get_baidu_label_arraylist(*args):
  242. width, height, vehicle_names, person_names, rainbows = args
  243. # line = int(round(0.002 * (height + width) / 2) + 1)
  244. line = max(1, int(round(width / 1920 * 3) + 1))
  245. label = ' 0.97'
  246. tf = max(line, 1)
  247. fontScale = line * 0.33
  248. text_width, text_height = cv2.getTextSize(label, 0, fontScale=fontScale, thickness=tf)[0]
  249. vehicle_label_arrays = get_label_arrays(vehicle_names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
  250. person_label_arrays = get_label_arrays(person_names, rainbows, fontSize=text_height, fontPath=FONT_PATH)
  251. font_config = (line, text_width, text_height, fontScale, tf)
  252. return vehicle_label_arrays, person_label_arrays, font_config
  253. def one_label(width, height, model_config):
  254. # (modeType, model_param, allowedList, names, rainbows)
  255. names = model_config[3]
  256. rainbows = model_config[4]
  257. label_arraylist, font_config = get_label_arraylist(width, height, names, rainbows)
  258. model_config[1]['label_arraylist'] = label_arraylist
  259. model_config[1]['font_config'] = font_config
  260. def baidu_label(width, height, model_config):
  261. # modeType, model_param, allowedList, (vehicle_names, person_names), rainbows
  262. vehicle_names = model_config[3][0]
  263. person_names = model_config[3][1]
  264. rainbows = model_config[4]
  265. vehicle_label_arrays, person_label_arrays, font_config = get_baidu_label_arraylist(width, height, vehicle_names,
  266. person_names, rainbows)
  267. model_config[1]['vehicle_label_arrays'] = vehicle_label_arrays
  268. model_config[1]['person_label_arrays'] = person_label_arrays
  269. model_config[1]['font_config'] = font_config
  270. def model_process1(args):
  271. imgarray_list, iframe_list, model_param, request_id = args
  272. model_conf, frame, request_id = args
  273. model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
  274. # modeType, model_param, allowedList, names, rainbows = model_conf
  275. # segmodel, names, label_arraylist, rainbows, objectPar, font, segPar, mode, postPar, requestId = args
  276. # model_param['digitFont'] = digitFont
  277. # model_param['label_arraylist'] = label_arraylist
  278. # model_param['font_config'] = font_config
  279. try:
  280. return AI_process([frame], model_param['model'], model_param['segmodel'], names, model_param['label_arraylist'],
  281. rainbows, objectPar=model_param['objectPar'], font=model_param['digitFont'],
  282. segPar=loads(dumps(model_param['segPar'])), mode=model_param['mode'],
  283. postPar=model_param['postPar'])
  284. except ServiceException as s:
  285. raise s
  286. except Exception:
  287. # self.num += 1
  288. # cv2.imwrite('/home/th/tuo_heng/dev/img%s.jpg' % str(self.num), frame)
  289. logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
  290. raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
  291. ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
# Registry of loadable models, keyed by ``ModelType2.<X>.value[1]``.
# Each value is a 4-tuple:
#   [0] loader ``(device, allowedList, requestId, gpu_name, base_dir, env)`` -> model
#       object; the lambda inserts the ModelType2 member as the ``modeType`` argument.
#   [1] the ModelType2 member itself.
#   [2] label builder ``(width, height, model_config)`` that stores pre-rendered label
#       data into ``model_config[1]``, or None when the model needs no labels.
#   [3] processor taking the single ``args`` tuple of the matching ``*_process`` function.
# The wrappers are lambdas rather than direct references, so the class/function
# names are looked up in module globals at call time, not when this dict is built.
MODEL_CONFIG2 = {
    # River model
    ModelType2.WATER_SURFACE_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.WATER_SURFACE_MODEL, t, z, h),
        ModelType2.WATER_SURFACE_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)
    ),
    # Forest model
    ModelType2.FOREST_FARM_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.FOREST_FARM_MODEL, t, z, h),
        ModelType2.FOREST_FARM_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)
    ),
    # Traffic model
    ModelType2.TRAFFIC_FARM_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.TRAFFIC_FARM_MODEL, t, z, h),
        ModelType2.TRAFFIC_FARM_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)
    ),
    # Epidemic-prevention (health/travel code) model: no label builder
    ModelType2.EPIDEMIC_PREVENTION_MODEL.value[1]: (
        lambda x, y, r, t, z, h: IMModel(x, y, r, ModelType2.EPIDEMIC_PREVENTION_MODEL, t, z, h),
        ModelType2.EPIDEMIC_PREVENTION_MODEL,
        None,
        lambda x: im_process(x)),
    # License-plate model: no label builder
    ModelType2.PLATE_MODEL.value[1]: (
        lambda x, y, r, t, z, h: IMModel(x, y, r, ModelType2.PLATE_MODEL, t, z, h),
        ModelType2.PLATE_MODEL,
        None,
        lambda x: im_process(x)),
    # Vehicle model
    ModelType2.VEHICLE_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.VEHICLE_MODEL, t, z, h),
        ModelType2.VEHICLE_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)
    ),
    # Pedestrian model
    ModelType2.PEDESTRIAN_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.PEDESTRIAN_MODEL, t, z, h),
        ModelType2.PEDESTRIAN_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)),
    # Smoke / fire model
    ModelType2.SMOGFIRE_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.SMOGFIRE_MODEL, t, z, h),
        ModelType2.SMOGFIRE_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)),
    # Angler / swimmer model
    ModelType2.ANGLERSWIMMER_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.ANGLERSWIMMER_MODEL, t, z, h),
        ModelType2.ANGLERSWIMMER_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)),
    # Rural (country road) model
    ModelType2.COUNTRYROAD_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.COUNTRYROAD_MODEL, t, z, h),
        ModelType2.COUNTRYROAD_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)),
    # Ship model (oriented bounding boxes; uses obb_process)
    ModelType2.SHIP_MODEL.value[1]: (
        lambda x, y, r, t, z, h: ShipModel(x, y, r, ModelType2.SHIP_MODEL, t, z, h),
        ModelType2.SHIP_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: obb_process(x)),
    # Baidu AI image-recognition model (remote API; uses baidu_label / baidu_process)
    ModelType2.BAIDU_MODEL.value[1]: (
        lambda x, y, r, t, z, h: BaiduAiImageModel(x, y, r, ModelType2.BAIDU_MODEL, t, z, h),
        ModelType2.BAIDU_MODEL,
        lambda x, y, z: baidu_label(x, y, z),
        lambda x: baidu_process(x)),
    # Shipping-channel model
    ModelType2.CHANNEL_EMERGENCY_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.CHANNEL_EMERGENCY_MODEL, t, z, h),
        ModelType2.CHANNEL_EMERGENCY_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)),
    # River detection model
    ModelType2.RIVER2_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.RIVER2_MODEL, t, z, h),
        ModelType2.RIVER2_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)),
    # Urban-management model
    ModelType2.CITY_MANGEMENT_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.CITY_MANGEMENT_MODEL, t, z, h),
        ModelType2.CITY_MANGEMENT_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)
    ),
    # Person-in-water (drowning) model
    ModelType2.DROWING_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.DROWING_MODEL, t, z, h),
        ModelType2.DROWING_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)
    ),
    # Urban violation (no-parking) model
    ModelType2.NOPARKING_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.NOPARKING_MODEL, t, z, h),
        ModelType2.NOPARKING_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x)
    ),
    # City road model
    ModelType2.CITYROAD_MODEL.value[1]: (
        lambda x, y, r, t, z, h: Model(x, y, r, ModelType2.CITYROAD_MODEL, t, z, h),
        ModelType2.CITYROAD_MODEL,
        lambda x, y, z: one_label(x, y, z),
        lambda x: model_process(x))
}