You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

165 lines
5.9KB

  1. import torch
  2. import sys,os
  3. sys.path.extend(['../AIlib2/segutils'])
  4. from model_stages import BiSeNet_STDC
  5. from torchvision import transforms
  6. import cv2,glob
  7. import numpy as np
  8. from core.models.dinknet import DinkNet34
  9. import matplotlib.pyplot as plt
  10. import time
  11. from PIL import Image
  12. import torch.nn.functional as F
  13. import torchvision.transforms as transforms
  class SegModel(object):
      """Semantic-segmentation wrapper around a BiSeNet_STDC (STDCNet813) network.

      Two inference entry points are provided; they differ only in how the
      input image is preprocessed:

      * ``eval``     -- OpenCV resize + manual per-channel normalisation.
      * ``eval_zyy`` -- PIL / torchvision ToTensor + Normalize followed by a
                        bilinear ``F.interpolate`` of the tensor.

      Both return ``(pred, outstr)`` where ``pred`` is a uint8 class-index map
      resized back to the input image size and ``outstr`` is a timing summary
      in milliseconds.
      """

      # Network input size as (width, height), the argument order cv2.resize
      # expects. NOTE(review): the ``modelsize`` constructor argument is accepted
      # but never used; the input size is fixed here.
      input_wh = (640, 360)

      def __init__(self, nclass=2, weights=None, modelsize=512, device='cuda:0'):
          """Build the network, load its weights and move it to ``device``.

          Args:
              nclass: number of classes predicted by the segmentation head.
              weights: path to a state_dict file readable by ``torch.load``;
                  required despite the ``None`` default.
              modelsize: unused; kept for backward compatibility.
              device: torch device string that the model and inputs are moved to.

          Raises:
              ValueError: if ``weights`` is None.
          """
          if weights is None:
              # torch.load(None) fails with an obscure error; fail early and clearly.
              raise ValueError('SegModel requires a path to model weights')
          self.model = BiSeNet_STDC(backbone='STDCNet813', n_classes=nclass,
                                    use_boundary_2=False, use_boundary_4=False,
                                    use_boundary_8=True, use_boundary_16=False,
                                    use_conv_last=False)
          self.device = device
          # NOTE: torch.load unpickles the file -- only load trusted weight files.
          state_dict = torch.load(weights, map_location=torch.device(self.device))
          self.model.load_state_dict(state_dict)
          self.model = self.model.to(self.device)
          # Per-channel mean/std in RGB order (the common ImageNet statistics).
          self.mean = (0.485, 0.456, 0.406)
          self.std = (0.229, 0.224, 0.225)
          # Preprocessing pipeline for eval_zyy, built once instead of per call.
          self.to_tensor = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize(self.mean, self.std),
          ])

      def eval(self, image):
          """Segment an image using OpenCV-based preprocessing.

          Args:
              image: HxWxC array in BGR channel order (channels 0 and 2 are
                  swapped to RGB before normalisation).

          Returns:
              (pred, outstr): ``pred`` is an HxW uint8 class-index map at the
              input resolution; ``outstr`` is a timing summary string.
          """
          time0 = time.time()
          imageH, imageW, _ = image.shape
          img = self.preprocess_image(self.RB_convert(image))
          # .to() returns the tensor unchanged when it is already on the device.
          imgs = img.to(self.device)
          time1 = time.time()
          output = self._infer(imgs)
          time2 = time.time()
          return self._postprocess(output, imageW, imageH, time0, time1, time2)

      def eval_zyy(self, image):
          """Segment a BGR image with PIL/torchvision preprocessing.

          The preprocessing used here reproduces the results of zyy's runs:
          BGR->RGB, ToTensor + Normalize, then a bilinear (align_corners=True)
          resize of the tensor to the network input size.

          Returns:
              The same ``(pred, outstr)`` pair as ``eval``.
          """
          time0 = time.time()
          imageH, imageW, _ = image.shape
          imgs = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
          imgs = self.to_tensor(imgs).to(self.device)
          imgs = torch.unsqueeze(imgs, dim=0)
          width, height = self.input_wh
          imgs = F.interpolate(imgs, [height, width], mode='bilinear', align_corners=True)
          time1 = time.time()
          output = self._infer(imgs)
          time2 = time.time()
          return self._postprocess(output, imageW, imageH, time0, time1, time2)

      def _infer(self, imgs):
          """Run the network in eval mode without autograd and return its raw output."""
          self.model.eval()
          with torch.no_grad():
              return self.model(imgs)

      def _postprocess(self, output, imageW, imageH, time0, time1, time2):
          """Argmax the network output, resize the label map and format timings."""
          # Assumes the model returns one logits tensor laid out (N, nclass, H, W)
          # in eval mode -- TODO confirm against BiSeNet_STDC.
          pred = output.detach().cpu().numpy()
          pred = np.argmax(pred, axis=1)[0]  # per-pixel class index of the first batch item
          time3 = time.time()
          # Nearest-neighbour keeps every pixel a valid class id; bilinear
          # interpolation would blend labels and create classes that don't exist.
          pred = cv2.resize(pred.astype(np.uint8), (imageW, imageH),
                            interpolation=cv2.INTER_NEAREST)
          time4 = time.time()
          outstr = 'pre-precess:%.1f ,infer:%.1f ,post-cpu-argmax:%.1f ,post-resize:%.1f, total:%.1f \n ' % (
              self.get_ms(time1, time0), self.get_ms(time2, time1),
              self.get_ms(time3, time2), self.get_ms(time4, time3),
              self.get_ms(time4, time0))
          return pred, outstr

      def get_ms(self, t1, t0):
          """Return the elapsed time from ``t0`` to ``t1`` (seconds) in milliseconds."""
          return (t1 - t0) * 1000.0

      def preprocess_image(self, image):
          """Resize, scale to [0, 1], normalise per channel and return a 1xCxHxW float tensor.

          Expects a 3-channel image already in RGB order (``eval`` converts it).
          """
          image = cv2.resize(image, self.input_wh, interpolation=cv2.INTER_LINEAR)
          image = image.astype(np.float32) / 255.0
          mean = np.asarray(self.mean, dtype=np.float32)
          std = np.asarray(self.std, dtype=np.float32)
          image = (image - mean) / std  # broadcasts over the channel axis
          # HWC -> CHW (contiguous for a cheap device copy), then add a batch axis.
          image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
          return torch.from_numpy(image).unsqueeze(0)

      def RB_convert(self, image):
          """Return a copy of ``image`` with channels 0 and 2 swapped (BGR <-> RGB).

          Any additional channels are copied unchanged; the input is not modified.
          """
          image_c = image.copy()
          image_c[:, :, [0, 2]] = image[:, :, [2, 0]]
          return image_c
  def get_ms(t1, t0):
      """Convert the interval between two time.time() stamps into milliseconds."""
      elapsed_seconds = t1 - t0
      return elapsed_seconds * 1000.0
  def get_largest_contours(contours):
      """Return the index of the contour with the largest cv2.contourArea.

      Ties resolve to the earliest contour. An empty ``contours`` sequence
      raises ValueError (from ``max``); callers should check for emptiness first.
      """
      contour_areas = [cv2.contourArea(cnt) for cnt in contours]
      return max(range(len(contour_areas)), key=contour_areas.__getitem__)
  104. if __name__=='__main__':
  105. impth = '../../river_demo/images/slope/'
  106. outpth= 'results'
  107. folders = os.listdir(impth)
  108. weights = '../weights/STDC/model_maxmIOU75_1720_0.946_360640.pth'
  109. segmodel = SegModel(nclass=2,weights=weights)
  110. for i in range(len(folders)):
  111. imgpath = os.path.join(impth, folders[i])
  112. time0 = time.time()
  113. #img = Image.open(imgpath).convert('RGB')
  114. img = cv2.imread(imgpath)
  115. img = np.array(img)
  116. time1 = time.time()
  117. pred, outstr = segmodel.eval(image=img)#####
  118. time2 = time.time()
  119. binary0 = pred.copy()
  120. contours, hierarchy = cv2.findContours(binary0,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  121. time3 = time.time()
  122. max_id = -1
  123. if len(contours)>0:
  124. max_id = get_largest_contours(contours)
  125. binary0[:,:] = 0
  126. cv2.fillPoly(binary0, [contours[max_id][:,0,:]], 1)
  127. cv2.drawContours(img,contours,max_id,(0,255,255),3)
  128. time4 = time.time()
  129. #img_n = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
  130. cv2.imwrite( os.path.join( outpth,folders[i] ) ,img )
  131. time5 = time.time()
  132. print('image:%d ,infer:%.1f ms,findcontours:%.1f ms, draw:%.1f, total:%.1f'%(i,get_ms(time2,time1),get_ms(time3,time2),get_ms(time4,time3),get_ms(time4,time1)))