AIlib2/obbUtils/pth2onnx.py

65 lines
1.9 KiB
Python
Raw Normal View History

2025-04-26 10:35:59 +08:00
import sys
#sys.path.extend(['..','../AIlib2' ])
from ocrTrt import toONNX,ONNXtoTrt
from collections import OrderedDict
import torch
import argparse
from load_obb_model import load_model_decoder_OBB
def getModel(opt):
    """Build the oriented-bounding-box (OBB) ship detector used for export.

    Args:
        opt: parsed command-line namespace. ``opt.weights`` is the checkpoint
            path handed to ``load_model_decoder_OBB``. ``opt.mWidth`` and
            ``opt.mHeight``, when present, set the configured model input
            size so it matches the shape the caller exports at; both default
            to 608 for namespaces that do not define them.

    Returns:
        The model returned by ``load_model_decoder_OBB``, moved to ``cuda:0``.
        The decoder it also returns is stored in the local parameter dict only
        and is not returned.
    """
    # OBB (rotated-box) ship object detection.
    # Previously 'model_size' was hard-coded to 608x608 (and listed twice in
    # the dict) even when --mWidth/--mHeight requested a different export size.
    width = getattr(opt, 'mWidth', 608)
    height = getattr(opt, 'mHeight', 608)
    par = {
        'model_size': (width, height),  # (width, height)
        'K': 100,  # maximum number of objects
        'conf_thresh': 0.18,  # confidence threshold; 0.1 for general evaluation
        'device': "cuda:0",
        'down_ratio': 4,
        'num_classes': 15,
        'weights': opt.weights,
        'dataset': 'dota',
        'test_dir': 'images/ship/',
        'result_dir': 'images/results',
        'half': False,
        'mean': (0.5, 0.5, 0.5),
        'std': (1, 1, 1),
        'category': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', 'boat'],
        'decoder': None,
        'test_flag': True,
        'heads': {'hm': None, 'wh': 10, 'reg': 2, 'cls_theta': 1},
    }
    # Load the model and its decoder.
    model, decoder2 = load_model_decoder_OBB(par)
    # NOTE(review): par is not used after this point in this function; the
    # assignment is kept in case load_model_decoder_OBB retains a reference
    # to par — confirm before removing.
    par['decoder'] = decoder2
    model = model.to(par['device'])
    return model
if __name__ == '__main__':
    import os  # only needed when run as a script

    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='/mnt/thsw2/DSP2/weights/ship2/obb_608X608.pth', help='model path(s)')
    parser.add_argument('--mWidth', type=int, default=608, help='segmodel mWdith')
    parser.add_argument('--mHeight', type=int, default=608, help='segmodel mHeight')
    opt = parser.parse_args()
    pthmodel = getModel(opt)

    # Convert to a TensorRT model: .pth -> .onnx -> .engine, written next to
    # the weights file. Output names are derived from the file extension only.
    # The old str.replace('.pth', ...) rewrote every '.pth' substring in the
    # path, and for weights not containing '.pth' it returned the weights
    # path itself, making the ONNX output path equal to the checkpoint path.
    stem, _ext = os.path.splitext(opt.weights)
    onnxFile = stem + '.onnx'
    trtFile = stem + '.engine'

    print('#' * 20, ' begin to toONNX')
    toONNX(pthmodel, onnxFile, inputShape=(1, 3, opt.mHeight, opt.mWidth), device='cuda:0')
    print('#' * 20, ' begin to TRT')
    ONNXtoTrt(onnxFile, trtFile, half=False)