You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

121 lines
4.6KB

  1. from models.experimental import attempt_load
  2. import tensorrt as trt
  3. import torch
  4. import sys
  5. from segutils.trtUtils import segPreProcess_image,segTrtForward,segPreProcess_image_torch
  6. from segutils.model_stages import BiSeNet_STDC
  7. import time,cv2
  8. import numpy as np
  9. class stdcModel(object):
  10. def __init__(self, weights=None,
  11. par={'modelSize':(640,360),'dynamic':False,'nclass':2,'predResize':True,'mean':(0.485, 0.456, 0.406),'std' :(0.229, 0.224, 0.225),'numpy':False, 'RGB_convert_first':True}
  12. ):
  13. self.par = par
  14. self.device = 'cuda:0'
  15. self.half =True
  16. if 'dynamic' not in par.keys():
  17. self.dynamic=False
  18. else: self.dynamic=par['dynamic']
  19. if weights.endswith('.engine'):
  20. self. infer_type ='trt'
  21. elif weights.endswith('.pth') or weights.endswith('.pt') :
  22. self. infer_type ='pth'
  23. else:
  24. print('#########ERROR:',weights,': no registered inference type, exit')
  25. sys.exit(0)
  26. if self.infer_type=='trt':
  27. if self.dynamic :
  28. print('####################ERROR##########,STDC动态模型不能采用trt格式########')
  29. logger = trt.Logger(trt.Logger.ERROR)
  30. with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
  31. self.model=runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
  32. elif self.infer_type=='pth':
  33. if self.dynamic: modelSize=None
  34. else: modelSize=( self.par['modelSize'][1], self.par['modelSize'][0] )
  35. self.model = BiSeNet_STDC(backbone='STDCNet813', n_classes=par['seg_nclass'],
  36. use_boundary_2=False, use_boundary_4=False,
  37. use_boundary_8=True, use_boundary_16=False,
  38. use_conv_last=False,
  39. modelSize = modelSize
  40. )
  41. self.model.load_state_dict(torch.load(weights, map_location=torch.device(self.device) ))
  42. self.model= self.model.to(self.device)
  43. print('#########加载模型:',weights,' 类型:',self.infer_type)
  44. def preprocess_image(self,image):
  45. image = self.RB_convert(image)
  46. if self.dynamic:
  47. H,W=image.shape[0:2];
  48. yscale = self.par['modelSize'][1]/H
  49. xscale = self.par['modelSize'][0]/W
  50. dscale = min(yscale,xscale)
  51. re_size = ( int((dscale*W)//4*4), int( (dscale*H)//4*4 ) )
  52. else: re_size = self.par['modelSize']
  53. #print('####line 58:,', re_size,image.shape)
  54. image = cv2.resize(image,re_size, interpolation=cv2.INTER_LINEAR)
  55. image = image.astype(np.float32)
  56. image /= 255.0
  57. image[:, :, 0] -= self.par['mean'][0]
  58. image[:, :, 1] -= self.par['mean'][1]
  59. image[:, :, 2] -= self.par['mean'][2]
  60. image[:, :, 0] /= self.par['std'][0]
  61. image[:, :, 1] /= self.par['std'][1]
  62. image[:, :, 2] /= self.par['std'][2]
  63. image = np.transpose(image, (2, 0, 1))
  64. image = torch.from_numpy(image).float()
  65. image = image.unsqueeze(0)
  66. if self.device != 'cpu':
  67. image = image.to(self.device)
  68. return image
  69. def RB_convert(self,image):
  70. image_c = image.copy()
  71. image_c[:,:,0] = image[:,:,2]
  72. image_c[:,:,2] = image[:,:,0]
  73. return image_c
  74. def get_ms(self,t1,t0):
  75. return (t1-t0)*1000.0
  76. def eval(self,image):
  77. time0 = time.time()
  78. imageH, imageW, _ = image.shape
  79. img = self.preprocess_image(image)
  80. time1 = time.time()
  81. if self.infer_type=='trt':
  82. pred=segTrtForward(self.model,[img])
  83. elif self.infer_type=='pth':
  84. self.model.eval()
  85. with torch.no_grad():
  86. pred = self.model(img)
  87. time2 = time.time()
  88. pred=torch.argmax(pred,dim=1).cpu().numpy()[0]
  89. time3 = time.time()
  90. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH))
  91. time4 = time.time()
  92. outstr= 'pre-precess:%.1f ,infer:%.1f ,post-cpu-argmax:%.1f ,post-resize:%.1f, total:%.1f \n '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3),self.get_ms(time4,time0) )
  93. return pred,outstr