# AIlib2/DMPR.py  (46 lines, 1.9 KiB, Python)
#
# Note: the repository viewer this file was copied from flagged it as
# containing Unicode characters that might be confused with other characters.

from DMPRUtils.DMPR_process import DMPR_process
import tensorrt as trt
import sys,os
#from DMPRUtils.model.detector import DirectionalPointDetector
from DMPRUtils.yolo_net import Model
import torch
class DMPRModel(object):
    """Load a DMPR model and run inference through ``DMPR_process``.

    The backend is selected from the weights file extension:
    ``.engine`` -> serialized TensorRT engine; ``.pth`` / ``.pt`` -> PyTorch
    state dict loaded into the ``Model`` built from
    ``DMPRUtils/config/yolov5s.yaml`` next to this file.
    """

    # Default inference parameters. Each default-constructed instance gets its
    # own copy, so writing 'modelType' never leaks between instances.
    _DEFAULT_PAR = {'depth_factor': 32, 'NUM_FEATURE_MAP_CHANNEL': 6,
                    'dmpr_thresh': 0.3, 'dmprimg_size': 640}

    def __init__(self, weights=None, par=None):
        """Load the model from ``weights``.

        Args:
            weights: path (str or os.PathLike) to a ``.engine``, ``.pth`` or
                ``.pt`` file. Required despite the ``None`` default.
            par: parameter dict handed to ``DMPR_process``. When omitted, a
                fresh copy of ``_DEFAULT_PAR`` is used. A caller-supplied dict
                is used as-is (not copied) and gains a ``'modelType'`` key.

        Raises:
            ValueError: ``weights`` is missing or has an unsupported extension.
            RuntimeError: TensorRT could not deserialize the engine file.
        """
        self.par = dict(self._DEFAULT_PAR) if par is None else par
        self.device = 'cuda:0'
        # Not read anywhere in this class; kept for external users.
        # NOTE(review): confirm something still consumes it.
        self.half = True
        self.infer_type = self._resolve_infer_type(weights)
        if self.infer_type == 'trt':
            self.model = self._load_trt_engine(weights)
        else:
            self.model = self._load_torch_model(weights, self.device)
        # DMPR_process receives par, so it can tell which backend self.model is.
        self.par['modelType'] = self.infer_type
        print('#########加载模型:', weights, ' 类型:', self.infer_type)

    @staticmethod
    def _resolve_infer_type(weights):
        """Return 'trt' or 'pth' for ``weights``; raise ValueError otherwise."""
        if weights is None:
            raise ValueError('#########ERROR: no weights path given')
        path = os.fspath(weights)
        if path.endswith('.engine'):
            return 'trt'
        if path.endswith(('.pth', '.pt')):
            return 'pth'
        raise ValueError('#########ERROR: %s : no registered inference type' % path)

    @staticmethod
    def _load_trt_engine(weights):
        """Read a local TensorRT engine file and return the ICudaEngine."""
        logger = trt.Logger(trt.Logger.ERROR)  # only surface errors
        with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
            engine = runtime.deserialize_cuda_engine(f.read())
        # Deserialization signals failure by returning None rather than raising.
        if engine is None:
            raise RuntimeError(
                '#########ERROR: failed to deserialize TensorRT engine: %s' % (weights,))
        return engine

    @staticmethod
    def _load_torch_model(weights, device):
        """Build the network from its YAML config and load the state dict."""
        conf_url = os.path.join(os.path.dirname(__file__), 'DMPRUtils', 'config', 'yolov5s.yaml')
        model = Model(conf_url, ch=3).to(device)
        # map_location lets checkpoints saved on another device load here.
        model.load_state_dict(torch.load(weights, map_location=device))
        print('#######load pt model:%s success ' % (weights))
        return model

    def eval(self, image):
        """Run inference on ``image`` via ``DMPR_process``.

        Returns:
            A ``(det, timeInfos)`` tuple. ``det`` is the detection tensor
            returned by ``DMPR_process``, converted to a NumPy array.
            ``timeInfos`` is passed through unchanged (format defined by
            ``DMPR_process``).
        """
        det, timeInfos = DMPR_process(image, self.model, self.device, self.par)
        # Detach first so no autograd graph is carried along to the host copy.
        det = det.detach().cpu().numpy()
        return det, timeInfos

    def get_ms(self, t1, t0):
        """Return ``(t1 - t0) * 1000``.

        Assumes the timestamps are in seconds (e.g. from ``time.time()``), so
        the result is in milliseconds.
        """
        return (t1 - t0) * 1000.0