You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

131 lines
4.3KB

  1. import torch
  2. import sys,os
  3. sys.path.extend(['../AIlib2/segutils'])
  4. from model_stages import BiSeNet_STDC
  5. from torchvision import transforms
  6. import cv2,glob
  7. import numpy as np
  8. from core.models.dinknet import DinkNet34
  9. import matplotlib.pyplot as plt
  10. import time
  11. class SegModel(object):
  12. def __init__(self, nclass=2,weights=None,modelsize=512,device='cuda:0'):
  13. #self.args = args
  14. self.model = BiSeNet_STDC(backbone='STDCNet813', n_classes=nclass,
  15. use_boundary_2=False, use_boundary_4=False,
  16. use_boundary_8=True, use_boundary_16=False,
  17. use_conv_last=False)
  18. self.device = device
  19. self.model.load_state_dict(torch.load(weights, map_location=torch.device(self.device) ))
  20. self.model= self.model.to(self.device)
  21. self.mean = (0.485, 0.456, 0.406)
  22. self.std = (0.229, 0.224, 0.225)
  23. def eval(self,image):
  24. time0 = time.time()
  25. imageH, imageW, _ = image.shape
  26. image = self.RB_convert(image)
  27. img = self.preprocess_image(image)
  28. if self.device != 'cpu':
  29. imgs = img.to(self.device)
  30. else:imgs=img
  31. time1 = time.time()
  32. self.model.eval()
  33. with torch.no_grad():
  34. output = self.model(imgs)
  35. time2 = time.time()
  36. pred = output.data.cpu().numpy()
  37. pred = np.argmax(pred, axis=1)[0]#得到每行
  38. time3 = time.time()
  39. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH))
  40. time4 = time.time()
  41. outstr= 'pre-precess:%.1f ,infer:%.1f ,post-cpu-argmax:%.1f ,post-resize:%.1f, total:%.1f \n '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3),self.get_ms(time4,time0) )
  42. return pred,outstr
  43. def get_ms(self,t1,t0):
  44. return (t1-t0)*1000.0
  45. def preprocess_image(self,image):
  46. image = cv2.resize(image, (640,360), interpolation=cv2.INTER_LINEAR)
  47. image = image.astype(np.float32)
  48. image /= 255.0
  49. image[:, :, 0] -= self.mean[0]
  50. image[:, :, 1] -= self.mean[1]
  51. image[:, :, 2] -= self.mean[2]
  52. image[:, :, 0] /= self.std[0]
  53. image[:, :, 1] /= self.std[1]
  54. image[:, :, 2] /= self.std[2]
  55. image = np.transpose(image, (2, 0, 1))
  56. image = torch.from_numpy(image).float()
  57. image = image.unsqueeze(0)
  58. return image
  59. def RB_convert(self,image):
  60. image_c = image.copy()
  61. image_c[:,:,0] = image[:,:,2]
  62. image_c[:,:,2] = image[:,:,0]
  63. return image_c
  64. def get_ms(t1,t0):
  65. return (t1-t0)*1000.0
  66. def get_largest_contours(contours):
  67. areas = [cv2.contourArea(x) for x in contours]
  68. max_area = max(areas)
  69. max_id = areas.index(max_area)
  70. return max_id
  71. if __name__=='__main__':
  72. impth = '../../river_demo/images/slope/'
  73. outpth= 'results'
  74. folders = os.listdir(impth)
  75. weights = '../weights/STDC/model_maxmIOU75_1720_0.946_360640.pth'
  76. segmodel = SegModel(nclass=2,weights=weights)
  77. for i in range(len(folders)):
  78. imgpath = os.path.join(impth, folders[i])
  79. time0 = time.time()
  80. #img = Image.open(imgpath).convert('RGB')
  81. img = cv2.imread(imgpath)
  82. img = np.array(img)
  83. time1 = time.time()
  84. pred, outstr = segmodel.eval(image=img)#####
  85. time2 = time.time()
  86. binary0 = pred.copy()
  87. contours, hierarchy = cv2.findContours(binary0,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  88. time3 = time.time()
  89. max_id = -1
  90. if len(contours)>0:
  91. max_id = get_largest_contours(contours)
  92. binary0[:,:] = 0
  93. cv2.fillPoly(binary0, [contours[max_id][:,0,:]], 1)
  94. cv2.drawContours(img,contours,max_id,(0,255,255),3)
  95. time4 = time.time()
  96. #img_n = cv2.cvtColor(img,cv2.COLOR_RGB2BGR)
  97. cv2.imwrite( os.path.join( outpth,folders[i] ) ,img )
  98. time5 = time.time()
  99. print('image:%d ,infer:%.1f ms,findcontours:%.1f ms, draw:%.1f, total:%.1f'%(i,get_ms(time2,time1),get_ms(time3,time2),get_ms(time4,time3),get_ms(time4,time1)))