AIlib2/obbUtils/pth2onnx.py

65 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
#sys.path.extend(['..','../AIlib2' ])
from ocrTrt import toONNX,ONNXtoTrt
from collections import OrderedDict
import torch
import argparse
from load_obb_model import load_model_decoder_OBB
def getModel(opt: argparse.Namespace):
    """Build the oriented-bounding-box (OBB) ship detector for export.

    Args:
        opt: Parsed command-line options. ``opt.weights`` (checkpoint path,
            forwarded to ``load_model_decoder_OBB``) is required. The
            following are optional and fall back to the previously hard-coded
            values when absent: ``opt.mWidth`` / ``opt.mHeight`` (model input
            size, default 608x608) and ``opt.device`` (default "cuda:0").

    Returns:
        The model produced by ``load_model_decoder_OBB``, moved to the chosen
        device and switched to eval mode. The model is presumably a
        ``torch.nn.Module``, since it supports ``.to``; TODO confirm.
    """
    # Keep model_size in sync with the shape actually used for ONNX export.
    # Older callers whose opt lacks these attributes still get 608x608.
    width = getattr(opt, 'mWidth', 608)
    height = getattr(opt, 'mHeight', 608)
    device = getattr(opt, 'device', 'cuda:0')

    # Configuration for oriented-bounding-box (OBB) ship detection.
    par = {
        # (width, height). This key was previously listed twice with the same
        # value; a dict literal silently keeps only the last occurrence.
        'model_size': (width, height),
        'K': 100,  # maximum number of objects
        'conf_thresh': 0.18,  # confidence threshold; 0.1 for general evaluation
        'device': device,
        'down_ratio': 4,
        'num_classes': 15,
        'weights': opt.weights,
        'dataset': 'dota',
        'test_dir': 'images/ship/',
        'result_dir': 'images/results',
        'half': False,
        'mean': (0.5, 0.5, 0.5),
        'std': (1, 1, 1),
        'category': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', 'boat'],
        'decoder': None,
        'test_flag': True,
        # 'hm': None is passed through as-is; presumably the loader fills it in
        # (e.g. from num_classes). TODO confirm.
        'heads': {'hm': None, 'wh': 10, 'reg': 2, 'cls_theta': 1},
    }

    # Load the model.
    model, decoder2 = load_model_decoder_OBB(par)
    # Kept for parity with the inference configuration; the export path below
    # does not use the decoder.
    par['decoder'] = decoder2
    model = model.to(par['device'])
    # Export must trace inference behaviour (e.g. BatchNorm running stats).
    # eval() is a no-op if the loader already did this.
    model.eval()
    return model
def derive_output_paths(weights):
    """Return ``(onnx_path, engine_path)`` placed next to the *weights* file.

    Only the final extension of *weights* is swapped. ``str.replace`` would
    also rewrite a ``.pth`` substring in a directory name, and it would return
    the input path unchanged when the extension is not ``.pth``. A path with
    no extension gets the new suffixes appended.

    Raises:
        ValueError: if a derived path equals *weights* (e.g. it already ends
            in ``.onnx``). Exporting would otherwise overwrite the input file.
    """
    import os

    stem, _ext = os.path.splitext(weights)
    onnx_path = stem + '.onnx'
    engine_path = stem + '.engine'
    if weights in (onnx_path, engine_path):
        raise ValueError(
            f'refusing to overwrite input weights {weights!r}; '
            'expected a .pth checkpoint'
        )
    return onnx_path, engine_path


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Export the OBB ship detector to ONNX, then build a TensorRT engine.'
    )
    parser.add_argument('--weights', type=str, default='/mnt/thsw2/DSP2/weights/ship2/obb_608X608.pth', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=608, help='model input width')
    parser.add_argument('--mHeight', type=int, default=608, help='model input height')
    parser.add_argument('--device', type=str, default='cuda:0', help='torch device used for export')
    parser.add_argument('--half', action='store_true', help='build an FP16 TensorRT engine')
    opt = parser.parse_args()

    # Resolve output paths first so a bad path fails before the model loads.
    onnxFile, trtFile = derive_output_paths(opt.weights)

    # .to() guarantees the model sits on the same device as the export input,
    # even if getModel placed it elsewhere.
    pthmodel = getModel(opt).to(opt.device)

    # Convert to ONNX, then to a TensorRT engine.
    print('#' * 20, ' begin to toONNX')
    toONNX(pthmodel, onnxFile, inputShape=(1, 3, opt.mHeight, opt.mWidth), device=opt.device)
    print('#' * 20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=opt.half)