AIlib2/DrGraph/Bussiness/Models.py

74 lines
3.6 KiB
Python

from loguru import logger
import time
import tensorrt as trt
from DMPR import DMPRModel
from traceback import format_exc
from models.experimental import attempt_load
from DrGraph.util.drHelper import *
from DrGraph.util.Constant import *
from DrGraph.enums.ExceptionEnum import ExceptionType
from DrGraph.util.stdc import stdcModel
# Shared loader for the river, river-detection, traffic, person-in-water and
# city-violation (common) models.
class Model1:
    """Load a detection model and an optional segmentation model for one mode type.

    On success ``model_conf`` holds the tuple
    ``(modeType, model_param, allowedList, names, rainbows)``.
    """
    __slots__ = "model_conf"

    # Original note: "3090" (presumably the GPU model this loader targets — unverified).
    def __init__(self, device, allowedList=None, requestId=None, modeType=None, gpu_name=None, base_dir=None, env=None):
        """Build ``model_conf`` from the config produced by ``modeType.value[4]``.

        Args:
            device: Device identifier; passed as ``str(device)`` to the config
                factory ``modeType.value[4]``.
            allowedList: Stored unchanged in ``model_conf``.
            requestId: Used only in log messages.
            modeType: Enum-like member; ``value[2]`` is logged as the model
                name, ``value[3]`` selects the segmentation model class and
                ``value[4]`` is called as ``(device_str, gpu_name)`` to obtain
                the config dict ``par``.
            gpu_name: Forwarded to the config factory.
            base_dir: Only logged here.
            env: Only logged here.

        Raises:
            ServiceException: ``MODEL_LOADING_EXCEPTION`` on any failure; the
                underlying exception is chained as ``__cause__``.
        """
        # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
        start = time.perf_counter()
        try:
            logger.info("########################加载{}########################, requestId:{}", modeType.value[2],
                        requestId)
            logger.info('__init__(device={}, allowedList={}, requestId={}, modeType={}, gpu_name={}, base_dir={}, env={})',
                        device, allowedList, requestId, modeType, gpu_name, base_dir, env)
            par = modeType.value[4](str(device), gpu_name)
            mode, postPar, segPar = par.get('mode', 'others'), par.get('postPar'), par.get('segPar')
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]
            new_device = torchHelper.select_device(par.get('device'))
            # Half precision is enabled whenever the selected device is not the CPU.
            half = new_device.type != 'cpu'
            model = self._load_det_model(par['Detweights'], par['trtFlag_det'], new_device, half)
            # The segmentation config receives the class count even when no
            # segmentation weights are configured (same order as before).
            par['segPar']['seg_nclass'] = par['seg_nclass']
            segmodel = self._load_seg_model(par['Segweights'], modeType, par['segPar'])
            model_param = {
                "model": model,
                "segmodel": segmodel,
                "objectPar": self._build_object_par(par, postFile, half, new_device),
                "segPar": segPar,
                "mode": mode,
                "postPar": postPar,
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as exc:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            # Chain the original error so callers keep the root cause.
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from exc
        logger.info("模型初始化时间:{}, requestId:{}", time.perf_counter() - start, requestId)

    @staticmethod
    def _load_det_model(weights, use_trt, device, half):
        """Return a deserialized TensorRT engine, or a PyTorch model loaded via ``attempt_load``."""
        if use_trt:
            with open(weights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
                return runtime.deserialize_cuda_engine(f.read())
        model = attempt_load(weights, map_location=device)  # load FP32 model
        if half:
            model.half()
        return model

    @staticmethod
    def _load_seg_model(weights, modeType, seg_par):
        """Return the segmentation model for ``weights``, or ``None`` when weights are falsy.

        ``modeType.value[3] == 'cityMangement3'`` selects ``DMPRModel``;
        every other mode uses ``stdcModel``.
        """
        if not weights:
            return None
        if modeType.value[3] == 'cityMangement3':
            return DMPRModel(weights=weights, par=seg_par)
        return stdcModel(weights=weights, par=seg_par)

    @staticmethod
    def _build_object_par(par, postFile, half, device):
        """Assemble the detection post-processing parameters from the config dicts."""
        return {
            'half': half,
            'device': device,
            'conf_thres': postFile["conf_thres"],
            'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"),
            'iou_thres': postFile["iou_thres"],
            # Original note: used to filter the highway (expressway) model.
            'segRegionCnt': par['segRegionCnt'],
            'trtFlag_det': par['trtFlag_det'],
            'trtFlag_seg': par['trtFlag_seg'],
            'score_byClass': par.get('score_byClass'),
            'fiterList': par.get('fiterList', []),
        }