import torch
import numpy as np
import cv2
import time
import os
import sys
# Presumably makes the project-local modules imported below resolvable;
# keep this ahead of those imports (TODO confirm which ones live there).
sys.path.extend(['../AIlib2/obbUtils'])
import matplotlib.pyplot as plt
import func_utils
import torchvision.transforms as transforms
from obbmodels import ctrbox_net
import decoder
import tensorrt as trt
import onnx
import onnxruntime as ort


def _load_torch_model(par, heads):
    """Build a CTRBOX network and load a ``.pth``/``.pt`` checkpoint into it."""
    resume = par['weights']
    model = ctrbox_net.CTRBOX(
        heads=heads,
        pretrained=True,
        down_ratio=par['down_ratio'],
        final_kernel=1,
        head_conv=256,
    )
    # Deserialize tensors onto the storage they were read into; the move to
    # the target device happens explicitly below.
    checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
    # 'epoch' is only used for this log line, so a checkpoint without it
    # should not abort loading.
    print('loaded weights from {}, epoch {}'.format(
        resume, checkpoint.get('epoch', 'unknown')))
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)
    model.eval()
    model = model.to(par['device'])
    if par['half']:
        model = model.half()
    return model


def _load_onnx_session(weights):
    """Validate an ONNX file and wrap an onnxruntime session for it."""
    onnx_model = onnx.load(weights)
    onnx.checker.check_model(onnx_model)
    # Set up the model session and its input information.
    sess = ort.InferenceSession(
        str(weights), providers=ort.get_available_providers())
    inputs = sess.get_inputs()
    print('len():', len(inputs))
    return {'sess': sess, 'input_name': inputs[0].name}


def _load_trt_engine(weights):
    """Deserialize a TensorRT engine from a local ``.engine`` file."""
    logger = trt.Logger(trt.Logger.ERROR)
    with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
        # Takes the local TRT file contents, returns an ICudaEngine object.
        model = runtime.deserialize_cuda_engine(f.read())
    print('#####load TRT file:', weights, 'success #####')
    return model


def load_model_decoder_OBB(par=None):
    """Load an OBB detection model and build its matching decoder.

    The backend is chosen from the extension of ``par['weights']``:
    ``.pth``/``.pt`` (PyTorch), ``.onnx`` (onnxruntime) or ``.engine``
    (TensorRT).

    Args:
        par: Configuration dict. Keys read: ``'weights'``, ``'heads'``,
            ``'num_classes'``, ``'K'``, ``'conf_thresh'``; for PyTorch
            weights additionally ``'down_ratio'``, ``'device'`` and
            ``'half'``. The dict is updated in place: ``par['heads']['hm']``
            is set to ``par['num_classes']`` and ``par['saveType']`` is set
            to ``'pth'``, ``'onnx'`` or ``'trt'``. When omitted, a fresh dict
            ``{'down_ratio': 4, 'num_classes': 15,
            'weights': 'weights_dota/obb.pth'}`` is used; it has no
            ``'heads'`` entry, so such a call raises ``KeyError`` exactly as
            it did before.

    Returns:
        ``(model, decoder2)``: ``model`` is the CTRBOX module for PyTorch
        weights, a dict ``{'sess': ..., 'input_name': ...}`` for ONNX, or the
        deserialized engine for TensorRT; ``decoder2`` is a
        ``decoder.DecDecoder`` instance.

    Raises:
        KeyError: If a required key is missing from ``par``.
        ValueError: If the weights file extension is not supported.
    """
    if par is None:
        # Build the default per call instead of sharing one mutable dict
        # across invocations.
        par = {'down_ratio': 4, 'num_classes': 15,
               'weights': 'weights_dota/obb.pth'}

    weights = par['weights']
    heads = par['heads']
    heads['hm'] = par['num_classes']
    par['heads'] = heads

    if weights.endswith(('.pth', '.pt')):
        model = _load_torch_model(par, heads)
        par['saveType'] = 'pth'
    elif weights.endswith('.onnx'):
        model = _load_onnx_session(weights)
        par['saveType'] = 'onnx'
    elif weights.endswith('.engine'):
        model = _load_trt_engine(weights)
        par['saveType'] = 'trt'
    else:
        # Previously this fell through and failed at ``return`` with an
        # UnboundLocalError on ``model``.
        raise ValueError(
            'unsupported weights file {!r}: expected a .pth, .pt, .onnx or '
            '.engine file'.format(weights))

    decoder2 = decoder.DecDecoder(
        K=par['K'],
        conf_thresh=par['conf_thresh'],
        num_classes=par['num_classes'],
    )
    return model, decoder2