更新 util/ModelUtils.py
This commit is contained in:
parent
919d15ec5f
commit
5cc22405a4
|
|
@ -17,7 +17,7 @@ from util.PlotsUtils import get_label_arrays, get_label_array_dict
|
|||
from util.TorchUtils import select_device
|
||||
|
||||
sys.path.extend(['..', '../AIlib2'])
|
||||
from AI import AI_process, AI_process_forest, get_postProcess_para, ocr_process, AI_process_N, AI_process_C
|
||||
from AI import AI_process, AI_process_forest, get_postProcess_para, ocr_process, AI_process_N, AI_process_C,AI_process_Ocr,AI_process_Crowd
|
||||
from stdc import stdcModel
|
||||
from segutils.segmodel import SegModel
|
||||
from models.experimental import attempt_load
|
||||
|
|
@ -241,7 +241,7 @@ def channel2_process(args):
|
|||
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
|
||||
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||
|
||||
|
||||
def get_label_arraylist(*args):
|
||||
width, height, names, rainbows = args
|
||||
# line = int(round(0.002 * (height + width) / 2) + 1)
|
||||
|
|
@ -354,6 +354,73 @@ def im_process(args):
|
|||
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||
|
||||
def immulti_process(args):
|
||||
model_conf, frame, requestId = args
|
||||
device, modelList, detpar = model_conf[1], model_conf[2], model_conf[3]
|
||||
try:
|
||||
# new_device = torch.device(device)
|
||||
# img, padInfos = pre_process(frame, new_device)
|
||||
# pred = model(img)
|
||||
# boxes = post_process(pred, padInfos, device, conf_thres=pardet['conf_thres'],
|
||||
# iou_thres=pardet['iou_thres'], nc=pardet['nc']) # 后处理
|
||||
return AI_process_Ocr([frame], modelList, device, detpar)
|
||||
except ServiceException as s:
|
||||
raise s
|
||||
except Exception:
|
||||
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
|
||||
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||
|
||||
class CARPLATEModel:
|
||||
__slots__ = "model_conf"
|
||||
def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
|
||||
env=None):
|
||||
try:
|
||||
logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
|
||||
requestId)
|
||||
par = modeType.value[4](str(device), gpu_name)
|
||||
modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
|
||||
detpar = par['models'][0]['par']
|
||||
# new_device = torch.device(par['device'])
|
||||
# modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
|
||||
logger.info("########################加载 plate_yolov5s_v3.jit 成功 ########################, requestId:{}",
|
||||
requestId)
|
||||
self.model_conf = (modeType, device, modelList, detpar, par['rainbows'])
|
||||
except Exception:
|
||||
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
|
||||
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
|
||||
ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
|
||||
|
||||
|
||||
class DENSECROWDCOUNTModel:
|
||||
__slots__ = "model_conf"
|
||||
|
||||
def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
|
||||
try:
|
||||
logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
|
||||
requestId)
|
||||
par = modeType.value[4](str(device), gpu_name)
|
||||
rainbows = par["rainbows"]
|
||||
models=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
|
||||
postPar = par['models'][0]['par']
|
||||
self.model_conf = (modeType, device, models[0], postPar, rainbows)
|
||||
except Exception:
|
||||
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
|
||||
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
|
||||
ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
|
||||
|
||||
def cc_process(args):
|
||||
model_conf, frame, requestId = args
|
||||
device, model, postPar = model_conf[1], model_conf[2], model_conf[3]
|
||||
try:
|
||||
return AI_process_Crowd([frame], model, device, postPar)
|
||||
except ServiceException as s:
|
||||
raise s
|
||||
except Exception:
|
||||
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
|
||||
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||
|
||||
|
||||
# 百度AI图片识别模型
|
||||
class BaiduAiImageModel:
|
||||
|
|
@ -470,7 +537,7 @@ MODEL_CONFIG = {
|
|||
lambda x, y, z: one_label(x, y, z),
|
||||
lambda x: detSeg_demo2(x)
|
||||
),
|
||||
|
||||
|
||||
# 加载交通模型
|
||||
ModelType.TRAFFIC_FARM_MODEL.value[1]: (
|
||||
lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFIC_FARM_MODEL, t, z, h),
|
||||
|
|
@ -613,22 +680,22 @@ MODEL_CONFIG = {
|
|||
lambda x, y, z: one_label(x, y, z),
|
||||
lambda x: model_process(x)
|
||||
),
|
||||
# 加载智慧工地模型
|
||||
# 加载智慧工地模型
|
||||
ModelType.SMARTSITE_MODEL.value[1]: (
|
||||
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.SMARTSITE_MODEL, t, z, h),
|
||||
ModelType.SMARTSITE_MODEL,
|
||||
lambda x, y, z: one_label(x, y, z),
|
||||
lambda x: detSeg_demo2(x)
|
||||
),
|
||||
|
||||
# 加载垃圾模型
|
||||
|
||||
# 加载垃圾模型
|
||||
ModelType.RUBBISH_MODEL.value[1]: (
|
||||
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.RUBBISH_MODEL, t, z, h),
|
||||
ModelType.RUBBISH_MODEL,
|
||||
lambda x, y, z: one_label(x, y, z),
|
||||
lambda x: detSeg_demo2(x)
|
||||
),
|
||||
|
||||
|
||||
# 加载烟花模型
|
||||
ModelType.FIREWORK_MODEL.value[1]: (
|
||||
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.FIREWORK_MODEL, t, z, h),
|
||||
|
|
@ -657,6 +724,13 @@ MODEL_CONFIG = {
|
|||
lambda x, y, z: one_label(x, y, z),
|
||||
lambda x: detSeg_demo2(x)
|
||||
),
|
||||
# 加载自研车牌检测模型
|
||||
ModelType.CITY_CARPLATE_MODEL.value[1]: (
|
||||
lambda x, y, r, t, z, h: CARPLATEModel(x, y, r, ModelType.CITY_CARPLATE_MODEL, t, z, h),
|
||||
ModelType.CITY_CARPLATE_MODEL,
|
||||
None,
|
||||
lambda x: immulti_process(x)
|
||||
),
|
||||
# 加载红外行人检测模型
|
||||
ModelType.CITY_INFRAREDPERSON_MODEL.value[1]: (
|
||||
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_INFRAREDPERSON_MODEL, t, z, h),
|
||||
|
|
@ -670,5 +744,12 @@ MODEL_CONFIG = {
|
|||
ModelType.CITY_NIGHTFIRESMOKE_MODEL,
|
||||
lambda x, y, z: one_label(x, y, z),
|
||||
lambda x: detSeg_demo2(x)
|
||||
),
|
||||
),
|
||||
# 加载密集人群计数检测模型
|
||||
ModelType.CITY_DENSECROWDCOUNT_MODEL.value[1]: (
|
||||
lambda x, y, r, t, z, h: DENSECROWDCOUNTModel(x, y, r, ModelType.CITY_DENSECROWDCOUNT_MODEL, t, z, h),
|
||||
ModelType.CITY_DENSECROWDCOUNT_MODEL,
|
||||
None,
|
||||
lambda x: cc_process(x)
|
||||
),
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue