Land Feature Classification Project Code

# If nested relationships exist among the classes, remove the contained class,
# recolor it to the class that is kept, and filter out small-area regions.
# Also regularize the boundaries of the building segmentation regions.
from models.model_stages import BiSeNet
from predict_city.heliushuju import Heliushuju
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pandas as pd
import numpy as np
from PIL import Image
import time
import os
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(2 ** 40)  # raise OpenCV's image-size limit before cv2 is imported
import cv2
import argparse
import torch
import torchvision.transforms as transforms
from osgeo import gdal, gdal_array, ogr, osr
from rdp_alg import rdp
from cal_dist_ang import cal_ang, cal_dist, azimuthAngle
from rotate_ang import Nrotation_angle_get_coor_coordinates, Srotation_angle_get_coor_coordinates
from line_intersection import line, intersection, par_line_dist, point_in_line
import shutil

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['PROJ_LIB'] = r'/home/thsw/anaconda3/envs/zyy-torch1.10/lib/python3.8/site-packages/pyproj/proj_dir/share/proj'

# For one class's segmentation result, store each contour together with its
# coordinate points in a single list entry.
def classInfo(contours):
    content = []
    if len(contours) != 0:
        for i in range(len(contours)):
            COOR = []
            x = contours[i][:, :, 0]
            y = contours[i][:, :, 1]
            for j in range(len(x)):
                COOR.append((x[j][0], y[j][0]))
            content.append([COOR, contours[i]])
    return content

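# Illustrative sketch (not part of the pipeline): the shape classInfo() works
# with. cv2.findContours returns each contour as an (N, 1, 2) array, so for a
# hypothetical square contour the pairing would look like:
#
#     cnt = np.array([[[0, 0]], [[4, 0]], [[4, 4]], [[0, 4]]], dtype=np.int32)
#     classInfo([cnt])
#     # -> [[[(0, 0), (4, 0), (4, 4), (0, 4)], cnt]]
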
# For a given class (Content1), split its contours into two groups:
# `target` (no containment relationship with any other class) and
# `questionCnt` (contains, or is contained by, another class's contour).
def single(Content1, Content2, Content3, Content4, Content5, Content6, Content7):
    tempCnt2, tempCnt3, tempCnt4, tempCnt5, tempCnt6, tempCnt7 = [], [], [], [], [], []
    tempCOOR2, tempCOOR3, tempCOOR4, tempCOOR5, tempCOOR6, tempCOOR7 = [], [], [], [], [], []
    # Store each other class's contours in tempCnt* and the coordinate points
    # of those contours in tempCOOR*.
    for i in range(len(Content2)):
        tempCnt2.append(Content2[i][1])
        tempCOOR2 = tempCOOR2 + Content2[i][0]
    for i in range(len(Content3)):
        tempCnt3.append(Content3[i][1])
        tempCOOR3 = tempCOOR3 + Content3[i][0]
    for i in range(len(Content4)):
        tempCnt4.append(Content4[i][1])
        tempCOOR4 = tempCOOR4 + Content4[i][0]
    for i in range(len(Content5)):
        tempCnt5.append(Content5[i][1])
        tempCOOR5 = tempCOOR5 + Content5[i][0]
    for i in range(len(Content6)):
        tempCnt6.append(Content6[i][1])
        tempCOOR6 = tempCOOR6 + Content6[i][0]
    for i in range(len(Content7)):
        tempCnt7.append(Content7[i][1])
        tempCOOR7 = tempCOOR7 + Content7[i][0]
    # Relative to Content1, pool all other classes' contours into elseCnt
    elseCnt = tempCnt2 + tempCnt3 + tempCnt4 + tempCnt5 + tempCnt6 + tempCnt7
    # and all other classes' coordinate points into elseCOOR
    elseCOOR = tempCOOR2 + tempCOOR3 + tempCOOR4 + tempCOOR5 + tempCOOR6 + tempCOOR7
    target = []
    questionCnt = []
    for i in range(len(Content1)):
        selfCOOR = Content1[i][0]  # this contour's coordinate points
        selfCnt = Content1[i][1]   # this contour itself
        flag1 = False
        flag2 = False
        # If any point of this class's contour lies inside another class's
        # contour, store this contour and its points in questionCnt.
        for j in range(len(elseCnt)):
            for k in range(len(selfCOOR)):
                x = int(selfCOOR[k][0])
                y = int(selfCOOR[k][1])
                # Is this contour inside another class's contour?
                flag = cv2.pointPolygonTest(elseCnt[j], (x, y), False)
                if flag >= 0:
                    questionCnt.append([selfCOOR, selfCnt])
                    flag1 = True
                    break
            if flag1 == True:
                break
        # Conversely, if any point of another class's contour lies inside this
        # contour, also store this contour and its points in questionCnt.
        if flag1 == False:
            for m in range(len(elseCOOR)):
                x = int(elseCOOR[m][0])
                y = int(elseCOOR[m][1])
                # Is another class's contour inside this contour?
                flag = cv2.pointPolygonTest(selfCnt, (x, y), False)
                if flag >= 0:
                    questionCnt.append([selfCOOR, selfCnt])
                    flag2 = True
                    break
            # If this contour neither contains nor is contained by any other
            # class's contour, store it in the target list.
            if flag2 == False:
                target.append(selfCnt)
    return target, questionCnt

# For a given class, based on single()'s questionCnt: if one of this class's
# contours is not contained in any other class's contour, store it in the
# externalCnt list. Finally, if externalCnt is not empty, fill it into the
# mask, which filters out other classes' contours nested inside these contours.
def findExternalCnt(questionCnt1, questionCnt2, questionCnt3, questionCnt4, questionCnt5, questionCnt6, questionCnt7,
                    mask, value):
    externalCnt = []
    elseCnt = []
    if len(questionCnt1) != 0:
        elseQuestionCnt = questionCnt2 + questionCnt3 + questionCnt4 + questionCnt5 + questionCnt6 + questionCnt7
        for m in range(len(elseQuestionCnt)):
            elseCnt.append(elseQuestionCnt[m][1])
        for i in range(len(questionCnt1)):
            selfCOOR = questionCnt1[i][0]  # this contour's coordinate points
            selfCnt = questionCnt1[i][1]   # this contour itself
            flag1 = False
            for j in range(len(elseCnt)):
                for k in range(len(selfCOOR)):
                    x = int(selfCOOR[k][0])
                    y = int(selfCOOR[k][1])
                    # Is this contour inside another class's contour?
                    flag = cv2.pointPolygonTest(elseCnt[j], (x, y), False)
                    if flag >= 0:
                        flag1 = True
                        break
                if flag1 == True:
                    break
            if flag1 == False:
                externalCnt.append(selfCnt)
        if len(externalCnt) != 0:
            cv2.fillPoly(mask, externalCnt, color=value)
    return questionCnt1, questionCnt2, questionCnt3, questionCnt4, questionCnt5, questionCnt6, questionCnt7, mask

# For a given class, take single()'s target contours, discard those with an
# area below 2000 pixels, fill the remainder into the mask, and return it.
def filterArea(targetCnt, mask, value):
    # Filter out small-area regions
    targetNew = []
    for i in range(len(targetCnt)):
        cnt = targetCnt[i]
        cntArea = cv2.contourArea(cnt)
        if cntArea >= 2000:
            targetNew.append(cnt)
    cv2.fillPoly(mask, targetNew, color=value)
    return mask

# Return the segmentation image after filtering by area and containment.
def judgeTermination(questionBuild, questionRoad, questionWater, questionFarmland, questionGrass, questionWoodland,
                     questionBareSoil, mask):
    questionBuild, questionRoad, questionWater, questionFarmland, questionGrass, questionWoodland, questionBareSoil, mask = findExternalCnt(
        questionBuild, questionRoad, questionWater, questionFarmland, questionGrass, questionWoodland, questionBareSoil,
        mask, 1)
    questionRoad, questionBuild, questionWater, questionFarmland, questionGrass, questionWoodland, questionBareSoil, mask = findExternalCnt(
        questionRoad, questionBuild, questionWater, questionFarmland, questionGrass, questionWoodland, questionBareSoil,
        mask, 2)
    questionWater, questionBuild, questionRoad, questionFarmland, questionGrass, questionWoodland, questionBareSoil, mask = findExternalCnt(
        questionWater, questionBuild, questionRoad, questionFarmland, questionGrass, questionWoodland, questionBareSoil,
        mask, 3)
    questionFarmland, questionBuild, questionRoad, questionWater, questionGrass, questionWoodland, questionBareSoil, mask = findExternalCnt(
        questionFarmland, questionBuild, questionRoad, questionWater, questionGrass, questionWoodland, questionBareSoil,
        mask, 4)
    questionGrass, questionBuild, questionRoad, questionWater, questionFarmland, questionWoodland, questionBareSoil, mask = findExternalCnt(
        questionGrass, questionBuild, questionRoad, questionWater, questionFarmland, questionWoodland, questionBareSoil,
        mask, 5)
    questionWoodland, questionBuild, questionRoad, questionWater, questionFarmland, questionGrass, questionBareSoil, mask = findExternalCnt(
        questionWoodland, questionBuild, questionRoad, questionWater, questionFarmland, questionGrass, questionBareSoil,
        mask, 6)
    questionBareSoil, questionBuild, questionRoad, questionWater, questionFarmland, questionGrass, questionWoodland, mask = findExternalCnt(
        questionBareSoil, questionBuild, questionRoad, questionWater, questionFarmland, questionGrass, questionWoodland,
        mask, 7)
    return mask

# Optimize the raw segmentation result using each class's region area and the
# containment relationships between classes; return the optimized result.
def optimize(preds):
    h, w = preds.shape[0], preds.shape[1]
    mask = np.zeros((h, w), dtype="uint8")
    building = preds.copy()
    road = preds.copy()
    water = preds.copy()
    farmland = preds.copy()
    grass = preds.copy()
    woodland = preds.copy()
    bareSoil = preds.copy()
    building[building != 1] = 0  # building
    road[road != 2] = 0  # road
    water[water != 3] = 0  # water
    farmland[farmland != 4] = 0  # farmland
    grass[grass != 5] = 0  # grass
    woodland[woodland != 6] = 0  # woodland
    bareSoil[bareSoil != 7] = 0  # bare soil
    # Extract each class's outer contours from the segmentation result
    buildCnt, hierarchy = cv2.findContours(np.uint8(building), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    roadCnt, hierarchy = cv2.findContours(np.uint8(road), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    waterCnt, hierarchy = cv2.findContours(np.uint8(water), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    farmlandCnt, hierarchy = cv2.findContours(np.uint8(farmland), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    grassCnt, hierarchy = cv2.findContours(np.uint8(grass), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    woodlandCnt, hierarchy = cv2.findContours(np.uint8(woodland), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    bareSoilCnt, hierarchy = cv2.findContours(np.uint8(bareSoil), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Pair each contour with its coordinate points
    buildContent = classInfo(buildCnt)  # [ [[(x1, y1), (x2, y2)], cnt], ... ]
    roadContent = classInfo(roadCnt)
    waterContent = classInfo(waterCnt)
    farmlandContent = classInfo(farmlandCnt)
    grassContent = classInfo(grassCnt)
    woodlandContent = classInfo(woodlandCnt)
    bareSoilContent = classInfo(bareSoilCnt)
    # For each class, split its contours into the two groups
    targetBuild, questionBuild = single(buildContent, roadContent, waterContent, farmlandContent, grassContent,
                                        woodlandContent, bareSoilContent)
    targetRoad, questionRoad = single(roadContent, buildContent, waterContent, farmlandContent, grassContent,
                                      woodlandContent, bareSoilContent)
    targetWater, questionWater = single(waterContent, buildContent, roadContent, farmlandContent, grassContent,
                                        woodlandContent, bareSoilContent)
    targetFarmland, questionFarmland = single(farmlandContent, buildContent, roadContent, waterContent, grassContent,
                                              woodlandContent, bareSoilContent)
    targetGrass, questionGrass = single(grassContent, buildContent, roadContent, waterContent, farmlandContent,
                                        woodlandContent, bareSoilContent)
    targetWoodland, questionWoodland = single(woodlandContent, buildContent, roadContent, waterContent, farmlandContent,
                                              grassContent, bareSoilContent)
    targetBareSoil, questionBareSoil = single(bareSoilContent, buildContent, roadContent, waterContent, farmlandContent,
                                              grassContent, woodlandContent)
    # For each class, drop target contours smaller than 2000 pixels and fill
    # the remainder into the mask
    if len(targetBuild) != 0:
        mask = filterArea(targetBuild, mask, 1)
    if len(targetRoad) != 0:
        mask = filterArea(targetRoad, mask, 2)
    if len(targetWater) != 0:
        mask = filterArea(targetWater, mask, 3)
    if len(targetFarmland) != 0:
        mask = filterArea(targetFarmland, mask, 4)
    if len(targetGrass) != 0:
        mask = filterArea(targetGrass, mask, 5)
    if len(targetWoodland) != 0:
        mask = filterArea(targetWoodland, mask, 6)
    if len(targetBareSoil) != 0:
        mask = filterArea(targetBareSoil, mask, 7)
    # Return the segmentation image filtered by area and containment
    preds = judgeTermination(questionBuild, questionRoad, questionWater, questionFarmland, questionGrass,
                             questionWoodland, questionBareSoil, mask)
    return preds

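# Minimal sketch (a toy input, not from the real pipeline): optimize() takes a
# label map with values 1..7. A 100x100 map that is water (3) everywhere except
# a small building (1) patch nested inside it would come back as all water,
# since the patch lies inside the water contour and is below the area threshold:
#
#     toy = np.full((100, 100), 3, dtype=np.uint8)
#     toy[40:50, 40:50] = 1   # 10x10 building nested inside water
#     out = optimize(toy)     # the building patch is removed and recolored
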
# Examine the source image's width and decide whether padding pixels are
# needed. Return the number of tiles along x and the (possibly padded) image.
def wideDirection(W, H, terrainClass, img):
    w_num = 0
    if W < terrainClass['w']:
        x_sup = terrainClass['w'] - W
        left = np.zeros((H, x_sup, 3), dtype='uint8')
        img = np.concatenate((img, left), axis=1)
        w_num = 1
    elif W == terrainClass['w']:
        w_num = 1
    else:
        for j in range(W):
            x_pixel = terrainClass['w'] * j - terrainClass['overlapX'] * (j - 1)
            if x_pixel - W == 0:
                w_num = j
                break
            elif x_pixel - W > 0:
                x_sup = x_pixel - W
                left = np.zeros((H, x_sup, 3), dtype='uint8')
                img = np.concatenate((img, left), axis=1)
                w_num = j
                break
    return w_num, img

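# Worked example (assumed numbers) of the tiling arithmetic above, with
# terrainClass = {'w': 1024, 'overlapX': 500} and W = 3000:
#     x_pixel = 1024*j - 500*(j - 1) = 524*j + 500
#     j = 1 -> 1024, j = 2 -> 1548, j = 3 -> 2072, j = 4 -> 2596, j = 5 -> 3120
# The first j with x_pixel >= W is j = 5, so the image is padded with
# 3120 - 3000 = 120 black columns and split into w_num = 5 tiles.
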
# Decide whether the source image needs padding along the height and width
# directions, and pad it if so. Return the padded image together with the
# number of tiles it can be split into along each direction.
def img_sup(img, terrainClass):
    H, W = img.shape[0], img.shape[1]
    h_num = 0
    if H < terrainClass['h']:
        y_sup = terrainClass['h'] - H
        up = np.zeros((y_sup, W, 3), dtype='uint8')
        img = np.concatenate((img, up), axis=0)
        H = H + y_sup
        h_num = 1
        w_num, img = wideDirection(W, H, terrainClass, img)
    elif H == terrainClass['h']:
        h_num = 1
        w_num, img = wideDirection(W, H, terrainClass, img)
    else:
        for j in range(H):
            y_pixel = terrainClass['h'] * j - terrainClass['overlapY'] * (j - 1)
            if y_pixel - H == 0:
                h_num = j
                break
            elif y_pixel - H > 0:
                y_sup = y_pixel - H
                up = np.zeros((y_sup, W, 3), dtype='uint8')
                img = np.concatenate((img, up), axis=0)
                H = H + y_sup
                h_num = j
                break
        w_num, img = wideDirection(W, H, terrainClass, img)
    return img, h_num, w_num

# Run the STDC network on one cropped tile and return its prediction.
def predict(img, self, size, net, label_info, terrainClass):
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    img = self.to_tensor(img)
    img = img.cuda()
    img = torch.unsqueeze(img, dim=0)
    img = F.interpolate(img, size, mode='bilinear', align_corners=True)
    logits = net(img)[0]
    logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
    probs = torch.softmax(logits, dim=1)
    preds = torch.argmax(probs, dim=1)
    preds_squeeze = preds.squeeze(0)
    preds = preds_squeeze.cpu().numpy()
    # Call optimize(): where classes are nested, remove the contained class,
    # recolor it to the kept class, and filter out small-area regions.
    preds = optimize(preds)
    preds = colour_code_segmentation(preds, label_info)
    preds = cv2.cvtColor(np.uint8(preds), cv2.COLOR_RGB2BGR)
    # Note: cv2.resize takes (width, height); passing (h, w) is only correct
    # here because terrainClass['h'] == terrainClass['w'] in this pipeline.
    img = cv2.resize(preds.astype("uint8"), (terrainClass['h'], terrainClass['w']))
    return img

# Crop form 1: tiles of the first row, varying along x.
def Cropping1(img, terrainClass, i):
    img1 = img[0:terrainClass['h'], (i * terrainClass['w'] - i * terrainClass['overlapX']):(
            (i + 1) * terrainClass['w'] - i * terrainClass['overlapX'] - 1), :]
    return img1


# Crop form 2: tiles of the first column, varying along y.
def Cropping2(img, terrainClass, i):
    img1 = img[(i * terrainClass['h'] - i * terrainClass['overlapY']):(
            (i + 1) * terrainClass['h'] - i * terrainClass['overlapY'] - 1), 0:terrainClass['w'], :]
    return img1


# Crop form 3: first-column tile of row j.
def Cropping3(img, terrainClass, j):
    img1 = img[(j * terrainClass['h'] - j * terrainClass['overlapY']):(
            (j + 1) * terrainClass['h'] - j * terrainClass['overlapY'] - 1), 0:terrainClass['w'], :]
    return img1


# Crop form 4: interior tiles, varying along both x and y.
def Cropping4(img, terrainClass, i, j):
    img1 = img[(j * terrainClass['h'] - j * terrainClass['overlapY']):(
            (j + 1) * terrainClass['h'] - j * terrainClass['overlapY'] - 1),
           (i * terrainClass['w'] - i * terrainClass['overlapX']):(
                   (i + 1) * terrainClass['w'] - i * terrainClass['overlapX'] - 1), :]
    return img1

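# Note on the slice bounds above: the stop index ((i + 1) * w - i * overlapX - 1)
# is exclusive in numpy, so each cropped tile is one pixel narrower/shorter than
# terrainClass['w']/['h'] (e.g. columns 524:1547, i.e. 1023 columns, for i = 1
# with w = 1024 and overlapX = 500). predict() masks this by resizing every tile
# prediction back to the full tile size, so the stitched result stays aligned.
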
# Horizontally concatenate the tile predictions.
def transverseConcatenate(savePredict):
    img1 = savePredict[0]
    for j in range(1, len(savePredict)):
        img1 = np.concatenate((img1, savePredict[j]), axis=1)
    return img1


# Vertically concatenate the tile predictions.
def longitudinalConcatenate(concatenateList, img):
    if len(concatenateList) == 0:
        concatenateList.append(img)
    else:
        concatenateList = [np.concatenate((concatenateList[0], img), axis=0)]
    return concatenateList

# Stitch the tile predictions, handling four cases.
def concatenateImage(h_num, w_num, self, size, img, net, label_info, terrainClass, h, w, save_path, tifFile, x1, x2, x3,
                     x4, y1, y2, y3, y4):
    # Case 1: H <= terrain['h'], W <= terrain['w']
    if h_num == 1 and w_num == 1:
        img1 = img
        img1 = predict(img1, self, size, net, label_info, terrainClass)
        finalResult = img1[0:h, 0:w]
        finalResult = cv2.cvtColor(finalResult, cv2.COLOR_BGR2GRAY)
    # Case 2: H <= terrain['h'], W > terrain['w']
    elif h_num == 1 and w_num > 1:
        savePredict = []
        for i in range(w_num):
            if i == 0:
                img1 = img[0:terrainClass['h'], 0:terrainClass['w'], :]
                img1 = predict(img1, self, size, net, label_info, terrainClass)
                if terrainClass['overlapX'] % 2 == 1:
                    img1 = img1[0:terrainClass['h'], 0:x3 + 1]
                else:
                    img1 = img1[0:terrainClass['h'], 0:x4 + 1]
                savePredict.append(img1)
            elif i > 0 and i < w_num - 1:
                img1 = Cropping1(img, terrainClass, i)
                img1 = predict(img1, self, size, net, label_info, terrainClass)
                if terrainClass['overlapX'] % 2 == 1:
                    img1 = img1[0:terrainClass['h'], x2:x3 + 1]
                else:
                    img1 = img1[0:terrainClass['h'], x1:x4 + 1]
                savePredict.append(img1)
            else:
                img1 = Cropping1(img, terrainClass, i)
                img1 = predict(img1, self, size, net, label_info, terrainClass)
                if terrainClass['overlapX'] % 2 == 1:
                    img1 = img1[0:terrainClass['h'], x2:]
                else:
                    img1 = img1[0:terrainClass['h'], x1:]
                savePredict.append(img1)
        img1 = transverseConcatenate(savePredict)
        finalResult = img1[0:h, 0:w]
        finalResult = cv2.cvtColor(finalResult, cv2.COLOR_BGR2GRAY)
    # Case 3: H > terrain['h'], W <= terrain['w']
    elif h_num > 1 and w_num == 1:
        for i in range(h_num):
            if i == 0:
                img1 = img[0:terrainClass['h'], 0:terrainClass['w'], :]
                img1 = predict(img1, self, size, net, label_info, terrainClass)
                if terrainClass['overlapY'] % 2 == 1:
                    img1 = img1[0:y3 + 1, 0:terrainClass['w']]
                else:
                    img1 = img1[0:y4 + 1, 0:terrainClass['w']]
                concatenateList = [img1]  # intermediate stitching result
            elif i > 0 and i < h_num - 1:
                img1 = Cropping2(img, terrainClass, i)
                img1 = predict(img1, self, size, net, label_info, terrainClass)
                if terrainClass['overlapY'] % 2 == 1:
                    img1 = img1[y2:y3 + 1, 0:terrainClass['w']]
                else:
                    img1 = img1[y1:y4 + 1, 0:terrainClass['w']]
                concatenateList = [np.concatenate((concatenateList[0], img1), axis=0)]
            else:
                img1 = Cropping2(img, terrainClass, i)
                img1 = predict(img1, self, size, net, label_info, terrainClass)
                if terrainClass['overlapY'] % 2 == 1:
                    img1 = img1[y2:, 0:terrainClass['w']]
                else:
                    img1 = img1[y1:, 0:terrainClass['w']]
                concatenateList = [np.concatenate((concatenateList[0], img1), axis=0)]
        finalResult = concatenateList[0][0:h, 0:w]
        finalResult = cv2.cvtColor(finalResult, cv2.COLOR_BGR2GRAY)
    # Case 4: H > terrain['h'], W > terrain['w']
    elif h_num > 1 and w_num > 1:
        concatenateList = []  # intermediate stitching results
        for j in range(h_num):
            if j == 0:
                savePredict = []
                for i in range(w_num):
                    if i == 0:
                        img1 = img[0:terrainClass['h'], 0:terrainClass['w'], :]
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[0:y3 + 1, 0:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[0:y4 + 1, 0:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[0:y4 + 1, 0:x4 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[0:y3 + 1, 0:x4 + 1])
                    elif i > 0 and i < w_num - 1:
                        img1 = Cropping1(img, terrainClass, i)
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[0:y3 + 1, x2:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[0:y4 + 1, x2:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[0:y4 + 1, x1:x4 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[0:y3 + 1, x1:x4 + 1])
                    else:
                        img1 = Cropping1(img, terrainClass, i)
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[0:y3 + 1, x2:])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[0:y4 + 1, x2:])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[0:y4 + 1, x1:])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[0:y3 + 1, x1:])
                img1 = transverseConcatenate(savePredict)
                concatenateList = longitudinalConcatenate(concatenateList, img1)
            elif j > 0 and j < h_num - 1:
                savePredict = []
                for i in range(w_num):
                    if i == 0:
                        img1 = Cropping3(img, terrainClass, j)
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:y3 + 1, 0:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:y4 + 1, 0:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:y4 + 1, 0:x4 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:y3 + 1, 0:x4 + 1])
                    elif i > 0 and i < w_num - 1:
                        img1 = Cropping4(img, terrainClass, i, j)
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:y3 + 1, x2:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:y4 + 1, x2:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:y4 + 1, x1:x4 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:y3 + 1, x1:x4 + 1])
                    else:
                        img1 = Cropping4(img, terrainClass, i, j)
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:y3 + 1, x2:])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:y4 + 1, x2:])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:y4 + 1, x1:])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:y3 + 1, x1:])
                img1 = transverseConcatenate(savePredict)
                concatenateList = longitudinalConcatenate(concatenateList, img1)
            else:
                savePredict = []
                for i in range(w_num):
                    if i == 0:
                        img1 = Cropping3(img, terrainClass, j)
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:, 0:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:, 0:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:, 0:x4 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:, 0:x4 + 1])
                    elif i > 0 and i < w_num - 1:
                        img1 = Cropping4(img, terrainClass, i, j)
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:, x2:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:, x2:x3 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:, x1:x4 + 1])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:, x1:x4 + 1])
                    else:
                        img1 = Cropping4(img, terrainClass, i, j)
                        img1 = predict(img1, self, size, net, label_info, terrainClass)
                        if terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:, x2:])
                        elif terrainClass['overlapX'] % 2 == 1 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:, x2:])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 0:
                            savePredict.append(img1[y1:, x1:])
                        elif terrainClass['overlapX'] % 2 == 0 and terrainClass['overlapY'] % 2 == 1:
                            savePredict.append(img1[y2:, x1:])
                img1 = transverseConcatenate(savePredict)
                concatenateList = longitudinalConcatenate(concatenateList, img1)
        finalResult = concatenateList[0][0:h, 0:w]
        finalResult = cv2.cvtColor(finalResult, cv2.COLOR_BGR2GRAY)
    return finalResult

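# Sketch of the trimming scheme implemented above (even-overlap case shown):
# each tile's prediction is cut back to the half-overlap seam before
# concatenation, so adjacent tiles meet exactly once:
#
#     first tile:   keep columns [0 : x4 + 1]
#     middle tiles: keep columns [x1 : x4 + 1]
#     last tile:    keep columns [x1 : ]
#
# and analogously with (y1, y4) for rows; odd overlaps use (x2, x3) / (y2, y3).
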
# Read an image file with gdal.
def read_img(filename):
    dataset = gdal.Open(filename)  # open the file
    im_width = dataset.RasterXSize  # number of columns in the raster
    im_height = dataset.RasterYSize  # number of rows in the raster
    im_geotrans = dataset.GetGeoTransform()  # geotransform (affine) parameters
    im_proj = dataset.GetProjection()  # map projection information
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)  # read the raster into an array
    del dataset
    return im_proj, im_geotrans, im_data

# Write a file, using tif as the example format.
def write_img(filename, im_proj, im_geotrans, im_data):  # file name, projection, geotransform, raster array
    # Determine the raster data type
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    # Determine the array dimensionality
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
    # Create the file; the data type is required so gdal can size the buffer
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    dataset.SetGeoTransform(im_geotrans)  # write the geotransform parameters
    dataset.SetProjection(im_proj)  # write the projection
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)  # write the array data
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset

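# Usage sketch (hypothetical paths): read_img()/write_img() together copy the
# spatial metadata of one raster onto another array, which is how the
# shp-generation steps below georeference the single-class grayscale images:
#
#     proj, geotrans, data = read_img('/path/to/source.tif')
#     write_img('/path/to/georeferenced_copy.tif', proj, geotrans, data)
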
# Regularize a building contour.
def boundary_regularization(img, epsilon=6):
    h, w = img.shape[0:2]
    # Locate the contours (retrieve all of them)
    contours, hierarchy = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    contours = np.squeeze(contours[0])  # [[x1, y1], [x2, y2], ...]
    # Simplify the contour with the Douglas-Peucker (RDP) algorithm
    contours = rdp(contours, epsilon=epsilon)
    contours[:, 1] = h - contours[:, 1]
    # Regularize the contour
    dists = []
    azis = []
    azis_index = []
    # Get each edge's length and azimuth
    for i in range(contours.shape[0]):
        cur_index = i
        next_index = i + 1 if i < contours.shape[0] - 1 else 0
        prev_index = i - 1
        cur_point = contours[cur_index]
        nest_point = contours[next_index]
        prev_point = contours[prev_index]
        dist = cal_dist(cur_point, nest_point)  # distance from the current point to the next point
        azi = azimuthAngle(cur_point, nest_point)  # azimuth of the edge: its counter-clockwise angle to the horizontal
        dists.append(dist)
        azis.append(azi)
        azis_index.append([cur_index, next_index])
    # Take the direction of the longest edge as the main direction
    longest_edge_idex = np.argmax(dists)
    main_direction = azis[longest_edge_idex]  # counter-clockwise angle between the main direction and the horizontal
    # Direction correction: rotate each edge about its midpoint until it is
    # parallel or perpendicular to the main direction
    correct_points = []
    para_vetr_idxs = []  # 0 parallel, 1 perpendicular
    for i, (azi, (point_0_index, point_1_index)) in enumerate(zip(azis, azis_index)):
        if i == longest_edge_idex:
            correct_points.append([contours[point_0_index], contours[point_1_index]])
            para_vetr_idxs.append(0)
        else:
            # Determine the rotation angle
            rotate_ang = main_direction - azi
            if np.abs(rotate_ang) < 180 / 4:
                rotate_ang = rotate_ang
                para_vetr_idxs.append(0)
            elif np.abs(rotate_ang) >= 90 - 180 / 4:
                rotate_ang = rotate_ang + 90
                para_vetr_idxs.append(1)
            # Perform the rotation
            point_0 = contours[point_0_index]  # current point
            point_1 = contours[point_1_index]  # the point after the current point
            point_middle = (point_0 + point_1) / 2
            if rotate_ang > 0:
                rotate_point_0 = Srotation_angle_get_coor_coordinates(point_0, point_middle, np.abs(rotate_ang))
                rotate_point_1 = Srotation_angle_get_coor_coordinates(point_1, point_middle, np.abs(rotate_ang))
            elif rotate_ang < 0:
                rotate_point_0 = Nrotation_angle_get_coor_coordinates(point_0, point_middle, np.abs(rotate_ang))
                rotate_point_1 = Nrotation_angle_get_coor_coordinates(point_1, point_middle, np.abs(rotate_ang))
            else:
                rotate_point_0 = point_0
                rotate_point_1 = point_1
            correct_points.append([rotate_point_0, rotate_point_1])
    correct_points = np.array(correct_points)
    # Adjacent-edge correction: take the intersection for perpendicular edges;
    # for parallel edges, either shift the shorter edge or insert a line
    final_points = []
    final_points.append(correct_points[0][0])
    for i in range(correct_points.shape[0] - 1):
        cur_index = i
        next_index = i + 1 if i < correct_points.shape[0] - 1 else 0
        cur_edge_point_0 = correct_points[cur_index][0]
        cur_edge_point_1 = correct_points[cur_index][1]
        next_edge_point_0 = correct_points[next_index][0]
        next_edge_point_1 = correct_points[next_index][1]
        cur_para_vetr_idx = para_vetr_idxs[cur_index]
        next_para_vetr_idx = para_vetr_idxs[next_index]
        if cur_para_vetr_idx != next_para_vetr_idx:
            # Perpendicular: take the intersection point
            L1 = line(cur_edge_point_0, cur_edge_point_1)
            L2 = line(next_edge_point_0, next_edge_point_1)
            point_intersection = intersection(L1, L2)
            final_points.append(point_intersection)
        elif cur_para_vetr_idx == next_para_vetr_idx:
            # Parallel: either shift the edge or add a short connecting line,
            # depending on a distance threshold
            L1 = line(cur_edge_point_0, cur_edge_point_1)
            L2 = line(next_edge_point_0, next_edge_point_1)
            marg = par_line_dist(L1, L2)  # distance between the two parallel lines
            if marg < 3:
                # Shift
                point_move = point_in_line(next_edge_point_0[0], next_edge_point_0[1], cur_edge_point_0[0],
                                           cur_edge_point_0[1], cur_edge_point_1[0], cur_edge_point_1[1])
                final_points.append(point_move)
                # Update the next edge after the shift
                correct_points[next_index][0] = point_move
                correct_points[next_index][1] = point_in_line(next_edge_point_1[0], next_edge_point_1[1],
                                                              cur_edge_point_0[0], cur_edge_point_0[1],
                                                              cur_edge_point_1[0], cur_edge_point_1[1])
            else:
                # Add a connecting line
                add_mid_point = (cur_edge_point_1 + next_edge_point_0) / 2
                add_point_1 = point_in_line(add_mid_point[0], add_mid_point[1], cur_edge_point_0[0],
                                            cur_edge_point_0[1], cur_edge_point_1[0], cur_edge_point_1[1])
                add_point_2 = point_in_line(add_mid_point[0], add_mid_point[1], next_edge_point_0[0],
                                            next_edge_point_0[1], next_edge_point_1[0], next_edge_point_1[1])
                final_points.append(add_point_1)
                final_points.append(add_point_2)
    final_points.append(final_points[0])
    final_points = np.array(final_points)
    final_points[:, 1] = h - final_points[:, 1]
    return final_points

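# Usage note (assumed input shape): boundary_regularization() expects a uint8
# binary mask containing a single connected component (foreground 255) and
# returns the regularized polygon as an (N, 2) array of (x, y) vertices;
# optimizeBuilding() below reshapes that to OpenCV's (N, 1, 2) contour layout
# before passing it to cv2.fillPoly.
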
# Seam coordinates used when stitching the tile predictions, chosen so the
# results join seamlessly whether the overlap between tiles is odd or even.
def concatenateCOOR(terrainClass):
    y1 = int(terrainClass['overlapY'] / 2)
    y2 = int((terrainClass['overlapY'] - 1) / 2 + 1)
    y3 = int(terrainClass['h'] - 1 - (terrainClass['overlapY'] - 1) / 2)  # odd overlap: up to the middle pixel
    y4 = int(terrainClass['h'] - 1 - terrainClass['overlapY'] / 2)  # even overlap, e.g. 6: up to the third pixel
    x1 = int(terrainClass['overlapX'] / 2)
    x2 = int((terrainClass['overlapX'] - 1) / 2 + 1)
    x3 = int(terrainClass['w'] - 1 - (terrainClass['overlapX'] - 1) / 2)  # odd overlap: up to the middle pixel
    x4 = int(terrainClass['w'] - 1 - terrainClass['overlapX'] / 2)  # even overlap, e.g. 6: up to the third pixel
    return y1, y2, y3, y4, x1, x2, x3, x4

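# Worked example with the values used in MscEvalV0 below (h = w = 1024,
# overlapX = overlapY = 500, an even overlap):
#     y1 = x1 = int(500 / 2) = 250
#     y2 = x2 = int((500 - 1) / 2 + 1) = 250
#     y3 = x3 = int(1024 - 1 - (500 - 1) / 2) = 773
#     y4 = x4 = int(1024 - 1 - 500 / 2) = 773
# An interior tile therefore contributes rows/columns [y1 : y4 + 1] = [250:774],
# i.e. 524 pixels, which equals the tile stride 1024 - 500, so every output
# pixel is covered exactly once.
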
# Based on the raw segmentation result, regularize the building boundaries and
# return the optimized segmentation image.
def optimizeBuilding(img, labelsDict):
    sourceSegImg = img.copy()
    sourceSegImg[sourceSegImg == labelsDict['building']] = 0  # remove buildings
    img[img != labelsDict['building']] = 0  # keep buildings only
    contours, hierarch = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    # Filter out boundary noise points by area
    for i in range(len(contours)):
        buildingCnt = contours[i]
        buildingCntArea = cv2.contourArea(buildingCnt)
        if buildingCntArea < 2000:
            cv2.drawContours(img, [buildingCnt], 0, 0, -1)  # fill this contour region with 0
    ori_img1 = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    h, w = ori_img1.shape[0], ori_img1.shape[1]
    # Median filtering for denoising (kernel size 5)
    ori_img = cv2.medianBlur(ori_img1, 5)
    ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2GRAY)
    ret, ori_img = cv2.threshold(ori_img, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    # Connected-component analysis (8-connectivity). Returns: the number of
    # components, a per-pixel label image, per-label statistics, and centroids.
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(ori_img, connectivity=8)
    # Iterate over the connected components
    allCnt = []
    for i in range(1, num_labels):
        img = np.zeros_like(labels)
        index = np.where(labels == i)
        img[index] = 255
        img = np.array(img, dtype=np.uint8)
        regularization_contour = boundary_regularization(img).astype(np.int32)
        rows = regularization_contour.shape[0]
        regularization_contour = regularization_contour.reshape(rows, 1, 2)
        regularization_contour = regularization_contour.astype(int)
        allCnt.append(regularization_contour)
    buildingMask = np.zeros((h, w), dtype='uint8')
    cv2.fillPoly(buildingMask, allCnt, color=labelsDict['building'])
    buildingNew = buildingMask.copy()
    buildingMask[buildingMask == 0] = 255
    buildingMask[buildingMask == labelsDict['building']] = 0
    # From the image with buildings removed, also remove everything inside the
    # regularized building boundaries
    imgElse = cv2.bitwise_and(sourceSegImg, sourceSegImg, mask=buildingMask)
    img = cv2.bitwise_or(buildingNew, imgElse)
    return img

# Convert the source image to grayscale and build a validity mask from it.
def sourceImg2gray(sourceImg, finalResult):
    imgGray = cv2.cvtColor(sourceImg, cv2.COLOR_BGR2GRAY)
    cv2.imwrite("/DATA/zyy/output/1-gray.tif", imgGray)
    maskNew = np.zeros((sourceImg.shape[0], sourceImg.shape[1]), dtype='uint8')
    contours, hierarch = cv2.findContours(imgGray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    cntList = []
    for i in range(len(contours)):
        # The edges of the source tif may contain noisy pixels; filter them
        # out by area, keeping only contours larger than 50000 pixels.
        if cv2.contourArea(contours[i]) >= 50000:
            cntList.append(contours[i])
    cv2.fillPoly(maskNew, cntList, color=255)
    cv2.imwrite("/DATA/zyy/output/1-binary.tif", maskNew)
    finalResult = cv2.bitwise_and(finalResult, finalResult, mask=maskNew)
    cv2.imwrite("/DATA/zyy/output/1-fill.tif", maskNew)
    return maskNew, finalResult

# Generate a binary image for a single class.
def generateBinaryImg(gdal_array, singleImgPath, outputImg):
    # Load the image into numpy via gdal
    srcArr = gdal_array.LoadFile(singleImgPath)
    # Split the histogram into 2 bins so the classes can be told apart
    classes = gdal_array.numpy.histogram(srcArr, bins=2)[1]
    # The colour lookup table must have len(classes) + 1 records
    lut = [[255, 0, 0], [0, 0, 0], [255, 255, 255]]
    # Starting value for classification
    start = 1
    # Build the output image
    rgb = gdal_array.numpy.zeros((3, srcArr.shape[0], srcArr.shape[1]), gdal_array.numpy.float32)
    # Process all classes and assign colours
    for i in range(len(classes)):
        mask = gdal_array.numpy.logical_and(start <= srcArr, srcArr <= classes[i])
        for j in range(len(lut[i])):
            rgb[j] = gdal_array.numpy.choose(mask, (rgb[j], lut[i][j]))  # clip the image against the mask layer
        start = classes[i] + 1
    # Save the image
    output = gdal_array.SaveArray(rgb.astype(gdal_array.numpy.uint8), outputImg, format="GTIFF",
                                  prototype=singleImgPath)
    output = None

# Implementation details for generating the shp file.
def generateShp(binaryImgPath, shp, shpLayer):
    # Open the input raster file
    srcDS = gdal.Open(binaryImgPath, gdal.GA_ReadOnly)
    # Get the first band
    band = srcDS.GetRasterBand(1)
    # Let gdal use this band as the mask layer
    mask = band
    # Create the output shapefile
    driver = ogr.GetDriverByName("ESRI Shapefile")
    shp = driver.CreateDataSource(shp)
    # Copy the spatial reference
    srs = osr.SpatialReference()
    srs.ImportFromWkt(srcDS.GetProjectionRef())
    layer = shp.CreateLayer(shpLayer, srs=srs)
    # Create the dbf file
    fd = ogr.FieldDefn("DN", ogr.OFTInteger)
    layer.CreateField(fd)
    dst_field = 0
    # Automatically extract features from the image
    extract = gdal.Polygonize(band, mask, layer, dst_field, [], None)
    extract = None
    shp = None  # be sure to close the datasource, otherwise the shp opens blank

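# Design note: because the band is passed as its own mask, gdal.Polygonize
# should trace only the non-zero (foreground) pixels into polygons, writing
# each polygon's pixel value into the "DN" field created above. A hypothetical
# call (paths are placeholders):
#
#     generateShp('/path/to/building_binary.tif', '/path/to/building.shp', 'building')
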
# Generate a shp file for every class except `other`, saved under the shpfile
# directory.
def generateShpFile(labelsDict, finalResult, tifFile, self, imgGray):
    path2 = "/DATA/zyy/output/{}/addCOOR".format(tifFile[:-4])
    if not os.path.exists(path2):
        os.makedirs(path2)
    path3 = "/DATA/zyy/output/{}/binaryImage".format(tifFile[:-4])
    if not os.path.exists(path3):
        os.makedirs(path3)
    for k, v in labelsDict.items():
        singleClassImg = finalResult.copy()
        singleClassImg[singleClassImg != v] = 0
        path1 = "/DATA/zyy/output/{}/labelClass".format(tifFile[:-4])
        if not os.path.exists(path1):
            os.makedirs(path1)
        singleClassCnt, hierarch = cv2.findContours(singleClassImg, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Filter out boundary noise points by area
        for i in range(len(singleClassCnt)):
            cnt1 = singleClassCnt[i]
            cntArea = cv2.contourArea(cnt1)
            if cntArea < 2000:
                cv2.drawContours(singleClassImg, [cnt1], 0, 0, -1)  # fill this contour region with 0
        cv2.imwrite(path1 + os.sep + str(k) + '.tif', singleClassImg)
        # Add the non-`other` class masks onto the source image's mask
        addImg = singleClassImg.copy()
        addImg[addImg == 0] = 255
        addImg[addImg != 255] = 0
        imgGray = cv2.bitwise_and(imgGray, imgGray, mask=addImg)
        path4 = "/DATA/zyy/output/{}/shpfile".format(tifFile[:-4]) + os.sep + str(k) + '_shpfile'
        if not os.path.exists(path4):
            os.makedirs(path4)
        # Write the source tif's spatial information into the grayscale image
        proj, geotrans, data = read_img(self.testImagePath + tifFile)  # projection, geotransform and raster of the source image
        proj_single, geotrans_single, data_single = read_img(
            path1 + os.sep + str(k) + '.tif')  # projection and geotransform of the single-class grayscale image
        write_img(path2 + os.sep + str(k) + '.tif', proj, geotrans, data_single)  # write the source image's projection and geotransform into the single-class grayscale image
        # Step 1: split the image into two classes
        singleImgPath = path2 + os.sep + str(k) + '.tif'  # classified source image
        outputImg = path3 + os.sep + str(k) + '_binary.tif'  # output file name
        generateBinaryImg(gdal_array, singleImgPath, outputImg)
        # Step 2: polygonize the thresholded raster
        binaryImgPath = path3 + os.sep + str(k) + '_binary.tif'  # thresholded output raster
        shp = path4 + os.sep + str(k) + '.shp'  # output shapefile name
        shpLayer = str(k)  # layer name
        generateShp(binaryImgPath, shp, shpLayer)
    return imgGray

# Generate the shp file for the `other` class.
def generateOtherShpFile(tifFile, self):
    path5 = "/DATA/zyy/output/{}/shpfile/other_shpfile".format(tifFile[:-4])
    if not os.path.exists(path5):
        os.makedirs(path5)
    # Write the source tif's spatial information into the grayscale image
    proj, geotrans, data = read_img(self.testImagePath + tifFile)  # projection, geotransform and raster of the source image
    proj_single, geotrans_single, data_single = read_img(
        "/DATA/zyy/output/{}/labelClass/other.tif".format(tifFile[:-4]))  # projection and geotransform of the single-class grayscale image
    write_img("/DATA/zyy/output/{}/addCOOR".format(tifFile[:-4]) + os.sep + 'other.tif', proj, geotrans,
              data_single)  # write the source image's projection and geotransform into the single-class grayscale image
    # Step 1: split the image into two classes
    singleImgPath = "/DATA/zyy/output/{}/addCOOR".format(tifFile[:-4]) + os.sep + 'other.tif'
    outputImg = "/DATA/zyy/output/{}/binaryImage".format(tifFile[:-4]) + os.sep + 'other_binary.tif'
    generateBinaryImg(gdal_array, singleImgPath, outputImg)
    # Step 2: polygonize the thresholded raster
    binaryImgPath = "/DATA/zyy/output/{}/binaryImage".format(tifFile[:-4]) + os.sep + 'other_binary.tif'
    shp = "/DATA/zyy/output/{}/shpfile/other_shpfile".format(tifFile[:-4]) + os.sep + 'other.shp'
    shpLayer = "other"
    generateShp(binaryImgPath, shp, shpLayer)

class MscEvalV0(object):
    def __init__(self, scaleH=1 / 3, scaleW=1 / 3, ignore_label=255, testImagePath=''):
        self.ignore_label = ignore_label
        self.scaleH = scaleH
        self.scaleW = scaleW
        self.testImagePath = testImagePath
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __call__(self, net, dl, n_classes):
        label_info = get_label_info('./class_dict.csv')
        tifList = os.listdir(self.testImagePath)
        # Tile size and overlap used for splitting and stitching
        terrainClass = {'overlapX': 500, 'overlapY': 500, 'h': 1024, 'w': 1024}
        # Fill in this dictionary with the gray value that each class's pixels
        # take in the predicted segmentation image
        labelsDict = {'building': 206, 'road': 245, 'water': 193, 'farmland': 235, 'grass': 159, 'woodland': 130,
                      'bareSoil': 211}
        size = [640, 360]  # inference size used in predict()
        path0 = "/DATA/zyy/output"
        if not os.path.exists(path0):
            os.makedirs(path0)
        for tifFile in tifList:
            sourceImg = cv2.imread(self.testImagePath + tifFile)
            h, w = sourceImg.shape[0], sourceImg.shape[1]
            print("source image shape:", sourceImg.shape)
            # Return the padded image and the number of tiles along h and w
            paddingImg, h_num, w_num = img_sup(sourceImg, terrainClass)
            save_path = './demo/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            # Seam coordinates that make the tile predictions stitch together
            # seamlessly, whether the overlap is odd or even
            y1, y2, y3, y4, x1, x2, x3, x4 = concatenateCOOR(terrainClass)
            # Stitch the tile predictions back into one image
            finalResult = concatenateImage(h_num, w_num, self, size, paddingImg, net, label_info, terrainClass, h, w,
                                           save_path, tifFile, x1,
                                           x2, x3, x4, y1, y2, y3, y4)
            cv2.imwrite(save_path + tifFile[:-4] + "_1024x1024-500-04-18.tif", finalResult)
            # Regularize building boundaries based on the raw segmentation
            finalResult = optimizeBuilding(finalResult, labelsDict)
            cv2.imwrite("./demo/finalResult.tif", finalResult)
            # Convert the source image to grayscale
            imgGray, finalResult = sourceImg2gray(sourceImg, finalResult)
            # Generate shp files for every class except `other`
            imgGray = generateShpFile(labelsDict, finalResult, tifFile, self, imgGray)
            cv2.imwrite("/DATA/zyy/output/{}/labelClass/other.tif".format(tifFile[:-4]), imgGray)
            # Generate the shp file for the `other` class
            generateOtherShpFile(tifFile, self)
            # Remove the intermediate files generated along the way
            shutil.rmtree("/DATA/zyy/output/{}/addCOOR".format(tifFile[:-4]))
            shutil.rmtree("/DATA/zyy/output/{}/binaryImage".format(tifFile[:-4]))
            shutil.rmtree("/DATA/zyy/output/{}/labelClass".format(tifFile[:-4]))

def colour_code_segmentation(image, label_values):
    label_values = [label_values[key] for key in label_values]  # e.g. [[0, 0, 0], [128, 0, 0], [0, 128, 0]], a list
    colour_codes = np.array(label_values)  # the same values as an ndarray
    image = colour_codes[image]
    return image


def get_label_info(csv_path):
    ann = pd.read_csv(csv_path)
    label = {}
    for _, row in ann.iterrows():
        label_name = row['name']
        r = row['r']
        g = row['g']
        b = row['b']
        label[label_name] = [int(r), int(g), int(b)]
    return label

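# Sketch of the class_dict.csv layout that get_label_info() expects (the colour
# values below are placeholders, not the project's real palette):
#
#     name,r,g,b
#     building,128,0,0
#     road,128,64,128
#     water,0,0,255
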
def evaluatev0(respth='', dspth='', backbone='', testImagePath='', scaleH=1 / 3, scaleW=1 / 3, use_boundary_2=False,
               use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
    # dataset
    batchsize = 1
    n_workers = 0
    dsval = Heliushuju(dspth, mode='test')
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)
    n_classes = 8
    net = BiSeNet(backbone=backbone, n_classes=n_classes,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)
    net.load_state_dict(torch.load(respth))
    net.cuda()
    net.eval()
    with torch.no_grad():
        single_scale = MscEvalV0(scaleH=scaleH, scaleW=scaleW, testImagePath='./data/test/images/')
        single_scale(net, dl, 8)

if __name__ == "__main__":
    # Note: these CLI flags are parsed but not used by evaluatev0() below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='./model_save/pths/best.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='./data/test/images', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    opt = parser.parse_args()
    t1 = time.time()
    evaluatev0(respth='./model_save/pths/model_final.pth',
               dspth='../trafficDetectionTestData/trafficAccidentTest/masks', backbone='STDCNet813', scaleH=1 / 3,
               testImagePath='./data/test/images/',
               scaleW=1 / 3, use_boundary_2=False, use_boundary_4=False, use_boundary_8=False,
               use_boundary_16=False, use_conv_last=False)
    t2 = time.time()
    print("elapsed time (s):", t2 - t1)