You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

segMultiOutModel.py 15KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. import torch
  2. from core.models.bisenet import BiSeNet,BiSeNet_MultiOutput
  3. from torchvision import transforms
  4. import cv2,os,glob
  5. import numpy as np
  6. from core.models.dinknet import DinkNet34
  7. import matplotlib.pyplot as plt
  8. import matplotlib.pyplot as plt
  9. import time
  class SegModel(object):
      """Wrap a segmentation network for single-image inference.

      The wrapped model is restored from a checkpoint, moved to ``device`` and
      called with ``test_flag=True``. Inputs are resized to a square
      ``modelsize`` x ``modelsize`` image, scaled to [0, 1] and normalized with
      the per-channel ``mean``/``std`` set in ``__init__``; predictions are
      resized back to the input's height and width.
      """

      def __init__(self, nclass=2, model=None, weights=None, modelsize=512,
                   device='cuda:3', multiOutput=False):
          """Load ``weights`` into ``model`` and move it to ``device``.

          Args:
              nclass: number of classes (an int, or a list for multi-output
                  models). Not used here; kept for interface compatibility.
              model: an already constructed torch module. Required.
              weights: path to a checkpoint whose ``'model'`` entry is a state
                  dict for ``model``.
              modelsize: side length the input is resized to before inference.
              device: torch device string the model and inputs are moved to.
              multiOutput: True when the model returns a sequence of outputs
                  (one per head) instead of a single tensor.

          Raises:
              ValueError: if ``model`` is None.
          """
          if model is None:
              # Fail fast with a clear message instead of an AttributeError on
              # ``None.load_state_dict`` after the checkpoint has been read.
              raise ValueError('SegModel requires a constructed model instance')
          self.model = model
          self.modelsize = modelsize
          self.device = device
          self.multiOutput = multiOutput
          # Restore on the CPU first so a checkpoint saved on a GPU that does
          # not exist on this machine can still be loaded; the model is moved
          # to ``device`` right afterwards.
          checkpoint = torch.load(weights, map_location='cpu')
          self.model.load_state_dict(checkpoint['model'])
          self.model = self.model.to(self.device)
          # Per-channel normalization statistics for [0, 1]-scaled pixels.
          self.mean = (0.335, 0.358, 0.332)
          self.std = (0.141, 0.138, 0.143)

      def eval(self, image, outsize=None, smooth_kernel=0):
          """Run the model on one HxWx3 image and return label map(s).

          Args:
              image: HxWx3 numpy image in the channel order the model expects
                  (the callers in this file pass RGB).
              outsize: unused; predictions are always resized to the input's
                  height and width.
              smooth_kernel: size of an optional box blur applied to the raw
                  prediction before resizing; values <= 0 disable smoothing.

          Returns:
              A uint8 HxW label map, or, when ``multiOutput`` is set, a list
              holding such maps for the first two model outputs.
          """
          imageH, imageW = image.shape[:2]
          time0 = time.time()
          image = self.preprocess_image(image)
          time1 = time.time()
          self.model.eval()
          image = image.to(self.device)
          with torch.no_grad():
              # Smoothing is done below on the numpy side, never by the model.
              output = self.model(image, test_flag=True, smooth_kernel=0)
          time2 = time.time()
          if self.multiOutput:
              pred = [outputx.data.cpu().numpy()[0] for outputx in output]
          else:
              pred = output.data.cpu().numpy()[0]
          time3 = time.time()

          def _postprocess(predx):
              # Optional box blur; a kernel size of 0 (the default) means
              # "no smoothing" and is not a meaningful box filter.
              if smooth_kernel > 0:
                  predx = cv2.blur(predx, (smooth_kernel, smooth_kernel))
              # Nearest-neighbour keeps the output a valid label map; linear
              # interpolation would create intermediate class ids at borders.
              return cv2.resize(predx.astype(np.uint8), (imageW, imageH),
                                interpolation=cv2.INTER_NEAREST)

          if self.multiOutput:
              # Only the first two heads are returned to the caller.
              pred = [_postprocess(predx) for predx in pred[0:2]]
          else:
              pred = _postprocess(pred)
          time4 = time.time()
          print('##line52:pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
          return pred

      def get_ms(self, t1, t0):
          """Return the elapsed time from ``t0`` to ``t1`` (seconds) in milliseconds."""
          return (t1 - t0) * 1000.0

      def preprocess_image(self, image):
          """Turn an HxWx3 image into a normalized 1x3xSxS float tensor.

          The image is resized to ``modelsize`` x ``modelsize``, scaled to
          [0, 1], normalized per channel with ``self.mean``/``self.std`` and
          transposed to channel-first order with a leading batch axis.
          """
          time0 = time.time()
          image = cv2.resize(image, (self.modelsize, self.modelsize))
          time1 = time.time()
          image = image.astype(np.float32) / 255.0
          time2 = time.time()
          # Broadcast over the channel axis instead of three per-channel updates.
          image -= np.asarray(self.mean, dtype=np.float32)
          time3 = time.time()
          image /= np.asarray(self.std, dtype=np.float32)
          time4 = time.time()
          image = np.transpose(image, (2, 0, 1))
          time5 = time.time()
          image = torch.from_numpy(image).float().unsqueeze(0)
          print('###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
          return image
  def get_ms(t1, t0):
      """Convert the span between two ``time.time()`` stamps to milliseconds."""
      elapsed_seconds = t1 - t0
      return elapsed_seconds * 1000.0
  def test():
      """Benchmark single-output BiSeNet inference on one river image.

      Loads the THriver checkpoint, segments ``temp_pics/DJI_0645.JPG`` ten
      times and prints the prediction shape together with the time spent
      copying the mask and extracting its contours.

      An earlier landcover setup used the image
      '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg',
      nclass=5 and the checkpoint
      'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'.
      """
      num_classes = 2
      checkpoint_path = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
      picture_path = 'temp_pics/DJI_0645.JPG'
      segmodel = SegModel(model=BiSeNet(num_classes), nclass=num_classes,
                          weights=checkpoint_path, device='cuda:4')
      for _ in range(10):
          bgr = cv2.imread(picture_path)
          height, width, _channels = bgr.shape
          # COLOR_RGB2BGR performs the same R<->B swap, i.e. BGR -> RGB here.
          rgb = cv2.cvtColor(bgr, cv2.COLOR_RGB2BGR)
          pred = segmodel.eval(rgb, outsize=None)
          t_start = time.time()
          mask = pred.copy()
          t_copied = time.time()
          contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
          t_done = time.time()
          print(pred.shape, ' time copy:%.1f finccontour:%.1f ' % (get_ms(t_copied, t_start), get_ms(t_done, t_copied)))
  # Colour palettes per task: entry k is the [R, G, B] colour used for class k
  # (class 0 is the background).
  label_dic = {
      'landcover': [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0]],
      'deepRoad': [[0, 0, 0], [255, 0, 0]],
      'water': [[0, 0, 0], [255, 255, 255]],
      'water_building': [[0, 0, 0], [0, 0, 255], [255, 0, 0]],
      'floater': [
          [0, 0, 0], [0, 255, 0], [255, 255, 0], [255, 0, 255],
          [0, 128, 255], [255, 0, 0], [0, 255, 255],
      ],
  }
  def index2color(label_mask, label_colours):
      """Colourize an integer label mask for OpenCV.

      A pixel whose value ``k`` is a valid index into ``label_colours`` gets
      the colour ``label_colours[k]`` (given as [R, G, B]); any other value is
      copied unchanged into all three channels. Channels are returned in
      B, G, R order as a uint8 array.
      """
      bgr = np.stack([label_mask] * 3, axis=-1)
      for class_id, colour in enumerate(label_colours):
          bgr[label_mask == class_id] = (colour[2], colour[1], colour[0])
      return bgr.astype(np.uint8)
  def get_largest_contours(contours):
      """Return the index of the contour with the largest ``cv2.contourArea``.

      Ties resolve to the earliest contour; an empty ``contours`` raises
      ValueError (from ``max`` over an empty sequence).
      """
      return max(range(len(contours)), key=lambda idx: cv2.contourArea(contours[idx]))
  def result_merge_sep(image, mask_colors):
      """Outline the outer contours of selected classes on ``image`` in place.

      Args:
          image: BGR image to draw on; it is modified and also returned.
          mask_colors: list of dicts with keys ``'mask'`` (label map with the
              same HxW as ``image``), ``'index'`` (class ids to outline) and
              ``'color'`` ([R, G, B] colours, paired element-wise with
              ``'index'``).

      Returns:
          ``image`` with a 3-px outline around every top-level contour of each
          requested class.
      """
      for mask_color in mask_colors:
          mask_map = mask_color['mask']
          for index, color in zip(mask_color['index'], mask_color['color']):
              mask_binary = (mask_map == index).astype(np.uint8)
              contours, hierarchy = cv2.findContours(mask_binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
              if len(contours) == 0:
                  continue
              # Keep contours without a parent (hierarchy[..., 3] < 0). A plain
              # list is used because np.array(contours, dtype=object) turns into
              # a multi-dimensional object array when all contours have the same
              # length (e.g. a single contour), which drawContours rejects.
              outer = [cont for cont, info in zip(contours, hierarchy[0]) if info[3] < 0]
              # Palettes are RGB; the image is drawn in OpenCV's BGR order.
              cv2.drawContours(image, outer, -1, color[::-1], 3)
      return image
  def result_merge(image, mask_colors):
      """Outline all non-background regions of each mask on ``image`` in place.

      Args:
          image: BGR image to draw on; it is modified and also returned.
          mask_colors: list of dicts with keys ``'mask'`` (label map with the
              same HxW as ``image``) and ``'color'`` (list of [R, G, B]
              colours; only the first one is used). An ``'index'`` key may be
              present but is not used here.

      Returns:
          ``image`` with a 3-px outline, in ``color[0]``, around every
          top-level contour of ``mask > 0``.
      """
      for mask_color in mask_colors:
          mask_map, colors = mask_color['mask'], mask_color['color']
          mask_binary = (mask_map > 0).astype(np.uint8)
          contours, hierarchy = cv2.findContours(mask_binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
          if len(contours) == 0:
              continue
          # Keep contours without a parent. Building np.array(contours) instead
          # fails on ragged contours (different point counts) in modern NumPy.
          outer = [cont for cont, info in zip(contours, hierarchy[0]) if info[3] < 0]
          # Palettes are RGB; the image is drawn in OpenCV's BGR order.
          cv2.drawContours(image, outer, -1, colors[0][::-1], 3)
          # TODO: per-class colouring (sample the class at each contour's
          # centroid and draw with colors[class]) was started but is disabled.
      return image
  def test_floater():
      """Segment water and floating objects and save outline overlays.

      Runs the two-head BiSeNet_MultiOutput floater checkpoint on the first 10
      files in ``image_dir`` and, when ``outresult`` is set, writes
      ``<outdir>/<name>_draw.png`` with the outlines drawn by ``result_merge``.
      """
      nclass = [2, 7]
      outresult = True
      weights = 'runs/thFloater/BiSeNet_MultiOutput/train/experiment_4/checkpoint.pth'
      outdir = 'temp'
      image_dir = '/host/workspace/WJ/data/thFloater/val/images/'
      model = BiSeNet_MultiOutput(nclass)
      segmodel = SegModel(model=model, nclass=nclass, weights=weights, device='cuda:9', multiOutput=True)
      image_url_list = glob.glob(os.path.join(image_dir, '*'))
      if outresult:
          # cv2.imwrite fails silently (returns False) if the folder is missing.
          os.makedirs(outdir, exist_ok=True)
      for image_url in image_url_list[0:10]:
          image_bgr = cv2.imread(image_url)
          if image_bgr is None:
              print('skip unreadable image:', image_url)
              continue
          # OpenCV loads BGR; the model is fed RGB, while outlines are drawn on
          # the BGR original so cv2.imwrite saves the photo with correct colours.
          image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
          pred = segmodel.eval(image_rgb, outsize=None)
          time0 = time.time()
          if isinstance(pred, list):
              binary = [predx.copy() for predx in pred]
          time1 = time.time()
          # NOTE(review): both palettes start with the background colour
          # [0, 0, 0] and result_merge draws with colour[0], so outlines come
          # out black — confirm whether a [1:] slice was intended.
          mask_colors = [
              {'mask': pred[0], 'index': range(1, 2), 'color': label_dic['water'][0:]},
              {'mask': pred[1], 'index': [1, 2, 3, 4, 5, 6], 'color': label_dic['floater'][0:]},
          ]
          result_draw = result_merge(image_bgr, mask_colors)
          time2 = time.time()
          if outresult:
              basename = os.path.splitext(os.path.basename(image_url))[0]
              outname = os.path.join(outdir, basename + '_draw.png')
              cv2.imwrite(outname, result_draw)
          print('##line151: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  def test_water_buildings():
      """Segment water (class 1) and buildings (class 2) and save outlines.

      Runs the single-output 3-class BiSeNet checkpoint on every file in
      ``image_dir`` and, when ``outresult`` is set, writes
      ``<outdir>/<name>_draw.png`` with per-class outlines drawn by
      ``result_merge_sep``.
      """
      from core.models.bisenet import BiSeNet
      nclass = 3
      outresult = True
      weights = 'runs/thWaterBuilding/BiSeNet/train/experiment_2/checkpoint.pth'
      outdir = 'temp'
      image_dir = '/home/thsw/WJ/data/river_buildings/'
      # image_dir = '/home/thsw/WJ/data/THWaterBuilding/val/images'
      model = BiSeNet(nclass)
      segmodel = SegModel(model=model, nclass=nclass, weights=weights, device='cuda:0', multiOutput=False)
      image_url_list = glob.glob(os.path.join(image_dir, '*'))
      if outresult:
          # cv2.imwrite fails silently (returns False) if the folder is missing.
          os.makedirs(outdir, exist_ok=True)
      for image_url in image_url_list:
          image_array0 = cv2.imread(image_url)
          if image_array0 is None:
              print('skip unreadable image:', image_url)
              continue
          # OpenCV loads BGR; the model is fed RGB, outlines go on the BGR image.
          image_array = cv2.cvtColor(image_array0, cv2.COLOR_BGR2RGB)
          pred = segmodel.eval(image_array, outsize=None)
          time0 = time.time()
          if isinstance(pred, list):
              binary = [predx.copy() for predx in pred]
          # Set unconditionally: with a single-output model ``pred`` is an
          # array, and time1 must still exist for the timing print below.
          time1 = time.time()
          mask_colors = [
              {'mask': pred, 'index': range(1, 3), 'color': label_dic['water_building'][1:]},
          ]
          result_draw = result_merge_sep(image_array0, mask_colors)
          time2 = time.time()
          if outresult:
              basename = os.path.splitext(os.path.basename(image_url))[0]
              outname = os.path.join(outdir, basename + '_draw.png')
              cv2.imwrite(outname, result_draw)
          print('##line294: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  def get_illegal_index(contours, hierarchy, water_dilate, overlap_threshold):
      """Find outer contours whose filled area overlaps ``water_dilate`` too much.

      Args:
          contours, hierarchy: output of ``cv2.findContours`` with RETR_TREE,
              computed on a mask of the same HxW as ``water_dilate``.
          water_dilate: HxW mask; presumably 0/1 (the caller in this file
              builds it with ``fillPoly(..., 1)`` followed by ``cv2.dilate``).
          overlap_threshold: a contour is reported when
              sum(filled * water_dilate) / sum(filled) exceeds this value.

      Returns:
          List of indices into the full ``contours`` sequence passed in, so
          they can be used directly with ``cv2.drawContours(img, contours, i, ...)``.
          Only contours without a parent (outer boundaries) are considered;
          contours with an empty filled area are skipped.
      """
      out_index = []
      if len(contours) == 0:
          return out_index
      for ii, (cont, info) in enumerate(zip(contours, hierarchy[0])):
          if info[3] >= 0:
              continue  # has a parent, i.e. an inner contour (hole)
          # Rasterize only inside the contour's bounding box instead of
          # allocating a full-image canvas for every contour.
          x, y, w, h = cv2.boundingRect(cont)
          build_area = np.zeros((h, w))
          cv2.fillPoly(build_area, [cont[:, 0, :]], 1, offset=(-x, -y))
          area1 = np.sum(build_area)
          if area1 == 0:
              continue
          area2 = np.sum(build_area * water_dilate[y:y + h, x:x + w])
          if area2 / area1 > overlap_threshold:
              out_index.append(ii)
      return out_index
  def test_water_building_seperately():
      """Flag buildings that overlap the dilated water ("blue line") area.

      Runs the two-head BiSeNet_MultiOutput checkpoint (head 0 is used as the
      water mask, head 1 as the building mask), then:
        * fills and outlines (yellow) the largest water contour,
        * dilates that region with a 100x100 kernel and outlines it (blue),
        * outlines (red) each building whose overlap with the dilated region
          exceeds 0.1 of its own area, as computed by ``get_illegal_index``,
      shows the drawing with matplotlib and, when ``outresult`` is set, saves
      ``<outdir>/<name>_draw.png`` after ``result_merge_sep`` adds its outlines.
      """
      nclass = [2, 2]
      outresult = True
      weights = 'runs/thWaterBuilding_seperate/BiSeNet_MultiOutput/train/experiment_0/checkpoint.pth'
      outdir = 'temp'
      image_dir = '/home/thsw/WJ/data/river_buildings/'
      # image_dir = '/home/thsw/WJ/data/THWaterBuilding/val/images'
      model = BiSeNet_MultiOutput(nclass)
      segmodel = SegModel(model=model, nclass=nclass, weights=weights, device='cuda:1', multiOutput=True)
      image_url_list = glob.glob(os.path.join(image_dir, '*'))
      print('###line307 image cnt:', len(image_url_list))
      if outresult:
          # cv2.imwrite fails silently (returns False) if the folder is missing.
          os.makedirs(outdir, exist_ok=True)
      for image_url in image_url_list[0:1]:
          # Debug override: always process this specific image.
          image_url = '/home/thsw/WJ/data/river_buildings/DJI_20210904092044_0001_S_output896.jpg'
          image_array0 = cv2.imread(image_url)
          if image_array0 is None:
              print('skip unreadable image:', image_url)
              continue
          # OpenCV loads BGR; the model is fed RGB, drawing happens on BGR.
          image_array = cv2.cvtColor(image_array0, cv2.COLOR_BGR2RGB)
          pred = segmodel.eval(image_array, outsize=None, smooth_kernel=20)

          # Draw the water region: fill and outline the largest water contour.
          contours, hierarchy = cv2.findContours(pred[0], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
          if len(contours) == 0:
              # get_largest_contours cannot pick from an empty list.
              print('no water region found, skip:', image_url)
              continue
          max_id = get_largest_contours(contours)
          water = np.zeros_like(pred[0])
          cv2.fillPoly(water, [contours[max_id][:, 0, :]], 1)
          cv2.drawContours(image_array0, contours, max_id, (0, 255, 255), 3)

          # Draw the "blue line": the water region dilated by a 100x100 kernel.
          kernel = np.ones((100, 100), np.uint8)
          water_dilate = cv2.dilate(water, kernel, iterations=1)
          contours, hierarchy = cv2.findContours(water_dilate, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
          cv2.drawContours(image_array0, contours, -1, (255, 0, 0), 3)

          # Check each building for overlap with the area inside the blue line;
          # if the overlap exceeds 0.1 of the building's own area it is treated
          # as an illegal building and outlined in red.
          contours, hierarchy = cv2.findContours(pred[1], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
          outIndex = get_illegal_index(contours, hierarchy, water_dilate, 0.1)
          for ii in outIndex:
              cv2.drawContours(image_array0, contours, ii, (0, 0, 255), 3)
          # matplotlib expects RGB; reverse the BGR canvas for display.
          plt.imshow(image_array0[:, :, ::-1])
          plt.show()

          time0 = time.time()
          time1 = time.time()
          mask_colors = [
              {'mask': pred[0], 'index': [1], 'color': label_dic['water_building'][1:2]},
              {'mask': pred[1], 'index': [1], 'color': label_dic['water_building'][2:3]},
          ]
          result_draw = result_merge_sep(image_array0, mask_colors)
          time2 = time.time()
          if outresult:
              basename = os.path.splitext(os.path.basename(image_url))[0]
              outname = os.path.join(outdir, basename + '_draw.png')
              cv2.imwrite(outname, result_draw)
          print('##line151: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  if __name__ == '__main__':
      # Exactly one demo runs; swap this call for test(), test_floater() or
      # test_water_buildings() to try the other pipelines.
      test_water_building_seperately()