from DMPRUtils.DMPR_process import DMPR_process
import tensorrt as trt
import sys
import os
# Alternative backbone, currently unused:
# from DMPRUtils.model.detector import DirectionalPointDetector
from DMPRUtils.yolo_net import Model
import torch

# Default inference parameters. Never mutated: each DMPRModel copies it.
_DEFAULT_PAR = {
    'depth_factor': 32,
    'NUM_FEATURE_MAP_CHANNEL': 6,
    'dmpr_thresh': 0.3,
    'dmprimg_size': 640,
}


class DMPRModel(object):
    """Load a DMPR model (TensorRT engine or PyTorch weights) and run DMPR_process on images.

    The inference backend is chosen from the weights file extension:
    ``.engine`` -> TensorRT, ``.pth``/``.pt`` -> PyTorch (``Model`` built from
    ``DMPRUtils/config/yolov5s.yaml``).
    """

    def __init__(self, weights=None, par=None):
        """Load the model named by *weights*.

        Args:
            weights: Path to a ``.engine``, ``.pth`` or ``.pt`` file.
            par: Parameter dict forwarded to ``DMPR_process``. Defaults to
                ``_DEFAULT_PAR``. The dict is copied, so neither the caller's
                dict nor the shared default is modified when ``'modelType'``
                is added below.

        Exits the process with status 1 if *weights* is missing or has an
        unsupported extension, or if the TensorRT engine cannot be
        deserialized.
        """
        # Per-instance copy. Previously a shared mutable default was mutated
        # here, so models of different types overwrote each other's 'modelType'.
        self.par = dict(_DEFAULT_PAR if par is None else par)
        self.device = 'cuda:0'
        self.half = True

        # isinstance guard: weights=None used to crash with AttributeError
        # instead of reaching the error message below.
        if isinstance(weights, str) and weights.endswith('.engine'):
            self.infer_type = 'trt'
        elif isinstance(weights, str) and weights.endswith(('.pth', '.pt')):
            self.infer_type = 'pth'
        else:
            print('#########ERROR:', weights, ': no registered inference type, exit')
            # Non-zero status: this is a failure (original exited with 0).
            sys.exit(1)

        if self.infer_type == 'trt':
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # Deserialize the local TensorRT engine file into an ICudaEngine.
                self.model = runtime.deserialize_cuda_engine(f.read())
            if self.model is None:
                print('#########ERROR:', weights, ': failed to deserialize TensorRT engine, exit')
                sys.exit(1)
        else:  # self.infer_type == 'pth'
            conf_url = os.path.join(os.path.dirname(__file__), 'DMPRUtils', 'config', 'yolov5s.yaml')
            self.model = Model(conf_url, ch=3).to(self.device)
            # map_location lets checkpoints saved on another device (e.g. CPU
            # or a different GPU index) load onto self.device.
            self.model.load_state_dict(torch.load(weights, map_location=self.device))
            # NOTE(review): the model is not switched to .eval() here; confirm
            # DMPR_process expects the module in its current (train) mode.
            print('#######load pt model:%s success ' % (weights))

        self.par['modelType'] = self.infer_type
        print('#########加载模型:', weights, ' 类型:', self.infer_type)

    def eval(self, image):
        """Run DMPR_process on *image*.

        Returns:
            A ``(det, timeInfos)`` tuple. ``det`` is the detection tensor
            returned by DMPR_process, converted to a NumPy array. ``timeInfos``
            is passed through unchanged.
        """
        det, timeInfos = DMPR_process(image, self.model, self.device, self.par)
        # Detach from autograd before copying to host, then convert to NumPy.
        det = det.detach().cpu().numpy()
        return det, timeInfos

    def get_ms(self, t1, t0):
        """Return the elapsed time ``t1 - t0`` converted from seconds to milliseconds."""
        return (t1 - t0) * 1000.0