|
|
@@ -82,7 +82,59 @@ class OneModel: |
|
|
|
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], |
|
|
|
ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) |
|
|
|
|
|
|
|
# 城管模型专用,多出一个score_byClass参数 |
|
|
|
class cityManagementModel: |
|
|
|
__slots__ = "model_conf" |
|
|
|
|
|
|
|
def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None): |
|
|
|
try: |
|
|
|
logger.info("########################加载{}########################, requestId:{}", modeType.value[2], |
|
|
|
requestId) |
|
|
|
par = modeType.value[4](str(device), gpu_name) |
|
|
|
mode, postPar, segPar = par.get('mode', 'others'), par.get('postPar'), par.get('segPar') |
|
|
|
names = par['labelnames'] |
|
|
|
postFile = par['postFile'] |
|
|
|
rainbows = postFile["rainbows"] |
|
|
|
new_device = select_device(par.get('device')) |
|
|
|
half = new_device.type != 'cpu' |
|
|
|
Detweights = par['Detweights'] |
|
|
|
with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime: |
|
|
|
model = runtime.deserialize_cuda_engine(f.read()) |
|
|
|
par['segPar']['seg_nclass'] = par['seg_nclass'] |
|
|
|
Segweights = par['Segweights'] |
|
|
|
if Segweights: |
|
|
|
if modeType.value[3] == 'cityMangement2': |
|
|
|
segmodel = DMPRModel(weights=Segweights, par=par['segPar']) |
|
|
|
else: |
|
|
|
segmodel = stdcModel(weights=Segweights, par=par['segPar']) |
|
|
|
else: |
|
|
|
segmodel = None |
|
|
|
objectPar = { |
|
|
|
'half': half, |
|
|
|
'device': new_device, |
|
|
|
'conf_thres': postFile["conf_thres"], |
|
|
|
'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"), |
|
|
|
'iou_thres': postFile["iou_thres"], |
|
|
|
'allowedList': [], |
|
|
|
'segRegionCnt': par['segRegionCnt'], |
|
|
|
'trtFlag_det': par['trtFlag_det'], |
|
|
|
'trtFlag_seg': par['trtFlag_seg'], |
|
|
|
'score_byClass': par['score_byClass'] |
|
|
|
} |
|
|
|
model_param = { |
|
|
|
"model": model, |
|
|
|
"segmodel": segmodel, |
|
|
|
"objectPar": objectPar, |
|
|
|
"segPar": segPar, |
|
|
|
"mode": mode, |
|
|
|
"postPar": postPar |
|
|
|
} |
|
|
|
self.model_conf = (modeType, model_param, allowedList, names, rainbows) |
|
|
|
except Exception: |
|
|
|
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId) |
|
|
|
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0], |
|
|
|
ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) |
|
|
|
|
|
|
|
def model_process(args): |
|
|
|
model_conf, frame, request_id = args |
|
|
|
model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4] |
|
|
@@ -447,7 +499,7 @@ MODEL_CONFIG = { |
|
|
|
lambda x: model_process(x)), |
|
|
|
# 城管模型 |
|
|
|
ModelType.CITY_MANGEMENT_MODEL.value[1]: ( |
|
|
|
lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.CITY_MANGEMENT_MODEL, t, z, h), |
|
|
|
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_MANGEMENT_MODEL, t, z, h), |
|
|
|
ModelType.CITY_MANGEMENT_MODEL, |
|
|
|
lambda x, y, z: one_label(x, y, z), |
|
|
|
lambda x: model_process(x) |