You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

65 lines
2.5KB

  1. import torch
  2. import numpy as np
  3. import cv2
  4. import time
  5. import os
  6. import sys
  7. sys.path.extend(['../AIlib2/obbUtils'])
  8. import matplotlib.pyplot as plt
  9. import func_utils
  10. import time
  11. import torchvision.transforms as transforms
  12. from obbmodels import ctrbox_net
  13. import decoder
  14. import tensorrt as trt
  15. import onnx
  16. import onnxruntime as ort
  17. def load_model_decoder_OBB(par={'down_ratio':4,'num_classes':15,'weights':'weights_dota/obb.pth'}):
  18. weights=par['weights']
  19. heads = par['heads']
  20. heads['hm']=par['num_classes']
  21. par['heads']=heads
  22. if weights.endswith('.pth') or weights.endswith('.pt'):
  23. resume=par['weights']
  24. down_ratio = par['down_ratio']
  25. model = ctrbox_net.CTRBOX(heads=heads,
  26. pretrained=True,
  27. down_ratio=down_ratio,
  28. final_kernel=1,
  29. head_conv=256)
  30. checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
  31. print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
  32. state_dict_ = checkpoint['model_state_dict']
  33. model.load_state_dict(state_dict_, strict=True)
  34. model.eval()
  35. model = model.to(par['device'])
  36. model = model.half() if par['half'] else model
  37. par['saveType']='pth'
  38. elif weights.endswith('.onnx'):
  39. onnx_model = onnx.load(weights)
  40. onnx.checker.check_model(onnx_model)
  41. # 设置模型session以及输入信息
  42. sess = ort.InferenceSession(str(weights),providers= ort.get_available_providers())
  43. print('len():',len( sess.get_inputs() ))
  44. input_name = sess.get_inputs()[0].name
  45. model = {'sess':sess,'input_name':input_name}
  46. par['saveType']='onnx'
  47. elif weights.endswith('.engine'):
  48. logger = trt.Logger(trt.Logger.ERROR)
  49. with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
  50. model = runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
  51. print('#####load TRT file:',weights,'success #####')
  52. par['saveType']='trt'
  53. decoder2 = decoder.DecDecoder(K=par['K'],
  54. conf_thresh=par['conf_thresh'],
  55. num_classes=par['num_classes'])
  56. return model, decoder2