用 Kafka 接收消息
Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

132 linhas
4.8KB

  1. import torch
  2. from core.models.bisenet import BiSeNet
  3. from torchvision import transforms
  4. import cv2,os
  5. import numpy as np
  6. from core.models.dinknet import DinkNet34
  7. import matplotlib.pyplot as plt
  8. import matplotlib.pyplot as plt
  9. import time
  10. class SegModel(object):
  11. def __init__(self, nclass=2,weights=None,modelsize=512,device='cuda:3'):
  12. #self.args = args
  13. self.model = BiSeNet(nclass)
  14. #self.model = DinkNet34(nclass)
  15. checkpoint = torch.load(weights)
  16. self.modelsize = modelsize
  17. self.model.load_state_dict(checkpoint['model'])
  18. self.device = device
  19. self.model= self.model.to(self.device)
  20. '''self.composed_transforms = transforms.Compose([
  21. transforms.Normalize(mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)),
  22. transforms.ToTensor()]) '''
  23. self.mean = (0.335, 0.358, 0.332)
  24. self.std = (0.141, 0.138, 0.143)
  25. def eval(self,image,outsize=None):
  26. imageW,imageH,imageC = image.shape
  27. time0 = time.time()
  28. image = self.preprocess_image(image)
  29. time1 = time.time()
  30. self.model.eval()
  31. image = image.to(self.device)
  32. with torch.no_grad():
  33. output = self.model(image,outsize=outsize)
  34. time2 = time.time()
  35. pred = output.data.cpu().numpy()
  36. pred = np.argmax(pred, axis=1)[0]#得到每行
  37. time3 = time.time()
  38. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH))
  39. time4 = time.time()
  40. print('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
  41. return pred
  42. def get_ms(self,t1,t0):
  43. return (t1-t0)*1000.0
  44. def preprocess_image(self,image):
  45. time0 = time.time()
  46. image = cv2.resize(image,(self.modelsize,self.modelsize))
  47. time1 = time.time()
  48. image = image.astype(np.float32)
  49. image /= 255.0
  50. time2 = time.time()
  51. #image -= self.mean
  52. image[:,:,0] -=self.mean[0]
  53. image[:,:,1] -=self.mean[1]
  54. image[:,:,2] -=self.mean[2]
  55. time3 = time.time()
  56. #image /= self.std
  57. image[:,:,0] /= self.std[0]
  58. image[:,:,1] /= self.std[1]
  59. image[:,:,2] /= self.std[2]
  60. time4 = time.time()
  61. image = np.transpose(image, ( 2, 0, 1))
  62. time5 = time.time()
  63. image = torch.from_numpy(image).float()
  64. image = image.unsqueeze(0)
  65. print('resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
  66. return image
  67. def get_ms(t1,t0):
  68. return (t1-t0)*1000.0
  69. if __name__=='__main__':
  70. #os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
  71. '''
  72. image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
  73. nclass = 5
  74. weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
  75. '''
  76. image_url = 'temp_pics/DJI_0645.JPG'
  77. nclass = 2
  78. #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
  79. weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
  80. #weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
  81. segmodel = SegModel(nclass=nclass,weights=weights,device='cuda:4')
  82. for i in range(10):
  83. image_array0 = cv2.imread(image_url)
  84. imageH,imageW,_ = image_array0.shape
  85. #print('###line84:',image_array0.shape)
  86. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  87. #image_in = segmodel.preprocess_image(image_array)
  88. pred = segmodel.eval(image_array,outsize=None)
  89. time0=time.time()
  90. binary = pred.copy()
  91. time1=time.time()
  92. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  93. time2=time.time()
  94. print(pred.shape,' time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  95. ##计算findconturs时间与大小的关系
  96. binary0 = binary.copy()
  97. for ii,ss in enumerate([22,256,512,1024,2048]):
  98. time0=time.time()
  99. image = cv2.resize(binary0,(ss,ss))
  100. time1=time.time()
  101. if ii ==0:
  102. contours, hierarchy = cv2.findContours(binary0,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  103. else:
  104. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  105. time2=time.time()
  106. print('size:%d resize:%.1f ,findtime:%.1f '%(ss, get_ms(time1,time0),get_ms(time2,time1)))