74 lines
3.6 KiB
Python
74 lines
3.6 KiB
Python
from loguru import logger
|
|
import time
|
|
import tensorrt as trt
|
|
from DMPR import DMPRModel
|
|
from traceback import format_exc
|
|
from models.experimental import attempt_load
|
|
|
|
from DrGraph.util.drHelper import *
|
|
from DrGraph.util.Constant import *
|
|
from DrGraph.enums.ExceptionEnum import ExceptionType
|
|
from DrGraph.util.stdc import stdcModel
|
|
|
|
# 河道模型、河道检测模型、交通模型、人员落水模型、城市违章公共模型
|
|
class Model1:
    """Load a detection model and an optional segmentation model for one model type.

    The loaded objects and their post-processing parameters are stored in the
    single slot ``model_conf`` as the tuple
    ``(modeType, model_param, allowedList, names, rainbows)``.
    """

    __slots__ = "model_conf"

    # Original note: "3090" (presumably the GPU this loader targets -- TODO confirm).
    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Build ``self.model_conf`` from the configuration attached to ``modeType``.

        Args:
            device: Device identifier; passed as ``str(device)`` to the
                configuration factory.
            allowedList: Stored unchanged as the third item of ``model_conf``.
            requestId: Identifier included in every log line.
            modeType: Enum-like object whose ``value`` is indexed here:
                ``value[2]`` is logged as the model being loaded, ``value[3]``
                is compared with ``'cityMangement3'`` to choose the
                segmentation class, and ``value[4]`` is called with
                ``(str(device), gpu_name)`` to obtain the configuration
                mapping ``par``.
            gpu_name: Forwarded to the configuration factory.
            base_dir: Only logged here.
            env: Only logged here.

        Raises:
            ServiceException: With the code and message of
                ``ExceptionType.MODEL_LOADING_EXCEPTION`` when any step of the
                loading fails; the original error is attached as the cause.
        """
        # time.time() cannot fail, so it is taken before the try block to make
        # it obvious that `start` is bound for the final timing log.
        start = time.time()
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            logger.info('__init__(device={}, allowedList={}, requestId={}, modeType={}, gpu_name={}, base_dir={}, env={})',
                        device, allowedList, requestId, modeType, gpu_name, base_dir, env)

            par = modeType.value[4](str(device), gpu_name)
            mode = par.get('mode', 'others')
            postPar = par.get('postPar')
            segPar = par.get('segPar')
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]

            new_device = torchHelper.select_device(par.get('device'))
            # Half precision is requested for every device type except 'cpu'.
            half = new_device.type != 'cpu'

            # --- detection model ---
            det_weights = par['Detweights']
            if par['trtFlag_det']:
                # Serialized TensorRT engine: deserialize it as-is.
                with open(det_weights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
                    model = runtime.deserialize_cuda_engine(f.read())
            else:
                model = attempt_load(det_weights, map_location=new_device)  # load FP32 model
                if half:
                    model.half()

            # --- segmentation model (optional) ---
            # 'segPar' is optional (it is read with .get above), so only write
            # into it when it exists. Indexing par['segPar'] unconditionally
            # made every configuration without segmentation parameters fail
            # to load, even though the no-segmentation case is supported below.
            if segPar is not None:
                segPar['seg_nclass'] = par['seg_nclass']

            seg_weights = par['Segweights']
            if not seg_weights:
                segmodel = None
            elif segPar is None:
                # Segmentation weights without parameters is a configuration
                # error; it is reported as a loading failure by the handler below.
                raise ValueError("Segweights is configured but segPar is missing")
            elif modeType.value[3] == 'cityMangement3':
                segmodel = DMPRModel(weights=seg_weights, par=segPar)
            else:
                segmodel = stdcModel(weights=seg_weights, par=segPar)

            objectPar = {
                'half': half,
                'device': new_device,
                'conf_thres': postFile["conf_thres"],
                'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"),
                'iou_thres': postFile["iou_thres"],
                # Filtering for the highway model (translated from the original comment).
                'segRegionCnt': par['segRegionCnt'],
                'trtFlag_det': par['trtFlag_det'],
                'trtFlag_seg': par['trtFlag_seg'],
                # Optional keys: None / empty list when the configuration omits them.
                'score_byClass': par.get('score_byClass'),
                'fiterList': par.get('fiterList', []),
            }
            model_param = {
                "model": model,
                "segmodel": segmodel,
                "objectPar": objectPar,
                "segPar": segPar,
                "mode": mode,
                "postPar": postPar,
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as exc:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            # Callers receive one uniform exception type; the root cause stays
            # attached through explicit chaining.
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from exc
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - start, requestId)
|
|
|