更新 util/ModelUtils.py

This commit is contained in:
zhoushuliang 2025-07-10 17:24:54 +08:00
parent 919d15ec5f
commit 5cc22405a4
1 changed files with 89 additions and 8 deletions

View File

@ -17,7 +17,7 @@ from util.PlotsUtils import get_label_arrays, get_label_array_dict
from util.TorchUtils import select_device
sys.path.extend(['..', '../AIlib2'])
from AI import AI_process, AI_process_forest, get_postProcess_para, ocr_process, AI_process_N, AI_process_C
from AI import AI_process, AI_process_forest, get_postProcess_para, ocr_process, AI_process_N, AI_process_C,AI_process_Ocr,AI_process_Crowd
from stdc import stdcModel
from segutils.segmodel import SegModel
from models.experimental import attempt_load
@ -241,7 +241,7 @@ def channel2_process(args):
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
def get_label_arraylist(*args):
width, height, names, rainbows = args
# line = int(round(0.002 * (height + width) / 2) + 1)
@ -354,6 +354,73 @@ def im_process(args):
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
def immulti_process(args):
model_conf, frame, requestId = args
device, modelList, detpar = model_conf[1], model_conf[2], model_conf[3]
try:
# new_device = torch.device(device)
# img, padInfos = pre_process(frame, new_device)
# pred = model(img)
# boxes = post_process(pred, padInfos, device, conf_thres=pardet['conf_thres'],
# iou_thres=pardet['iou_thres'], nc=pardet['nc']) # 后处理
return AI_process_Ocr([frame], modelList, device, detpar)
except ServiceException as s:
raise s
except Exception:
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
class CARPLATEModel:
__slots__ = "model_conf"
def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
env=None):
try:
logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
requestId)
par = modeType.value[4](str(device), gpu_name)
modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
detpar = par['models'][0]['par']
# new_device = torch.device(par['device'])
# modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
logger.info("########################加载 plate_yolov5s_v3.jit 成功 ########################, requestId:{}",
requestId)
self.model_conf = (modeType, device, modelList, detpar, par['rainbows'])
except Exception:
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
class DENSECROWDCOUNTModel:
__slots__ = "model_conf"
def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
try:
logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
requestId)
par = modeType.value[4](str(device), gpu_name)
rainbows = par["rainbows"]
models=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
postPar = par['models'][0]['par']
self.model_conf = (modeType, device, models[0], postPar, rainbows)
except Exception:
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
def cc_process(args):
model_conf, frame, requestId = args
device, model, postPar = model_conf[1], model_conf[2], model_conf[3]
try:
return AI_process_Crowd([frame], model, device, postPar)
except ServiceException as s:
raise s
except Exception:
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
# 百度AI图片识别模型
class BaiduAiImageModel:
@ -470,7 +537,7 @@ MODEL_CONFIG = {
lambda x, y, z: one_label(x, y, z),
lambda x: detSeg_demo2(x)
),
# 加载交通模型
ModelType.TRAFFIC_FARM_MODEL.value[1]: (
lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFIC_FARM_MODEL, t, z, h),
@ -613,22 +680,22 @@ MODEL_CONFIG = {
lambda x, y, z: one_label(x, y, z),
lambda x: model_process(x)
),
# 加载智慧工地模型
# 加载智慧工地模型
ModelType.SMARTSITE_MODEL.value[1]: (
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.SMARTSITE_MODEL, t, z, h),
ModelType.SMARTSITE_MODEL,
lambda x, y, z: one_label(x, y, z),
lambda x: detSeg_demo2(x)
),
# 加载垃圾模型
# 加载垃圾模型
ModelType.RUBBISH_MODEL.value[1]: (
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.RUBBISH_MODEL, t, z, h),
ModelType.RUBBISH_MODEL,
lambda x, y, z: one_label(x, y, z),
lambda x: detSeg_demo2(x)
),
# 加载烟花模型
ModelType.FIREWORK_MODEL.value[1]: (
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.FIREWORK_MODEL, t, z, h),
@ -657,6 +724,13 @@ MODEL_CONFIG = {
lambda x, y, z: one_label(x, y, z),
lambda x: detSeg_demo2(x)
),
# 加载自研车牌检测模型
ModelType.CITY_CARPLATE_MODEL.value[1]: (
lambda x, y, r, t, z, h: CARPLATEModel(x, y, r, ModelType.CITY_CARPLATE_MODEL, t, z, h),
ModelType.CITY_CARPLATE_MODEL,
None,
lambda x: immulti_process(x)
),
# 加载红外行人检测模型
ModelType.CITY_INFRAREDPERSON_MODEL.value[1]: (
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_INFRAREDPERSON_MODEL, t, z, h),
@ -670,5 +744,12 @@ MODEL_CONFIG = {
ModelType.CITY_NIGHTFIRESMOKE_MODEL,
lambda x, y, z: one_label(x, y, z),
lambda x: detSeg_demo2(x)
),
),
# 加载密集人群计数检测模型
ModelType.CITY_DENSECROWDCOUNT_MODEL.value[1]: (
lambda x, y, r, t, z, h: DENSECROWDCOUNTModel(x, y, r, ModelType.CITY_DENSECROWDCOUNT_MODEL, t, z, h),
ModelType.CITY_DENSECROWDCOUNT_MODEL,
None,
lambda x: cc_process(x)
),
}