您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

2 年前
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. import torch
  2. from core.models.bisenet import BiSeNet,BiSeNet_MultiOutput
  3. from torchvision import transforms
  4. import cv2,os,glob
  5. import numpy as np
  6. from core.models.dinknet import DinkNet34
  7. import matplotlib.pyplot as plt
  8. import matplotlib.pyplot as plt
  9. import time
  10. class SegModel(object):
  11. def __init__(self, nclass=2,model = None,weights=None,modelsize=512,device='cuda:3',multiOutput=False):
  12. #self.args = args
  13. self.model = model
  14. #self.model = DinkNet34(nclass)
  15. checkpoint = torch.load(weights)
  16. self.modelsize = modelsize
  17. self.model.load_state_dict(checkpoint['model'])
  18. self.device = device
  19. self.multiOutput = multiOutput
  20. self.model= self.model.to(self.device)
  21. '''self.composed_transforms = transforms.Compose([
  22. transforms.Normalize(mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)),
  23. transforms.ToTensor()]) '''
  24. self.mean = (0.335, 0.358, 0.332)
  25. self.std = (0.141, 0.138, 0.143)
  26. #mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)
  27. def eval(self,image,outsize=None,smooth_kernel=0):
  28. imageH,imageW,imageC = image.shape
  29. time0 = time.time()
  30. image = self.preprocess_image(image)
  31. time1 = time.time()
  32. self.model.eval()
  33. image = image.to(self.device)
  34. with torch.no_grad():
  35. output = self.model(image,test_flag=True,smooth_kernel = 0)
  36. time2 = time.time()
  37. if self.multiOutput:
  38. pred = [outputx.data.cpu().numpy()[0] for outputx in output]
  39. else:
  40. pred = output.data.cpu().numpy()
  41. pred = pred[0]
  42. time3 = time.time()
  43. if self.multiOutput:
  44. pred = [ cv2.blur(predx,(smooth_kernel,smooth_kernel) ) for predx in pred]
  45. pred = [cv2.resize(predx.astype(np.uint8),(imageW,imageH)) for predx in pred[0:2]]
  46. else:
  47. pred = cv2.blur(pred,(smooth_kernel,smooth_kernel) )
  48. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH),interpolation = cv2.INTER_NEAREST)
  49. time4 = time.time()
  50. print('##line52:pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
  51. return pred
  52. def get_ms(self,t1,t0):
  53. return (t1-t0)*1000.0
  54. def preprocess_image(self,image):
  55. time0 = time.time()
  56. image = cv2.resize(image,(self.modelsize,self.modelsize))
  57. time1 = time.time()
  58. image = image.astype(np.float32)
  59. image /= 255.0
  60. time2 = time.time()
  61. #image = image * 3.2 - 1.6
  62. image[:,:,0] -=self.mean[0]
  63. image[:,:,1] -=self.mean[1]
  64. image[:,:,2] -=self.mean[2]
  65. time3 = time.time()
  66. image[:,:,0] /= self.std[0]
  67. image[:,:,1] /= self.std[1]
  68. image[:,:,2] /= self.std[2]
  69. time4 = time.time()
  70. image = np.transpose(image, ( 2, 0, 1))
  71. time5 = time.time()
  72. image = torch.from_numpy(image).float()
  73. image = image.unsqueeze(0)
  74. print('###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
  75. return image
  76. def get_ms(t1,t0):
  77. return (t1-t0)*1000.0
  78. def test():
  79. #os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
  80. '''
  81. image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
  82. nclass = 5
  83. weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
  84. '''
  85. image_url = 'temp_pics/DJI_0645.JPG'
  86. nclass = 2
  87. #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
  88. weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
  89. #weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
  90. model = BiSeNet(nclass)
  91. segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:4')
  92. for i in range(10):
  93. image_array0 = cv2.imread(image_url)
  94. imageH,imageW,_ = image_array0.shape
  95. #print('###line84:',image_array0.shape)
  96. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  97. #image_in = segmodel.preprocess_image(image_array)
  98. pred = segmodel.eval(image_array,outsize=None)
  99. time0=time.time()
  100. binary = pred.copy()
  101. time1=time.time()
  102. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  103. time2=time.time()
  104. print(pred.shape,' time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  105. label_dic={'landcover':[[0, 0, 0], [255, 0, 0], [0,255,0], [0,0,255], [255,255,0]],
  106. 'deepRoad':[[0,0,0],[255,0,0]],
  107. 'water':[[0,0,0],[255,255,255]],
  108. 'water_building':[[0,0,0],[0,0,255],[255,0,0]],
  109. 'floater':[[0,0,0], [0,255,0],[255,255,0],[255,0,255],[0,128, 255], [255,0,0], [0,255,255] ]
  110. }
  111. def index2color(label_mask,label_colours):
  112. r = label_mask.copy()
  113. g = label_mask.copy()
  114. b = label_mask.copy()
  115. label_cnt = len(label_colours)
  116. for ll in range(0, label_cnt):
  117. r[label_mask == ll] = label_colours[ll][0]
  118. g[label_mask == ll] = label_colours[ll][1]
  119. b[label_mask == ll] = label_colours[ll][2]
  120. rgb = np.stack((b, g,r), axis=-1)
  121. return rgb.astype(np.uint8)
  122. def get_largest_contours(contours):
  123. areas = [cv2.contourArea(x) for x in contours]
  124. max_area = max(areas)
  125. max_id = areas.index(max_area)
  126. return max_id
  127. def result_merge_sep(image,mask_colors):
  128. #mask_colors=[{ 'mask':mask_map,'index':[1],'color':[255,255,255] }]
  129. for mask_color in mask_colors:
  130. mask_map,indexes,colors = mask_color['mask'], mask_color['index'], mask_color['color']
  131. ishow = 2
  132. #plt.figure(1);plt.imshow(mask_map);
  133. for index,color in zip(indexes,colors):
  134. mask_binaray = (mask_map == index).astype(np.uint8)
  135. contours, hierarchy = cv2.findContours(mask_binaray,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  136. if len(contours)>0:
  137. d=hierarchy[0,:,3]<0 ;
  138. contours = np.array(contours,dtype=object)[d]
  139. cv2.drawContours(image,contours,-1,color[::-1],3)
  140. #plt.figure(ishow);plt.imshow(mask_binaray);ishow+=1
  141. #plt.show()
  142. return image
  143. def result_merge(image,mask_colors):
  144. #mask_colors=[{ 'mask':mask_map,'index':[1],'color':[255,255,255] }]
  145. for mask_color in mask_colors:
  146. mask_map,indexes,colors = mask_color['mask'], mask_color['index'], mask_color['color']
  147. mask_binary = (mask_map>0).astype(np.uint8)
  148. contours, hierarchy = cv2.findContours(mask_binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  149. if len(contours)>0:
  150. d=hierarchy[0,:,3]<0 ; contours = np.array(contours)[d]
  151. cv2.drawContours(image,contours,-1,colors[0][::-1],3)
  152. coors = np.array([(np.mean(contours_x ,axis=0)+0.5).astype(np.int32)[0] for contours_x in contours])
  153. #print(mask_map.shape,coors.shape)
  154. typess = mask_map[ coors[:,1],coors[:,0]]
  155. #for jj,iclass in enumerate(typess):
  156. #print(iclass,colors)
  157. # cv2.drawContours(image,contours,-1, colors[iclass][::-1],3)
  158. return image
  159. def test_floater():
  160. from core.models.dinknet import DinkNet34_MultiOutput
  161. #create_model('DinkNet34_MultiOutput',[2,5])
  162. image_url = 'temp_pics/DJI_0645.JPG'
  163. nclass = [2,7]
  164. outresult=True
  165. weights = 'runs/thFloater/BiSeNet_MultiOutput/train/experiment_4/checkpoint.pth'
  166. model = BiSeNet_MultiOutput(nclass)
  167. outdir='temp'
  168. image_dir = '/host/workspace/WJ/data/thFloater/val/images/'
  169. image_url_list=glob.glob('%s/*'%(image_dir))
  170. segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:9',multiOutput=True)
  171. for i,image_url in enumerate(image_url_list[0:10]) :
  172. image_array0 = cv2.imread(image_url)
  173. image_array0 = cv2.cvtColor(image_array0, cv2.COLOR_BGR2RGB) # cv2默认为bgr顺序
  174. imageH,imageW,_ = image_array0.shape
  175. #image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  176. pred = segmodel.eval(image_array,outsize=None)
  177. time0=time.time()
  178. if isinstance(pred,list):
  179. binary = [predx.copy() for predx in pred]
  180. time1=time.time()
  181. mask_colors=[ { 'mask':pred[0] ,'index':range(1,2),'color':label_dic['water'][0:] },
  182. { 'mask':pred[1] ,'index':[1,2,3,4,5,6],'color':label_dic['floater'][0:] } ]
  183. result_draw = result_merge(image_array0,mask_colors)
  184. time2=time.time()
  185. if outresult:
  186. basename=os.path.splitext( os.path.basename(image_url))[0]
  187. outname=os.path.join(outdir,basename+'_draw.png')
  188. cv2.imwrite(outname,result_draw[:,:,:])
  189. print('##line151: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  190. def test_water_buildings():
  191. from core.models.bisenet import BiSeNet
  192. #image_url = 'temp_pics/DJI_0645.JPG'
  193. nclass = 3
  194. outresult=True
  195. weights = 'runs/thWaterBuilding/BiSeNet/train/experiment_2/checkpoint.pth'
  196. model = BiSeNet(nclass)
  197. outdir='temp'
  198. image_dir = '/home/thsw/WJ/data/river_buildings/'
  199. #image_dir = '/home/thsw/WJ/data/THWaterBuilding/val/images'
  200. image_url_list=glob.glob('%s/*'%(image_dir))
  201. segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:0',multiOutput=False)
  202. for i,image_url in enumerate(image_url_list[0:]) :
  203. #image_url = '/home/thsw/WJ/data/THWaterBuilding/val/images/0anWqgmO9rGe1n8P.png'
  204. image_array0 = cv2.imread(image_url)
  205. imageH,imageW,_ = image_array0.shape
  206. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  207. pred = segmodel.eval(image_array,outsize=None)
  208. time0=time.time()
  209. if isinstance(pred,list):
  210. binary = [predx.copy() for predx in pred]
  211. #print(binary[0].shape)
  212. time1=time.time()
  213. mask_colors=[ { 'mask':pred ,'index':range(1,3),'color':label_dic['water_building'][1:] },
  214. #{ 'mask':pred[1] ,'index':[1,2,3,4,5,6],'color':label_dic['floater'][0:] }
  215. ]
  216. result_draw = result_merge_sep(image_array0,mask_colors)
  217. time2=time.time()
  218. if outresult:
  219. basename=os.path.splitext( os.path.basename(image_url))[0]
  220. outname=os.path.join(outdir,basename+'_draw.png')
  221. cv2.imwrite(outname,result_draw[:,:,:])
  222. print('##line294: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  223. def get_illegal_index(contours,hierarchy,water_dilate,overlap_threshold):
  224. out_index=[]
  225. if len(contours)>0:
  226. d=hierarchy[0,:,3]<0 ;
  227. contours = np.array(contours,dtype=object)[d]
  228. imageH,imageW = water_dilate.shape
  229. for ii,cont in enumerate(contours):
  230. build_area=np.zeros((imageH,imageW ))
  231. cv2.fillPoly(build_area,[cont[:,0,:]],1)
  232. area1=np.sum(build_area);area2=np.sum(build_area*water_dilate)
  233. if (area2/area1) >overlap_threshold:
  234. out_index.append(ii)
  235. return out_index
def test_water_building_seperately():
    """Flag buildings that encroach on the dilated water area of one image.

    Uses a two-head BiSeNet_MultiOutput (head 0 drawn as water, head 1 as
    buildings below). The largest water contour is filled and dilated with a
    100 x 100 kernel to form the "blue line" zone. Every building contour whose
    overlap with that zone exceeds 0.1 of its own area is outlined in red.
    The annotated image is shown with matplotlib. When ``outresult`` is set,
    per-class contours are added with result_merge_sep and the result is
    written to ``outdir``.

    NOTE(review): debug harness. Only one image is processed, and its path is
    hard-coded inside the loop, overriding the globbed ``image_url``.
    """
    from core.models.dinknet import DinkNet34_MultiOutput  # not used in this function
    #create_model('DinkNet34_MultiOutput',[2,5])
    image_url = 'temp_pics/DJI_0645.JPG'  # overwritten inside the loop
    nclass = [2,2]  # two output heads with 2 classes each
    outresult=True
    weights = 'runs/thWaterBuilding_seperate/BiSeNet_MultiOutput/train/experiment_0/checkpoint.pth'
    model = BiSeNet_MultiOutput(nclass)
    outdir='temp'  # NOTE(review): cv2.imwrite does not create this directory
    image_dir = '/home/thsw/WJ/data/river_buildings/'
    #image_dir = '/home/thsw/WJ/data/THWaterBuilding/val/images'
    image_url_list=glob.glob('%s/*'%(image_dir))
    segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:1',multiOutput=True)
    print('###line307 image cnt:',len(image_url_list))
    for i,image_url in enumerate(image_url_list[0:1]) :
        image_url = '/home/thsw/WJ/data/river_buildings/DJI_20210904092044_0001_S_output896.jpg'
        image_array0 = cv2.imread(image_url)  # BGR order; all drawing below happens on this image
        imageH,imageW,_ = image_array0.shape
        # Swaps channels 0 and 2 (i.e. BGR -> RGB) for the model input.
        image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
        pred = segmodel.eval(image_array,outsize=None,smooth_kernel=20)
        ## Draw the water region: keep only the largest water contour (yellow).
        # NOTE(review): get_largest_contours raises ValueError if no water is found.
        contours, hierarchy = cv2.findContours(pred[0],cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        max_id = get_largest_contours(contours);
        water = pred[0].copy(); water[:,:] = 0
        cv2.fillPoly(water, [contours[max_id][:,0,:]], 1)
        cv2.drawContours(image_array0,contours,max_id,(0,255,255),3)
        ## Draw the "blue line" zone: the water region after dilation.
        kernel = np.ones((100,100),np.uint8)
        water_dilate = cv2.dilate(water,kernel,iterations = 1)
        contours, hierarchy = cv2.findContours(water_dilate,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        #print('####line310:',contours)
        cv2.drawContours(image_array0,contours,-1,(255,0,0),3)
        ### Check each building for intersection with the area inside the blue line.
        ### If the intersection exceeds 0.1 of the building's own area, it is treated as an illegal building.
        contours, hierarchy = cv2.findContours(pred[1],cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
        outIndex=get_illegal_index(contours,hierarchy,water_dilate,0.1)
        # NOTE(review): each ii is used as an index into this `contours` list;
        # confirm get_illegal_index returns indices in that numbering.
        for ii in outIndex:
            cv2.drawContours(image_array0,contours,ii,(0,0,255),3)
        # NOTE(review): image_array0 is BGR, so matplotlib displays red and blue swapped.
        plt.imshow(image_array0);plt.show()
        ##
        time0=time.time()
        time1=time.time()
        mask_colors=[ { 'mask':pred[0],'index':[1],'color':label_dic['water_building'][1:2]},
                      { 'mask':pred[1],'index':[1],'color':label_dic['water_building'][2:3]}
                    ]
        result_draw = result_merge_sep(image_array0,mask_colors)
        time2=time.time()
        if outresult:
            basename=os.path.splitext( os.path.basename(image_url))[0]
            outname=os.path.join(outdir,basename+'_draw.png')
            cv2.imwrite(outname,result_draw[:,:,:])
        print('##line151: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  287. if __name__=='__main__':
  288. #test()
  289. #test_floater()
  290. #test_water_buildings()
  291. test_water_building_seperately()