You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

132 lines
4.8KB

  1. import torch
  2. from core.models.bisenet import BiSeNet
  3. from torchvision import transforms
  4. import cv2,os
  5. import numpy as np
  6. from core.models.dinknet import DinkNet34
  7. import matplotlib.pyplot as plt
  8. import matplotlib.pyplot as plt
  9. import time
  10. class SegModel(object):
  11. def __init__(self, nclass=2,weights=None,modelsize=512,device='cuda:3'):
  12. #self.args = args
  13. self.model = BiSeNet(nclass)
  14. #self.model = DinkNet34(nclass)
  15. checkpoint = torch.load(weights)
  16. self.modelsize = modelsize
  17. self.model.load_state_dict(checkpoint['model'])
  18. self.device = device
  19. self.model= self.model.to(self.device)
  20. '''self.composed_transforms = transforms.Compose([
  21. transforms.Normalize(mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)),
  22. transforms.ToTensor()]) '''
  23. self.mean = (0.335, 0.358, 0.332)
  24. self.std = (0.141, 0.138, 0.143)
  25. def eval(self,image,outsize=None):
  26. imageW,imageH,imageC = image.shape
  27. time0 = time.time()
  28. image = self.preprocess_image(image)
  29. time1 = time.time()
  30. self.model.eval()
  31. image = image.to(self.device)
  32. with torch.no_grad():
  33. output = self.model(image,outsize=outsize)
  34. time2 = time.time()
  35. pred = output.data.cpu().numpy()
  36. pred = np.argmax(pred, axis=1)[0]#得到每行
  37. time3 = time.time()
  38. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH))
  39. time4 = time.time()
  40. print('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
  41. return pred
  42. def get_ms(self,t1,t0):
  43. return (t1-t0)*1000.0
  44. def preprocess_image(self,image):
  45. time0 = time.time()
  46. image = cv2.resize(image,(self.modelsize,self.modelsize))
  47. time1 = time.time()
  48. image = image.astype(np.float32)
  49. image /= 255.0
  50. time2 = time.time()
  51. #image -= self.mean
  52. image[:,:,0] -=self.mean[0]
  53. image[:,:,1] -=self.mean[1]
  54. image[:,:,2] -=self.mean[2]
  55. time3 = time.time()
  56. #image /= self.std
  57. image[:,:,0] /= self.std[0]
  58. image[:,:,1] /= self.std[1]
  59. image[:,:,2] /= self.std[2]
  60. time4 = time.time()
  61. image = np.transpose(image, ( 2, 0, 1))
  62. time5 = time.time()
  63. image = torch.from_numpy(image).float()
  64. image = image.unsqueeze(0)
  65. print('resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
  66. return image
  67. def get_ms(t1,t0):
  68. return (t1-t0)*1000.0
  69. if __name__=='__main__':
  70. #os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
  71. '''
  72. image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
  73. nclass = 5
  74. weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
  75. '''
  76. image_url = 'temp_pics/DJI_0645.JPG'
  77. nclass = 2
  78. #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
  79. weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
  80. #weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
  81. segmodel = SegModel(nclass=nclass,weights=weights,device='cuda:4')
  82. for i in range(10):
  83. image_array0 = cv2.imread(image_url)
  84. imageH,imageW,_ = image_array0.shape
  85. #print('###line84:',image_array0.shape)
  86. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  87. #image_in = segmodel.preprocess_image(image_array)
  88. pred = segmodel.eval(image_array,outsize=None)
  89. time0=time.time()
  90. binary = pred.copy()
  91. time1=time.time()
  92. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  93. time2=time.time()
  94. print(pred.shape,' time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  95. ##计算findconturs时间与大小的关系
  96. binary0 = binary.copy()
  97. for ii,ss in enumerate([22,256,512,1024,2048]):
  98. time0=time.time()
  99. image = cv2.resize(binary0,(ss,ss))
  100. time1=time.time()
  101. if ii ==0:
  102. contours, hierarchy = cv2.findContours(binary0,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  103. else:
  104. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  105. time2=time.time()
  106. print('size:%d resize:%.1f ,findtime:%.1f '%(ss, get_ms(time1,time0),get_ms(time2,time1)))