AIlib2/obbUtils/load_obb_model.py

65 lines
2.5 KiB
Python
Raw Normal View History

2025-04-26 10:35:59 +08:00
import torch
import numpy as np
import cv2
import time
import os
import sys
sys.path.extend(['../AIlib2/obbUtils'])
import matplotlib.pyplot as plt
import func_utils
import time
import torchvision.transforms as transforms
from obbmodels import ctrbox_net
import decoder
import tensorrt as trt
import onnx
import onnxruntime as ort
def load_model_decoder_OBB(par={'down_ratio':4,'num_classes':15,'weights':'weights_dota/obb.pth'}):
    """Load an OBB model from a weights file and build its decoder.

    The loader is selected from the extension of ``par['weights']``:

    * ``.pth`` / ``.pt`` -- builds ``ctrbox_net.CTRBOX``, restores the
      checkpoint's ``'model_state_dict'``, switches to eval mode, moves the
      model to ``par['device']`` and converts it to half precision when
      ``par['half']`` is truthy.
    * ``.onnx``          -- validates the file with ``onnx.checker`` and
      opens an ONNX Runtime session.
    * ``.engine``        -- deserializes a TensorRT engine.

    Args:
        par: Configuration mapping. Keys read in every case: ``'weights'``,
            ``'heads'``, ``'num_classes'``, ``'K'``, ``'conf_thresh'``.
            Keys read only for ``.pth``/``.pt`` weights: ``'down_ratio'``,
            ``'device'``, ``'half'``.
            NOTE: the default value does not contain ``'heads'`` (nor
            several other required keys), so callers must pass a complete
            mapping; calling with no argument raises ``KeyError``.

    Returns:
        A ``(model, decoder)`` tuple. ``model`` is the ``CTRBOX`` module for
        PyTorch weights, a dict ``{'sess': ..., 'input_name': ...}`` for
        ONNX weights, or the object returned by
        ``deserialize_cuda_engine`` for TensorRT weights. ``decoder`` is a
        ``decoder.DecDecoder`` instance.

    Raises:
        KeyError: If a required key is missing from ``par``.
        ValueError: If the weights file extension is not one of ``.pth``,
            ``.pt``, ``.onnx`` or ``.engine``.

    Side effects:
        Mutates ``par`` in place: sets ``par['heads']['hm']`` to
        ``par['num_classes']`` and sets ``par['saveType']`` to ``'pth'``,
        ``'onnx'`` or ``'trt'``.
    """
    weights = par['weights']

    # Store the class count under the 'hm' head and write the heads mapping
    # back into the caller's configuration.
    heads = par['heads']
    heads['hm'] = par['num_classes']
    par['heads'] = heads

    if weights.endswith(('.pth', '.pt')):
        model = ctrbox_net.CTRBOX(heads=heads,
                                  pretrained=True,
                                  down_ratio=par['down_ratio'],
                                  final_kernel=1,
                                  head_conv=256)
        # Keep every tensor on the storage it was deserialized to; the model
        # is moved to the requested device explicitly below.
        checkpoint = torch.load(weights, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(weights, checkpoint['epoch']))
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        model.eval()
        model = model.to(par['device'])
        model = model.half() if par['half'] else model
        par['saveType'] = 'pth'
    elif weights.endswith('.onnx'):
        # Validate the ONNX graph before handing the file to the runtime.
        onnx.checker.check_model(onnx.load(weights))
        # Set up the model session and its input information.
        sess = ort.InferenceSession(str(weights), providers=ort.get_available_providers())
        print('len():', len(sess.get_inputs()))
        input_name = sess.get_inputs()[0].name
        model = {'sess': sess, 'input_name': input_name}
        par['saveType'] = 'onnx'
    elif weights.endswith('.engine'):
        logger = trt.Logger(trt.Logger.ERROR)
        with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
            # Feed the serialized TRT file contents in; get the engine
            # object back.
            # NOTE(review): deserialize_cuda_engine may return None on
            # failure -- confirm, and consider checking before printing
            # the success message.
            model = runtime.deserialize_cuda_engine(f.read())
        print('#####load TRT file:', weights, 'success #####')
        par['saveType'] = 'trt'
    else:
        # Without this branch `model` would be unbound and the function
        # would fail at the return statement with an UnboundLocalError.
        raise ValueError(
            "unsupported weights file '{}': expected a .pth, .pt, .onnx "
            "or .engine file".format(weights)
        )

    decoder2 = decoder.DecDecoder(K=par['K'],
                                  conf_thresh=par['conf_thresh'],
                                  num_classes=par['num_classes'])
    return model, decoder2