# AIlib2/DMPRUtils/toTrt.py
# (48 lines, 1.6 KiB, Python)

import os
import time,argparse
import cv2
import torch
import sys
sys.path.extend(['..' ])
from DMPRUtils.model.detector import DirectionalPointDetector
from pathlib import Path
from segutils.trtUtils import toONNX,ONNXtoTrt
from DMPRUtils.yolo_net import Model
def main(opt):
    """Convert a DMPR checkpoint to ONNX, then build a TensorRT engine from it.

    The ONNX and engine files are written next to the checkpoint and share its
    stem, e.g. ``dp_detector_499.pth`` -> ``dp_detector_499.onnx`` ->
    ``dp_detector_499.engine``.

    Args:
        opt: Parsed command-line namespace. Only ``opt.weights`` (path to the
            ``.pth`` state dict) is read; surrounding whitespace is stripped.
    """
    # Only mWidth/mHeight are used below; the remaining keys are kept for
    # reference alongside the (commented-out) DirectionalPointDetector variant.
    pars = {'depth_factor': 32, 'NUM_FEATURE_MAP_CHANNEL': 6, 'dmpr_thresh': 0.3, 'dmprimg_size': 640,
            'mWidth': 640, 'mHeight': 640}
    # The parameters below cannot currently be changed.
    # DMPRweights = "weights/urbanManagement/DMPR/dp_detector_499.pth"
    DMPRweights = opt.weights.strip()
    DMPR_pthFile = Path(DMPRweights)
    inputShape = (1, 3, pars['mHeight'], pars['mWidth'])  # (bs, channels, height, width)
    DMPR_onnxFile = str(DMPR_pthFile.with_suffix('.onnx'))
    # Derive the engine path from the checkpoint's suffix only. A plain
    # str.replace('.onnx', '.engine') would also rewrite any '.onnx'
    # substring that happens to appear in a parent directory name.
    DMPR_trtFile = str(DMPR_pthFile.with_suffix('.engine'))

    # Load the model.
    device = 'cuda:0'
    # DMPR model (alternative architecture, currently unused):
    # DMPRmodel = DirectionalPointDetector(3, pars['depth_factor'], pars['NUM_FEATURE_MAP_CHANNEL']).to(device)
    confUrl = os.path.join(os.path.dirname(__file__), 'config', 'yolov5s.yaml')
    DMPRmodel = Model(confUrl, ch=3).to(device)
    # map_location puts the loaded tensors directly on `device`, regardless of
    # which device the checkpoint was saved from.
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    DMPRmodel.load_state_dict(torch.load(DMPRweights, map_location=device))
    # Export in inference mode (affects layers such as BatchNorm and Dropout).
    DMPRmodel.eval()

    toONNX(DMPRmodel, DMPR_onnxFile, inputShape=inputShape, device=device, dynamic=True)
    ONNXtoTrt(DMPR_onnxFile, DMPR_trtFile)
if __name__ == '__main__':
    # Command-line entry point: read the checkpoint path and run the conversion.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--weights',
        type=str,
        default='/mnt/thsw2/DSP2/weights/cityMangement2/weights/urbanManagement/DMPR/dp_detector_499.pth',
        help='model path(s)',
    )
    main(cli.parse_args())