# Crowd-counting (head point detection) inference script.
# Supports TensorRT ('.engine'/'.trt') and PyTorch ('.pth'/'.pt') backends.
# Standard library
import argparse
import os
import sys
import time
import warnings
from copy import deepcopy

# Third-party
import cv2
import numpy as np
import tensorrt as trt
import torch
from PIL import Image

# Local
from crowdUtils.engine import standard_transforms, preprocess, postprocess, DictToObject, AnchorPointsf
from crowdUtils.models import build_model
from segutils import trtUtils2

warnings.filterwarnings('ignore')
  12. class crowdModel(object):
  13. def __init__(self, weights=None,
  14. par={'mean':[0.485, 0.456, 0.406], 'std':[0.229, 0.224, 0.225],'threshold':0.5,
  15. 'modelPar':{'backbone':'vgg16_bn', 'gpu_id':0,'anchorFlag':False,'line':2,'width':None,'height':None , 'output_dir':'./output', 'row':2}
  16. }
  17. ):
  18. print('-'*20,par['modelPar'] )
  19. self.mean = par['mean']
  20. self.std =par['std']
  21. self.width = par['modelPar']['width']
  22. self.height = par['modelPar']['height']
  23. self.minShape = par['input_profile_shapes'][0]
  24. self.maxShape = par['input_profile_shapes'][2]
  25. self.IOShapes0,self.IOShapes1 = [ None,None,None ],[ None,None,None ]
  26. self.Oshapes0,self.Oshapes1 = [ None,None,None ], [ None,None,None ]
  27. self.modelPar = DictToObject(par['modelPar'])
  28. self.threshold = par['threshold']
  29. self.device = 'cuda:0'
  30. if weights.endswith('.engine') or weights.endswith('.trt'):
  31. self.infer_type ='trt'
  32. elif weights.endswith('.pth') or weights.endswith('.pt') :
  33. self.infer_type ='pth'
  34. else:
  35. print('#########ERROR:',weights,': no registered inference type, exit')
  36. sys.exit(0)
  37. if self.infer_type=='trt':
  38. logger = trt.Logger(trt.Logger.ERROR)
  39. with open(weights, "rb") as f, trt.Runtime(logger) as runtime:
  40. self.engine=runtime.deserialize_cuda_engine(f.read())# 输入trt本地文件,返回ICudaEngine对象
  41. #self.stream=cuda.Stream()
  42. self.bindingNames=[ self.engine.get_binding_name(ib) for ib in range(len(self.engine)) ]
  43. print('############load seg model trt success: ',weights,self.bindingNames)
  44. self.inputs,self.outputs,self.bindings,self.stream=None,None,None,None
  45. self.context = self.engine.create_execution_context()
  46. elif self.infer_type=='pth':
  47. #self.model = DirectionalPointDetector(3, self.par['depth_factor'], self.par['NUM_FEATURE_MAP_CHANNEL']).to(self.device)
  48. self.model = build_model(self.modelPar)
  49. checkpoint = torch.load(args.weight_path, map_location='cpu')
  50. self.model.load_state_dict(checkpoint['model'])
  51. self.model=self.model.to(self.device)
  52. if not self.modelPar.anchorFlag:
  53. if self.infer_type=='trt':
  54. self.anchors = AnchorPointsf(pyramid_levels=[3,], strides=None, row=self.modelPar.row, line=self.modelPar.line,device='cpu' )
  55. elif self.infer_type=='pth':
  56. self.anchors = AnchorPointsf(pyramid_levels=[3,], strides=None, row=self.modelPar.row, line=self.modelPar.line ,device='cuda:0')
  57. print('#########加载模型:',weights,' 类型:',self.infer_type)
  58. def preprocess(self,img):
  59. tmpImg = preprocess(img,mean=self.mean, std=self.std,minShape=self.minShape,maxShape=self.maxShape)
  60. if self.infer_type=='pth':
  61. tmpImg = torch.from_numpy(tmpImg)
  62. tmpImg = torch.Tensor(tmpImg).unsqueeze(0)
  63. elif self.infer_type=='trt':
  64. #if not self.height:
  65. chs, height, width= tmpImg.shape[0:3]
  66. self.width, self.height = width,height
  67. self.IOShapes1 = [ (1, chs, height, width ),(1, height//4*width//4,2),(1, height//4*width//4,2) ]
  68. self.Oshapes1 = [ (1, height//4*width//4,2),(1, height//4*width//4,2) ]
  69. tmpImg = tmpImg[np.newaxis,:,:,:]#CHW->NCHW
  70. return tmpImg
  71. def ms(self,t1,t0):
  72. return '%.1f'%( (t1-t0)*1000 )
  73. def eval(self,img):
  74. time0 = time.time()
  75. img_b = img.copy()
  76. #print('-----line54:',img.shape)
  77. samples = self.preprocess(img)
  78. time1 = time.time()
  79. if self.infer_type=='pth':
  80. samples = samples.to(self.device)
  81. elif self.infer_type=='trt' :
  82. #print('##### line83: 决定是否申请 内存 ',self.IOShapes1, self.IOShapes0,self.IOShapes1==self.IOShapes0)
  83. #if self.IOShapes1 != self.IOShapes0:
  84. self.inputs,self.outputs,self.bindings,self.stream = trtUtils2.allocate_buffers(self.engine,self.IOShapes1)
  85. #print('##### line96: 开辟新内存成功 ' ,self.height,self.width)
  86. self.IOShapes0=deepcopy(self.IOShapes1)
  87. time2 = time.time()
  88. if not self.modelPar.anchorFlag:
  89. self.anchor_points = self.anchors.eval(samples)
  90. if self.infer_type=='pth':
  91. # run inference
  92. self.model.eval()
  93. with torch.no_grad():
  94. outputs = self.model(samples)
  95. outputs['pred_points'] = outputs['pred_points'] + self.anchor_points
  96. #print('###line64:',outputs.keys(), outputs['pred_points'].shape, outputs['pred_logits'].shape)
  97. elif self.infer_type=='trt':
  98. outputs = trtUtils2.trt_inference( samples,self.height,self.width,self.context,self.inputs,self.outputs,self.bindings,self.stream,input_name = self.bindingNames[0])
  99. for i in range(len(self.Oshapes1)):
  100. outputs[i] = torch.from_numpy( np.reshape(outputs[i],self.Oshapes1[i]))
  101. outputs={'pred_points':outputs[0], 'pred_logits':outputs[1]}
  102. #print('###line117:',outputs.keys(), outputs['pred_points'].shape, outputs['pred_logits'].shape)
  103. outputs['pred_points'] = outputs['pred_points'] + self.anchor_points
  104. time3 = time.time()
  105. points,scores = self.postprocess(outputs)
  106. time4 = time.time()
  107. infos = 'precess:%s datacopy:%s infer:%s post:%s'%( self.ms(time1,time0) , self.ms(time2,time1), self.ms(time3,time2), self.ms(time4,time3) )
  108. p2 = self.toOBBformat(points,scores,cls=0 )
  109. presults=[ img_b, points,p2 ]
  110. return presults, infos
  111. def postprocess(self,outputs):
  112. return postprocess(outputs,threshold=self.threshold)
  113. def toOBBformat(self,points,scores,cls=0):
  114. outs = []
  115. for i in range(len(points)):
  116. pt,score = points[i],scores[i]
  117. pts4=[pt]*4
  118. ret = [ pts4,score,cls]
  119. outs.append(ret)
  120. return outs
  121. def main():
  122. par={'mean':[0.485, 0.456, 0.406], 'std':[0.229, 0.224, 0.225],'threshold':0.5, 'output_dir':'./output','input_profile_shapes':[(1,3,256,256),(1,3,1024,1024),(1,3,2048,2048)],'modelPar':{'backbone':'vgg16_bn', 'gpu_id':0,'anchorFlag':False, 'width':None,'height':None ,'line':2, 'row':2}
  123. }
  124. weights='weights/best_mae_dynamic.engine'
  125. #weights='weights/best_mae.pth'
  126. cmodel = crowdModel(weights,par)
  127. img_path = "./testImages"
  128. File = os.listdir(img_path)
  129. targetList = []
  130. for file in File[0:]:
  131. COORlist = []
  132. imgPath = img_path + os.sep + file
  133. img_raw = np.array(Image.open(imgPath).convert('RGB') )
  134. points, infos = cmodel.eval(img_raw)
  135. print(file,infos,img_raw.shape)
  136. img_to_draw = cv2.cvtColor(np.array(img_raw), cv2.COLOR_RGB2BGR)
  137. # 打印预测图像中人头的个数
  138. for p in points:
  139. img_to_draw = cv2.circle(img_to_draw, (int(p[0]), int(p[1])), 2, (0, 255, 0), -1)
  140. COORlist.append((int(p[0]), int(p[1])))
  141. # 将各测试图像中的人头坐标存储在targetList中, 格式:[[(x1, y1),(x2, y2),...], [(X1, Y1),(X2, Y2),..], ...]
  142. targetList.append(COORlist)
  143. time.sleep(2)
  144. # 保存预测图片
  145. cv2.imwrite(os.path.join(par['output_dir'], file), img_to_draw)
  146. #print(targetList )
if __name__ == '__main__':
    # Script entry point: build the PyTorch-path config and run inference.
    # NOTE(review): `args` is consumed as a module-level global by
    # crowdModel.__init__ when loading a .pth checkpoint (args.weight_path) —
    # confirm before removing it.
    par = {'backbone':'vgg16_bn', 'gpu_id':0, 'line':2, 'output_dir':'./output', 'row':2,'anchorFlag':False, 'weight_path':'./weights/best_mae.pth'}
    args = DictToObject(par)
    # main() draws and saves predictions, and collects per-image coordinates.
    targetList = main()
    print("line81", targetList)