# STDC-th/evaluation_process.py
#
# 325 lines
# 12 KiB
# Python

#!/usr/bin/python
# -*- encoding: utf-8 -*-
from logger import setup_logger
from models.model_stages import BiSeNet
from cityscapes import CityScapes
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist
import os
import os.path as osp
import logging
import time
import numpy as np
from tqdm import tqdm
import math
from PIL import Image
from heliushuju_process import Heliushuju
import json
from utils.metrics import Evaluator
class MscEvalV0(object):
    """Single-scale evaluator that accumulates a confusion matrix on the GPU.

    The network input is resized by ``scale`` before the forward pass and the
    logits are resized back to the label resolution before taking the argmax.
    Pixels whose label equals ``ignore_label`` are excluded from the matrix
    used for the returned mIoU.
    """

    def __init__(self, scale=0.5, ignore_label=255):
        self.ignore_label = ignore_label
        self.scale = scale

    @staticmethod
    def _mean_iou(hist):
        """Return the mean IoU of a square confusion matrix as a Python float.

        Classes whose union (row sum + column sum - diagonal) is zero never
        occurred in either labels or predictions; their IoU would be 0/0, so
        they are left out of the mean instead of turning the result into NaN.
        Returns NaN only when no class occurred at all.
        """
        hits = hist.diag()
        union = hist.sum(dim=0) + hist.sum(dim=1) - hits
        seen = union > 0
        if not bool(seen.any()):
            return float('nan')
        return (hits[seen] / union[seen]).mean().item()

    def __call__(self, net, dl, n_classes):
        """Evaluate ``net`` on every batch of ``dl`` and return the mIoU.

        Args:
            net: model; ``net(imgs)[0]`` is used as the logits.
            dl: iterable yielding ``(imgs, label)`` batches. ``label`` is
                squeezed along dim 1 — presumably (N, 1, H, W); TODO confirm.
            n_classes: number of classes (size of the confusion matrix).

        Returns:
            float: mean IoU over the classes that occurred, computed from the
            confusion matrix (summed across processes when torch.distributed
            is initialised).

        Side effects:
            Stores a fresh ``Evaluator`` on ``self.evaluator`` and prints its
            metrics. Those printed metrics are not all-reduced, so in a
            distributed run they describe the local process only.
        """
        hist = torch.zeros(n_classes, n_classes).cuda().detach()
        self.evaluator = Evaluator(n_classes)  # fresh metrics helper for this run
        self.evaluator.reset()

        # Only rank 0 (or a non-distributed run) shows a progress bar.
        if dist.is_initialized() and dist.get_rank() != 0:
            diter = enumerate(dl)
        else:
            diter = enumerate(tqdm(dl))

        # Evaluation never needs gradients; do not rely on the caller for this.
        with torch.no_grad():
            for _, (imgs, label) in diter:
                label = label.squeeze(1).cuda()
                size = label.size()[-2:]

                imgs = imgs.cuda()
                _, _, H, W = imgs.size()
                new_hw = [int(H * self.scale), int(W * self.scale)]
                imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)

                logits = net(imgs)[0]
                # Bring the logits back to label resolution before the argmax.
                logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(probs, dim=1)

                keep = label != self.ignore_label
                # Flattened (label, pred) pair index -> one confusion-matrix cell.
                hist += torch.bincount(
                    label[keep] * n_classes + preds[keep],
                    minlength=n_classes ** 2,
                ).view(n_classes, n_classes).float()
                # Update the Evaluator's own confusion matrix as well.
                self.evaluator.add_batch(label.cpu().numpy(), preds.cpu().numpy())

        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        class_IoU, mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        # Computed as in the original code; the values are currently not reported.
        recall, precision, f1 = self.evaluator.Recall_Precision()
        print("val Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        for i, iou in enumerate(class_IoU):
            print(' class:%d ,Iou:%.4f ' % (i, iou), end='')
        print()

        if dist.is_initialized():
            dist.all_reduce(hist, dist.ReduceOp.SUM)
        return self._mean_iou(hist)
def evaluatev0(respth='./pretrained', dspth='./data', backbone='CatNetSmall', scale=0.75,
               use_boundary_2=False, use_boundary_4=False, use_boundary_8=False,
               use_boundary_16=False, use_conv_last=False, n_classes=4,
               modelSize=(640, 360), mode='test', outpath='outputs/test/',
               labelJson='data/heliushuju_info.json'):
    """Evaluate a BiSeNet checkpoint, or export colourised predictions.

    Args:
        respth: checkpoint path handed to ``torch.load``.
        dspth: dataset root handed to ``Heliushuju``.
        backbone, use_boundary_*, use_conv_last, n_classes: forwarded to
            ``BiSeNet``.
        scale: resize factor applied to the input before the forward pass.
        modelSize: passed to the dataset as ``cropsize``.
        mode: ``'val'`` computes the mIoU with ``MscEvalV0``; any other value
            writes one colour PNG per input into ``outpath``.
        outpath: output directory for the PNGs (created if missing).
        labelJson: JSON list of entries with ``'id'`` and ``'color'`` keys.

    Returns:
        The mIoU (float) in ``'val'`` mode, otherwise ``None``.
    """
    print('scale', scale)
    print('use_boundary_2', use_boundary_2)
    print('use_boundary_4', use_boundary_4)
    print('use_boundary_8', use_boundary_8)
    print('use_boundary_16', use_boundary_16)

    ## dataset
    batchsize = 5
    n_workers = 2
    dsval = Heliushuju(dspth, mode=mode, cropsize=modelSize, labelJson=labelJson)

    with open(labelJson, 'r') as fr:
        labels_info = json.load(fr)
    lb_map = {el['id']: el['color'] for el in labels_info}
    # NOTE(review): row k is the colour of the k-th distinct id in the JSON, so
    # predictions are colourised by position, not by 'id'. This assumes the
    # entries are listed in class-index order — confirm against the label file.
    lb_colors = np.array(list(lb_map.values()))

    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)

    print("backbone:", backbone)
    net = BiSeNet(backbone=backbone, n_classes=n_classes,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)
    net.load_state_dict(torch.load(respth))
    net.cuda()
    net.eval()

    if mode == 'val':
        with torch.no_grad():
            single_scale = MscEvalV0(scale=scale, ignore_label=255)
            mIOU = single_scale(net, dl, n_classes)
        logger = logging.getLogger()
        logger.info('mIOU is: %s\n', mIOU)
        return mIOU

    # Export mode: the dataset is expected to yield (imgs, filenames) here.
    # Create the target directory up front so the first save cannot fail on a
    # missing folder.
    os.makedirs(outpath, exist_ok=True)
    with torch.no_grad():
        for _, (imgs, filenames) in enumerate(tqdm(dl)):
            imgs = imgs.cuda()
            _, _, H, W = imgs.size()
            new_hw = [int(H * scale), int(W * scale)]
            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
            logits = net(imgs)[0]
            # Resize logits back to the resolution the loader delivered.
            logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1).cpu().numpy()
            print(preds.shape, logits.shape)
            for jj, ff in enumerate(filenames):
                # Map every predicted class index to its colour.
                pred_color = lb_colors[preds[jj]]
                t1 = Image.fromarray(np.uint8(pred_color))
                t1.save(os.path.join(outpath, ff + '.png'))
    return None
class MscEval(object):
    """Evaluator with optional horizontal-flip augmentation and sliding crops.

    Holds the model, the dataloader and the evaluation settings; ``evaluate``
    returns the mean IoU over the dataloader.
    """

    def __init__(self,
                 model,
                 dataloader,
                 scales=(0.5, 0.75, 1, 1.25, 1.5, 1.75),
                 n_classes=19,
                 lb_ignore=255,
                 cropsize=1024,
                 flip=True,
                 *args, **kwargs):
        # Copy into a per-instance list: the default is immutable and no two
        # evaluators share (or alias the caller's) scale list.
        self.scales = list(scales)
        self.n_classes = n_classes
        self.lb_ignore = lb_ignore
        self.flip = flip
        self.cropsize = cropsize
        ## dataloader
        self.dl = dataloader
        self.net = model

    def pad_tensor(self, inten, size):
        """Zero-pad a 4-D tensor to ``size`` (H, W), centring the input.

        Returns the padded CUDA tensor and ``[hst, hed, wst, wed]``, the slice
        bounds of the original content inside it.
        """
        N, C, H, W = inten.size()
        outten = torch.zeros(N, C, size[0], size[1]).cuda()
        margin_h, margin_w = size[0] - H, size[1] - W
        hst, hed = margin_h // 2, margin_h // 2 + H
        wst, wed = margin_w // 2, margin_w // 2 + W
        outten[:, :, hst:hed, wst:wed] = inten
        return outten, [hst, hed, wst, wed]

    def eval_chip(self, crop):
        """Run the net on ``crop`` (plus its horizontal flip when enabled).

        Returns ``exp`` of the (summed) softmax outputs, as in the original
        implementation.
        """
        with torch.no_grad():
            out = self.net(crop)[0]
            prob = F.softmax(out, 1)
            if self.flip:
                # Flip along the last dim, predict, and flip the output back.
                crop = torch.flip(crop, dims=(3,))
                out = self.net(crop)[0]
                out = torch.flip(out, dims=(3,))
                prob += F.softmax(out, 1)
            prob = torch.exp(prob)
        return prob

    def crop_eval(self, im):
        """Evaluate ``im`` with ``cropsize`` x ``cropsize`` windows.

        Images smaller than the crop on both sides are padded and evaluated
        once; otherwise overlapping windows (stride = ceil(5/6 * cropsize))
        are evaluated and their outputs summed. Padding is cropped away again
        before returning.
        """
        cropsize = self.cropsize
        stride_rate = 5 / 6.
        N, C, H, W = im.size()
        long_size, short_size = (H, W) if H > W else (W, H)
        if long_size < cropsize:
            im, indices = self.pad_tensor(im, (cropsize, cropsize))
            prob = self.eval_chip(im)
            prob = prob[:, :, indices[0]:indices[1], indices[2]:indices[3]]
        else:
            stride = math.ceil(cropsize * stride_rate)
            if short_size < cropsize:
                # Pad only the short side up to the crop size.
                if H < W:
                    im, indices = self.pad_tensor(im, (cropsize, W))
                else:
                    im, indices = self.pad_tensor(im, (H, cropsize))
            N, C, H, W = im.size()
            n_x = math.ceil((W - cropsize) / stride) + 1
            n_y = math.ceil((H - cropsize) / stride) + 1
            prob = torch.zeros(N, self.n_classes, H, W).cuda()
            for iy in range(n_y):
                for ix in range(n_x):
                    # Clamp the window to the image so the last one ends at the edge.
                    hed, wed = min(H, stride * iy + cropsize), min(W, stride * ix + cropsize)
                    hst, wst = hed - cropsize, wed - cropsize
                    chip = im[:, :, hst:hed, wst:wed]
                    prob_chip = self.eval_chip(chip)
                    prob[:, :, hst:hed, wst:wed] += prob_chip
            if short_size < cropsize:
                prob = prob[:, :, indices[0]:indices[1], indices[2]:indices[3]]
        return prob

    def scale_crop_eval(self, im, scale):
        """Resize ``im`` by ``scale``, run ``crop_eval``, resize the result back."""
        N, C, H, W = im.size()
        new_hw = [int(H * scale), int(W * scale)]
        im = F.interpolate(im, new_hw, mode='bilinear', align_corners=True)
        prob = self.crop_eval(im)
        prob = F.interpolate(prob, (H, W), mode='bilinear', align_corners=True)
        return prob

    def compute_hist(self, pred, lb):
        """Return the (n_classes, n_classes) confusion matrix of one batch.

        Rows index the prediction, columns the label; positions where ``lb``
        equals ``lb_ignore`` are skipped.
        """
        n_classes = self.n_classes
        ignore_idx = self.lb_ignore
        keep = np.logical_not(lb == ignore_idx)
        merge = pred[keep] * n_classes + lb[keep]
        hist = np.bincount(merge, minlength=n_classes ** 2)
        hist = hist.reshape((n_classes, n_classes))
        return hist

    @staticmethod
    def _mean_iou(hist):
        """Mean IoU of a confusion matrix over the classes that occurred.

        A class with zero union (never labelled and never predicted) would
        give 0/0; it is excluded so it cannot turn the mean into NaN. Returns
        NaN only when no class occurred at all.
        """
        hits = np.diag(hist)
        union = np.sum(hist, axis=0) + np.sum(hist, axis=1) - hits
        seen = union > 0
        if not seen.any():
            return np.nan
        return np.mean(hits[seen] / union[seen])

    def evaluate(self):
        """Run the model over ``self.dl`` and return the mean IoU."""
        n_classes = self.n_classes
        hist = np.zeros((n_classes, n_classes), dtype=np.float32)
        # Progress bar only on rank 0 / non-distributed runs; no bar object is
        # created on the other ranks.
        if dist.is_initialized() and not dist.get_rank() == 0:
            dloader = self.dl
        else:
            dloader = tqdm(self.dl)
        for i, (imgs, label) in enumerate(dloader):
            N, _, H, W = label.shape
            probs = torch.zeros((N, self.n_classes, H, W))
            imgs = imgs.cuda()
            for sc in self.scales:
                # NOTE(review): `sc` is unused — the multi-scale call below is
                # disabled, so each entry of `scales` repeats the same
                # full-image forward pass.
                # prob = self.scale_crop_eval(imgs, sc)
                prob = self.eval_chip(imgs)
                probs += prob.detach().cpu()
            probs = probs.data.numpy()
            preds = np.argmax(probs, axis=1)
            hist_once = self.compute_hist(preds, label.data.numpy().squeeze(1))
            hist = hist + hist_once
        return self._mean_iou(hist)
def evaluate(respth='./resv1_catnet/pths/', dspth='./data'):
    """Restore a 19-class BiSeNet from ``respth`` and log its mIoU.

    The model is evaluated on ``CityScapes(dspth, mode='val')`` with
    ``MscEval`` at a single scale and without flipping. Nothing is returned;
    the result is written to the root logger.
    """
    log = logging.getLogger()

    # Banner plus status line before the model is built.
    for message in ('\n',
                    '====' * 20,
                    'evaluating the model ...\n',
                    'setup and restore model'):
        log.info(message)

    num_classes = 19
    model = BiSeNet(n_classes=num_classes)
    model.load_state_dict(torch.load(respth))
    model.cuda()
    model.eval()

    # Validation loader: fixed order, incomplete last batch kept.
    val_set = CityScapes(dspth, mode='val')
    val_loader = DataLoader(
        val_set,
        batch_size=5,
        shuffle=False,
        num_workers=2,
        drop_last=False,
    )

    log.info('compute the mIOU')
    msc = MscEval(model, val_loader, scales=[1], flip=False)
    score = msc.evaluate()
    log.info('mIOU is: {:.6f}'.format(score))
if __name__ == "__main__":
    # Script entry point: set up file logging, then evaluate one checkpoint.
    log_dir = 'evaluation_logs/'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    setup_logger(log_dir)

    # Alternative run configuration, kept for reference:
    #   modelpath='./checkpooints/0430/pths/model_final.pth'; n_classes=4
    #   labelJson='data/heliushuju_info.json'; dspth='../../data/carRoadLane/'; mode='val'
    modelpath = './checkpooints/0430pm/pths/model_final.pth'
    labelJson = 'data/RoadLane_info.json'
    n_classes = 3
    dspth = '../../data/RoadLane/'
    mode = 'val'

    evaluatev0(
        modelpath,
        dspth=dspth,
        backbone='STDCNet813',
        scale=1.0,
        use_boundary_2=False,
        use_boundary_4=False,
        use_boundary_8=True,
        use_boundary_16=False,
        n_classes=n_classes,
        modelSize=(1920, 1080),
        mode=mode,
        outpath='outputs/test2/',
        labelJson=labelJson,
    )