|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import torch
- import numpy as np
- import cv2
- import time
- import os
- import sys
- sys.path.extend(['../AIlib2/obbUtils'])
- import matplotlib.pyplot as plt
- import func_utils
- import time
- import torchvision.transforms as transforms
- from obbmodels import ctrbox_net
- import decoder
- import tensorrt as trt
- import onnx
- import onnxruntime as ort
-
def load_model_decoder_OBB(par=None):
    """Load an oriented-bounding-box (OBB) model and build its decoder.

    The backend is chosen from the file extension of ``par['weights']``:

    * ``.pth`` / ``.pt`` -- a ``ctrbox_net.CTRBOX`` model is built, the
      checkpoint's ``'model_state_dict'`` is loaded with ``strict=True``, and
      the model is put in eval mode, moved to ``par['device']`` and cast to
      half precision when ``par['half']`` is truthy.
    * ``.onnx`` -- the file is validated with ``onnx.checker.check_model`` and
      an ``onnxruntime.InferenceSession`` is created with every available
      execution provider.
    * ``.engine`` -- a serialized TensorRT engine is deserialized.

    Args:
        par: Configuration dict. Keys always read: ``'weights'``, ``'heads'``,
            ``'num_classes'``, ``'K'`` and ``'conf_thresh'``. For ``.pth``/``.pt``
            weights, ``'down_ratio'``, ``'device'`` and ``'half'`` are also read.
            The dict is updated in place: ``par['heads']['hm']`` is set to
            ``par['num_classes']`` and, on success, ``par['saveType']`` is set
            to ``'pth'``, ``'onnx'`` or ``'trt'``. When omitted, a fresh copy of
            the legacy default dict is used. That default has no ``'heads'``
            key, so callers are expected to pass a complete dict.

    Returns:
        ``(model, decoder)``. ``model`` depends on the backend:

        * pth: the ``ctrbox_net.CTRBOX`` instance.
        * onnx: ``{'sess': InferenceSession, 'input_name': str}``, where
          ``input_name`` is the session's first input.
        * trt: the engine returned by ``deserialize_cuda_engine``.

        ``decoder`` is a ``decoder.DecDecoder`` built from ``par['K']``,
        ``par['conf_thresh']`` and ``par['num_classes']``.

    Raises:
        KeyError: If a required key is missing from ``par``.
        ValueError: If the weights file extension is not supported.
        RuntimeError: If TensorRT fails to deserialize the engine.
    """
    if par is None:
        # Build the default per call: a dict default would be shared between
        # calls, and this function mutates ``par``.
        par = {'down_ratio': 4, 'num_classes': 15, 'weights': 'weights_dota/obb.pth'}

    weights = par['weights']
    # Size the 'hm' head to the number of classes. This mutates the caller's
    # heads dict, which is the same object as par['heads'].
    heads = par['heads']
    heads['hm'] = par['num_classes']

    if weights.endswith(('.pth', '.pt')):
        # NOTE(review): pretrained=True presumably initialises backbone weights
        # that the strict checkpoint load below fully overwrites. Confirm
        # whether it is still needed.
        model = ctrbox_net.CTRBOX(heads=heads,
                                  pretrained=True,
                                  down_ratio=par['down_ratio'],
                                  final_kernel=1,
                                  head_conv=256)
        # This map_location deserializes every tensor onto CPU, so GPU-saved
        # checkpoints also load on CPU-only hosts. The model is moved to
        # par['device'] below.
        checkpoint = torch.load(weights, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(weights, checkpoint['epoch']))
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        model.eval()
        model = model.to(par['device'])
        if par['half']:
            model = model.half()
        par['saveType'] = 'pth'
    elif weights.endswith('.onnx'):
        onnx.checker.check_model(onnx.load(weights))
        # Create the inference session and record the name of its first input.
        sess = ort.InferenceSession(str(weights), providers=ort.get_available_providers())
        print('len():', len(sess.get_inputs()))
        model = {'sess': sess, 'input_name': sess.get_inputs()[0].name}
        par['saveType'] = 'onnx'
    elif weights.endswith('.engine'):
        logger = trt.Logger(trt.Logger.ERROR)
        with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
            # Deserialize the local TensorRT file into an ICudaEngine object.
            model = runtime.deserialize_cuda_engine(f.read())
        # deserialize_cuda_engine signals failure by returning None. Do not
        # report success or hand a None model to the caller.
        if model is None:
            raise RuntimeError('failed to deserialize TensorRT engine: {}'.format(weights))
        print('#####load TRT file:', weights, 'success #####')
        par['saveType'] = 'trt'
    else:
        # Previously this fell through to an UnboundLocalError on `model`.
        raise ValueError(
            'unsupported weights file {!r}: expected .pth, .pt, .onnx or .engine'.format(weights))

    decoder2 = decoder.DecDecoder(K=par['K'],
                                  conf_thresh=par['conf_thresh'],
                                  num_classes=par['num_classes'])
    return model, decoder2
|