Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

132 lines
4.8KB

  1. import torch
  2. from core.models.bisenet import BiSeNet
  3. from torchvision import transforms
  4. import cv2,os
  5. import numpy as np
  6. from core.models.dinknet import DinkNet34
  7. import matplotlib.pyplot as plt
  8. import matplotlib.pyplot as plt
  9. import time
  10. class SegModel(object):
  11. def __init__(self, nclass=2,weights=None,modelsize=512,device='cuda:3'):
  12. #self.args = args
  13. self.model = BiSeNet(nclass)
  14. #self.model = DinkNet34(nclass)
  15. checkpoint = torch.load(weights)
  16. self.modelsize = modelsize
  17. self.model.load_state_dict(checkpoint['model'])
  18. self.device = device
  19. self.model= self.model.to(self.device)
  20. '''self.composed_transforms = transforms.Compose([
  21. transforms.Normalize(mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)),
  22. transforms.ToTensor()]) '''
  23. self.mean = (0.335, 0.358, 0.332)
  24. self.std = (0.141, 0.138, 0.143)
  25. def eval(self,image,outsize=None):
  26. imageW,imageH,imageC = image.shape
  27. time0 = time.time()
  28. image = self.preprocess_image(image)
  29. time1 = time.time()
  30. self.model.eval()
  31. image = image.to(self.device)
  32. with torch.no_grad():
  33. output = self.model(image,outsize=outsize)
  34. time2 = time.time()
  35. pred = output.data.cpu().numpy()
  36. pred = np.argmax(pred, axis=1)[0]#得到每行
  37. time3 = time.time()
  38. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH))
  39. time4 = time.time()
  40. print('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
  41. return pred
  42. def get_ms(self,t1,t0):
  43. return (t1-t0)*1000.0
  44. def preprocess_image(self,image):
  45. time0 = time.time()
  46. image = cv2.resize(image,(self.modelsize,self.modelsize))
  47. time1 = time.time()
  48. image = image.astype(np.float32)
  49. image /= 255.0
  50. time2 = time.time()
  51. #image -= self.mean
  52. image[:,:,0] -=self.mean[0]
  53. image[:,:,1] -=self.mean[1]
  54. image[:,:,2] -=self.mean[2]
  55. time3 = time.time()
  56. #image /= self.std
  57. image[:,:,0] /= self.std[0]
  58. image[:,:,1] /= self.std[1]
  59. image[:,:,2] /= self.std[2]
  60. time4 = time.time()
  61. image = np.transpose(image, ( 2, 0, 1))
  62. time5 = time.time()
  63. image = torch.from_numpy(image).float()
  64. image = image.unsqueeze(0)
  65. print('resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
  66. return image
  67. def get_ms(t1,t0):
  68. return (t1-t0)*1000.0
  69. if __name__=='__main__':
  70. #os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
  71. '''
  72. image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
  73. nclass = 5
  74. weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
  75. '''
  76. image_url = 'temp_pics/DJI_0645.JPG'
  77. nclass = 2
  78. #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
  79. weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
  80. #weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
  81. segmodel = SegModel(nclass=nclass,weights=weights,device='cuda:4')
  82. for i in range(10):
  83. image_array0 = cv2.imread(image_url)
  84. imageH,imageW,_ = image_array0.shape
  85. #print('###line84:',image_array0.shape)
  86. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  87. #image_in = segmodel.preprocess_image(image_array)
  88. pred = segmodel.eval(image_array,outsize=None)
  89. time0=time.time()
  90. binary = pred.copy()
  91. time1=time.time()
  92. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  93. time2=time.time()
  94. print(pred.shape,' time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  95. ##计算findconturs时间与大小的关系
  96. binary0 = binary.copy()
  97. for ii,ss in enumerate([22,256,512,1024,2048]):
  98. time0=time.time()
  99. image = cv2.resize(binary0,(ss,ss))
  100. time1=time.time()
  101. if ii ==0:
  102. contours, hierarchy = cv2.findContours(binary0,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  103. else:
  104. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  105. time2=time.time()
  106. print('size:%d resize:%.1f ,findtime:%.1f '%(ss, get_ms(time1,time0),get_ms(time2,time1)))