用kafka接收消息
Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

132 rindas
4.8KB

  1. import torch
  2. from core.models.bisenet import BiSeNet
  3. from torchvision import transforms
  4. import cv2,os
  5. import numpy as np
  6. from core.models.dinknet import DinkNet34
  7. import matplotlib.pyplot as plt
  8. import matplotlib.pyplot as plt
  9. import time
  10. class SegModel(object):
  11. def __init__(self, nclass=2,weights=None,modelsize=512,device='cuda:3'):
  12. #self.args = args
  13. self.model = BiSeNet(nclass)
  14. #self.model = DinkNet34(nclass)
  15. checkpoint = torch.load(weights)
  16. self.modelsize = modelsize
  17. self.model.load_state_dict(checkpoint['model'])
  18. self.device = device
  19. self.model= self.model.to(self.device)
  20. '''self.composed_transforms = transforms.Compose([
  21. transforms.Normalize(mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)),
  22. transforms.ToTensor()]) '''
  23. self.mean = (0.335, 0.358, 0.332)
  24. self.std = (0.141, 0.138, 0.143)
  25. def eval(self,image,outsize=None):
  26. imageW,imageH,imageC = image.shape
  27. time0 = time.time()
  28. image = self.preprocess_image(image)
  29. time1 = time.time()
  30. self.model.eval()
  31. image = image.to(self.device)
  32. with torch.no_grad():
  33. output = self.model(image,outsize=outsize)
  34. time2 = time.time()
  35. pred = output.data.cpu().numpy()
  36. pred = np.argmax(pred, axis=1)[0]#得到每行
  37. time3 = time.time()
  38. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH))
  39. time4 = time.time()
  40. print('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
  41. return pred
  42. def get_ms(self,t1,t0):
  43. return (t1-t0)*1000.0
  44. def preprocess_image(self,image):
  45. time0 = time.time()
  46. image = cv2.resize(image,(self.modelsize,self.modelsize))
  47. time1 = time.time()
  48. image = image.astype(np.float32)
  49. image /= 255.0
  50. time2 = time.time()
  51. #image -= self.mean
  52. image[:,:,0] -=self.mean[0]
  53. image[:,:,1] -=self.mean[1]
  54. image[:,:,2] -=self.mean[2]
  55. time3 = time.time()
  56. #image /= self.std
  57. image[:,:,0] /= self.std[0]
  58. image[:,:,1] /= self.std[1]
  59. image[:,:,2] /= self.std[2]
  60. time4 = time.time()
  61. image = np.transpose(image, ( 2, 0, 1))
  62. time5 = time.time()
  63. image = torch.from_numpy(image).float()
  64. image = image.unsqueeze(0)
  65. print('resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
  66. return image
  67. def get_ms(t1,t0):
  68. return (t1-t0)*1000.0
  69. if __name__=='__main__':
  70. #os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
  71. '''
  72. image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
  73. nclass = 5
  74. weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
  75. '''
  76. image_url = 'temp_pics/DJI_0645.JPG'
  77. nclass = 2
  78. #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
  79. weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
  80. #weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
  81. segmodel = SegModel(nclass=nclass,weights=weights,device='cuda:4')
  82. for i in range(10):
  83. image_array0 = cv2.imread(image_url)
  84. imageH,imageW,_ = image_array0.shape
  85. #print('###line84:',image_array0.shape)
  86. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  87. #image_in = segmodel.preprocess_image(image_array)
  88. pred = segmodel.eval(image_array,outsize=None)
  89. time0=time.time()
  90. binary = pred.copy()
  91. time1=time.time()
  92. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  93. time2=time.time()
  94. print(pred.shape,' time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  95. ##计算findconturs时间与大小的关系
  96. binary0 = binary.copy()
  97. for ii,ss in enumerate([22,256,512,1024,2048]):
  98. time0=time.time()
  99. image = cv2.resize(binary0,(ss,ss))
  100. time1=time.time()
  101. if ii ==0:
  102. contours, hierarchy = cv2.findContours(binary0,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  103. else:
  104. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  105. time2=time.time()
  106. print('size:%d resize:%.1f ,findtime:%.1f '%(ss, get_ms(time1,time0),get_ms(time2,time1)))