用kafka接收消息
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

2 роки тому
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import torch
  2. from core.models.bisenet import BiSeNet
  3. from torchvision import transforms
  4. import cv2,os
  5. import numpy as np
  6. from core.models.dinknet import DinkNet34
  7. import matplotlib.pyplot as plt
  8. import matplotlib.pyplot as plt
  9. import time
  class SegModel(object):
      """Inference wrapper around a BiSeNet semantic-segmentation checkpoint.

      An input image is resized to a square ``modelsize`` x ``modelsize``
      network input and normalized with fixed per-channel mean/std. The network
      is run, and the per-pixel class-index map is resized back to the input
      image's height and width. Per-stage timings (milliseconds) are printed on
      every call.
      """

      def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:3'):
          """Build the network and load its weights.

          Args:
              nclass: number of output classes of the network.
              weights: path of a checkpoint file whose ``'model'`` entry holds
                  the network's state dict.
              modelsize: side length (pixels) of the square network input.
              device: torch device string that the model and inputs are moved to.
          """
          self.model = BiSeNet(nclass)
          # self.model = DinkNet34(nclass)  # alternative architecture
          # Load onto the CPU first. Without map_location, torch.load restores
          # tensors onto the device they were saved from, which may not exist
          # on this machine (or may be a different GPU than `device`).
          checkpoint = torch.load(weights, map_location='cpu')
          self.modelsize = modelsize
          self.model.load_state_dict(checkpoint['model'])
          self.device = device
          self.model = self.model.to(self.device)
          # Per-channel statistics applied to pixels already scaled to [0, 1].
          self.mean = (0.335, 0.358, 0.332)
          self.std = (0.141, 0.138, 0.143)

      def eval(self, image, outsize=None):
          """Segment one image and return its class-index map.

          Args:
              image: HxWx3 numpy image. Presumably it must use the same channel
                  order as the training data -- TODO confirm.
              outsize: optional value forwarded unchanged to the network's
                  forward call as ``outsize``.

          Returns:
              uint8 numpy array of shape (H, W), matching ``image``, holding the
              arg-max class index of every pixel.
          """
          # numpy images are (rows, cols, channels) == (height, width, channels).
          imageH, imageW = image.shape[:2]
          time0 = time.time()
          image = self.preprocess_image(image)
          time1 = time.time()
          self.model.eval()
          image = image.to(self.device)
          with torch.no_grad():
              output = self.model(image, outsize=outsize)
              # CUDA kernels run asynchronously. Wait for them here so that the
              # "infer" timing covers the forward pass instead of being billed
              # to the device-to-host copy in the post-processing stage.
              if output.is_cuda:
                  torch.cuda.synchronize(output.device)
          time2 = time.time()
          pred = output.cpu().numpy()
          pred = np.argmax(pred, axis=1)[0]  # class index per pixel of batch item 0
          time3 = time.time()
          # cv2.resize takes dsize as (width, height). Nearest-neighbour keeps
          # every value a valid class index; bilinear would blend labels.
          pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH),
                            interpolation=cv2.INTER_NEAREST)
          time4 = time.time()
          print('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f ' % (
              self.get_ms(time1, time0), self.get_ms(time2, time1),
              self.get_ms(time3, time2), self.get_ms(time4, time3)))
          return pred

      def get_ms(self, t1, t0):
          """Return the time from ``t0`` to ``t1`` (seconds) in milliseconds."""
          return (t1 - t0) * 1000.0

      def preprocess_image(self, image):
          """Turn an HxWx3 numpy image into a normalized 1x3xSxS float32 tensor.

          S is ``self.modelsize``. Pixels are scaled to [0, 1], then each channel
          is standardized with ``self.mean`` and ``self.std``.
          """
          time0 = time.time()
          image = cv2.resize(image, (self.modelsize, self.modelsize))
          time1 = time.time()
          image = image.astype(np.float32)
          image /= 255.0
          time2 = time.time()
          # The 3-element statistics broadcast over the trailing channel axis.
          image -= np.asarray(self.mean, dtype=np.float32)
          time3 = time.time()
          image /= np.asarray(self.std, dtype=np.float32)
          time4 = time.time()
          # HWC -> CHW, the layout expected by torch convolutions.
          image = np.transpose(image, (2, 0, 1))
          time5 = time.time()
          image = torch.from_numpy(image).float()
          image = image.unsqueeze(0)  # add the batch dimension
          print('resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.1f ' % (
              self.get_ms(time1, time0), self.get_ms(time2, time1),
              self.get_ms(time3, time2), self.get_ms(time4, time3),
              self.get_ms(time5, time4)))
          return image
  def get_ms(t1, t0):
      """Express the interval between two ``time.time()`` readings in milliseconds.

      Args:
          t1: later timestamp, in seconds.
          t0: earlier timestamp, in seconds.

      Returns:
          ``t1 - t0`` converted to milliseconds, as a float.
      """
      seconds = t1 - t0
      return seconds * 1000.0
  69. if __name__=='__main__':
  70. #os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
  71. '''
  72. image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
  73. nclass = 5
  74. weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
  75. '''
  76. image_url = 'temp_pics/DJI_0645.JPG'
  77. nclass = 2
  78. #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
  79. weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
  80. #weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
  81. segmodel = SegModel(nclass=nclass,weights=weights,device='cuda:4')
  82. for i in range(10):
  83. image_array0 = cv2.imread(image_url)
  84. imageH,imageW,_ = image_array0.shape
  85. #print('###line84:',image_array0.shape)
  86. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  87. #image_in = segmodel.preprocess_image(image_array)
  88. pred = segmodel.eval(image_array,outsize=None)
  89. time0=time.time()
  90. binary = pred.copy()
  91. time1=time.time()
  92. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  93. time2=time.time()
  94. print(pred.shape,' time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  95. ##计算findconturs时间与大小的关系
  96. binary0 = binary.copy()
  97. for ii,ss in enumerate([22,256,512,1024,2048]):
  98. time0=time.time()
  99. image = cv2.resize(binary0,(ss,ss))
  100. time1=time.time()
  101. if ii ==0:
  102. contours, hierarchy = cv2.findContours(binary0,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  103. else:
  104. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  105. time2=time.time()
  106. print('size:%d resize:%.1f ,findtime:%.1f '%(ss, get_ms(time1,time0),get_ms(time2,time1)))