地物分类项目代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

116 lines
4.1KB

  1. # "./data/test"目录下不需要有labels_2文件夹
  2. import os
  3. os.environ['CUDA_VISIBLE_DEVICES'] = '1'
  4. from models.model_stages import BiSeNet
  5. from predict_city.heliushuju import Heliushuju
  6. import cv2
  7. import torch
  8. from torch.utils.data import DataLoader
  9. import torch.nn.functional as F
  10. import os.path as osp
  11. import numpy as np
  12. from tqdm import tqdm
  13. import pandas as pd
  14. import matplotlib.pyplot as plt
  15. class MscEvalV0(object):
  16. def __init__(self, scale=0.75, ignore_label=255):
  17. self.ignore_label = ignore_label
  18. self.scale = scale
  19. def __call__(self, net, dl, n_classes):
  20. # evaluate
  21. label_info = get_label_info('./class_dict.csv')
  22. hist = torch.zeros(n_classes, n_classes).cuda().detach()
  23. diter = enumerate(tqdm(dl))
  24. # for i, (imgs, label, img_tt) in diter: # 测试时,"./data/test"目录下需要有labels_2文件夹(labels_2文件夹里存放标签文件,标签的个数和文件名与测试图像对应)时,需要把这一行加上
  25. for i, (imgs, img_tt) in diter:
  26. loop_start = cv2.getTickCount()
  27. # N, _, H, W = label.shape
  28. # label = label.squeeze(1).cuda()
  29. # size = label.size()[-2:]
  30. # size = [360, 640]
  31. size = [810, 1440]
  32. imgs = imgs.cuda()
  33. N, C, H, W = imgs.size()
  34. new_hw = [int(H * self.scale), int(W * self.scale)]
  35. print(new_hw)
  36. print("line43", imgs.size())
  37. imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
  38. logits = net(imgs)[0]
  39. loop_time = cv2.getTickCount() - loop_start
  40. tool_time = loop_time / (cv2.getTickFrequency())
  41. running_fps = int(1 / tool_time)
  42. print('running_fps:', running_fps)
  43. logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
  44. probs = torch.softmax(logits, dim=1)
  45. preds = torch.argmax(probs, dim=1)
  46. preds_squeeze = preds.squeeze(0)
  47. preds_squeeze_predict = colour_code_segmentation(np.array(preds_squeeze.cpu()), label_info)
  48. print(preds_squeeze_predict.shape)
  49. # preds_squeeze_predict = cv2.resize(np.uint(preds_squeeze_predict), (W, H))
  50. save_path = './demo/' + img_tt[0] + '.png'
  51. cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
  52. def colour_code_segmentation(image, label_values):
  53. label_values = [label_values[key] for key in label_values]
  54. colour_codes = np.array(label_values)
  55. x = colour_codes[image.astype(int)]
  56. return x
  57. def get_label_info(csv_path):
  58. ann = pd.read_csv(csv_path)
  59. label = {}
  60. for iter, row in ann.iterrows():
  61. label_name = row['name']
  62. r = row['r']
  63. g = row['g']
  64. b = row['b']
  65. label[label_name] = [int(r), int(g), int(b)]
  66. return label
  67. def evaluatev0(respth='', dspth='', backbone='', scale=0.75, use_boundary_2=False,
  68. use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  69. print('scale', scale)
  70. ## dataset
  71. batchsize = 1
  72. n_workers = 0
  73. dsval = Heliushuju(dspth, mode='test')
  74. dl = DataLoader(dsval,
  75. batch_size=batchsize,
  76. shuffle=False,
  77. num_workers=n_workers,
  78. drop_last=False)
  79. n_classes = 3
  80. print("backbone:", backbone)
  81. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  82. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  83. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  84. use_conv_last=use_conv_last)
  85. net.load_state_dict(torch.load(respth))
  86. net.cuda()
  87. net.eval()
  88. with torch.no_grad():
  89. single_scale = MscEvalV0(scale=scale)
  90. single_scale(net, dl, 2)
  91. if __name__ == "__main__":
  92. evaluatev0('./model_save/pths/model_final.pth',
  93. dspth='./data/', backbone='STDCNet813', scale=0.75,
  94. use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)