AIlib2/stdc.py

121 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from models.experimental import attempt_load
import tensorrt as trt
import torch
import sys
from segutils.trtUtils import segPreProcess_image,segTrtForward,segPreProcess_image_torch
from segutils.model_stages import BiSeNet_STDC
import time,cv2
import numpy as np
class stdcModel(object):
    """Inference wrapper around a BiSeNet-STDC segmentation model.

    The backend is chosen from the weights file extension:
    ``.engine`` loads a serialized TensorRT engine, ``.pth`` / ``.pt`` loads
    a PyTorch state dict into ``BiSeNet_STDC``.

    Attributes:
        par: configuration dict (see ``_DEFAULT_PAR`` for the default keys).
        device: torch device string; fixed to ``'cuda:0'``.
        half: always ``True``; not read anywhere inside this class.
        dynamic: value of ``par['dynamic']`` (``False`` when the key is absent).
        infer_type: ``'trt'`` or ``'pth'``.
        model: the deserialized TensorRT engine or the ``BiSeNet_STDC`` module.
    """

    # Used when the caller passes no ``par``. A fresh copy is made per
    # instance so that instances never share one mutable dict.
    # NOTE(review): 'predResize', 'numpy' and 'RGB_convert_first' are not read
    # by this class; they may be consumed elsewhere - confirm before removing.
    _DEFAULT_PAR = {
        'modelSize': (640, 360),
        'dynamic': False,
        'nclass': 2,
        'predResize': True,
        'mean': (0.485, 0.456, 0.406),
        'std': (0.229, 0.224, 0.225),
        'numpy': False,
        'RGB_convert_first': True,
    }

    def __init__(self, weights=None, par=None):
        """Load the model described by ``weights``.

        Args:
            weights: path to a ``.engine`` (TensorRT) or ``.pth`` / ``.pt``
                (PyTorch state dict) file.
            par: configuration dict. ``None`` selects a private copy of
                ``_DEFAULT_PAR``. A dict passed by the caller is stored as-is.

        Raises:
            SystemExit: (status 1) when ``weights`` is ``None`` or has an
                extension with no registered inference type.
            KeyError: for ``.pth`` / ``.pt`` weights when ``par`` defines
                neither ``'seg_nclass'`` nor ``'nclass'``.
        """
        if par is None:
            par = dict(self._DEFAULT_PAR)
        self.par = par
        self.device = 'cuda:0'
        self.half = True
        self.dynamic = par.get('dynamic', False)

        # ``weights`` defaults to None; normalise it to a string so the
        # extension check reports the regular error instead of raising
        # AttributeError.
        weights_name = '' if weights is None else str(weights)
        if weights_name.endswith('.engine'):
            self.infer_type = 'trt'
        elif weights_name.endswith(('.pth', '.pt')):
            self.infer_type = 'pth'
        else:
            print('#########ERROR:', weights, ': no registered inference type, exit')
            # Non-zero status: this is a failure, not a clean shutdown.
            sys.exit(1)

        if self.infer_type == 'trt':
            if self.dynamic:
                # Message says: the dynamic STDC model cannot use the trt format.
                # NOTE(review): only a warning is printed; loading continues.
                print('####################ERROR##########,STDC动态模型不能采用trt格式########')
            logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
                # Deserialize the local TensorRT engine file into an
                # ICudaEngine object.
                self.model = runtime.deserialize_cuda_engine(f.read())
        elif self.infer_type == 'pth':
            # par['modelSize'] is indexed as (width, height) in
            # preprocess_image; it is passed to the network reversed.
            if self.dynamic:
                modelSize = None
            else:
                modelSize = (self.par['modelSize'][1], self.par['modelSize'][0])
            self.model = BiSeNet_STDC(
                backbone='STDCNet813',
                n_classes=self._num_classes(par),
                use_boundary_2=False,
                use_boundary_4=False,
                use_boundary_8=True,
                use_boundary_16=False,
                use_conv_last=False,
                modelSize=modelSize,
            )
            self.model.load_state_dict(
                torch.load(weights, map_location=torch.device(self.device))
            )
            self.model = self.model.to(self.device)
        print('#########加载模型:', weights, ' 类型:', self.infer_type)

    @staticmethod
    def _num_classes(par):
        """Return the class count from ``par``.

        ``'seg_nclass'`` takes precedence (the key the loader has always
        read); ``'nclass'`` is the fallback because that is the key the
        default configuration defines.
        """
        if 'seg_nclass' in par:
            return par['seg_nclass']
        return par['nclass']

    def _normalize(self, image):
        """Scale ``image`` to [0, 1] and standardise its first three channels.

        Returns a new float32 array; channel ``i`` (i < 3) becomes
        ``(image / 255 - par['mean'][i]) / par['std'][i]``.
        """
        image = image.astype(np.float32)
        image /= 255.0
        mean = np.asarray(self.par['mean'][:3], dtype=np.float32)
        std = np.asarray(self.par['std'][:3], dtype=np.float32)
        # Only the first three channels are normalised.
        image[:, :, :3] -= mean
        image[:, :, :3] /= std
        return image

    def preprocess_image(self, image):
        """Turn an H x W x C image array into a 1 x C x H x W float tensor.

        Steps: swap channels 0 and 2, resize, normalise with
        ``par['mean']`` / ``par['std']``, move channels first, add a batch
        dimension and move the tensor to ``self.device``.

        In dynamic mode the image is scaled to fit inside ``par['modelSize']``
        keeping its aspect ratio, with both sides rounded down to a multiple
        of 4; otherwise it is resized to exactly ``par['modelSize']``.
        """
        image = self.RB_convert(image)
        if self.dynamic:
            height, width = image.shape[0:2]
            yscale = self.par['modelSize'][1] / height
            xscale = self.par['modelSize'][0] / width
            dscale = min(yscale, xscale)
            re_size = (int((dscale * width) // 4 * 4), int((dscale * height) // 4 * 4))
        else:
            re_size = self.par['modelSize']
        image = cv2.resize(image, re_size, interpolation=cv2.INTER_LINEAR)
        image = self._normalize(image)
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image).float()
        image = image.unsqueeze(0)
        if self.device != 'cpu':
            image = image.to(self.device)
        return image

    def RB_convert(self, image):
        """Return a copy of ``image`` with channels 0 and 2 swapped.

        Presumably converts OpenCV BGR input to RGB - confirm with callers.
        The input array is not modified.
        """
        image_c = image.copy()
        image_c[:, :, [0, 2]] = image[:, :, [2, 0]]
        return image_c

    def get_ms(self, t1, t0):
        """Return the interval ``t1 - t0`` (seconds) in milliseconds."""
        return (t1 - t0) * 1000.0

    def eval(self, image):
        """Segment ``image`` and return ``(mask, timing_string)``.

        Args:
            image: H x W x C array (three-way unpacking of ``image.shape``
                requires exactly three dimensions).

        Returns:
            A tuple ``(pred, outstr)`` where ``pred`` is a uint8 array of
            per-pixel class indices resized to the input's height and width,
            and ``outstr`` is a one-line timing report in milliseconds.
        """
        time0 = time.time()
        imageH, imageW, _ = image.shape
        img = self.preprocess_image(image)
        time1 = time.time()
        if self.infer_type == 'trt':
            pred = segTrtForward(self.model, [img])
        elif self.infer_type == 'pth':
            self.model.eval()
            with torch.no_grad():
                pred = self.model(img)
        # NOTE(review): no CUDA synchronisation before this timestamp, so
        # part of the inference time may be attributed to the argmax step.
        time2 = time.time()
        # Assumes pred is (batch, classes, H, W) - TODO confirm for the trt path.
        pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
        time3 = time.time()
        # NOTE(review): cv2.resize defaults to bilinear interpolation, which
        # can invent intermediate class ids when there are more than two
        # classes; cv2.INTER_NEAREST is the usual choice for label maps.
        # Left unchanged to keep existing outputs identical.
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH))
        time4 = time.time()
        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-cpu-argmax:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
            self.get_ms(time1, time0),
            self.get_ms(time2, time1),
            self.get_ms(time3, time2),
            self.get_ms(time4, time3),
            self.get_ms(time4, time0),
        )
        return pred, outstr