用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

389 lines
16KB

  1. import torch
  2. from core.models.bisenet import BiSeNet,BiSeNet_MultiOutput
  3. from torchvision import transforms
  4. import cv2,os,glob
  5. import numpy as np
  6. from core.models.dinknet import DinkNet34
  7. import matplotlib.pyplot as plt
  8. import time
  9. class SegModel(object):
  10. def __init__(self, nclass=2,model = None,weights=None,modelsize=512,device='cuda:3',multiOutput=False):
  11. #self.args = args
  12. self.model = model
  13. #self.model = DinkNet34(nclass)
  14. checkpoint = torch.load(weights)
  15. self.modelsize = modelsize
  16. self.model.load_state_dict(checkpoint['model'])
  17. self.device = device
  18. self.multiOutput = multiOutput
  19. self.model= self.model.to(self.device)
  20. '''self.composed_transforms = transforms.Compose([
  21. transforms.Normalize(mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)),
  22. transforms.ToTensor()]) '''
  23. self.mean = (0.335, 0.358, 0.332)
  24. self.std = (0.141, 0.138, 0.143)
  25. #mean=(0.335, 0.358, 0.332), std=(0.141, 0.138, 0.143)
  26. def eval(self,image,outsize=None,smooth_kernel=0):
  27. imageH,imageW,imageC = image.shape
  28. time0 = time.time()
  29. image = self.preprocess_image(image)
  30. time1 = time.time()
  31. self.model.eval()
  32. image = image.to(self.device)
  33. with torch.no_grad():
  34. output = self.model(image,test_flag=True,smooth_kernel = 0)
  35. time2 = time.time()
  36. if self.multiOutput:
  37. pred = [outputx.data.cpu().numpy()[0] for outputx in output]
  38. else:
  39. pred = output.data.cpu().numpy()
  40. pred = pred[0]
  41. time3 = time.time()
  42. if self.multiOutput:
  43. pred = [ cv2.blur(predx,(smooth_kernel,smooth_kernel) ) for predx in pred]
  44. pred = [cv2.resize(predx.astype(np.uint8),(imageW,imageH)) for predx in pred[0:2]]
  45. else:
  46. pred = cv2.blur(pred,(smooth_kernel,smooth_kernel) )
  47. pred = cv2.resize(pred.astype(np.uint8),(imageW,imageH),interpolation = cv2.INTER_NEAREST)
  48. time4 = time.time()
  49. outStr= '##line52:pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) )
  50. #print('##line52:pre-precess:%.1f ,infer:%.1f ,post-precess:%.1f ,post-resize:%.1f '%( self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ))
  51. return pred
  52. def get_ms(self,t1,t0):
  53. return (t1-t0)*1000.0
  54. def preprocess_image(self,image):
  55. time0 = time.time()
  56. image = cv2.resize(image,(self.modelsize,self.modelsize))
  57. time1 = time.time()
  58. image = image.astype(np.float32)
  59. image /= 255.0
  60. time2 = time.time()
  61. #image = image * 3.2 - 1.6
  62. image[:,:,0] -=self.mean[0]
  63. image[:,:,1] -=self.mean[1]
  64. image[:,:,2] -=self.mean[2]
  65. time3 = time.time()
  66. image[:,:,0] /= self.std[0]
  67. image[:,:,1] /= self.std[1]
  68. image[:,:,2] /= self.std[2]
  69. time4 = time.time()
  70. image = np.transpose(image, ( 2, 0, 1))
  71. time5 = time.time()
  72. image = torch.from_numpy(image).float()
  73. image = image.unsqueeze(0)
  74. outStr='###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) )
  75. #print('###line84: in preprocess: resize:%.1f norm:%.1f mean:%.1f std:%.1f trans:%.f '%(self.get_ms(time1,time0),self.get_ms(time2,time1),self.get_ms(time3,time2),self.get_ms(time4,time3) ,self.get_ms(time5,time4) ) )
  76. return image
  77. def get_ms(t1,t0):
  78. return (t1-t0)*1000.0
  79. def test():
  80. #os.environ["CUDA_VISIBLE_DEVICES"] = str('4')
  81. '''
  82. image_url = '../../data/landcover/corp512/test/images/N-33-139-C-d-2-4_169.jpg'
  83. nclass = 5
  84. weights = 'runs/landcover/DinkNet34_save/experiment_wj_loss-10-10-1/checkpoint.pth'
  85. '''
  86. image_url = 'temp_pics/DJI_0645.JPG'
  87. nclass = 2
  88. #weights = '../weights/segmentation/BiSeNet/checkpoint.pth'
  89. weights = 'runs/THriver/BiSeNet/train/experiment_0/checkpoint.pth'
  90. #weights = 'runs/segmentation/BiSeNet_test/experiment_10/checkpoint.pth'
  91. model = BiSeNet(nclass)
  92. segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:4')
  93. for i in range(10):
  94. image_array0 = cv2.imread(image_url)
  95. imageH,imageW,_ = image_array0.shape
  96. #print('###line84:',image_array0.shape)
  97. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  98. #image_in = segmodel.preprocess_image(image_array)
  99. pred = segmodel.eval(image_array,outsize=None)
  100. time0=time.time()
  101. binary = pred.copy()
  102. time1=time.time()
  103. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  104. time2=time.time()
  105. print(pred.shape,' time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  106. label_dic={'landcover':[[0, 0, 0], [255, 0, 0], [0,255,0], [0,0,255], [255,255,0]],
  107. 'deepRoad':[[0,0,0],[255,0,0]],
  108. 'water':[[0,0,0],[255,255,255]],
  109. 'water_building':[[0,0,0],[0,0,255],[255,0,0]],
  110. 'floater':[[0,0,0], [0,255,0],[255,255,0],[255,0,255],[0,128, 255], [255,0,0], [0,255,255] ]
  111. }
  112. def index2color(label_mask,label_colours):
  113. r = label_mask.copy()
  114. g = label_mask.copy()
  115. b = label_mask.copy()
  116. label_cnt = len(label_colours)
  117. for ll in range(0, label_cnt):
  118. r[label_mask == ll] = label_colours[ll][0]
  119. g[label_mask == ll] = label_colours[ll][1]
  120. b[label_mask == ll] = label_colours[ll][2]
  121. rgb = np.stack((b, g,r), axis=-1)
  122. return rgb.astype(np.uint8)
  123. def get_largest_contours(contours):
  124. areas = [cv2.contourArea(x) for x in contours]
  125. max_area = max(areas)
  126. max_id = areas.index(max_area)
  127. return max_id
  128. def result_merge_sep(image,mask_colors):
  129. #mask_colors=[{ 'mask':mask_map,'index':[1],'color':[255,255,255] }]
  130. for mask_color in mask_colors:
  131. mask_map,indexes,colors = mask_color['mask'], mask_color['index'], mask_color['color']
  132. ishow = 2
  133. #plt.figure(1);plt.imshow(mask_map);
  134. for index,color in zip(indexes,colors):
  135. mask_binaray = (mask_map == index).astype(np.uint8)
  136. contours, hierarchy = cv2.findContours(mask_binaray,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  137. if len(contours)>0:
  138. d=hierarchy[0,:,3]<0 ;
  139. contours = np.array(contours,dtype=object)[d]
  140. cv2.drawContours(image,contours,-1,color[::-1],3)
  141. #plt.figure(ishow);plt.imshow(mask_binaray);ishow+=1
  142. #plt.show()
  143. return image
  144. def result_merge(image,mask_colors):
  145. #mask_colors=[{ 'mask':mask_map,'index':[1],'color':[255,255,255] }]
  146. for mask_color in mask_colors:
  147. mask_map,indexes,colors = mask_color['mask'], mask_color['index'], mask_color['color']
  148. mask_binary = (mask_map>0).astype(np.uint8)
  149. contours, hierarchy = cv2.findContours(mask_binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  150. if len(contours)>0:
  151. d=hierarchy[0,:,3]<0 ; contours = np.array(contours)[d]
  152. cv2.drawContours(image,contours,-1,colors[0][::-1],3)
  153. coors = np.array([(np.mean(contours_x ,axis=0)+0.5).astype(np.int32)[0] for contours_x in contours])
  154. #print(mask_map.shape,coors.shape)
  155. typess = mask_map[ coors[:,1],coors[:,0]]
  156. #for jj,iclass in enumerate(typess):
  157. #print(iclass,colors)
  158. # cv2.drawContours(image,contours,-1, colors[iclass][::-1],3)
  159. return image
  160. def test_floater():
  161. from core.models.dinknet import DinkNet34_MultiOutput
  162. #create_model('DinkNet34_MultiOutput',[2,5])
  163. image_url = 'temp_pics/DJI_0645.JPG'
  164. nclass = [2,7]
  165. outresult=True
  166. weights = 'runs/thFloater/BiSeNet_MultiOutput/train/experiment_4/checkpoint.pth'
  167. model = BiSeNet_MultiOutput(nclass)
  168. outdir='temp'
  169. image_dir = '/host/workspace/WJ/data/thFloater/val/images/'
  170. image_url_list=glob.glob('%s/*'%(image_dir))
  171. segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:9',multiOutput=True)
  172. for i,image_url in enumerate(image_url_list[0:10]) :
  173. image_array0 = cv2.imread(image_url)
  174. image_array0 = cv2.cvtColor(image_array0, cv2.COLOR_BGR2RGB) # cv2默认为bgr顺序
  175. imageH,imageW,_ = image_array0.shape
  176. #image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  177. pred = segmodel.eval(image_array,outsize=None)
  178. time0=time.time()
  179. if isinstance(pred,list):
  180. binary = [predx.copy() for predx in pred]
  181. time1=time.time()
  182. mask_colors=[ { 'mask':pred[0] ,'index':range(1,2),'color':label_dic['water'][0:] },
  183. { 'mask':pred[1] ,'index':[1,2,3,4,5,6],'color':label_dic['floater'][0:] } ]
  184. result_draw = result_merge(image_array0,mask_colors)
  185. time2=time.time()
  186. if outresult:
  187. basename=os.path.splitext( os.path.basename(image_url))[0]
  188. outname=os.path.join(outdir,basename+'_draw.png')
  189. cv2.imwrite(outname,result_draw[:,:,:])
  190. print('##line151: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  191. def test_water_buildings():
  192. from core.models.bisenet import BiSeNet
  193. #image_url = 'temp_pics/DJI_0645.JPG'
  194. nclass = 3
  195. outresult=True
  196. weights = 'runs/thWaterBuilding/BiSeNet/train/experiment_2/checkpoint.pth'
  197. model = BiSeNet(nclass)
  198. outdir='temp'
  199. image_dir = '/home/thsw/WJ/data/river_buildings/'
  200. #image_dir = '/home/thsw/WJ/data/THWaterBuilding/val/images'
  201. image_url_list=glob.glob('%s/*'%(image_dir))
  202. segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:0',multiOutput=False)
  203. for i,image_url in enumerate(image_url_list[0:]) :
  204. #image_url = '/home/thsw/WJ/data/THWaterBuilding/val/images/0anWqgmO9rGe1n8P.png'
  205. image_array0 = cv2.imread(image_url)
  206. imageH,imageW,_ = image_array0.shape
  207. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  208. pred = segmodel.eval(image_array,outsize=None)
  209. time0=time.time()
  210. if isinstance(pred,list):
  211. binary = [predx.copy() for predx in pred]
  212. #print(binary[0].shape)
  213. time1=time.time()
  214. mask_colors=[ { 'mask':pred ,'index':range(1,3),'color':label_dic['water_building'][1:] },
  215. #{ 'mask':pred[1] ,'index':[1,2,3,4,5,6],'color':label_dic['floater'][0:] }
  216. ]
  217. result_draw = result_merge_sep(image_array0,mask_colors)
  218. time2=time.time()
  219. if outresult:
  220. basename=os.path.splitext( os.path.basename(image_url))[0]
  221. outname=os.path.join(outdir,basename+'_draw.png')
  222. cv2.imwrite(outname,result_draw[:,:,:])
  223. print('##line294: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  224. def get_illegal_index(contours,hierarchy,water_dilate,overlap_threshold):
  225. out_index=[]
  226. if len(contours)>0:
  227. d=hierarchy[0,:,3]<0 ;
  228. contours = np.array(contours,dtype=object)[d]
  229. imageH,imageW = water_dilate.shape
  230. for ii,cont in enumerate(contours):
  231. cont = cont.astype(np.int32)
  232. build_area=np.zeros((imageH,imageW ))
  233. try:
  234. cv2.fillPoly(build_area,[cont[:,0,:]],1)
  235. area1=np.sum(build_area);area2=np.sum(build_area*water_dilate)
  236. if (area2/area1) >overlap_threshold:
  237. out_index.append(ii)
  238. except Exception as e:
  239. print('###read error:%s '%(e))
  240. print(cont.shape,type(cont),cont.dtype)
  241. return out_index
  242. def illBuildings(pred,image_array0):
  243. ##画出水体区域
  244. contours, hierarchy = cv2.findContours(pred[0],cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  245. water = pred[0].copy(); water[:,:] = 0
  246. if len(contours)==0:
  247. return image_array0,water
  248. max_id = get_largest_contours(contours);
  249. cv2.fillPoly(water, [contours[max_id][:,0,:]], 1)
  250. cv2.drawContours(image_array0,contours,max_id,(0,255,255),3)
  251. ##画出水体膨胀后的蓝线区域。
  252. kernel = np.ones((100,100),np.uint8)
  253. water_dilate = cv2.dilate(water,kernel,iterations = 1)
  254. contours, hierarchy = cv2.findContours(water_dilate,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  255. #print('####line310:',contours)
  256. cv2.drawContours(image_array0,contours,-1,(255,0,0),3)
  257. ##确定违法建筑并绘图
  258. ###逐个建筑判断是否与蓝线内区域有交叉。如果交叉面积占本身面积超过0.1,则认为是违法建筑。
  259. contours, hierarchy = cv2.findContours(pred[1],cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
  260. outIndex=get_illegal_index(contours,hierarchy,water_dilate,0.1)
  261. for ii in outIndex:
  262. cv2.drawContours(image_array0,contours,ii,(0,0,255),3)
  263. return image_array0,water
  264. def test_water_building_seperately():
  265. #from core.models.dinknet import DinkNet34_MultiOutput
  266. #create_model('DinkNet34_MultiOutput',[2,5])
  267. image_url = 'temp_pics/DJI_0645.JPG'
  268. nclass = [2,2]
  269. outresult=True
  270. weights = '../weights/BiSeNet/checkpoint.pth'
  271. model = BiSeNet_MultiOutput(nclass)
  272. outdir='temp'
  273. image_dir = '/home/thsw/WJ/data/river_buildings/'
  274. #image_dir = '/home/thsw/WJ/data/THWaterBuilding/val/images'
  275. image_url_list=glob.glob('%s/*'%(image_dir))
  276. #segmodel = SegModel(model=model,nclass=nclass,weights=weights,device='cuda:1',multiOutput=True)
  277. segmodel = SegModel(nclass=nclass,weights=weights,device='cuda:1')
  278. print('###line307 image cnt:',len(image_url_list))
  279. for i,image_url in enumerate(image_url_list[0:1]) :
  280. image_url = '/home/thsw/WJ/data/river_buildings/DJI_20210904092044_0001_S_output896.jpg'
  281. image_array0 = cv2.imread(image_url)
  282. imageH,imageW,_ = image_array0.shape
  283. image_array = cv2.cvtColor( image_array0,cv2.COLOR_RGB2BGR)
  284. pred = segmodel.eval(image_array,outsize=None,smooth_kernel=20)
  285. image_array0,water = illBuildings(pred,image_array0)
  286. plt.imshow(image_array0);plt.show()
  287. ##
  288. time0=time.time()
  289. time1=time.time()
  290. mask_colors=[ { 'mask':pred[0],'index':[1],'color':label_dic['water_building'][1:2]},
  291. { 'mask':pred[1],'index':[1],'color':label_dic['water_building'][2:3]}
  292. ]
  293. result_draw = result_merge_sep(image_array0,mask_colors)
  294. time2=time.time()
  295. if outresult:
  296. basename=os.path.splitext( os.path.basename(image_url))[0]
  297. outname=os.path.join(outdir,basename+'_draw.png')
  298. cv2.imwrite(outname,result_draw[:,:,:])
  299. print('##line151: time copy:%.1f finccontour:%.1f '%(get_ms(time1,time0),get_ms(time2,time1) ))
  300. if __name__=='__main__':
  301. #test()
  302. #test_floater()
  303. #test_water_buildings()
  304. test_water_building_seperately()