Browse Source

城管模型新增按类别分数过滤参数

develop
YAO 1 year ago
parent
commit
9e865e8957
2 changed files with 54 additions and 1 deletions
  1. +1
    -0
      enums/ModelTypeEnum.py
  2. +53
    -1
      util/ModelUtils.py

+ 1
- 0
enums/ModelTypeEnum.py View File

@@ -347,6 +347,7 @@ class ModelType(Enum):
"classes": 5,
"rainbows": COLOR
},
"score_byClass": {'0':0.8, '1':0.5, '2':0.5},
'Segweights': '../AIlib2/weights/cityMangement2/dmpr_%s.engine' % gpuName
})


+ 53
- 1
util/ModelUtils.py View File

@@ -82,7 +82,59 @@ class OneModel:
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
ExceptionType.MODEL_LOADING_EXCEPTION.value[1])

# 城管模型专用,多出一个score_byClass参数
class cityManagementModel:
__slots__ = "model_conf"

def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
try:
logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
requestId)
par = modeType.value[4](str(device), gpu_name)
mode, postPar, segPar = par.get('mode', 'others'), par.get('postPar'), par.get('segPar')
names = par['labelnames']
postFile = par['postFile']
rainbows = postFile["rainbows"]
new_device = select_device(par.get('device'))
half = new_device.type != 'cpu'
Detweights = par['Detweights']
with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
model = runtime.deserialize_cuda_engine(f.read())
par['segPar']['seg_nclass'] = par['seg_nclass']
Segweights = par['Segweights']
if Segweights:
if modeType.value[3] == 'cityMangement2':
segmodel = DMPRModel(weights=Segweights, par=par['segPar'])
else:
segmodel = stdcModel(weights=Segweights, par=par['segPar'])
else:
segmodel = None
objectPar = {
'half': half,
'device': new_device,
'conf_thres': postFile["conf_thres"],
'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"),
'iou_thres': postFile["iou_thres"],
'allowedList': [],
'segRegionCnt': par['segRegionCnt'],
'trtFlag_det': par['trtFlag_det'],
'trtFlag_seg': par['trtFlag_seg'],
'score_byClass': par['score_byClass']
}
model_param = {
"model": model,
"segmodel": segmodel,
"objectPar": objectPar,
"segPar": segPar,
"mode": mode,
"postPar": postPar
}
self.model_conf = (modeType, model_param, allowedList, names, rainbows)
except Exception:
logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
ExceptionType.MODEL_LOADING_EXCEPTION.value[1])
def model_process(args):
model_conf, frame, request_id = args
model_param, names, rainbows = model_conf[1], model_conf[3], model_conf[4]
@@ -447,7 +499,7 @@ MODEL_CONFIG = {
lambda x: model_process(x)),
# 城管模型
ModelType.CITY_MANGEMENT_MODEL.value[1]: (
lambda x, y, r, t, z, h: OneModel(x, y, r, ModelType.CITY_MANGEMENT_MODEL, t, z, h),
lambda x, y, r, t, z, h: cityManagementModel(x, y, r, ModelType.CITY_MANGEMENT_MODEL, t, z, h),
ModelType.CITY_MANGEMENT_MODEL,
lambda x, y, z: one_label(x, y, z),
lambda x: model_process(x)

Loading…
Cancel
Save