Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

segmodel_BiseNet.py 4.9KB

1 ano atrás
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import torch
  2. import sys,os
  3. sys.path.extend(['segutils'])
  4. from core.models.bisenet import BiSeNet
  5. from torchvision import transforms
  6. import cv2,glob
  7. import numpy as np
  8. from core.models.dinknet import DinkNet34
  9. import matplotlib.pyplot as plt
  10. import time
  11. class SegModel(object):
  12. def __init__(self, nclass=2,weights=None,modelsize=512,device='cuda:0'):
  13. #self.args = args
  14. self.model = BiSeNet(nclass)
  15. #self.model = DinkNet34(nclass)
  16. checkpoint = torch.load(weights)
  17. self.modelsize = modelsize
  18. self.model.load_state_dict(checkpoint['model'])
  19. self.device = device
  20. self.model= self.model.to(self.device)
  21. '''self.composed_transforms = transforms.Compose([
  22. transforms.Normalize(mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)),
  23. transforms.ToTensor()]) '''
  24. self.mean = (0.335, 0.358, 0.332)
  25. self.std = (0.141, 0.138, 0.143)
  26. def eval(self,image):
  27. time0 = time.time()
  28. imageH,imageW,imageC = image.shape
  29. image = self.preprocess_image(image)
  30. time1 = time.time()
  31. self.model.eval()
  32. image = image.to(self.device)
  33. with torch.no_grad():
  34. output = self.model(image)
  35. time2 = time.time()
  36. pred = output.data.cpu().numpy()
  37. pred = np.argmax(pred, axis=1)[0]#得到每行
  38. time3 = time.time()
  39. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH))
  40. time4 = time.time()
  41. outstr= 'pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f, total:%.1f \n '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3),self.get_ms(time4,time0) )
  42. #print('pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f, total:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3),self.get_ms(time4,time0) ))
  43. return pred,outstr
  44. def get_ms(self,t1,t0):
  45. return (t1-t0)*1000.0
  46. def preprocess_image(self,image):
  47. time0 = time.time()
  48. image = cv2.resize(image,(self.modelsize,self.modelsize))
  49. time0 = time.time()
  50. image = image.astype(np.float32)
  51. image /= 255.0
  52. image[:,:,0] -=self.mean[0]
  53. image[:,:,1] -=self.mean[1]
  54. image[:,:,2] -=self.mean[2]
  55. image[:,:,0] /= self.std[0]
  56. image[:,:,1] /= self.std[1]
  57. image[:,:,2] /= self.std[2]
  58. image = cv2.cvtColor( image,cv2.COLOR_RGB2BGR)
  59. #image -= self.mean
  60. #image /= self.std
  61. image = np.transpose(image, ( 2, 0, 1))
  62. image = torch.from_numpy(image).float()
  63. image = image.unsqueeze(0)
  64. return image
  65. def get_ms(t1,t0):
  66. return (t1-t0)*1000.0
  67. def get_largest_contours(contours):
  68. areas = [cv2.contourArea(x) for x in contours]
  69. max_area = max(areas)
  70. max_id = areas.index(max_area)
  71. return max_id
  72. if __name__=='__main__':
  73. image_url = '/home/thsw2/WJ/data/THexit/val/images/DJI_0645.JPG'
  74. nclass = 2
  75. #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
  76. weights = '../weights/BiSeNet/checkpoint.pth'
  77. segmodel = SegModel(nclass=nclass,weights=weights)
  78. image_urls=glob.glob('../../river_demo/images/slope/*')
  79. out_dir ='../../river_demo/images/results/';
  80. os.makedirs(out_dir,exist_ok=True)
  81. for im,image_url in enumerate(image_urls[0:]):
  82. #image_url = '/home/thsw2/WJ/data/THexit/val/images/54(199).JPG'
  83. image_array0 = cv2.imread(image_url)
  84. H,W,C = image_array0.shape
  85. time_1=time.time()
  86. pred,outstr = segmodel.eval(image_array0 )
  87. #plt.figure(1);plt.imshow(pred);
  88. #plt.show()
  89. binary0 = pred.copy()
  90. time0 = time.time()
  91. contours, hierarchy = cv2.findContours(binary0,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  92. max_id = -1
  93. if len(contours)>0:
  94. max_id = get_largest_contours(contours)
  95. binary0[:,:] = 0
  96. cv2.fillPoly(binary0, [contours[max_id][:,0,:]], 1)
  97. time1 = time.time()
  98. time2 = time.time()
  99. cv2.drawContours(image_array0,contours,max_id,(0,255,255),3)
  100. time3 = time.time()
  101. out_url='%s/%s'%(out_dir,os.path.basename(image_url))
  102. ret = cv2.imwrite(out_url,image_array0)
  103. time4 = time.time()
  104. print('image:%d,%s ,%d*%d,eval:%.1f ms, %s,findcontours:%.1f ms,draw:%.1f total:%.1f'%(im,os.path.basename(image_url),H,W,get_ms(time0,time_1),outstr,get_ms(time1,time0), get_ms(time3,time2),get_ms(time3,time_1)) )
  105. #print(outstr)
  106. #plt.figure(0);plt.imshow(pred)
  107. #plt.figure(1);plt.imshow(image_array0)
  108. #plt.figure(2);plt.imshow(binary0)
  109. #plt.show()
  110. #print(out_url,ret)