diff --git a/util/ModelUtils.py b/util/ModelUtils.py index 9154274..a5e6c9f 100644 --- a/util/ModelUtils.py +++ b/util/ModelUtils.py @@ -17,7 +17,7 @@ from util.PlotsUtils import get_label_arrays, get_label_array_dict from util.TorchUtils import select_device sys.path.extend(['..', '../AIlib2']) -from AI import AI_process, AI_process_forest, get_postProcess_para, ocr_process, AI_process_N, AI_process_C +from AI import AI_process, AI_process_forest, get_postProcess_para, ocr_process, AI_process_N, AI_process_C,AI_process_Ocr,AI_process_Crowd from stdc import stdcModel from segutils.segmodel import SegModel from models.experimental import attempt_load @@ -241,7 +241,7 @@ def channel2_process(args): logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id) raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0], ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) - + def get_label_arraylist(*args): width, height, names, rainbows = args # line = int(round(0.002 * (height + width) / 2) + 1) @@ -354,6 +354,73 @@ def im_process(args): raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0], ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) +def immulti_process(args): + model_conf, frame, requestId = args + device, modelList, detpar = model_conf[1], model_conf[2], model_conf[3] + try: + # new_device = torch.device(device) + # img, padInfos = pre_process(frame, new_device) + # pred = model(img) + # boxes = post_process(pred, padInfos, device, conf_thres=pardet['conf_thres'], + # iou_thres=pardet['iou_thres'], nc=pardet['nc']) # 后处理 + return AI_process_Ocr([frame], modelList, device, detpar) + except ServiceException as s: + raise s + except Exception: + logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId) + raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0], + ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) + +class CARPLATEModel: + __slots__ = "model_conf" + def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, + env=None): + try: + logger.info("########################加载{}########################, requestId:{}", modeType.value[2], + requestId) + par = modeType.value[4](str(device), gpu_name) + modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ] + detpar = par['models'][0]['par'] + # new_device = torch.device(par['device']) + # modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ] + logger.info("########################加载 plate_yolov5s_v3.jit 成功 ########################, requestId:{}", + requestId) + self.model_conf = (modeType, device, modelList, detpar, par['rainbows']) + except Exception: + logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) + raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], + ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) + + +class DENSECROWDCOUNTModel: + __slots__ = "model_conf" + + def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None): + try: + logger.info("########################加载{}########################, requestId:{}", modeType.value[2], + requestId) + par = modeType.value[4](str(device), gpu_name) + rainbows = par["rainbows"] + models=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ] + postPar = par['models'][0]['par'] + self.model_conf = (modeType, device, models[0], postPar, rainbows) + except Exception: + logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) + raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], + ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) + +def cc_process(args): + model_conf, frame, requestId = args + device, model, postPar = model_conf[1], model_conf[2], model_conf[3] + try: + return AI_process_Crowd([frame], model, device, postPar) + except ServiceException as s: + raise s + except Exception: + logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId) + raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0], + ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1]) + # 百度AI图片识别模型 class BaiduAiImageModel: @@ -470,7 +537,7 @@ MODEL_CONFIG = { lambda x, y, z: one_label(x, y, z), lambda x: detSeg_demo2(x) ), - + # 加载交通模型 ModelType.TRAFFIC_FARM_MODEL.value[1]: ( lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFIC_FARM_MODEL, t, z, h), @@ -613,22 +680,22 @@ MODEL_CONFIG = { lambda x, y, z: one_label(x, y, z), lambda x: model_process(x) ), - # 加载智慧工地模型 + # 加载智慧工地模型 ModelType.SMARTSITE_MODEL.value[1]: ( lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.SMARTSITE_MODEL, t, z, h), ModelType.SMARTSITE_MODEL, lambda x, y, z: one_label(x, y, z), lambda x: detSeg_demo2(x) ), - - # 加载垃圾模型 + + # 加载垃圾模型 ModelType.RUBBISH_MODEL.value[1]: ( lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.RUBBISH_MODEL, t, z, h), ModelType.RUBBISH_MODEL, lambda x, y, z: one_label(x, y, z), lambda x: detSeg_demo2(x) ), - + # 加载烟花模型 ModelType.FIREWORK_MODEL.value[1]: ( lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.FIREWORK_MODEL, t, z, h), @@ -657,6 +724,13 @@ MODEL_CONFIG = { lambda x, y, z: one_label(x, y, z), lambda x: detSeg_demo2(x) ), + # 加载自研车牌检测模型 + ModelType.CITY_CARPLATE_MODEL.value[1]: ( + lambda x, y, r, t, z, h: CARPLATEModel(x, y, r, ModelType.CITY_CARPLATE_MODEL, t, z, h), + ModelType.CITY_CARPLATE_MODEL, + None, + lambda x: immulti_process(x) + ), # 加载红外行人检测模型 ModelType.CITY_INFRAREDPERSON_MODEL.value[1]: ( lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_INFRAREDPERSON_MODEL, t, z, h), @@ -670,5 +744,12 @@ MODEL_CONFIG = { ModelType.CITY_NIGHTFIRESMOKE_MODEL, lambda x, y, z: one_label(x, y, z), lambda x: detSeg_demo2(x) - ), + ), + # 加载密集人群计数检测模型 + ModelType.CITY_DENSECROWDCOUNT_MODEL.value[1]: ( + lambda x, y, r, t, z, h: DENSECROWDCOUNTModel(x, y, r, ModelType.CITY_DENSECROWDCOUNT_MODEL, t, z, h), + ModelType.CITY_DENSECROWDCOUNT_MODEL, + None, + lambda x: cc_process(x) + ), }