|
- #!/usr/bin/python
- # -*- encoding: utf-8 -*-
- import os
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
- from logger import setup_logger
- from models.model_stages import BiSeNet
- from predict_city.heliushuju2 import Heliushuju
- import cv2
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- import torch.nn.functional as F
- import torch.distributed as dist
-
-
- import os.path as osp
- import logging
- import time
- import numpy as np
- from tqdm import tqdm
- import math
- import pandas as pd
- # from cv2 import getTickCount, getTickFrequency # 原始
class MscEvalV0(object):
    """Run single-scale inference over a data loader and save colour-coded masks.

    Despite the name, no metric is accumulated here: each batch is resized by
    ``scale``, passed through the network, colour-coded with the palette read
    from ``./class_dict.csv`` and written to ``./demo/<name>.png``.
    """

    def __init__(self, scale=0.5, ignore_label=255):
        """Store the inference settings.

        Args:
            scale: Factor applied to the input height and width before the
                forward pass.
            ignore_label: Stored on the instance; not read by ``__call__``.
        """
        self.ignore_label = ignore_label
        self.scale = scale

    def __call__(self, net, dl, n_classes):
        """Predict every batch of ``dl`` and write one PNG per batch.

        Args:
            net: Callable model; element 0 of its output is used as the logits.
            dl: Iterable yielding ``(imgs, _, img_tt)`` triples. Only the
                first item of each batch is saved (``squeeze(0)`` and
                ``img_tt[0]``), so a batch size of 1 is assumed.
            n_classes: Not used by the prediction path; kept so existing
                callers keep working.
        """
        output_dir = './demo/'
        label_info = get_label_info('./class_dict.csv')
        # cv2.imwrite does not create missing directories, it just returns False.
        os.makedirs(output_dir, exist_ok=True)

        for imgs, _, img_tt in tqdm(dl):
            loop_start = cv2.getTickCount()

            imgs = imgs.cuda()
            N, C, H, W = imgs.size()
            new_hw = [int(H * self.scale), int(W * self.scale)]
            print(new_hw)

            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
            logits = net(imgs)[0]
            # Bring the logits to the (scaled) network input resolution.
            logits = F.interpolate(logits, size=imgs.size()[-2:],
                                   mode='bilinear', align_corners=True)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            class_map = preds.squeeze(0).cpu().numpy()
            colour_mask = colour_code_segmentation(class_map, label_info)
            print(colour_mask.shape)

            # Nearest-neighbour keeps every pixel on a palette colour; the
            # default bilinear mode would blend colours along class borders.
            colour_mask = cv2.resize(np.uint8(colour_mask), (W, H),
                                     interpolation=cv2.INTER_NEAREST)

            save_path = output_dir + img_tt[0] + '.png'
            written = cv2.imwrite(save_path,
                                  cv2.cvtColor(colour_mask, cv2.COLOR_RGB2BGR))
            if not written:
                # Surface the failure instead of silently losing the result.
                print('failed to write:', save_path)

            elapsed = (cv2.getTickCount() - loop_start) / cv2.getTickFrequency()
            running_fps = int(1 / elapsed)
            print('running_fps:', running_fps)
-
def colour_code_segmentation(image, label_values):
    """Map an array of class indices to the colours listed in ``label_values``.

    Args:
        image: NumPy array of class indices; it is cast to ``int`` before
            being used as an index.
        label_values: Mapping whose values are the per-class colours, in
            class-index order (the iteration order of the mapping).

    Returns:
        Array with the shape of ``image`` plus the trailing colour axis.
    """
    palette = np.array([label_values[name] for name in label_values])
    class_indices = image.astype(int)
    return palette[class_indices]
def get_label_info(csv_path):
    """Read the class palette from a CSV file.

    The file must provide the columns ``name``, ``r``, ``g`` and ``b``.

    Args:
        csv_path: Path of the CSV file handed to ``pandas.read_csv``.

    Returns:
        Dict mapping each ``name`` to ``[r, g, b]`` as ints, in file order.
    """
    table = pd.read_csv(csv_path)
    return {
        record['name']: [int(record['r']), int(record['g']), int(record['b'])]
        for _, record in table.iterrows()
    }
def evaluatev0(respth='', dspth='', backbone='', scale=0.75, use_boundary_2=False,
               use_boundary_4=False, use_boundary_8=False, use_boundary_16=False,
               use_conv_last=False, n_classes=2):
    """Load a BiSeNet checkpoint and run single-scale prediction on the test split.

    Args:
        respth: Path of the checkpoint passed to ``torch.load``.
        dspth: Dataset root handed to ``Heliushuju``.
        backbone: Backbone name forwarded to ``BiSeNet``.
        scale: Input resize factor forwarded to ``MscEvalV0``.
        use_boundary_2: Forwarded unchanged to ``BiSeNet``.
        use_boundary_4: Forwarded unchanged to ``BiSeNet``.
        use_boundary_8: Forwarded unchanged to ``BiSeNet``.
        use_boundary_16: Forwarded unchanged to ``BiSeNet``.
        use_conv_last: Forwarded unchanged to ``BiSeNet``.
        n_classes: Number of output classes of the network. Defaults to 2,
            the value that used to be hard-coded in this function.
    """
    print('scale', scale)

    # Dataset: the 'test' split is used here (an earlier revision used 'val').
    dsval = Heliushuju(dspth, mode='test')
    dl = DataLoader(dsval,
                    batch_size=1,
                    shuffle=False,
                    num_workers=0,
                    drop_last=False)

    print("backbone:", backbone)
    net = BiSeNet(backbone=backbone, n_classes=n_classes,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)

    # Load onto the CPU first so a checkpoint saved from another GPU index
    # still loads when only one device is visible; net.cuda() moves it after.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(respth, map_location='cpu')
    net.load_state_dict(state_dict)
    net.cuda()
    net.eval()

    with torch.no_grad():
        single_scale = MscEvalV0(scale=scale)
        # Pass the configured class count instead of a literal 2.
        single_scale(net, dl, n_classes)
-
-
if __name__ == "__main__":
    # STDC2-Seg75 mIoU 0.7704
    #
    # Original configuration (absolute paths on the training host):
    #   checkpoint: /host/data/segmentation/Buildings_checkpoints/checkpoints3/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth
    #   dataset:    /host/data/segmentation/Buildings3/images_12/
    #
    # Modified configuration (paths relative to the working directory):
    checkpoint_path = './checkpoints_1720/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth'
    dataset_root = './data/segmentation/shuiyufenge_1720/'

    evaluatev0(
        checkpoint_path,
        dspth=dataset_root,
        backbone='STDCNet813',
        scale=0.75,
        use_boundary_2=False,
        use_boundary_4=False,
        use_boundary_8=True,
        use_boundary_16=False,
    )
-
-
|