2025-04-26 10:35:59 +08:00
|
|
|
|
from DMPRUtils.DMPR_process import DMPR_process
|
|
|
|
|
|
import tensorrt as trt
|
|
|
|
|
|
import sys,os
|
|
|
|
|
|
#from DMPRUtils.model.detector import DirectionalPointDetector
|
|
|
|
|
|
from DMPRUtils.yolo_net import Model
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
class DMPRModel(object):
    """Load a DMPR model from disk and run inference through ``DMPR_process``.

    The backend is selected from the weights file extension:

    * ``.engine``        -- a serialized TensorRT engine, deserialized into an
      ``ICudaEngine`` and stored in ``self.model``.
    * ``.pth`` / ``.pt`` -- a PyTorch ``state_dict`` loaded into the
      ``DMPRUtils.yolo_net.Model`` built from
      ``DMPRUtils/config/yolov5s.yaml`` next to this file.
    """

    # Configuration used when the caller does not pass ``par``. It is copied
    # per instance so that ``__init__`` adding ``'modelType'`` can never leak
    # into a dict shared between instances (the old mutable default did).
    _DEFAULT_PAR = {
        'depth_factor': 32,
        'NUM_FEATURE_MAP_CHANNEL': 6,
        'dmpr_thresh': 0.3,
        'dmprimg_size': 640,
    }

    def __init__(self, weights=None, par=None):
        """Load ``weights`` and prepare the inference configuration.

        Args:
            weights: Path to a ``.engine``, ``.pth`` or ``.pt`` file.
            par: Configuration dict forwarded to ``DMPR_process``. When omitted
                (or ``None``), a fresh copy of ``_DEFAULT_PAR`` is used. A
                caller-supplied dict is used as-is (not copied) and gains a
                ``'modelType'`` key, as before.

        Raises:
            SystemExit: with status 1 if ``weights`` is missing or has an
                unrecognised extension.
        """
        self.par = dict(self._DEFAULT_PAR) if par is None else par
        self.device = 'cuda:0'
        # Not read anywhere in this class.
        # NOTE(review): presumably consumed by external callers -- verify
        # before removing.
        self.half = True

        # Pick the backend from the file extension. ``None`` (the default)
        # falls through to the error branch instead of raising AttributeError.
        weights_name = '' if weights is None else str(weights)
        if weights_name.endswith('.engine'):
            self.infer_type = 'trt'
        elif weights_name.endswith(('.pth', '.pt')):
            self.infer_type = 'pth'
        else:
            print('#########ERROR:', weights, ': no registered inference type, exit')
            # Non-zero status: exiting with 0 would report this failure as success.
            sys.exit(1)

        if self.infer_type == 'trt':
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # Deserialize the local engine file into an ICudaEngine object.
                # NOTE(review): the result is not checked for None here.
                self.model = runtime.deserialize_cuda_engine(f.read())
        elif self.infer_type == 'pth':
            confUrl = os.path.join(os.path.dirname(__file__), 'DMPRUtils', 'config', 'yolov5s.yaml')
            self.model = Model(confUrl, ch=3).to(self.device)
            # map_location lets a checkpoint saved on another device
            # (e.g. a different GPU index) load onto self.device.
            self.model.load_state_dict(torch.load(weights, map_location=self.device))
            print('#######load pt model:%s success ' % (weights))

        # Record the backend so DMPR_process can tell which kind of model it got.
        self.par['modelType'] = self.infer_type
        print('#########加载模型:', weights, ' 类型:', self.infer_type)

    def eval(self, image):
        """Run DMPR inference on ``image``.

        Args:
            image: Input passed straight to ``DMPR_process``.
                NOTE(review): presumably a decoded image array -- confirm
                against ``DMPR_process``.

        Returns:
            ``(det, timeInfos)``. ``det`` is the tensor returned by
            ``DMPR_process``, moved to CPU and converted to a NumPy array.
            ``timeInfos`` is passed through unchanged; its format is defined
            by ``DMPR_process``.
        """
        det, timeInfos = DMPR_process(image, self.model, self.device, self.par)
        det = det.cpu().detach().numpy()
        return det, timeInfos

    def get_ms(self, t1, t0):
        """Return ``t1 - t0`` scaled by 1000.

        This converts to milliseconds, assuming the timestamps are in seconds.
        """
        return (t1 - t0) * 1000.0
|
|
|
|
|
|
|
2025-04-26 14:13:15 +08:00
|
|
|
|
|