65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
import torch
|
||
import numpy as np
|
||
import cv2
|
||
import time
|
||
import os
|
||
import sys
|
||
sys.path.extend(['../AIlib2/obbUtils'])
|
||
import matplotlib.pyplot as plt
|
||
import func_utils
|
||
import time
|
||
import torchvision.transforms as transforms
|
||
from obbmodels import ctrbox_net
|
||
import decoder
|
||
import tensorrt as trt
|
||
import onnx
|
||
import onnxruntime as ort
|
||
|
||
def _load_pth_model(par, heads):
    """Build a CTRBOX network and load a PyTorch checkpoint (``par['weights']``) into it."""
    resume = par['weights']
    down_ratio = par['down_ratio']
    # NOTE(review): pretrained=True presumably initialises backbone weights that
    # the strict load_state_dict below then overwrites -- confirm whether it can
    # be False to avoid the extra work.
    model = ctrbox_net.CTRBOX(heads=heads,
                              pretrained=True,
                              down_ratio=down_ratio,
                              final_kernel=1,
                              head_conv=256)
    # map_location returning `storage` keeps every tensor on the CPU; the model
    # is moved to par['device'] only after the weights are in place.
    checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
    print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)
    model.eval()
    model = model.to(par['device'])
    return model.half() if par['half'] else model


def _load_onnx_session(weights):
    """Validate an ONNX file and wrap it in an ONNX Runtime session dict."""
    onnx_model = onnx.load(weights)
    onnx.checker.check_model(onnx_model)
    # Create the inference session on every provider available in this install.
    sess = ort.InferenceSession(str(weights), providers=ort.get_available_providers())
    print('len():', len(sess.get_inputs()))
    return {'sess': sess, 'input_name': sess.get_inputs()[0].name}


def _load_trt_engine(weights):
    """Deserialize a TensorRT engine file, raising if TensorRT rejects it."""
    logger = trt.Logger(trt.Logger.ERROR)
    with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
        # Read the serialized plan file and turn it into an ICudaEngine.
        engine = runtime.deserialize_cuda_engine(f.read())
    if engine is None:
        # deserialize_cuda_engine signals failure by returning None; do not
        # report success and hand a None "model" back to the caller.
        raise RuntimeError('failed to deserialize TensorRT engine: {}'.format(weights))
    print('#####load TRT file:', weights, 'success #####')
    return engine


def load_model_decoder_OBB(par=None):
    """Load an oriented-bounding-box (OBB) model and build its decoder.

    The backend is selected from the suffix of ``par['weights']``:

    * ``.pth`` / ``.pt`` -- PyTorch checkpoint loaded into ``ctrbox_net.CTRBOX``
      (``par['saveType'] = 'pth'``);
    * ``.onnx``          -- ONNX Runtime session (``par['saveType'] = 'onnx'``);
    * ``.engine``        -- TensorRT engine (``par['saveType'] = 'trt'``).

    Args:
        par: Configuration dict. Always read: ``weights``, ``heads``,
            ``num_classes``, ``K``, ``conf_thresh``. Additionally read for
            PyTorch checkpoints: ``down_ratio``, ``device``, ``half``.
            When omitted, a fresh dict ``{'down_ratio': 4, 'num_classes': 15,
            'weights': 'weights_dota/obb.pth'}`` is used; it has no ``heads``
            entry, so in practice callers must pass a complete dict.
            ``par`` is updated in place: ``par['heads']['hm']`` is set to
            ``par['num_classes']`` and ``par['saveType']`` records the backend.

    Returns:
        ``(model, decoder)``. ``model`` is the eval-mode ``CTRBOX`` module
        (pth), a dict ``{'sess': InferenceSession, 'input_name': str}``
        (onnx), or the deserialized TensorRT engine (trt). ``decoder`` is a
        ``decoder.DecDecoder`` built from ``K``, ``conf_thresh`` and
        ``num_classes``.

    Raises:
        KeyError: if a required key is missing from ``par``.
        ValueError: if ``par['weights']`` has an unsupported suffix.
        RuntimeError: if TensorRT fails to deserialize the engine file.
    """
    if par is None:
        # Built per call: the function mutates `par`, so a shared default dict
        # would carry state between calls.
        par = {'down_ratio': 4, 'num_classes': 15, 'weights': 'weights_dota/obb.pth'}

    weights = par['weights']
    heads = par['heads']
    # The heatmap head has one output channel per class.
    heads['hm'] = par['num_classes']
    par['heads'] = heads

    if weights.endswith(('.pth', '.pt')):
        model = _load_pth_model(par, heads)
        par['saveType'] = 'pth'
    elif weights.endswith('.onnx'):
        model = _load_onnx_session(weights)
        par['saveType'] = 'onnx'
    elif weights.endswith('.engine'):
        model = _load_trt_engine(weights)
        par['saveType'] = 'trt'
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(
            'unsupported weights file {!r}: expected .pth, .pt, .onnx or .engine'.format(weights))

    decoder2 = decoder.DecDecoder(K=par['K'],
                                  conf_thresh=par['conf_thresh'],
                                  num_classes=par['num_classes'])
    return model, decoder2
|