#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from logger import setup_logger
from models.model_stages import BiSeNet
from predict_city.heliushuju import Heliushuju

import cv2
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist
import os.path as osp
import logging
import time
import numpy as np
from tqdm import tqdm
import math
import pandas as pd
import matplotlib.pyplot as plt
# from cv2 import getTickCount, getTickFrequency


class MscEvalV0(object):
    """Single-scale evaluator: runs the network over the loader and writes
    colour-coded prediction maps to ./demo/."""

    def __init__(self, scale=0.5, ignore_label=255):
        self.ignore_label = ignore_label
        self.scale = scale

    def __call__(self, net, dl, n_classes):
        # evaluate
        label_info = get_label_info('./class_dict.csv')
        # confusion-matrix placeholder kept from the original; it is never updated here
        hist = torch.zeros(n_classes, n_classes).cuda().detach()
        diter = enumerate(tqdm(dl))
        for i, (imgs, label, img_tt) in diter:
            loop_start = cv2.getTickCount()
            N, _, H, W = label.shape
            label = label.squeeze(1).cuda()
            size = label.size()[-2:]
            imgs = imgs.cuda()
            N, C, H, W = imgs.size()
            # down-scale the input before the forward pass
            new_hw = [int(H * self.scale), int(W * self.scale)]
            print(new_hw)
            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
            logits = net(imgs)[0]
            # rough per-image FPS measured around the forward pass
            loop_time = cv2.getTickCount() - loop_start
            tool_time = loop_time / cv2.getTickFrequency()
            running_fps = int(1 / tool_time)
            print('running_fps:', running_fps)
            # up-sample logits back to label resolution and take the per-pixel argmax
            logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            preds_squeeze = preds.squeeze(0)
            # map class indices to RGB colours and save the prediction image
            preds_squeeze_predict = colour_code_segmentation(np.array(preds_squeeze.cpu()), label_info)
            print(preds_squeeze_predict.shape)
            # cast to uint8 before resizing (np.uint in the original yields uint64, which cv2.resize rejects)
            preds_squeeze_predict = cv2.resize(np.uint8(preds_squeeze_predict), (W, H))
            save_path = './demo/' + img_tt[0] + '.png'
            cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
            # preds_squeeze_predict = preds_squeeze.cpu().numpy().copy()
            # plt.imshow(preds_squeeze_predict); plt.show()
            # preds_3chs = np.zeros((*preds_squeeze_predict.shape, 3))
            # preds_3chs[..., 0] = preds_squeeze_predict.copy()
            # preds_3chs[..., 1] = preds_squeeze_predict.copy()
            # preds_3chs[..., 2] = preds_squeeze_predict.copy()
            # preds_3chs = (preds_3chs * 255).astype(np.uint8)
            #
            # # print('####line66', preds_squeeze_predict.shape)
            # preds_squeeze_predict = cv2.resize(np.uint8(preds_squeeze_predict), (W, H))
            # save_path = './demo/' + img_tt[0] + '.png'
            # cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
            # print('#####DEBUG#####')
            # sys.exit(0)


def colour_code_segmentation(image, label_values):
    """Map an array of class indices to an RGB image using the colour table."""
    label_values = [label_values[key] for key in label_values]
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x


def get_label_info(csv_path):
    """Read the class/colour table from a CSV with 'name', 'r', 'g', 'b' columns."""
    ann = pd.read_csv(csv_path)
    label = {}
    for _, row in ann.iterrows():
        label_name = row['name']
        r = row['r']
        g = row['g']
        b = row['b']
        label[label_name] = [int(r), int(g), int(b)]
    return label


def evaluatev0(respth='', dspth='', backbone='', scale=0.75,
               use_boundary_2=False, use_boundary_4=False, use_boundary_8=False,
               use_boundary_16=False, use_conv_last=False):
    print('scale', scale)
    ## dataset
    batchsize = 1
    n_workers = 0
    # dsval = Heliushuju(dspth, mode='val')  # original
    dsval = Heliushuju(dspth, mode='test')  # modified
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)
    n_classes = 3
    print("backbone:", backbone)
    net = BiSeNet(backbone=backbone, n_classes=n_classes,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)
    net.load_state_dict(torch.load(respth))
    net.cuda()
    net.eval()
    with torch.no_grad():
        single_scale = MscEvalV0(scale=scale)
        single_scale(net, dl, n_classes)  # was hard-coded to 2; pass the configured class count


if __name__ == "__main__":
    # STDC2-Seg75 mIoU 0.7704
    # original
    # evaluatev0('/host/data/segmentation/Buildings_checkpoints/checkpoints3/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth',
    #            dspth='/host/data/segmentation/Buildings2/images_12/', backbone='STDCNet813', scale=0.75,
    #            use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
    # modified
    evaluatev0('./checkpoints_1720/wurenji_train_STDC1-Seg/pths/model_final.pth',
               dspth='./data/segmentation/shuiyufenge_1720/', backbone='STDCNet813', scale=0.75,
               use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
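# Note on the colour map CSV used by get_label_info(): it expects './class_dict.csv'
# to contain one row per class with 'name', 'r', 'g', 'b' columns, matching the
# n_classes = 3 configured in evaluatev0(). A minimal sketch of the assumed layout is
# shown below; the class names and colours are hypothetical, not taken from this repo:
#
#   name,r,g,b
#   background,0,0,0
#   class_1,0,255,0
#   class_2,255,0,0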