AIlib2/DrGraph/util/stdc.py

134 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from loguru import logger
from models.experimental import attempt_load
import tensorrt as trt
import torch
import sys
from DrGraph.util.segutils.trtUtils import segPreProcess_image,segTrtForward,segPreProcess_image_torch
from DrGraph.util.segutils.model_stages import BiSeNet_STDC
import time,cv2
import numpy as np
from DrGraph.util.drHelper import *
class stdcModel(object):
    """Segmentation wrapper around a BiSeNet_STDC network.

    Loads either a serialized TensorRT engine (``.engine``) or a PyTorch
    state dict (``.pth`` / ``.pt``). :meth:`eval` turns an ``H x W x C``
    image into an ``H x W`` ``uint8`` map of class indices plus a timing
    summary string.
    """

    # Configuration used when the caller does not pass ``par``.
    # 'modelSize' is (width, height), the order cv2.resize expects.
    _DEFAULT_PAR = {
        'modelSize': (640, 360),
        'dynamic': False,
        'nclass': 2,
        'predResize': True,
        'mean': (0.485, 0.456, 0.406),
        'std': (0.229, 0.224, 0.225),
        'numpy': False,
        'RGB_convert_first': True,
    }
    # (width, height) used when ``par`` has no 'modelSize' entry.
    _DEFAULT_MODEL_SIZE = (640, 360)
    # Class count used when ``par`` has neither 'seg_nclass' nor 'nclass'.
    _DEFAULT_NCLASS = 2

    def __init__(self, weights=None, par=None):
        """Load the model from ``weights``.

        Args:
            weights: path (str or path-like) to a ``.engine`` file (TensorRT)
                or a ``.pth``/``.pt`` file (PyTorch state dict). Any other
                value, including None, logs an error and exits the process
                with status 1.
            par: configuration dict; defaults to ``_DEFAULT_PAR``. Keys read
                here: 'dynamic', 'modelSize' (width, height), 'seg_nclass'
                (falls back to 'nclass', then 2) and 'device' (defaults to
                'cuda:0'). The dict is copied, so later defaults filled in by
                preprocess_image never leak into the caller's dict.
        """
        self.par = dict(self._DEFAULT_PAR if par is None else par)
        self.device = self.par.get('device', 'cuda:0')
        self.half = True
        self.dynamic = self.par.get('dynamic', False)

        self.infer_type = self._infer_type_from_weights(weights)
        if self.infer_type is None:
            logger.error(f'{weights}: no registered inference type, exit')
            # Non-zero status: this is a failure, not a normal shutdown.
            sys.exit(1)

        if self.infer_type == 'trt':
            if self.dynamic:
                # "STDC dynamic models cannot use the TRT format". This is only
                # reported; loading continues.
                logger.error('STDC动态模型不能采用trt格式')
            trt_logger = trt.Logger(trt.Logger.ERROR)
            with open(weights, "rb") as f, trt.Runtime(trt_logger) as runtime:
                # Deserialize the local TRT engine file into an ICudaEngine.
                self.model = runtime.deserialize_cuda_engine(f.read())
        else:
            if self.dynamic:
                modelSize = None
            else:
                width, height = self.par.get('modelSize', self._DEFAULT_MODEL_SIZE)
                # The network is given (height, width).
                modelSize = (height, width)
            self.model = BiSeNet_STDC(
                backbone='STDCNet813',
                n_classes=self._resolve_nclass(self.par),
                use_boundary_2=False,
                use_boundary_4=False,
                use_boundary_8=True,
                use_boundary_16=False,
                use_conv_last=False,
                modelSize=modelSize,
            )
            state_dict = torch.load(weights, map_location=torch.device(self.device))
            self.model.load_state_dict(state_dict)
            self.model = self.model.to(self.device)
        # Loguru formats positional args into '{}' placeholders; without them
        # the weights path and type would be silently dropped.
        logger.info('加载 stdcModel 模型:{} 类型:{}', weights, self.infer_type)

    @staticmethod
    def _infer_type_from_weights(weights):
        """Map a weights path to 'trt', 'pth', or None from its file extension."""
        name = str(weights)
        if name.endswith('.engine'):
            return 'trt'
        if name.endswith(('.pth', '.pt')):
            return 'pth'
        return None

    @classmethod
    def _resolve_nclass(cls, par):
        """Return the class count: 'seg_nclass', else 'nclass', else the default."""
        return par.get('seg_nclass', par.get('nclass', cls._DEFAULT_NCLASS))

    @staticmethod
    def _target_size(model_size, dynamic, H, W):
        """Return the (width, height) an ``H x W`` image is resized to.

        Non-dynamic models always use ``model_size``. Dynamic models keep the
        aspect ratio so the image fits inside ``model_size``, with each side
        rounded down to a multiple of 4 and never below 4, so cv2.resize
        never receives a zero dimension.
        """
        if not dynamic:
            return model_size
        scale = min(model_size[1] / H, model_size[0] / W)
        new_w = max(4, int((scale * W) // 4 * 4))
        new_h = max(4, int((scale * H) // 4 * 4))
        return (new_w, new_h)

    @staticmethod
    def _normalize(image, mean, std):
        """Return ``(image / 255 - mean) / std`` as float32.

        Only channels 0-2 are standardized, one mean/std value per channel.
        """
        image = image.astype(np.float32)
        image /= 255.0
        for ch in range(3):
            image[:, :, ch] -= mean[ch]
            image[:, :, ch] /= std[ch]
        return image

    def preprocess_image(self, image):
        """Convert an ``H x W x C`` image into a ``1 x C x h x w`` float tensor.

        The steps are:
        - swap channels 0 and 2 (RB_convert);
        - resize to the model input size;
        - normalize with par['mean'] / par['std'];
        - move channels first and add a batch axis;
        - move the tensor to ``self.device`` unless it is 'cpu'.

        Missing 'mean' / 'std' entries are filled into ``self.par`` with
        defaults and a warning is logged.
        """
        image = self.RB_convert(image)
        if 'modelSize' in self.par:
            H, W = image.shape[0:2]
            re_size = self._target_size(self.par['modelSize'], self.dynamic, H, W)
        else:
            re_size = self._DEFAULT_MODEL_SIZE
            logger.warning('modelSize not in par, use default size(640, 360)')
        image = cv2.resize(image, re_size, interpolation=cv2.INTER_LINEAR)

        if 'mean' not in self.par:
            self.par['mean'] = (0.485, 0.456, 0.406)
            logger.warning('mean not in par, use default mean(0.485, 0.456, 0.406)')
        if 'std' not in self.par:
            self.par['std'] = (0.229, 0.224, 0.225)
            logger.warning('std not in par, use default std(0.229, 0.224, 0.225)')
        image = self._normalize(image, self.par['mean'], self.par['std'])

        tensor = torch.from_numpy(np.transpose(image, (2, 0, 1))).float().unsqueeze(0)
        if self.device != 'cpu':
            tensor = tensor.to(self.device)
        return tensor

    def RB_convert(self, image):
        """Return a copy of ``image`` with channels 0 and 2 swapped (BGR <-> RGB)."""
        image_c = image.copy()
        image_c[:, :, 0] = image[:, :, 2]
        image_c[:, :, 2] = image[:, :, 0]
        return image_c

    def eval(self, image):
        """Segment one image.

        Args:
            image: ``H x W x C`` array (at least 3 channels).

        Returns:
            (pred, outstr): ``pred`` is a ``uint8`` ``H x W`` array of class
            indices at the input resolution. ``outstr`` summarizes per-stage
            timings from timeHelper.deltaTime_MS (presumably milliseconds).
        """
        time0 = time.time()
        imageH, imageW, _ = image.shape
        img = self.preprocess_image(image)
        time1 = time.time()
        if self.infer_type == 'trt':
            pred = segTrtForward(self.model, [img])
        else:
            self.model.eval()
            with torch.no_grad():
                pred = self.model(img)
        time2 = time.time()
        pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
        time3 = time.time()
        # The labels are categorical. Nearest-neighbour keeps every pixel a
        # valid class id; bilinear would blend neighbouring ids into new ones.
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH),
                          interpolation=cv2.INTER_NEAREST)
        time4 = time.time()
        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-cpu-argmax:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
            timeHelper.deltaTime_MS(time1, time0),
            timeHelper.deltaTime_MS(time2, time1),
            timeHelper.deltaTime_MS(time3, time2),
            timeHelper.deltaTime_MS(time4, time3),
            timeHelper.deltaTime_MS(time4, time0),
        )
        return pred, outstr