algN/util/ModelUtils2.py

443 lines
20 KiB
Python
Raw Normal View History

2025-08-23 10:12:26 +08:00
# -*- coding: utf-8 -*-
import sys
from json import dumps, loads
from traceback import format_exc
import cv2
from loguru import logger
from common.Constant import COLOR
from enums.BaiduSdkEnum import VehicleEnum
from enums.ExceptionEnum import ExceptionType
from enums.ModelTypeEnum2 import ModelType2, BAIDU_MODEL_TARGET_CONFIG2
from exception.CustomerException import ServiceException
from util.ImgBaiduSdk import AipBodyAnalysisClient, AipImageClassifyClient
from util.PlotsUtils import get_label_arrays
from util.TorchUtils import select_device
import time
import torch
import tensorrt as trt
sys.path.extend(['..', '../AIlib2'])
from AI import AI_process, get_postProcess_para, get_postProcess_para_dic, AI_det_track, AI_det_track_batch, AI_det_track_batch_N
from stdc import stdcModel
from utilsK.jkmUtils import pre_process, post_process, get_return_data
from obbUtils.shipUtils import OBB_infer, OBB_tracker, draw_obb, OBB_tracker_batch
from obbUtils.load_obb_model import load_model_decoder_OBB
from trackUtils.sort import Sort
from trackUtils.sort_obb import OBB_Sort
from DMPR import DMPRModel
# Font file passed as ``fontPath`` to get_label_arrays when pre-rendering label images.
# NOTE(review): relative path, presumably resolved against the process working directory
# (same '..'-relative layout as the sys.path entries above) — confirm for each deployment.
FONT_PATH = "../AIlib2/conf/platech.ttf"
class Model:
    """Generic detection + tracking model wrapper.

    After construction, ``self.model_conf`` is the tuple
    ``(modeType, model_param, allowedList, names, rainbows)`` where
    ``model_param`` holds the loaded sub-models (``modelList``), the
    post-processing parameters, a ``Sort`` tracker and the tracker parameters.
    """

    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Load every sub-model described by ``modeType`` and prepare the tracker.

        Args:
            device: device identifier; passed as ``str(device)`` to the
                parameter factory ``modeType.value[4]``.
            allowedList: stored unchanged in ``model_conf``.
            requestId: used only for logging.
            modeType: model type; ``value[2]`` is logged and ``value[4]`` builds
                the parameter dict (keys 'trackPar', 'labelnames', 'postFile',
                'models', 'postProcess').
            gpu_name: forwarded to the parameter factory.
            base_dir: unused; kept for a uniform loader signature.
            env: unused; kept for a uniform loader signature.

        Raises:
            ServiceException: MODEL_LOADING_EXCEPTION if any loading step fails.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            par = modeType.value[4](str(device), gpu_name)
            trackPar = par['trackPar']
            names = par['labelnames']
            rainbows = par['postFile']["rainbows"]
            # Step 1: instantiate every sub-model from its spec: constructor ('model'),
            # weights ('weight') and the sub-model's own parameters ('par').
            modelList = [modelPar['model'](weights=modelPar['weight'], par=modelPar['par'])
                         for modelPar in par['models']]
            # Step 2: SORT tracker configured from the tracking parameters read above.
            sort_tracker = Sort(max_age=trackPar['sort_max_age'],
                                min_hits=trackPar['sort_min_hits'],
                                iou_threshold=trackPar['sort_iou_thresh'])
            postProcess = par['postProcess']
            model_param = {
                "modelList": modelList,
                "postProcess": postProcess,
                "sort_tracker": sort_tracker,
                "trackPar": trackPar,
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as exc:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from exc
def get_label_arraylist(*args):
    """Pre-render label images and compute font metrics for a frame width.

    Args:
        *args: exactly ``(width, height, names, rainbows)``; ``height`` is
            accepted but not used.

    Returns:
        ``(label_arraylist, (line, text_width, text_height, fontScale, tf))``.
        ``line`` and ``tf`` are the same thickness value, which scales with
        ``width`` (3 at 1920 px) and is never below 1.
    """
    frame_width, _unused_height, label_names, palette = args
    thickness = max(1, int(round(frame_width / 1920 * 3)))
    scale = thickness * 0.33
    # Measure a representative score string to size the label font.
    (text_width, text_height), _baseline = cv2.getTextSize(' 0.95', 0, fontScale=scale, thickness=thickness)
    label_images = get_label_arrays(label_names, palette, fontSize=text_height, fontPath=FONT_PATH)
    return label_images, (thickness, text_width, text_height, scale, thickness)
"""
输入
imgarray_list--图像列表
iframe_list -- 帧号列表
modelPar--模型参数,字典,modelPar={'det_Model':,'seg_Model':}
processPar--字典存放检测相关参数'half', 'device', 'conf_thres', 'iou_thres','trtFlag_det'
sort_tracker--对象初始化的跟踪对象为了保持一致即使是单帧也要有
trackPar--跟踪参数关键字包括det_cntwindowsize
segPar--None,分割模型相关参数如果用不到则为None
输出:retResults, timeInfos
retResults--list
retResults[0]--imgarray_list
retResults[1]--所有结果用numpy格式表示,包含所有的检测结果,共8列,每列分别是x1, y1, x2, y2, conf, detclass, iframe, trackId
retResults[2]--所有结果用list表示,其中每一个元素为一个list,表示每一帧的检测结果;每一帧的结果由多个list构成,每个list表示一个框,格式为[ cls , x0 ,y0 ,x1 ,y1 ,conf, iframe, trackId ];retResults[2][j][k]表示第j帧的第k个框
"""
def model_process(args):
    """Run batched detection + tracking for a ``Model`` instance.

    Args:
        args: ``(imgarray_list, iframe_list, model_param, request_id)`` where
            ``model_param`` is the dict built by ``Model.__init__``.

    Returns:
        The result of ``AI_det_track_batch_N`` unchanged.

    Raises:
        ServiceException: re-raised as-is, or MODEL_ANALYSE_EXCEPTION wrapping
            any other error (which is logged with its traceback).
    """
    frames, frame_ids, params, request_id = args
    try:
        result = AI_det_track_batch_N(frames, frame_ids,
                                      params['modelList'],
                                      params['postProcess'],
                                      params['sort_tracker'],
                                      params['trackPar'])
    except ServiceException:
        raise
    except Exception:
        logger.error("算法模型分析异常: {}, requestId: {}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
    return result
# Ship model.
class ShipModel:
    """Ship detector built on the OBB model loader plus an ``OBB_Sort`` tracker.

    After construction, ``self.model_conf`` is
    ``(modeType, model_param, allowedList, names, rainbows)``; ``model_param``
    is the dict consumed by ``obb_process``.
    """

    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Load the OBB model/decoder and prepare the OBB tracker.

        Args:
            device: passed as ``str(device)`` to ``modeType.value[4]``.
            allowedList: stored unchanged in ``model_conf``.
            requestId: used only for logging.
            modeType: model type; ``value[4]`` builds the parameter dict
                (keys 'obbModelPar', 'labelnames', 'postFile', 'trackPar').
            gpu_name: forwarded to the parameter factory.
            base_dir: unused; kept for a uniform loader signature.
            env: unused; kept for a uniform loader signature.

        Raises:
            ServiceException: MODEL_LOADING_EXCEPTION if any loading step fails.
        """
        start = time.time()
        try:
            logger.info("########################加载船只模型########################, requestId:{}", requestId)
            par = modeType.value[4](str(device), gpu_name)
            obbModelPar = par['obbModelPar']
            model, decoder2 = load_model_decoder_OBB(obbModelPar)
            # Stored back into the OBB parameters (mutating par's dict); obb_process passes
            # obbModelPar to OBB_tracker_batch, which presumably reads the decoder from it.
            obbModelPar['decoder'] = decoder2
            names = par['labelnames']
            rainbows = par['postFile']["rainbows"]
            trackPar = par['trackPar']
            sort_tracker = OBB_Sort(max_age=trackPar['sort_max_age'], min_hits=trackPar['sort_min_hits'],
                                    iou_threshold=trackPar['sort_iou_thresh'])
            model_param = {
                "modelPar": {'obbmodel': model},
                "obbModelPar": obbModelPar,
                "sort_tracker": sort_tracker,
                "trackPar": trackPar,
                "segPar": None,  # no segmentation model is used for ships
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as exc:
            # logger.error rather than logger.exception: the traceback is already in the
            # message via format_exc(), so logger.exception would emit it a second time.
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from exc
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - start, requestId)
def obb_process(args):
    """Run OBB detection + tracking on a batch of frames for ``ShipModel``.

    Args:
        args: ``(imgarray_list, iframe_list, model_param, request_id)`` where
            ``model_param`` is the dict built by ``ShipModel.__init__``.

    Returns:
        The result of ``OBB_tracker_batch`` unchanged.

    Raises:
        ServiceException: re-raised as-is, or MODEL_ANALYSE_EXCEPTION wrapping
            any other error (which is logged with its traceback).
    """
    frames, frame_ids, params, request_id = args
    try:
        tracked = OBB_tracker_batch(frames, frame_ids,
                                    params['modelPar'], params['obbModelPar'],
                                    params['sort_tracker'], params['trackPar'],
                                    params['segPar'])
    except ServiceException:
        raise
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
    return tracked
# Licence-plate, health-code and travel-code segmentation models.
class IMModel:
    """TorchScript model for the health-code ('code') and licence-plate ('plate') modes.

    After construction, ``self.model_conf`` is ``(modeType, model_param, allowedList)``;
    ``model_param`` is the dict consumed by ``im_process``.
    """

    __slots__ = "model_conf"

    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Load the TorchScript weights for the variant selected by ``modeType``.

        ``img_type`` is 'plate' when ``modeType`` is ``ModelType2.PLATE_MODEL`` and
        'code' otherwise.

        Args:
            device: CUDA device index/suffix; formatted as ``'cuda:%s' % device``.
            allowedList: stored unchanged in ``model_conf``.
            requestId: used only for logging.
            modeType: model type; ``value[2]`` is logged.
            gpu_name: unused here.
            base_dir: unused here.
            env: unused here.

        Raises:
            ServiceException: MODEL_LOADING_EXCEPTION if loading fails.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            img_type = 'plate' if ModelType2.PLATE_MODEL == modeType else 'code'
            par = {
                'code': {'weights': '../AIlib2/weights/conf/jkm/health_yolov5s_v3.jit', 'img_type': 'code', 'nc': 10},
                'plate': {'weights': '../AIlib2/weights/conf/jkm/plate_yolov5s_v3.jit', 'img_type': 'plate', 'nc': 1},
                'conf_thres': 0.4,
                'iou_thres': 0.45,
                'device': 'cuda:%s' % device,
                'plate_dilate': (0.5, 0.3)
            }
            new_device = torch.device(par['device'])
            # Without map_location, torch.jit.load restores the module onto the device it
            # was saved from, which need not be par['device'] (the device im_process hands
            # to pre_process). Map it explicitly so model and inputs share a device.
            model = torch.jit.load(par[img_type]['weights'], map_location=new_device)
            model_param = {
                "device": new_device,
                "model": model,
                "par": par,
                "img_type": img_type
            }
            self.model_conf = (modeType, model_param, allowedList)
        except Exception as exc:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from exc
def im_process(args):
    """Run the health-code / plate TorchScript model on a single frame.

    Args:
        args: ``(model_param, frame, request_id)`` where ``model_param`` is the
            dict built by ``IMModel`` (keys 'device', 'model', 'par', 'img_type').

    Returns:
        The result of ``get_return_data`` for ``frame``.

    Raises:
        ServiceException: re-raised as-is, or MODEL_ANALYSE_EXCEPTION wrapping any
            other error, including a malformed ``model_param``.
    """
    model_param, frame, request_id = args
    try:
        # Read inside the try so a missing key is reported like any other analysis failure.
        device, par, img_type = model_param['device'], model_param['par'], model_param['img_type']
        img, padInfos = pre_process(frame, device)
        # Inference only: disable autograd so no graph is recorded for a backward pass.
        with torch.no_grad():
            pred = model_param['model'](img)
        boxes = post_process(pred, padInfos, device, conf_thres=par['conf_thres'],
                             iou_thres=par['iou_thres'], nc=par[img_type]['nc'])  # post-processing
        return get_return_data(frame, boxes, modelType=img_type, plate_dilate=par['plate_dilate'])
    except ServiceException:
        raise
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
# Baidu AI image-recognition model.
class BaiduAiImageModel:
    """Wraps the Baidu SDK clients used for vehicle and body (person) recognition.

    No local weights are loaded. After construction, ``self.model_conf`` is
    ``(modeType, model_param, allowedList, (vehicle_names, person_names), rainbows)``
    where ``model_param`` holds 'vehicle_client' and 'person_client'.
    """

    __slots__ = "model_conf"

    def __init__(self, device=None, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
                 env=None):
        """Create the Baidu SDK clients and the label-name lists.

        Args:
            device: unused; kept for a uniform loader signature.
            allowedList: stored unchanged in ``model_conf``.
            requestId: used only for logging.
            modeType: model type; ``value[2]`` is logged.
            gpu_name: unused; kept for a uniform loader signature.
            base_dir: forwarded to both SDK client constructors.
            env: forwarded to both SDK client constructors.

        Raises:
            ServiceException: MODEL_LOADING_EXCEPTION if client creation fails.
        """
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            aipBodyAnalysisClient = AipBodyAnalysisClient(base_dir, env)
            aipImageClassifyClient = AipImageClassifyClient(base_dir, env)
            rainbows = COLOR
            vehicle_names = [VehicleEnum.CAR.value[1], VehicleEnum.TRICYCLE.value[1], VehicleEnum.MOTORBIKE.value[1],
                             VehicleEnum.CARPLATE.value[1], VehicleEnum.TRUCK.value[1], VehicleEnum.BUS.value[1]]
            # NOTE(review): a single empty label for persons — presumably persons are drawn without a name; confirm.
            person_names = ['']
            model_param = {
                "vehicle_client": aipImageClassifyClient,
                "person_client": aipBodyAnalysisClient,
            }
            self.model_conf = (modeType, model_param, allowedList, (vehicle_names, person_names), rainbows)
        except Exception as exc:
            # logger.error rather than logger.exception: format_exc() already puts the
            # traceback in the message, so logger.exception would emit it twice.
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from exc
def baidu_process(args):
    """Dispatch a Baidu recognition call for ``target`` on the image at ``url``.

    Args:
        args: ``(model_param, target, url, request_id)``; ``model_param`` is the
            dict built by ``BaiduAiImageModel`` (keys 'vehicle_client', 'person_client').

    Returns:
        Whatever the handler ``BAIDU_MODEL_TARGET_CONFIG2[target].value[2]`` returns.

    Raises:
        ServiceException: DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED when ``target`` has
            no configured handler; MODEL_ANALYSE_EXCEPTION for any other failure.
    """
    model_param, target, url, request_id = args
    try:
        baiduEnum = BAIDU_MODEL_TARGET_CONFIG2.get(target)
        if baiduEnum is None:
            # str(): a non-string target (e.g. None or an int) must still produce the
            # "not supported" error instead of a TypeError from string concatenation.
            raise ServiceException(ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[0],
                                   ExceptionType.DETECTION_TARGET_TYPES_ARE_NOT_SUPPORTED.value[1]
                                   + " target: " + str(target))
        return baiduEnum.value[2](model_param['vehicle_client'], model_param['person_client'], url, request_id)
    except ServiceException:
        raise
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
def get_baidu_label_arraylist(*args):
    """Pre-render vehicle and person label images for the Baidu model.

    Args:
        *args: exactly ``(width, height, vehicle_names, person_names, rainbows)``;
            ``height`` is accepted but not used.

    Returns:
        ``(vehicle_label_arrays, person_label_arrays, font_config)`` with
        ``font_config = (line, text_width, text_height, fontScale, tf)``, where
        ``line`` and ``tf`` are the same thickness value (never below 1).
    """
    frame_width, _unused_height, vehicle_labels, person_labels, palette = args
    # Thickness: round(width / 1920 * 3) + 1, clamped to at least 1.
    thickness = max(1, int(round(frame_width / 1920 * 3) + 1))
    scale = thickness * 0.33
    (text_width, text_height), _baseline = cv2.getTextSize(' 0.97', 0, fontScale=scale, thickness=thickness)

    def render(label_names):
        # Same font size and palette for both label sets.
        return get_label_arrays(label_names, palette, fontSize=text_height, fontPath=FONT_PATH)

    vehicle_images = render(vehicle_labels)
    person_images = render(person_labels)
    return vehicle_images, person_images, (thickness, text_width, text_height, scale, thickness)
def one_label(width, height, model_config):
    """Prepare label images for a single-label-set model, in place.

    ``model_config`` is a ``model_conf`` tuple
    ``(modeType, model_param, allowedList, names, rainbows)``. The rendered
    labels and font metrics are stored in ``model_param`` under
    ``'label_arraylist'`` and ``'font_config'``. Returns None.
    """
    rendered, metrics = get_label_arraylist(width, height, model_config[3], model_config[4])
    model_config[1].update(label_arraylist=rendered, font_config=metrics)
def baidu_label(width, height, model_config):
    """Prepare vehicle and person label images for the Baidu model, in place.

    ``model_config`` is ``(modeType, model_param, allowedList,
    (vehicle_names, person_names), rainbows)``. Results are stored in
    ``model_param`` under ``'vehicle_label_arrays'``, ``'person_label_arrays'``
    and ``'font_config'``. Returns None.
    """
    label_sets = model_config[3]
    vehicle_arrays, person_arrays, metrics = get_baidu_label_arraylist(
        width, height, label_sets[0], label_sets[1], model_config[4])
    model_config[1].update(vehicle_label_arrays=vehicle_arrays,
                           person_label_arrays=person_arrays,
                           font_config=metrics)
def model_process1(args):
    """Run the single-frame ``AI_process`` pipeline on one frame.

    Args:
        args: exactly ``(model_conf, frame, request_id)`` where ``model_conf`` is
            ``(modeType, model_param, allowedList, names, rainbows)`` and
            ``model_param`` provides 'model', 'segmodel', 'label_arraylist',
            'objectPar', 'digitFont', 'segPar', 'mode' and 'postPar'.

    Returns:
        The result of ``AI_process`` unchanged.

    Raises:
        ServiceException: re-raised as-is, or MODEL_ANALYSE_EXCEPTION wrapping
            any other error (which is logged with its traceback).
    """
    # ``args`` is a 3-tuple; only these three items are used below.
    model_conf, frame, request_id = args
    model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
    try:
        return AI_process([frame], model_param['model'], model_param['segmodel'], names,
                          model_param['label_arraylist'], rainbows,
                          objectPar=model_param['objectPar'], font=model_param['digitFont'],
                          # JSON round-trip hands AI_process a fresh copy of segPar
                          # (presumably to keep the stored dict from being mutated).
                          segPar=loads(dumps(model_param['segPar'])), mode=model_param['mode'],
                          postPar=model_param['postPar'])
    except ServiceException:
        raise
    except Exception:
        logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
        raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
                               ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
def _model_config_entry(model_type, loader_cls, label_fn, process_fn):
    """Build one ``MODEL_CONFIG2`` entry: ``(loader, model_type, label_fn, process_fn)``.

    ``loader(x, y, r, t, z, h)`` constructs
    ``loader_cls(x, y, r, model_type, t, z, h)``, i.e. the arguments map to
    (device, allowedList, requestId, gpu_name, base_dir, env) with the model
    type injected. ``label_fn(width, height, model_conf)`` prepares label
    images, or is None when the model needs none; ``process_fn(args)`` runs
    inference.
    """
    def loader(x, y, r, t, z, h):
        return loader_cls(x, y, r, model_type, t, z, h)

    return loader, model_type, label_fn, process_fn


# (model type, loader class, label preparer, inference function), in registration order.
_MODEL_CONFIG2_SPECS = (
    (ModelType2.WATER_SURFACE_MODEL, Model, one_label, model_process),        # river / water surface
    (ModelType2.FOREST_FARM_MODEL, Model, one_label, model_process),          # forest
    (ModelType2.TRAFFIC_FARM_MODEL, Model, one_label, model_process),         # traffic
    (ModelType2.EPIDEMIC_PREVENTION_MODEL, IMModel, None, im_process),        # epidemic prevention
    (ModelType2.PLATE_MODEL, IMModel, None, im_process),                      # licence plate
    (ModelType2.VEHICLE_MODEL, Model, one_label, model_process),              # vehicle
    (ModelType2.PEDESTRIAN_MODEL, Model, one_label, model_process),           # pedestrian
    (ModelType2.SMOGFIRE_MODEL, Model, one_label, model_process),             # smoke / fire
    (ModelType2.ANGLERSWIMMER_MODEL, Model, one_label, model_process),        # angler / swimmer
    (ModelType2.COUNTRYROAD_MODEL, Model, one_label, model_process),          # countryside road
    (ModelType2.SHIP_MODEL, ShipModel, one_label, obb_process),               # ship
    (ModelType2.BAIDU_MODEL, BaiduAiImageModel, baidu_label, baidu_process),  # Baidu AI image recognition
    (ModelType2.CHANNEL_EMERGENCY_MODEL, Model, one_label, model_process),    # waterway channel
    (ModelType2.RIVER2_MODEL, Model, one_label, model_process),               # river detection
    (ModelType2.CITY_MANGEMENT_MODEL, Model, one_label, model_process),       # city management
    (ModelType2.DROWING_MODEL, Model, one_label, model_process),              # person in water
    (ModelType2.NOPARKING_MODEL, Model, one_label, model_process),            # urban parking violations
    (ModelType2.CITYROAD_MODEL, Model, one_label, model_process),             # city road
    (ModelType2.POTHOLE_MODEL, Model, one_label, model_process),              # pothole
)

# Keyed by ``ModelType2.<member>.value[1]``.
MODEL_CONFIG2 = {
    model_type.value[1]: _model_config_entry(model_type, loader_cls, label_fn, process_fn)
    for model_type, loader_cls, label_fn, process_fn in _MODEL_CONFIG2_SPECS
}