選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

65 行
1.9KB

  1. import sys
  2. #sys.path.extend(['..','../AIlib2' ])
  3. from ocrTrt import toONNX,ONNXtoTrt
  4. from collections import OrderedDict
  5. import torch
  6. import argparse
  7. from load_obb_model import load_model_decoder_OBB
  8. def getModel(opt):
  9. ###倾斜框(OBB)的ship目标检测
  10. par={
  11. 'model_size':(608,608), #width,height
  12. 'K':100, #Maximum of objects'
  13. 'conf_thresh':0.18,##Confidence threshold, 0.1 for general evaluation
  14. 'device':"cuda:0",
  15. 'down_ratio':4,'num_classes':15,
  16. 'weights':opt.weights,
  17. 'dataset':'dota',
  18. 'test_dir': 'images/ship/',
  19. 'result_dir': 'images/results',
  20. 'half': False,
  21. 'mean':(0.5, 0.5, 0.5),
  22. 'std':(1, 1, 1),
  23. 'category':['0','1','2','3','4','5','6','7','8','9','10','11','12','13','boat'],
  24. 'model_size':(608,608),##width,height
  25. 'decoder':None,
  26. 'test_flag':True,
  27. 'heads': {'hm': None,'wh': 10,'reg': 2,'cls_theta': 1},
  28. }
  29. ####加载模型
  30. model,decoder2=load_model_decoder_OBB(par)
  31. par['decoder']=decoder2
  32. model = model.to(par['device'])
  33. return model
  34. if __name__=='__main__':
  35. parser = argparse.ArgumentParser()
  36. parser.add_argument('--weights', type=str, default='/mnt/thsw2/DSP2/weights/ship2/obb_608X608.pth', help='model path(s)')
  37. parser.add_argument('--mWidth', type=int, default=608, help='segmodel mWdith')
  38. parser.add_argument('--mHeight', type=int, default=608, help='segmodel mHeight')
  39. opt = parser.parse_args()
  40. pthmodel = getModel(opt)
  41. ###转换TRT模型
  42. onnxFile=opt.weights.replace('.pth','.onnx')
  43. trtFile=opt.weights.replace('.pth','.engine')
  44. print('#'*20, ' begin to toONNX')
  45. toONNX(pthmodel,onnxFile,inputShape=(1,3,opt.mHeight, opt.mWidth),device='cuda:0')
  46. print('#'*20, ' begin to TRT')
  47. ONNXtoTrt(onnxFile,trtFile,half=False)