You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

45 lines
1.9KB

  1. from DMPRUtils.DMPR_process import DMPR_process
  2. import tensorrt as trt
  3. import sys,os
  4. #from DMPRUtils.model.detector import DirectionalPointDetector
  5. from DMPRUtils.yolo_net import Model
  6. import torch
  7. class DMPRModel(object):
  8. def __init__(self, weights=None,
  9. par={'depth_factor':32,'NUM_FEATURE_MAP_CHANNEL':6,'dmpr_thresh':0.3, 'dmprimg_size':640}
  10. ):
  11. self.par = par
  12. self.device = 'cuda:0'
  13. self.half =True
  14. if weights.endswith('.engine'):
  15. self.infer_type ='trt'
  16. elif weights.endswith('.pth') or weights.endswith('.pt') :
  17. self.infer_type ='pth'
  18. else:
  19. print('#########ERROR:',weights,': no registered inference type, exit')
  20. sys.exit(0)
  21. if self.infer_type=='trt':
  22. logger = trt.Logger(trt.Logger.ERROR)
  23. with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
  24. self.model=runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
  25. elif self.infer_type=='pth':
  26. #self.model = DirectionalPointDetector(3, self.par['depth_factor'], self.par['NUM_FEATURE_MAP_CHANNEL']).to(self.device)
  27. confUrl = os.path.join( os.path.dirname(__file__),'DMPRUtils','config','yolov5s.yaml' )
  28. self.model = Model(confUrl, ch=3).to(self.device)
  29. self.model.load_state_dict(torch.load(weights))
  30. print('#######load pt model:%s success '%(weights))
  31. self.par['modelType']=self.infer_type
  32. print('#########加载模型:',weights,' 类型:',self.infer_type)
  33. def eval(self,image):
  34. det,timeInfos = DMPR_process(image, self.model, self.device, self.par)
  35. det = det.cpu().detach().numpy()
  36. return det,timeInfos
  37. def get_ms(self,t1,t0):
  38. return (t1-t0)*1000.0