48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
import os
|
|
import time,argparse
|
|
import cv2
|
|
import torch
|
|
import sys
|
|
sys.path.extend(['..' ])
|
|
from DMPRUtils.model.detector import DirectionalPointDetector
|
|
from pathlib import Path
|
|
from segutils.trtUtils import toONNX,ONNXtoTrt
|
|
from DMPRUtils.yolo_net import Model
|
|
|
|
def main(opt):
    """Export the DMPR model to ONNX and then to a TensorRT engine.

    The checkpoint named by ``opt.weights`` is loaded (as a ``state_dict``)
    into a ``Model`` built from ``config/yolov5s.yaml``. That config path is
    resolved relative to this script's directory. The model is exported to
    ``<weights stem>.onnx`` with ``toONNX`` and then converted to
    ``<weights stem>.engine`` with ``ONNXtoTrt``. Both files are written next
    to the weights file.

    Args:
        opt: Parsed command-line namespace. ``opt.weights`` is the checkpoint
            path (surrounding whitespace is stripped). ``opt.device`` is
            optional and selects the torch device; when it is absent or empty,
            ``'cuda:0'`` is used, matching the previous hard-coded value.
    """
    # Only 'mHeight' / 'mWidth' are read below. The remaining keys belonged to
    # the earlier DirectionalPointDetector configuration.
    pars = {
        'depth_factor': 32,
        'NUM_FEATURE_MAP_CHANNEL': 6,
        'dmpr_thresh': 0.3,
        'dmprimg_size': 640,
        'mWidth': 640,
        'mHeight': 640,
    }

    # The following parameters currently cannot be changed.
    DMPRweights = opt.weights.strip()
    DMPR_pthFile = Path(DMPRweights)
    inputShape = (1, 3, pars['mHeight'], pars['mWidth'])  # (bs, channels, height, width)
    # Derive both output names from the Path so only the final suffix changes.
    # A plain str.replace('.onnx', '.engine') would also rewrite any '.onnx'
    # that happens to appear in a directory name.
    DMPR_onnxFile = str(DMPR_pthFile.with_suffix('.onnx'))
    DMPR_trtFile = str(DMPR_pthFile.with_suffix('.engine'))

    # Load the model.
    device = getattr(opt, 'device', None) or 'cuda:0'
    confUrl = os.path.join(os.path.dirname(__file__), 'config', 'yolov5s.yaml')
    DMPRmodel = Model(confUrl, ch=3).to(device)
    # map_location lets a checkpoint saved on a different device (another GPU
    # index, or CPU) be loaded onto the target device.
    DMPRmodel.load_state_dict(torch.load(DMPRweights, map_location=device))
    # Export in inference mode so that mode-dependent layers (e.g. BatchNorm,
    # Dropout) behave as they will at deployment time.
    DMPRmodel.eval()

    toONNX(DMPRmodel, DMPR_onnxFile, inputShape=inputShape, device=device, dynamic=True)
    ONNXtoTrt(DMPR_onnxFile, DMPR_trtFile)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: read the checkpoint location from the command line
    # and hand the parsed namespace straight to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--weights',
        type=str,
        default='/mnt/thsw2/DSP2/weights/cityMangement2/weights/urbanManagement/DMPR/dp_detector_499.pth',
        help='model path(s)',
    )
    main(cli.parse_args())