import time
from traceback import format_exc

from loguru import logger
import tensorrt as trt

from DMPR import DMPRModel
from models.experimental import attempt_load
from DrGraph.util.drHelper import *
from DrGraph.util.Constant import *
from DrGraph.enums.ExceptionEnum import ExceptionType
from DrGraph.util.stdc import stdcModel


# River-channel model, river-channel detection model, traffic model,
# person-in-water (drowning) model, and the shared urban-violation model.
class Model1:
    """Load a detection model and an optional segmentation model for one mode.

    After construction, ``self.model_conf`` holds the tuple
    ``(modeType, model_param, allowedList, names, rainbows)`` where
    ``model_param`` is a dict with the keys ``"model"``, ``"segmodel"``,
    ``"objectPar"``, ``"segPar"``, ``"mode"`` and ``"postPar"``.
    """

    # Tuple form makes the single slot explicit (a bare string is also
    # accepted by Python and means the same thing).
    __slots__ = ("model_conf",)

    # NOTE(review): the original carried a bare "3090" comment here,
    # presumably the GPU model this loader was written for — confirm.
    def __init__(self, device, allowedList=None, requestId=None, modeType=None,
                 gpu_name=None, base_dir=None, env=None):
        """Build the model configuration for ``modeType`` on ``device``.

        Args:
            device: Device identifier; passed as ``str(device)`` to the
                mode's parameter factory.
            allowedList: Stored unchanged in ``model_conf``.
            requestId: Only used in log messages.
            modeType: Enum member whose ``value`` tuple is indexed here:
                ``value[2]`` is logged as the model's display name,
                ``value[3]`` is compared with ``'cityMangement3'`` to pick
                the segmentation model class, and ``value[4]`` is called as
                ``value[4](str(device), gpu_name)`` to get the parameter
                dict. (Meaning of the other fields: TODO confirm.)
            gpu_name: Forwarded to the parameter factory.
            base_dir: Only logged; not otherwise used in this method.
            env: Only logged; not otherwise used in this method.

        Raises:
            ServiceException: With ``MODEL_LOADING_EXCEPTION`` codes if any
                step fails. The original error is logged and chained as the
                cause.
        """
        # NOTE(review): ServiceException and torchHelper are not imported by
        # name above; presumably they come from the DrGraph star imports.
        start = time.time()
        try:
            logger.info("########################加载{}########################, requestId:{}",
                        modeType.value[2], requestId)
            logger.info('__init__(device={}, allowedList={}, requestId={}, modeType={}, gpu_name={}, base_dir={}, env={})',
                        device, allowedList, requestId, modeType, gpu_name, base_dir, env)

            par = modeType.value[4](str(device), gpu_name)
            mode, postPar, segPar = par.get('mode', 'others'), par.get('postPar'), par.get('segPar')
            names = par['labelnames']
            postFile = par['postFile']
            rainbows = postFile["rainbows"]

            new_device = torchHelper.select_device(par.get('device'))
            # Half precision is enabled on any non-CPU device.
            half = new_device.type != 'cpu'

            Detweights = par['Detweights']
            if par['trtFlag_det']:
                # Detection weights are a serialized TensorRT engine.
                with open(Detweights, "rb") as f, trt.Runtime(trt.Logger(trt.Logger.ERROR)) as runtime:
                    model = runtime.deserialize_cuda_engine(f.read())
            else:
                model = attempt_load(Detweights, map_location=new_device)  # load FP32 model
                # Only the PyTorch model can be converted to FP16; a
                # deserialized TRT engine has no .half().
                if half:
                    model.half()

            # NOTE(review): this indexes par['segPar'] unconditionally, so a
            # missing 'segPar' fails the whole load even when no
            # segmentation weights are configured — confirm this is intended.
            par['segPar']['seg_nclass'] = par['seg_nclass']
            Segweights = par['Segweights']
            if Segweights:
                if modeType.value[3] == 'cityMangement3':
                    segmodel = DMPRModel(weights=Segweights, par=par['segPar'])
                else:
                    segmodel = stdcModel(weights=Segweights, par=par['segPar'])
            else:
                segmodel = None

            objectPar = {
                'half': half,
                'device': new_device,
                'conf_thres': postFile["conf_thres"],
                'ovlap_thres_crossCategory': postFile.get("ovlap_thres_crossCategory"),
                'iou_thres': postFile["iou_thres"],
                # Used to filter the highway (expressway) model's results.
                'segRegionCnt': par['segRegionCnt'],
                'trtFlag_det': par['trtFlag_det'],
                'trtFlag_seg': par['trtFlag_seg'],
                'score_byClass': par.get('score_byClass'),
                # Key spelling ('fiterList') is part of the downstream contract.
                'fiterList': par.get('fiterList', []),
            }
            model_param = {
                "model": model,
                "segmodel": segmodel,
                "objectPar": objectPar,
                "segPar": segPar,
                "mode": mode,
                "postPar": postPar,
            }
            self.model_conf = (modeType, model_param, allowedList, names, rainbows)
        except Exception as exc:
            logger.error("模型加载异常:{}, requestId:{}", format_exc(), requestId)
            # Chain the cause so the original failure stays attached.
            raise ServiceException(ExceptionType.MODEL_LOADING_EXCEPTION.value[0],
                                   ExceptionType.MODEL_LOADING_EXCEPTION.value[1]) from exc
        logger.info("模型初始化时间:{}, requestId:{}", time.time() - start, requestId)