Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

65 Zeilen
2.5KB

  1. import torch
  2. import numpy as np
  3. import cv2
  4. import time
  5. import os
  6. import sys
  7. sys.path.extend(['../AIlib2/obbUtils'])
  8. import matplotlib.pyplot as plt
  9. import func_utils
  10. import time
  11. import torchvision.transforms as transforms
  12. from obbmodels import ctrbox_net
  13. import decoder
  14. import tensorrt as trt
  15. import onnx
  16. import onnxruntime as ort
  17. def load_model_decoder_OBB(par={'down_ratio':4,'num_classes':15,'weights':'weights_dota/obb.pth'}):
  18. weights=par['weights']
  19. heads = par['heads']
  20. heads['hm']=par['num_classes']
  21. par['heads']=heads
  22. if weights.endswith('.pth') or weights.endswith('.pt'):
  23. resume=par['weights']
  24. down_ratio = par['down_ratio']
  25. model = ctrbox_net.CTRBOX(heads=heads,
  26. pretrained=True,
  27. down_ratio=down_ratio,
  28. final_kernel=1,
  29. head_conv=256)
  30. checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
  31. print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
  32. state_dict_ = checkpoint['model_state_dict']
  33. model.load_state_dict(state_dict_, strict=True)
  34. model.eval()
  35. model = model.to(par['device'])
  36. model = model.half() if par['half'] else model
  37. par['saveType']='pth'
  38. elif weights.endswith('.onnx'):
  39. onnx_model = onnx.load(weights)
  40. onnx.checker.check_model(onnx_model)
  41. # 设置模型session以及输入信息
  42. sess = ort.InferenceSession(str(weights),providers= ort.get_available_providers())
  43. print('len():',len( sess.get_inputs() ))
  44. input_name = sess.get_inputs()[0].name
  45. model = {'sess':sess,'input_name':input_name}
  46. par['saveType']='onnx'
  47. elif weights.endswith('.engine'):
  48. logger = trt.Logger(trt.Logger.ERROR)
  49. with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
  50. model = runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
  51. print('#####load TRT file:',weights,'success #####')
  52. par['saveType']='trt'
  53. decoder2 = decoder.DecDecoder(K=par['K'],
  54. conf_thresh=par['conf_thresh'],
  55. num_classes=par['num_classes'])
  56. return model, decoder2