Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

65 rindas
2.5KB

  1. import torch
  2. import numpy as np
  3. import cv2
  4. import time
  5. import os
  6. import sys
  7. sys.path.extend(['../AIlib2/obbUtils'])
  8. import matplotlib.pyplot as plt
  9. import func_utils
  10. import time
  11. import torchvision.transforms as transforms
  12. from obbmodels import ctrbox_net
  13. import decoder
  14. import tensorrt as trt
  15. import onnx
  16. import onnxruntime as ort
  17. def load_model_decoder_OBB(par={'down_ratio':4,'num_classes':15,'weights':'weights_dota/obb.pth'}):
  18. weights=par['weights']
  19. heads = par['heads']
  20. heads['hm']=par['num_classes']
  21. par['heads']=heads
  22. if weights.endswith('.pth') or weights.endswith('.pt'):
  23. resume=par['weights']
  24. down_ratio = par['down_ratio']
  25. model = ctrbox_net.CTRBOX(heads=heads,
  26. pretrained=True,
  27. down_ratio=down_ratio,
  28. final_kernel=1,
  29. head_conv=256)
  30. checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
  31. print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
  32. state_dict_ = checkpoint['model_state_dict']
  33. model.load_state_dict(state_dict_, strict=True)
  34. model.eval()
  35. model = model.to(par['device'])
  36. model = model.half() if par['half'] else model
  37. par['saveType']='pth'
  38. elif weights.endswith('.onnx'):
  39. onnx_model = onnx.load(weights)
  40. onnx.checker.check_model(onnx_model)
  41. # 设置模型session以及输入信息
  42. sess = ort.InferenceSession(str(weights),providers= ort.get_available_providers())
  43. print('len():',len( sess.get_inputs() ))
  44. input_name = sess.get_inputs()[0].name
  45. model = {'sess':sess,'input_name':input_name}
  46. par['saveType']='onnx'
  47. elif weights.endswith('.engine'):
  48. logger = trt.Logger(trt.Logger.ERROR)
  49. with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
  50. model = runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
  51. print('#####load TRT file:',weights,'success #####')
  52. par['saveType']='trt'
  53. decoder2 = decoder.DecDecoder(K=par['K'],
  54. conf_thresh=par['conf_thresh'],
  55. num_classes=par['num_classes'])
  56. return model, decoder2