更新 util/ModelUtils.py
This commit is contained in:
parent
919d15ec5f
commit
5cc22405a4
|
|
@ -17,7 +17,7 @@ from util.PlotsUtils import get_label_arrays, get_label_array_dict
|
||||||
from util.TorchUtils import select_device
|
from util.TorchUtils import select_device
|
||||||
|
|
||||||
sys.path.extend(['..', '../AIlib2'])
|
sys.path.extend(['..', '../AIlib2'])
|
||||||
from AI import AI_process, AI_process_forest, get_postProcess_para, ocr_process, AI_process_N, AI_process_C
|
from AI import AI_process, AI_process_forest, get_postProcess_para, ocr_process, AI_process_N, AI_process_C,AI_process_Ocr,AI_process_Crowd
|
||||||
from stdc import stdcModel
|
from stdc import stdcModel
|
||||||
from segutils.segmodel import SegModel
|
from segutils.segmodel import SegModel
|
||||||
from models.experimental import attempt_load
|
from models.experimental import attempt_load
|
||||||
|
|
@ -241,7 +241,7 @@ def channel2_process(args):
|
||||||
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
|
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), request_id)
|
||||||
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||||
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||||
|
|
||||||
def get_label_arraylist(*args):
|
def get_label_arraylist(*args):
|
||||||
width, height, names, rainbows = args
|
width, height, names, rainbows = args
|
||||||
# line = int(round(0.002 * (height + width) / 2) + 1)
|
# line = int(round(0.002 * (height + width) / 2) + 1)
|
||||||
|
|
@ -354,6 +354,73 @@ def im_process(args):
|
||||||
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||||
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||||
|
|
||||||
|
def immulti_process(args):
|
||||||
|
model_conf, frame, requestId = args
|
||||||
|
device, modelList, detpar = model_conf[1], model_conf[2], model_conf[3]
|
||||||
|
try:
|
||||||
|
# new_device = torch.device(device)
|
||||||
|
# img, padInfos = pre_process(frame, new_device)
|
||||||
|
# pred = model(img)
|
||||||
|
# boxes = post_process(pred, padInfos, device, conf_thres=pardet['conf_thres'],
|
||||||
|
# iou_thres=pardet['iou_thres'], nc=pardet['nc']) # 后处理
|
||||||
|
return AI_process_Ocr([frame], modelList, device, detpar)
|
||||||
|
except ServiceException as s:
|
||||||
|
raise s
|
||||||
|
except Exception:
|
||||||
|
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
|
||||||
|
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||||
|
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||||
|
|
||||||
|
class CARPLATEModel:
|
||||||
|
__slots__ = "model_conf"
|
||||||
|
def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None,
|
||||||
|
env=None):
|
||||||
|
try:
|
||||||
|
logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
|
||||||
|
requestId)
|
||||||
|
par = modeType.value[4](str(device), gpu_name)
|
||||||
|
modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
|
||||||
|
detpar = par['models'][0]['par']
|
||||||
|
# new_device = torch.device(par['device'])
|
||||||
|
# modelList=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
|
||||||
|
logger.info("########################加载 plate_yolov5s_v3.jit 成功 ########################, requestId:{}",
|
||||||
|
requestId)
|
||||||
|
self.model_conf = (modeType, device, modelList, detpar, par['rainbows'])
|
||||||
|
except Exception:
|
||||||
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
|
||||||
|
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
|
||||||
|
ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
|
||||||
|
|
||||||
|
|
||||||
|
class DENSECROWDCOUNTModel:
|
||||||
|
__slots__ = "model_conf"
|
||||||
|
|
||||||
|
def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
|
||||||
|
try:
|
||||||
|
logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
|
||||||
|
requestId)
|
||||||
|
par = modeType.value[4](str(device), gpu_name)
|
||||||
|
rainbows = par["rainbows"]
|
||||||
|
models=[ modelPar['model'](weights=modelPar['weight'],par=modelPar['par']) for modelPar in par['models'] ]
|
||||||
|
postPar = par['models'][0]['par']
|
||||||
|
self.model_conf = (modeType, device, models[0], postPar, rainbows)
|
||||||
|
except Exception:
|
||||||
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
|
||||||
|
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
|
||||||
|
ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
|
||||||
|
|
||||||
|
def cc_process(args):
|
||||||
|
model_conf, frame, requestId = args
|
||||||
|
device, model, postPar = model_conf[1], model_conf[2], model_conf[3]
|
||||||
|
try:
|
||||||
|
return AI_process_Crowd([frame], model, device, postPar)
|
||||||
|
except ServiceException as s:
|
||||||
|
raise s
|
||||||
|
except Exception:
|
||||||
|
logger.error("算法模型分析异常:{}, requestId:{}", format_exc(), requestId)
|
||||||
|
raise ServiceException(ExceptionType.MODEL_ANALYSE_EXCEPTION.value[0],
|
||||||
|
ExceptionType.MODEL_ANALYSE_EXCEPTION.value[1])
|
||||||
|
|
||||||
|
|
||||||
# 百度AI图片识别模型
|
# 百度AI图片识别模型
|
||||||
class BaiduAiImageModel:
|
class BaiduAiImageModel:
|
||||||
|
|
@ -470,7 +537,7 @@ MODEL_CONFIG = {
|
||||||
lambda x, y, z: one_label(x, y, z),
|
lambda x, y, z: one_label(x, y, z),
|
||||||
lambda x: detSeg_demo2(x)
|
lambda x: detSeg_demo2(x)
|
||||||
),
|
),
|
||||||
|
|
||||||
# 加载交通模型
|
# 加载交通模型
|
||||||
ModelType.TRAFFIC_FARM_MODEL.value[1]: (
|
ModelType.TRAFFIC_FARM_MODEL.value[1]: (
|
||||||
lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFIC_FARM_MODEL, t, z, h),
|
lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.TRAFFIC_FARM_MODEL, t, z, h),
|
||||||
|
|
@ -613,22 +680,22 @@ MODEL_CONFIG = {
|
||||||
lambda x, y, z: one_label(x, y, z),
|
lambda x, y, z: one_label(x, y, z),
|
||||||
lambda x: model_process(x)
|
lambda x: model_process(x)
|
||||||
),
|
),
|
||||||
# 加载智慧工地模型
|
# 加载智慧工地模型
|
||||||
ModelType.SMARTSITE_MODEL.value[1]: (
|
ModelType.SMARTSITE_MODEL.value[1]: (
|
||||||
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.SMARTSITE_MODEL, t, z, h),
|
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.SMARTSITE_MODEL, t, z, h),
|
||||||
ModelType.SMARTSITE_MODEL,
|
ModelType.SMARTSITE_MODEL,
|
||||||
lambda x, y, z: one_label(x, y, z),
|
lambda x, y, z: one_label(x, y, z),
|
||||||
lambda x: detSeg_demo2(x)
|
lambda x: detSeg_demo2(x)
|
||||||
),
|
),
|
||||||
|
|
||||||
# 加载垃圾模型
|
# 加载垃圾模型
|
||||||
ModelType.RUBBISH_MODEL.value[1]: (
|
ModelType.RUBBISH_MODEL.value[1]: (
|
||||||
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.RUBBISH_MODEL, t, z, h),
|
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.RUBBISH_MODEL, t, z, h),
|
||||||
ModelType.RUBBISH_MODEL,
|
ModelType.RUBBISH_MODEL,
|
||||||
lambda x, y, z: one_label(x, y, z),
|
lambda x, y, z: one_label(x, y, z),
|
||||||
lambda x: detSeg_demo2(x)
|
lambda x: detSeg_demo2(x)
|
||||||
),
|
),
|
||||||
|
|
||||||
# 加载烟花模型
|
# 加载烟花模型
|
||||||
ModelType.FIREWORK_MODEL.value[1]: (
|
ModelType.FIREWORK_MODEL.value[1]: (
|
||||||
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.FIREWORK_MODEL, t, z, h),
|
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.FIREWORK_MODEL, t, z, h),
|
||||||
|
|
@ -657,6 +724,13 @@ MODEL_CONFIG = {
|
||||||
lambda x, y, z: one_label(x, y, z),
|
lambda x, y, z: one_label(x, y, z),
|
||||||
lambda x: detSeg_demo2(x)
|
lambda x: detSeg_demo2(x)
|
||||||
),
|
),
|
||||||
|
# 加载自研车牌检测模型
|
||||||
|
ModelType.CITY_CARPLATE_MODEL.value[1]: (
|
||||||
|
lambda x, y, r, t, z, h: CARPLATEModel(x, y, r, ModelType.CITY_CARPLATE_MODEL, t, z, h),
|
||||||
|
ModelType.CITY_CARPLATE_MODEL,
|
||||||
|
None,
|
||||||
|
lambda x: immulti_process(x)
|
||||||
|
),
|
||||||
# 加载红外行人检测模型
|
# 加载红外行人检测模型
|
||||||
ModelType.CITY_INFRAREDPERSON_MODEL.value[1]: (
|
ModelType.CITY_INFRAREDPERSON_MODEL.value[1]: (
|
||||||
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_INFRAREDPERSON_MODEL, t, z, h),
|
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_INFRAREDPERSON_MODEL, t, z, h),
|
||||||
|
|
@ -670,5 +744,12 @@ MODEL_CONFIG = {
|
||||||
ModelType.CITY_NIGHTFIRESMOKE_MODEL,
|
ModelType.CITY_NIGHTFIRESMOKE_MODEL,
|
||||||
lambda x, y, z: one_label(x, y, z),
|
lambda x, y, z: one_label(x, y, z),
|
||||||
lambda x: detSeg_demo2(x)
|
lambda x: detSeg_demo2(x)
|
||||||
),
|
),
|
||||||
|
# 加载密集人群计数检测模型
|
||||||
|
ModelType.CITY_DENSECROWDCOUNT_MODEL.value[1]: (
|
||||||
|
lambda x, y, r, t, z, h: DENSECROWDCOUNTModel(x, y, r, ModelType.CITY_DENSECROWDCOUNT_MODEL, t, z, h),
|
||||||
|
ModelType.CITY_DENSECROWDCOUNT_MODEL,
|
||||||
|
None,
|
||||||
|
lambda x: cc_process(x)
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue