65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
|
|
import torch
|
|||
|
|
import numpy as np
|
|||
|
|
import cv2
|
|||
|
|
import time
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
sys.path.extend(['../AIlib2/obbUtils'])
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
import func_utils
|
|||
|
|
import time
|
|||
|
|
import torchvision.transforms as transforms
|
|||
|
|
from obbmodels import ctrbox_net
|
|||
|
|
import decoder
|
|||
|
|
import tensorrt as trt
|
|||
|
|
import onnx
|
|||
|
|
import onnxruntime as ort
|
|||
|
|
|
|||
|
|
def load_model_decoder_OBB(par=None):
    """Load an oriented-bounding-box model and build its matching decoder.

    The backend is selected from the extension of ``par['weights']``:

    * ``.pth`` / ``.pt`` -- a ``ctrbox_net.CTRBOX`` network is built, the
      checkpoint's ``'model_state_dict'`` is loaded into it, and it is put
      in eval mode, moved to ``par['device']`` and optionally cast to half.
    * ``.onnx`` -- the file is checked with ``onnx.checker`` and wrapped in
      an ONNX Runtime session; the "model" is the dict
      ``{'sess': session, 'input_name': name_of_first_input}``.
    * ``.engine`` -- the file is deserialized with a TensorRT runtime; the
      "model" is whatever ``deserialize_cuda_engine`` returns.

    Args:
        par: Configuration mapping. Always required keys: ``'weights'``,
            ``'heads'``, ``'num_classes'``, ``'K'``, ``'conf_thresh'``.
            The PyTorch backend additionally needs ``'down_ratio'``,
            ``'device'`` and ``'half'``. When omitted, a fresh dict with
            ``down_ratio=4``, ``num_classes=15`` and
            ``weights='weights_dota/obb.pth'`` is used (which by itself
            lacks the other required keys).

    Returns:
        A ``(model, decoder)`` tuple, where ``decoder`` is a
        ``decoder.DecDecoder`` built from ``K``, ``conf_thresh`` and
        ``num_classes``.

    Raises:
        KeyError: A required key is missing from ``par``.
        ValueError: The weights file has an unsupported extension.
        RuntimeError: TensorRT could not deserialize the engine file.

    Side effects:
        ``par['heads']['hm']`` is set to ``par['num_classes']`` and
        ``par['saveType']`` is set to ``'pth'``, ``'onnx'`` or ``'trt'``.
    """
    if par is None:
        # Build the defaults per call: the function writes into `par`, so a
        # dict literal in the signature would be shared between calls.
        par = {'down_ratio': 4, 'num_classes': 15, 'weights': 'weights_dota/obb.pth'}

    # Resolve the backend first so a bad path fails before anything is
    # loaded or mutated.
    weights = par['weights']
    if weights.endswith(('.pth', '.pt')):
        save_type = 'pth'
        required = ('heads', 'num_classes', 'K', 'conf_thresh',
                    'down_ratio', 'device', 'half')
    elif weights.endswith('.onnx'):
        save_type = 'onnx'
        required = ('heads', 'num_classes', 'K', 'conf_thresh')
    elif weights.endswith('.engine'):
        save_type = 'trt'
        required = ('heads', 'num_classes', 'K', 'conf_thresh')
    else:
        raise ValueError(
            'unsupported weights file {!r}: expected .pth, .pt, .onnx or .engine'.format(weights)
        )

    # Check the configuration up front instead of after an expensive load.
    missing = [key for key in required if key not in par]
    if missing:
        raise KeyError(
            'load_model_decoder_OBB: missing required par keys: {}'.format(missing)
        )

    # The heat-map head gets one channel per class.
    heads = par['heads']
    heads['hm'] = par['num_classes']
    par['heads'] = heads

    if save_type == 'pth':
        model = ctrbox_net.CTRBOX(heads=heads,
                                  pretrained=True,
                                  down_ratio=par['down_ratio'],
                                  final_kernel=1,
                                  head_conv=256)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # The map_location callable returns each storage as deserialized.
        checkpoint = torch.load(weights, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(weights, checkpoint['epoch']))
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        model.eval()
        model = model.to(par['device'])
        if par['half']:
            model = model.half()
    elif save_type == 'onnx':
        onnx_model = onnx.load(weights)
        onnx.checker.check_model(onnx_model)
        # Set up the inference session and record its input information.
        sess = ort.InferenceSession(str(weights), providers=ort.get_available_providers())
        print('len():', len(sess.get_inputs()))
        input_name = sess.get_inputs()[0].name
        model = {'sess': sess, 'input_name': input_name}
    else:
        logger = trt.Logger(trt.Logger.ERROR)
        with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
            # Reads the local TRT file and returns an ICudaEngine object.
            model = runtime.deserialize_cuda_engine(f.read())
        if model is None:
            # Do not report success (or hand back None) for an unreadable engine.
            raise RuntimeError(
                'failed to deserialize TensorRT engine from {}'.format(weights)
            )
        print('#####load TRT file:', weights, 'success #####')

    par['saveType'] = save_type

    decoder2 = decoder.DecDecoder(K=par['K'],
                                  conf_thresh=par['conf_thresh'],
                                  num_classes=par['num_classes'])

    return model, decoder2
|