Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

65 linhas
2.5KB

  1. import torch
  2. import numpy as np
  3. import cv2
  4. import time
  5. import os
  6. import sys
  7. sys.path.extend(['../AIlib2/obbUtils'])
  8. import matplotlib.pyplot as plt
  9. import func_utils
  10. import time
  11. import torchvision.transforms as transforms
  12. from obbmodels import ctrbox_net
  13. import decoder
  14. import tensorrt as trt
  15. import onnx
  16. import onnxruntime as ort
  17. def load_model_decoder_OBB(par={'down_ratio':4,'num_classes':15,'weights':'weights_dota/obb.pth'}):
  18. weights=par['weights']
  19. heads = par['heads']
  20. heads['hm']=par['num_classes']
  21. par['heads']=heads
  22. if weights.endswith('.pth') or weights.endswith('.pt'):
  23. resume=par['weights']
  24. down_ratio = par['down_ratio']
  25. model = ctrbox_net.CTRBOX(heads=heads,
  26. pretrained=True,
  27. down_ratio=down_ratio,
  28. final_kernel=1,
  29. head_conv=256)
  30. checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
  31. print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
  32. state_dict_ = checkpoint['model_state_dict']
  33. model.load_state_dict(state_dict_, strict=True)
  34. model.eval()
  35. model = model.to(par['device'])
  36. model = model.half() if par['half'] else model
  37. par['saveType']='pth'
  38. elif weights.endswith('.onnx'):
  39. onnx_model = onnx.load(weights)
  40. onnx.checker.check_model(onnx_model)
  41. # 设置模型session以及输入信息
  42. sess = ort.InferenceSession(str(weights),providers= ort.get_available_providers())
  43. print('len():',len( sess.get_inputs() ))
  44. input_name = sess.get_inputs()[0].name
  45. model = {'sess':sess,'input_name':input_name}
  46. par['saveType']='onnx'
  47. elif weights.endswith('.engine'):
  48. logger = trt.Logger(trt.Logger.ERROR)
  49. with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
  50. model = runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
  51. print('#####load TRT file:',weights,'success #####')
  52. par['saveType']='trt'
  53. decoder2 = decoder.DecDecoder(K=par['K'],
  54. conf_thresh=par['conf_thresh'],
  55. num_classes=par['num_classes'])
  56. return model, decoder2