|
- # The "./data/test" directory does not need to contain a labels_2 folder.
-
- import os
- os.environ['CUDA_VISIBLE_DEVICES'] = '1'
- from models.model_stages import BiSeNet
- from predict_city.heliushuju import Heliushuju
- import cv2
- import torch
- from torch.utils.data import DataLoader
- import torch.nn.functional as F
- import os.path as osp
- import numpy as np
- from tqdm import tqdm
- import pandas as pd
- import matplotlib.pyplot as plt
-
-
class MscEvalV0(object):
    """Single-scale inference runner that saves colour-coded predictions.

    Despite the "eval" name, no metric is computed here: every batch from the
    loader is pushed through the network and the per-pixel arg-max class map
    is written to ``save_dir`` as a PNG image.

    Args:
        scale: Factor applied to the input height and width before the
            forward pass.
        ignore_label: Stored on the instance; not used by ``__call__``.
        out_size: ``(height, width)`` the logits are resized to before the
            arg-max. The default is the value that used to be hard-coded.
        save_dir: Directory the PNG files are written to. The default is the
            location that used to be hard-coded.
    """

    def __init__(self, scale=0.75, ignore_label=255, out_size=(810, 1440),
                 save_dir='./demo/'):
        self.ignore_label = ignore_label
        self.scale = scale
        self.out_size = out_size
        self.save_dir = save_dir

    def __call__(self, net, dl, n_classes):
        """Run ``net`` over every batch of ``dl`` and save the predictions.

        Args:
            net: Network whose first output is taken as the class logits.
                It is fed CUDA tensors, so it is presumably expected to live
                on the GPU already.
            dl: Iterable yielding ``(imgs, img_tt)`` pairs, where ``img_tt``
                holds one file stem per image in the batch.
            n_classes: Unused; kept so that existing callers keep working.

        Raises:
            IOError: If OpenCV reports that a prediction could not be written.
        """
        label_info = get_label_info('./class_dict.csv')
        size = list(self.out_size)

        # NOTE: according to the original author's note, when "./data/test"
        # also contains a labels_2 folder (label files matching the test
        # images in number and name) the dataset yields
        # (imgs, label, img_tt) and this unpacking has to be adapted.
        for imgs, img_tt in tqdm(dl):
            loop_start = cv2.getTickCount()

            imgs = imgs.cuda()
            _, _, H, W = imgs.size()
            new_hw = [int(H * self.scale), int(W * self.scale)]
            print(new_hw)
            print("line43", imgs.size())
            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
            logits = net(imgs)[0]

            # CUDA work is queued asynchronously; wait for it to finish so the
            # timer covers the actual computation and not just the launch.
            torch.cuda.synchronize()
            loop_time = cv2.getTickCount() - loop_start
            tool_time = loop_time / cv2.getTickFrequency()
            running_fps = int(1 / tool_time) if tool_time > 0 else 0
            print('running_fps:', running_fps)

            logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
            # Softmax is monotonic per pixel, so the arg-max over the raw
            # logits selects the same class without the extra pass.
            preds = torch.argmax(logits, dim=1)

            # One output file per image, so batches larger than one work too.
            for pred, name in zip(preds, img_tt):
                coloured = colour_code_segmentation(pred.cpu().numpy(), label_info)
                print(coloured.shape)
                save_path = os.path.join(self.save_dir, name + '.png')
                # cv2.imwrite does not create missing directories; it just
                # returns False, which used to lose the result silently.
                os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
                written = cv2.imwrite(
                    save_path,
                    cv2.cvtColor(np.uint8(coloured), cv2.COLOR_RGB2BGR))
                if not written:
                    raise IOError('failed to write prediction to ' + save_path)
-
-
def colour_code_segmentation(image, label_values):
    """Map an array of class indices to the colours of the matching classes.

    Args:
        image: NumPy array of class indices; it is cast to ``int`` before
            being used as an index.
        label_values: Mapping whose values are the per-class colours. The
            iteration order of the mapping defines which index gets which
            colour.

    Returns:
        NumPy array shaped like ``image`` with the colour components appended
        as a trailing axis.
    """
    palette = np.array([label_values[name] for name in label_values])
    class_ids = image.astype(int)
    return palette[class_ids]
-
-
def get_label_info(csv_path):
    """Read the class-to-colour table from a CSV file.

    Args:
        csv_path: Path of a CSV file with the columns ``name``, ``r``, ``g``
            and ``b``.

    Returns:
        Dict mapping each ``name`` to its ``[r, g, b]`` list of ints, in the
        row order of the file.
    """
    table = pd.read_csv(csv_path)
    return {
        row['name']: [int(row[channel]) for channel in ('r', 'g', 'b')]
        for _, row in table.iterrows()
    }
-
-
def evaluatev0(respth='', dspth='', backbone='', scale=0.75, use_boundary_2=False,
               use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
    """Load a trained BiSeNet checkpoint and run single-scale inference.

    Args:
        respth: Path of the checkpoint (state dict) to load.
        dspth: Dataset root handed to ``Heliushuju`` in ``'test'`` mode.
        backbone: Backbone name forwarded to ``BiSeNet``.
        scale: Input scale factor forwarded to ``MscEvalV0``.
        use_boundary_2: Forwarded to ``BiSeNet`` unchanged.
        use_boundary_4: Forwarded to ``BiSeNet`` unchanged.
        use_boundary_8: Forwarded to ``BiSeNet`` unchanged.
        use_boundary_16: Forwarded to ``BiSeNet`` unchanged.
        use_conv_last: Forwarded to ``BiSeNet`` unchanged.
    """
    print('scale', scale)

    # Dataset: one image at a time, loaded in the main process.
    batchsize = 1
    n_workers = 0
    dsval = Heliushuju(dspth, mode='test')
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)

    # Model
    n_classes = 3
    print("backbone:", backbone)
    net = BiSeNet(backbone=backbone, n_classes=n_classes,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)
    # Deserialize onto the CPU first: a checkpoint saved from another GPU
    # index cannot be restored when that index is hidden by
    # CUDA_VISIBLE_DEVICES. The weights move to the GPU with net.cuda().
    net.load_state_dict(torch.load(respth, map_location='cpu'))
    net.cuda()
    net.eval()

    with torch.no_grad():
        single_scale = MscEvalV0(scale=scale)
        # Pass the class count the network was built with, not a literal.
        single_scale(net, dl, n_classes)
-
-
if __name__ == "__main__":
    # Script entry point: run inference with the final checkpoint on the
    # data under ./data/, using the STDCNet813 backbone.
    evaluatev0(
        respth='./model_save/pths/model_final.pth',
        dspth='./data/',
        backbone='STDCNet813',
        scale=0.75,
        use_boundary_2=False,
        use_boundary_4=False,
        use_boundary_8=True,
        use_boundary_16=False,
    )
|