from loguru import logger
from models.experimental import attempt_load
import tensorrt as trt
import torch
import sys
from DrGraph.util.segutils.trtUtils import segPreProcess_image, segTrtForward, segPreProcess_image_torch
from DrGraph.util.segutils.model_stages import BiSeNet_STDC
import time
import cv2
import numpy as np
from DrGraph.util.drHelper import *


class stdcModel(object):
    """STDC (BiSeNet-STDC) semantic-segmentation model wrapper.

    The inference backend is selected from the weights-file suffix:
      * ``.engine``        -> TensorRT engine ('trt')
      * ``.pth`` / ``.pt`` -> PyTorch state dict ('pth')
    """

    # Default configuration. Kept as a class-level constant and COPIED per
    # instance: the original mutable default argument was shared between
    # instances and mutated by preprocess_image (which writes back
    # par['mean'] / par['std']).
    _DEFAULT_PAR = {
        'modelSize': (640, 360),
        'dynamic': False,
        'nclass': 2,
        'predResize': True,
        'mean': (0.485, 0.456, 0.406),
        'std': (0.229, 0.224, 0.225),
        'numpy': False,
        'RGB_convert_first': True,
    }

    def __init__(self, weights=None, par=None):
        """
        Args:
            weights: path to the model weights; the suffix selects the
                backend (.engine -> TensorRT, .pth/.pt -> PyTorch).
            par: configuration dict; when omitted, a fresh copy of the
                documented defaults is used (same contents as before).
        """
        self.par = dict(self._DEFAULT_PAR) if par is None else par
        par = self.par  # keep the local alias the rest of the body uses
        self.device = 'cuda:0'
        self.half = True
        self.dynamic = par.get('dynamic', False)

        if weights.endswith('.engine'):
            self.infer_type = 'trt'
        elif weights.endswith(('.pth', '.pt')):
            self.infer_type = 'pth'
        else:
            logger.error(f'{weights}: no registered inference type, exit')
            # BUGFIX: was sys.exit(0) — an error path must not report success.
            sys.exit(1)

        if self.infer_type == 'trt':
            if self.dynamic:
                # "The dynamic STDC model cannot use the trt format"
                logger.error('STDC动态模型不能采用trt格式')
            trt_logger = trt.Logger(trt.Logger.ERROR)
            # Deserialize the local TensorRT engine file into an ICudaEngine.
            with open(weights, "rb") as f, trt.Runtime(trt_logger) as runtime:
                self.model = runtime.deserialize_cuda_engine(f.read())
        elif self.infer_type == 'pth':
            # Dynamic models resize per image (size resolved in
            # preprocess_image); static models fix (H, W) at build time.
            if self.dynamic:
                modelSize = None
            else:
                modelSize = (par['modelSize'][1], par['modelSize'][0])
            # BUGFIX: the default par only defines 'nclass'; the original
            # par['seg_nclass'] raised KeyError unless the caller explicitly
            # supplied that key. Prefer 'seg_nclass' when present for
            # backward compatibility, else fall back to 'nclass'.
            n_classes = par.get('seg_nclass', par.get('nclass', 2))
            self.model = BiSeNet_STDC(backbone='STDCNet813',
                                      n_classes=n_classes,
                                      use_boundary_2=False,
                                      use_boundary_4=False,
                                      use_boundary_8=True,
                                      use_boundary_16=False,
                                      use_conv_last=False,
                                      modelSize=modelSize)
            self.model.load_state_dict(
                torch.load(weights, map_location=torch.device(self.device)))
            self.model = self.model.to(self.device)

        # BUGFIX: loguru formats messages with str.format-style '{}'
        # placeholders; the original print-style positional args were
        # silently dropped, so the weights path and backend never appeared
        # in the log.
        logger.info('加载 stdcModel 模型:{} 类型:{}', weights, self.infer_type)

    def preprocess_image(self, image):
        """Turn a BGR uint8 HxWx3 image into a normalized (1, 3, h, w)
        float tensor on self.device.

        Steps: BGR->RGB, resize to the configured model size (aspect-
        preserving and 4-aligned in dynamic mode), scale to [0, 1],
        per-channel mean/std normalization, HWC->CHW, add batch dim.
        """
        image = self.RB_convert(image)  # BGR -> RGB

        re_size = (640, 360)
        if 'modelSize' in self.par:
            if self.dynamic:
                # Fit inside modelSize while keeping the aspect ratio,
                # rounding each side down to a multiple of 4.
                H, W = image.shape[0:2]
                yscale = self.par['modelSize'][1] / H
                xscale = self.par['modelSize'][0] / W
                dscale = min(yscale, xscale)
                re_size = (int((dscale * W) // 4 * 4),
                           int((dscale * H) // 4 * 4))
            else:
                re_size = self.par['modelSize']
        else:
            logger.warning('modelSize not in par, use default size(640, 360)')

        image = cv2.resize(image, re_size, interpolation=cv2.INTER_LINEAR)
        image = image.astype(np.float32)
        image /= 255.0

        if 'mean' not in self.par:
            self.par['mean'] = (0.485, 0.456, 0.406)
            logger.warning('mean not in par, use default mean(0.485, 0.456, 0.406)')
        if 'std' not in self.par:
            self.par['std'] = (0.229, 0.224, 0.225)
            logger.warning('std not in par, use default std(0.229, 0.224, 0.225)')

        # Per-channel normalization, vectorized over the last (channel)
        # axis — numerically identical to the original channel-by-channel
        # subtraction/division.
        image -= np.asarray(self.par['mean'], dtype=np.float32)
        image /= np.asarray(self.par['std'], dtype=np.float32)

        image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
        image = torch.from_numpy(image).float().unsqueeze(0)
        if self.device != 'cpu':
            image = image.to(self.device)
        return image

    def RB_convert(self, image):
        """Return a copy of `image` with the R and B channels swapped
        (BGR <-> RGB); the input array is left untouched."""
        return image[:, :, ::-1].copy()

    def eval(self, image):
        """Run segmentation on one BGR image.

        Returns:
            (pred, outstr): pred is an (H, W) uint8 class-index map resized
            back to the input size; outstr is a per-stage timing summary.
        """
        time0 = time.time()
        imageH, imageW, _ = image.shape
        img = self.preprocess_image(image)
        time1 = time.time()

        if self.infer_type == 'trt':
            pred = segTrtForward(self.model, [img])
        elif self.infer_type == 'pth':
            self.model.eval()
            with torch.no_grad():
                pred = self.model(img)
        time2 = time.time()

        # Collapse class logits to a per-pixel class-index map.
        pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
        time3 = time.time()

        # BUGFIX: a class-index map must be resized with nearest-neighbour;
        # the previous default (bilinear) interpolation invents labels that
        # do not exist at class boundaries (e.g. averaging classes 0 and 2
        # into 1).
        pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH),
                          interpolation=cv2.INTER_NEAREST)
        time4 = time.time()

        outstr = 'pre-precess:%.1f ,infer:%.1f ,post-cpu-argmax:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
            timeHelper.deltaTime_MS(time1, time0),
            timeHelper.deltaTime_MS(time2, time1),
            timeHelper.deltaTime_MS(time3, time2),
            timeHelper.deltaTime_MS(time4, time3),
            timeHelper.deltaTime_MS(time4, time0))
        return pred, outstr