AIlib2/crowdUtils/toTrt.py

57 lines
1.8 KiB
Python
Executable File

import sys
from models import build_model
sys.path.extend(['..','.' ])
from segutils.trtUtils2 import pth2onnx,onnx2engine,onnx_inference
from engine import DictToObject
from pathlib import Path
import torch
import os
import tensorrt as trt
import numpy as np
import argparse
def main(opt):
    """Convert a ``.pth`` checkpoint into a dynamic-shape ONNX file and a TensorRT engine.

    The network architecture comes from ``build_model`` with the fixed
    configuration in ``par`` below; only the weights are taken from
    ``opt.weights``. The checkpoint's ``'model'`` entry is loaded as the
    state dict. Output paths are derived from the checkpoint file name and
    placed in the same directory:

    * ``<stem>_dynamic.onnx``   -- passed to ``pth2onnx`` (dynamic batch/H/W input axes)
    * ``<stem>_dynamic.engine`` -- passed to ``onnx2engine`` (built with ``half=True``)

    Args:
        opt: parsed command-line options; only ``opt.weights`` (checkpoint
            path, surrounding whitespace ignored) is read.
    """
    pth_path = Path(opt.weights.strip())
    # Derive artifact names from the file name only. A plain
    # str.replace('.pth', ...) on the whole path also rewrote directory names
    # and, for a checkpoint without '.pth' in its name (e.g. '*.pt'), returned
    # the checkpoint path itself, so the ONNX export overwrote the weights.
    onnx_name = str(pth_path.with_name(pth_path.stem + '_dynamic.onnx'))
    trt_name = str(Path(onnx_name).with_suffix('.engine'))

    # Axes left symbolic in the exported ONNX graph: batch/H/W of the input
    # and dim 1 of each of the two outputs.
    dynamic_hw = {
        'input': {0: 'batch', 2: 'H', 3: 'W'},
        'output0': {1: 'C'},
        'output1': {1: 'C'},
    }

    # Fixed model configuration handed to build_model; it must match the
    # architecture the checkpoint was trained with.
    # NOTE(review): 'weight_path' is hard-coded and differs from opt.weights;
    # confirm whether build_model reads it (e.g. for backbone initialisation).
    par = {'backbone': 'vgg16_bn', 'gpu_id': 0, 'line': 2, 'output_dir': './output',
           'row': 2, 'anchorFlag': False, 'weight_path': './weights/best_mae.pth'}
    args = DictToObject(par)
    model = build_model(args)

    checkpoint = torch.load(pth_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    # Put layers such as BatchNorm/Dropout into inference mode explicitly so
    # the exported graph does not depend on what pth2onnx does internally.
    model = model.to('cuda:0').eval()

    # Dummy input used for tracing; H and W are exported as dynamic axes.
    input_shape = (1, 3, 128 * 4, 128 * 4)  # (bs, channels, height, width)
    # Shapes for the TensorRT optimisation profile -- presumably min/opt/max;
    # confirm against onnx2engine.
    input_profile_shapes = [(1, 3, 256, 256), (1, 3, 1024, 1024), (1, 3, 2048, 2048)]

    # 'dynamix_axis' (sic) is the keyword name pth2onnx accepts.
    pth2onnx(model, onnx_name, input_shape=input_shape, input_names=['input'],
             output_names=['output0', 'output1'], dynamix_axis=dynamic_hw)
    onnx2engine(onnx_name, trt_name, input_shape=[1, 3, -1, -1], half=True,
                max_batch_size=1, input_profile_shapes=input_profile_shapes)
if __name__ == '__main__':
    # Script entry point: the checkpoint path is the only configurable option.
    # NOTE(review): the default points at a DMPR 'dp_detector' checkpoint while
    # main() builds its model from a fixed 'vgg16_bn' config -- confirm this
    # default is the intended checkpoint for this converter.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--weights',
        type=str,
        default='/mnt/thsw2/DSP2/weights/cityMangement2/weights/urbanManagement/DMPR/dp_detector_499.pth',
        help='model path(s)',
    )
    main(cli.parse_args())