tuoheng_algN/util/ModelUtils2.py

443 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import sys
from json import dumps, loads
from traceback import format_exc
import cv2
from loguru import logger
from common.Constant import COLOR
from enums.BaiduSdkEnum import VehicleEnum
from enums.ExceptionEnum import ExceptionType
from enums.ModelTypeEnum2 import ModelType2, BAIDU_MODEL_TARGET_CONFIG2
from exception.CustomerException import ServiceException
from util.ImgBaiduSdk import AipBodyAnalysisClient, AipImageClassifyClient
from util.PlotsUtils import get_label_arrays
from util.TorchUtils import select_device
import time
import torch
import tensorrt as trt
sys.path.extend(['..', '../AIlib2'])
from AI import AI_process, get_postProcess_para, get_postProcess_para_dic, AI_det_track, AI_det_track_batch, AI_det_track_batch_N
from stdc import stdcModel
from utilsK.jkmUtils import pre_process, post_process, get_return_data
from obbUtils.shipUtils import OBB_infer, OBB_tracker, draw_obb, OBB_tracker_batch
from obbUtils.load_obb_model import load_model_decoder_OBB
from trackUtils.sort import Sort
from trackUtils.sort_obb import OBB_Sort
from DMPR import DMPRModel
# Font file passed to get_label_arrays when rendering label text.
# NOTE(review): this is a relative path, presumably resolved against the process
# working directory; confirm the service is always launched from the expected dir.
FONT_PATH = "../AIlib2/conf/platech.ttf"
class Model:
    """Detection + SORT-tracking model bundle described by a ``ModelType2`` member.

    After construction ``model_conf`` holds the tuple
    ``(modeType, model_param, allowedList, names, rainbows)``, where
    ``model_param`` has the keys ``modelList``, ``postProcess``,
    ``sort_tracker`` and ``trackPar`` (consumed by ``model_process``).
    """
    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Load every sub-model listed for ``modeType`` and create a SORT tracker.

        :param device: device identifier; passed as ``str(device)`` to the
            parameter factory ``modeType.value[4]``.
        :param allowedList: stored unchanged in ``model_conf``.
        :param requestId: request id, used only for logging.
        :param modeType: ``ModelType2`` member; ``value[2]`` is logged as the
            model name and ``value[4]`` builds the parameter dict.
        :param gpu_name: forwarded to the parameter factory.
        :param base_dir: unused here; kept for a uniform constructor signature.
        :param env: unused here; kept for a uniform constructor signature.
        :raises ServiceException: MODEL_LOADING_EXCEPTION on any failure
            (the original error is chained as ``__cause__``).
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device), gpu_name)
            # Read tracking params before loading weights so a bad config fails
            # fast (previously this key was also re-read a second time below).
            trackPar = par['trackPar']
            names = par['labelnames']
            detPostPar = par['postFile']
            rainbows = detPostPar["rainbows"]
            # Step 1: instantiate each sub-model from its factory, weights and params.
            modelList = [modelPar['model'](weights=modelPar['weight'], par=modelPar['par'])
                         for modelPar in par['models']]
            # Step 2: build the tracker from the tracking parameters.
            sort_tracker = Sort(max_age=trackPar['sort_max_age'],
                                min_hits=trackPar['sort_min_hits'],
                                iou_threshold=trackPar['sort_iou_thresh'])
            postProcess = par['postProcess']
            model_param = {
                "modelList": modelList,
                "postProcess": postProcess,
                "sort_tracker": sort_tracker,
                "trackPar": trackPar,
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as e:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from e
def get_label_arraylist(*args):
    """Render per-class label images and compute font metrics for a frame width.

    :param args: ``(width, height, names, rainbows)``; ``height`` is accepted
        but not used by the current sizing rule.
    :return: ``(label_arraylist, (line, text_width, text_height, fontScale, tf))``.
    """
    width, _unused_height, names, rainbows = args
    # Stroke width scales with frame width relative to a 1920-px reference, minimum 1.
    stroke = max(1, int(round(width / 1920 * 3)))
    thickness = max(stroke, 1)
    scale = stroke * 0.33
    # Size of a representative confidence string decides the label font size.
    (label_w, label_h), _baseline = cv2.getTextSize(' 0.95', 0, fontScale=scale, thickness=thickness)
    rendered = get_label_arrays(names, rainbows, fontSize=label_h, fontPath=FONT_PATH)
    metrics = (stroke, label_w, label_h, scale, thickness)
    return rendered, metrics
"""
输入:
imgarray_list--图像列表
iframe_list -- 帧号列表
modelPar--模型参数,字典,modelPar={'det_Model':,'seg_Model':}
processPar--字典,存放检测相关参数,'half', 'device', 'conf_thres', 'iou_thres','trtFlag_det'
sort_tracker--对象,初始化的跟踪对象。为了保持一致,即使是单帧也要有。
trackPar--跟踪参数关键字包括det_cntwindowsize
segPar--None,分割模型相关参数。如果用不到则为None
输出:retResults,timeInfos
retResultslist
retResults[0]--imgarray_list
retResults[1]--所有结果用numpy格式表示,所有的检测结果包括8列,每列分别是x1, y1, x2, y2, conf, detclass, iframe, trackId
retResults[2]--所有结果用list表示,其中每一个元素为一个list,表示每一帧的检测结果;每一帧的结果由多个list构成,每个list表示一个框,格式为[ cls , x0 ,y0 ,x1 ,y1 ,conf,iframe,trackId ],如 retResults[2][j][k]表示第j帧的第k个框。
"""
def model_process(args):
    """Run batched detection + tracking for a ``Model`` bundle.

    :param args: tuple ``(imgarray_list, iframe_list, model_param, request_id)``
        where ``model_param`` is the dict built by ``Model.__init__`` (keys
        ``modelList``, ``postProcess``, ``sort_tracker``, ``trackPar``).
    :return: whatever ``AI_det_track_batch_N`` returns (presumably
        ``(retResults, timeInfos)`` -- confirm against AIlib2).
    :raises ServiceException: propagated unchanged if raised downstream;
        any other error is logged and wrapped as MODEL_ANALYSE_EXCEPTION
        with the original error chained as ``__cause__``.
    """
    imgarray_list, iframe_list, model_param, request_id = args
    try:
        return AI_det_track_batch_N(imgarray_list, iframe_list,
                                    model_param['modelList'],
                                    model_param['postProcess'],
                                    model_param['sort_tracker'],
                                    model_param['trackPar'])
    except ServiceException:
        # Already a service-level error: re-raise with its original traceback.
        raise
    except Exception as e:
        logger.error("算法模型分析异常: {}, requestId: {}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) from e
# 船只模型
class ShipModel:
    """Ship detector built on the OBB model loader with an ``OBB_Sort`` tracker.

    After construction ``model_conf`` holds
    ``(modeType, model_param, allowedList, names, rainbows)``, where
    ``model_param`` has the keys ``modelPar``, ``obbModelPar``,
    ``sort_tracker``, ``trackPar`` and ``segPar`` (consumed by ``obb_process``).
    """
    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Load the OBB model + decoder and create the tracker.

        :param device: passed as ``str(device)`` to ``modeType.value[4]``.
        :param allowedList: stored unchanged in ``model_conf``.
        :param requestId: request id, used only for logging.
        :param modeType: ``ModelType2`` member whose ``value[4]`` builds the params.
        :param gpu_name: forwarded to the parameter factory.
        :param base_dir: unused here; kept for a uniform constructor signature.
        :param env: unused here; kept for a uniform constructor signature.
        :raises ServiceException: MODEL_LOADING_EXCEPTION on any failure.
        """
        start = time.time()
        try:
            logger.info("########################加载船只模型########################, requestId:{}", requestId)
            par = modeType.value[4](str(device), gpu_name)
            obbModelPar = par['obbModelPar']
            model, decoder2 = load_model_decoder_OBB(obbModelPar)
            # The decoder is stored back into the OBB params that obb_process forwards.
            obbModelPar['decoder'] = decoder2
            names = par['labelnames']
            rainbows = par['postFile']["rainbows"]
            trackPar = par['trackPar']
            sort_tracker = OBB_Sort(max_age=trackPar['sort_max_age'], min_hits=trackPar['sort_min_hits'],
                                    iou_threshold=trackPar['sort_iou_thresh'])
            model_param = {
                "modelPar": {'obbmodel': model},
                "obbModelPar": obbModelPar,
                "sort_tracker": sort_tracker,
                "trackPar": trackPar,
                "segPar": None,
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as e:
            # logger.error (not logger.exception): format_exc() already puts the
            # traceback in the message; logger.exception would log it a second time.
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from e
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - start, requestId)
def obb_process(args):
    """Run batched OBB detection + tracking for a ``ShipModel`` bundle.

    :param args: tuple ``(imgarray_list, iframe_list, model_param, request_id)``
        where ``model_param`` is the dict built by ``ShipModel.__init__``.
    :return: whatever ``OBB_tracker_batch`` returns.
    :raises ServiceException: propagated unchanged if raised downstream;
        any other error is logged and wrapped as MODEL_ANALYSE_EXCEPTION
        with the original error chained as ``__cause__``.
    """
    imgarray_list, iframe_list, model_param, request_id = args
    try:
        return OBB_tracker_batch(imgarray_list, iframe_list, model_param['modelPar'], model_param['obbModelPar'],
                                 model_param['sort_tracker'], model_param['trackPar'], model_param['segPar'])
    except ServiceException:
        raise
    except Exception as e:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) from e
# 车牌分割模型、健康码、行程码分割模型
class IMModel:
    """TorchScript YOLOv5 wrapper for the plate / health-code / trip-code models.

    After construction ``model_conf`` is ``(modeType, model_param, allowedList)``;
    ``model_param`` holds ``device``, ``model``, ``par`` and ``img_type`` and is
    consumed by ``im_process``.
    """
    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Pick the plate or code weights for ``modeType`` and load them with ``torch.jit.load``.

        :param device: formatted into ``'cuda:%s'`` (presumably a GPU index -- confirm).
        :param allowedList: stored unchanged in ``model_conf``.
        :param requestId: request id, used only for logging.
        :param modeType: ``ModelType2`` member; ``PLATE_MODEL`` selects the plate
            weights, anything else selects the health-code weights.
        :param gpu_name: unused here.
        :param base_dir: unused here.
        :param env: unused here.
        :raises ServiceException: MODEL_LOADING_EXCEPTION on any failure.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            kind = 'plate' if ModelType2.PLATE_MODEL == modeType else 'code'
            settings = dict(
                code={'weights': '../AIlib2/weights/conf/jkm/health_yolov5s_v3.jit', 'img_type': 'code', 'nc': 10},
                plate={'weights': '../AIlib2/weights/conf/jkm/plate_yolov5s_v3.jit', 'img_type': 'plate', 'nc': 1},
                conf_thres=0.4,
                iou_thres=0.45,
                device='cuda:%s' % device,
                plate_dilate=(0.5, 0.3),
            )
            torch_device = torch.device(settings['device'])
            scripted = torch.jit.load(settings[kind]['weights'])
            self.model_conf = (
                modeType,
                {"device": torch_device, "model": scripted, "par": settings, "img_type": kind},
                allowedList,
            )
        except Exception:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
def im_process(args):
    """Run the plate / health-code TorchScript model on a single frame.

    :param args: tuple ``(model_param, frame, request_id)`` where ``model_param``
        is the dict built by ``IMModel.__init__`` (keys ``device``, ``model``,
        ``par``, ``img_type``).
    :return: the result of ``get_return_data`` for the detected boxes.
    :raises ServiceException: propagated unchanged if raised downstream;
        any other error -- including a malformed ``model_param`` -- is logged
        and wrapped as MODEL_ANALYSE_EXCEPTION, like the other *_process helpers.
    """
    model_param, frame, request_id = args
    try:
        # Lookups live inside the try so a bad model_param is reported as a
        # ServiceException instead of escaping as a raw KeyError.
        device, par, img_type = model_param['device'], model_param['par'], model_param['img_type']
        img, padInfos = pre_process(frame, device)
        pred = model_param['model'](img)
        boxes = post_process(pred, padInfos, device, conf_thres=par['conf_thres'],
                             iou_thres=par['iou_thres'], nc=par[img_type]['nc'])  # post-processing
        return get_return_data(frame, boxes, modelType=img_type, plate_dilate=par['plate_dilate'])
    except ServiceException:
        raise
    except Exception as e:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) from e
# 百度AI图片识别模型
class BaiduAiImageModel:
    """Holds the Baidu AIP SDK clients used for vehicle / person recognition.

    After construction ``model_conf`` is
    ``(modeType, model_param, allowedList, (vehicle_names, person_names), rainbows)``;
    ``model_param`` has the keys ``vehicle_client`` and ``person_client``
    (consumed by ``baidu_process``).
    """
    __slots__ = "model_conf"

    def __init__(self, device=None, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
                 env=None):
        """Create the SDK clients from ``base_dir`` / ``env`` and collect label names.

        :param device: unused here.
        :param allowedList: stored unchanged in ``model_conf``.
        :param requestId: request id, used only for logging.
        :param modeType: ``ModelType2`` member; ``value[2]`` is logged.
        :param gpu_name: unused here.
        :param base_dir: forwarded to both SDK client constructors.
        :param env: forwarded to both SDK client constructors.
        :raises ServiceException: MODEL_LOADING_EXCEPTION on any failure.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            aipBodyAnalysisClient = AipBodyAnalysisClient(base_dir, env)
            aipImageClassifyClient = AipImageClassifyClient(base_dir, env)
            rainbows = COLOR
            vehicle_names = [VehicleEnum.CAR.value[1], VehicleEnum.TRICYCLE.value[1], VehicleEnum.MOTORBIKE.value[1],
                             VehicleEnum.CARPLATE.value[1], VehicleEnum.TRUCK.value[1], VehicleEnum.BUS.value[1]]
            person_names = ['']
            model_param = {
                "vehicle_client": aipImageClassifyClient,
                "person_client": aipBodyAnalysisClient,
            }
            self.model_conf = (modeType, model_param, allowedList, (vehicle_names, person_names), rainbows)
        except Exception as e:
            # logger.error (not logger.exception): format_exc() already puts the
            # traceback in the message; logger.exception would log it a second time.
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from e
def baidu_process(args):
    """Dispatch one image URL to the Baidu recognition handler registered for ``target``.

    :param args: tuple ``(model_param, target, url, request_id)``; ``model_param``
        is the dict built by ``BaiduAiImageModel.__init__``.
    :return: whatever the handler ``BAIDU_MODEL_TARGET_CONFIG2[target].value[2]`` returns.
    :raises ServiceException: DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED for an
        unknown ``target``; MODEL_ANALYSE_EXCEPTION for any other failure.
    """
    model_param, target, url, request_id = args
    try:
        baiduEnum = BAIDU_MODEL_TARGET_CONFIG2.get(target)
        if baiduEnum is None:
            # str(): with a non-str target the bare concatenation raised TypeError,
            # which was then misreported as MODEL_ANALYSE_EXCEPTION.
            raise ServiceException(ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[0],
                                   ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[1]
                                   + " target: " + str(target))
        # value[2] is the handler, called as (vehicle client, person client, url, request_id).
        return baiduEnum.value[2](model_param['vehicle_client'], model_param['person_client'], url, request_id)
    except ServiceException:
        raise
    except Exception as e:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) from e
def get_baidu_label_arraylist(*args):
    """Render vehicle and person label images plus font metrics for a frame width.

    :param args: ``(width, height, vehicle_names, person_names, rainbows)``;
        ``height`` is accepted but unused by the sizing rule.
    :return: ``(vehicle_label_arrays, person_label_arrays,
        (line, text_width, text_height, fontScale, tf))``.
    """
    width, _unused_height, vehicle_names, person_names, rainbows = args
    # One step thicker than get_label_arraylist's rule.
    stroke = max(1, int(round(width / 1920 * 3) + 1))
    thickness = max(stroke, 1)
    scale = stroke * 0.33
    (label_w, label_h), _baseline = cv2.getTextSize(' 0.97', 0, fontScale=scale, thickness=thickness)

    def _render(label_names):
        # Same font size / colours for both label groups.
        return get_label_arrays(label_names, rainbows, fontSize=label_h, fontPath=FONT_PATH)

    vehicle_arrays = _render(vehicle_names)
    person_arrays = _render(person_names)
    return vehicle_arrays, person_arrays, (stroke, label_w, label_h, scale, thickness)
def one_label(width, height, model_config):
    """Attach label images and font metrics for a ``width`` x ``height`` frame.

    ``model_config`` is a ``(modeType, model_param, allowedList, names, rainbows)``
    tuple. Its ``model_param`` dict is updated in place with ``label_arraylist``
    and ``font_config``; nothing is returned.
    """
    names, rainbows = model_config[3], model_config[4]
    labels, metrics = get_label_arraylist(width, height, names, rainbows)
    model_config[1].update(label_arraylist=labels, font_config=metrics)
def baidu_label(width, height, model_config):
    """Attach Baidu vehicle/person label images and font metrics to a model config.

    ``model_config`` is
    ``(modeType, model_param, allowedList, (vehicle_names, person_names), rainbows)``.
    Its ``model_param`` dict is updated in place with ``vehicle_label_arrays``,
    ``person_label_arrays`` and ``font_config``; nothing is returned.
    """
    name_groups = model_config[3]
    vehicle_arrays, person_arrays, metrics = get_baidu_label_arraylist(
        width, height, name_groups[0], name_groups[1], model_config[4])
    model_config[1].update(vehicle_label_arrays=vehicle_arrays,
                           person_label_arrays=person_arrays,
                           font_config=metrics)
def model_process1(args):
    """Single-frame detection/segmentation path via ``AI_process``.

    :param args: tuple ``(model_conf, frame, request_id)`` where ``model_conf`` is
        ``(modeType, model_param, allowedList, names, rainbows)`` and
        ``model_param`` carries ``model``, ``segmodel``, ``label_arraylist``,
        ``objectPar``, ``digitFont``, ``segPar``, ``mode`` and ``postPar``.
    :return: whatever ``AI_process`` returns for ``[frame]``.
    :raises ServiceException: propagated unchanged if raised downstream;
        any other error is logged and wrapped as MODEL_ANALYSE_EXCEPTION.
    """
    # Bug fix: ``args`` used to be unpacked both as a 4-tuple and as this 3-tuple,
    # so one of the two unpackings always raised ValueError before any work ran.
    model_conf, frame, request_id = args
    model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
    try:
        # JSON round-trip gives AI_process a fresh deep copy of segPar
        # (presumably so per-call mutation cannot leak into shared config -- confirm).
        return AI_process([frame], model_param['model'], model_param['segmodel'], names, model_param['label_arraylist'],
                          rainbows, objectPar=model_param['objectPar'], font=model_param['digitFont'],
                          segPar=loads(dumps(model_param['segPar'])), mode=model_param['mode'],
                          postPar=model_param['postPar'])
    except ServiceException:
        raise
    except Exception as e:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) from e
def _model_config_entry(model_cls, model_type, label_fn, process_fn):
    """Build one ``MODEL_CONFIG2`` value.

    :param model_cls: model wrapper class to instantiate.
    :param model_type: the ``ModelType2`` member this entry is registered for.
    :param label_fn: label hook called as ``label_fn(width, height, model_config)``,
        or ``None`` (used by the IMModel entries).
    :param process_fn: inference hook called with a single ``args`` tuple.
    :return: ``(constructor, model_type, label_hook, process_hook)``, where
        ``constructor(device, allowedList, requestId, gpu_name, base_dir, env)``
        builds ``model_cls`` with ``model_type`` injected as its ``modeType``.
    """
    return (
        lambda x, y, r, t, z, h: model_cls(x, y, r, model_type, t, z, h),
        model_type,
        None if label_fn is None else (lambda x, y, z: label_fn(x, y, z)),
        lambda x: process_fn(x),
    )


# Registry keyed by ``ModelType2.<member>.value[1]``.
MODEL_CONFIG2 = {
    # River-course (water surface) model
    ModelType2.WATER_SURFACE_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.WATER_SURFACE_MODEL, one_label, model_process),
    # Forest model
    ModelType2.FOREST_FARM_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.FOREST_FARM_MODEL, one_label, model_process),
    # Traffic model
    ModelType2.TRAFFIC_FARM_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.TRAFFIC_FARM_MODEL, one_label, model_process),
    # Epidemic-prevention (health code) model
    ModelType2.EPIDEMIC_PREVENTION_MODEL.value[1]: _model_config_entry(
        IMModel, ModelType2.EPIDEMIC_PREVENTION_MODEL, None, im_process),
    # License-plate model
    ModelType2.PLATE_MODEL.value[1]: _model_config_entry(
        IMModel, ModelType2.PLATE_MODEL, None, im_process),
    # Vehicle model
    ModelType2.VEHICLE_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.VEHICLE_MODEL, one_label, model_process),
    # Pedestrian model
    ModelType2.PEDESTRIAN_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.PEDESTRIAN_MODEL, one_label, model_process),
    # Smoke / fire model
    ModelType2.SMOGFIRE_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.SMOGFIRE_MODEL, one_label, model_process),
    # Angler / swimmer model
    ModelType2.ANGLERSWIMMER_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.ANGLERSWIMMER_MODEL, one_label, model_process),
    # Country-road model
    ModelType2.COUNTRYROAD_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.COUNTRYROAD_MODEL, one_label, model_process),
    # Ship model
    ModelType2.SHIP_MODEL.value[1]: _model_config_entry(
        ShipModel, ModelType2.SHIP_MODEL, one_label, obb_process),
    # Baidu AI image-recognition model
    ModelType2.BAIDU_MODEL.value[1]: _model_config_entry(
        BaiduAiImageModel, ModelType2.BAIDU_MODEL, baidu_label, baidu_process),
    # Shipping-channel emergency model
    ModelType2.CHANNEL_EMERGENCY_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.CHANNEL_EMERGENCY_MODEL, one_label, model_process),
    # River detection model
    ModelType2.RIVER2_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.RIVER2_MODEL, one_label, model_process),
    # City-management model
    ModelType2.CITY_MANGEMENT_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.CITY_MANGEMENT_MODEL, one_label, model_process),
    # Person-in-water (drowning) model
    ModelType2.DROWING_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.DROWING_MODEL, one_label, model_process),
    # Urban violation (no-parking) model
    ModelType2.NOPARKING_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.NOPARKING_MODEL, one_label, model_process),
    # City road model
    ModelType2.CITYROAD_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.CITYROAD_MODEL, one_label, model_process),
    # Pothole model
    ModelType2.POTHOLE_MODEL.value[1]: _model_config_entry(
        Model, ModelType2.POTHOLE_MODEL, one_label, model_process),
}