AIlib2/crowdUtils/engine.py

300 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
Mostly copy-paste from DETR (https://github.com/facebookresearch/detr).
"""
import math
import os
import sys
from typing import Iterable
import torch
#print( os.path.abspath( os.path.dirname(__file__) ) )
sys.path.append( os.path.abspath( os.path.dirname(__file__) ) )
import util.misc as utils
from util.misc import NestedTensor
import numpy as np
import time
import torchvision.transforms as standard_transforms
import cv2
import PIL
class DictToObject:
    """Expose the entries of a (possibly nested) dict as instance attributes.

    Every key becomes an attribute name. A value that is itself a ``dict`` is
    wrapped recursively in another ``DictToObject``; any other value is
    stored unchanged.
    """

    def __init__(self, dictionary):
        for name, entry in dictionary.items():
            wrapped = DictToObject(entry) if isinstance(entry, dict) else entry
            setattr(self, name, wrapped)
def letterImage(img, minShape, maxShape):
    """Rescale ``img``, keeping its aspect ratio, so it fits the size bounds.

    Args:
        img: image whose first two ``shape`` entries are height and width.
        minShape: 4-element sequence; its last two entries are the minimum
            height and minimum width.
        maxShape: 4-element sequence; its last two entries are the maximum
            height and maximum width.

    Returns:
        ``img`` itself (no copy) when it already lies inside the bounds,
        otherwise the result of ``cv2.resize`` at the computed size.

    Raises:
        AssertionError: if the aspect-preserving size cannot satisfy both the
            minimum and the maximum bound.
    """
    iH, iW = img.shape[0:2]
    minH, minW = minShape[2:]
    maxH, maxW = maxShape[2:]
    newH, newW = iH, iW
    fx = fy = 0.0  # only used in the diagnostic message below
    flag = False

    if iH < minH or iW < minW:
        fy = iH / minH
        fx = iW / minW
        # Enlarge by the smaller of the two ratios so that both sides reach
        # their minimum. The side that owns that ratio lands exactly on its
        # bound; the other is derived with integer arithmetic. Going through
        # ``int(iH / (iH / minH))`` can come out one pixel short because of
        # floating-point rounding.
        if iW * minH <= iH * minW:  # fx <= fy: width is the limiting side
            newW, newH = int(minW), int(iH * minW // iW)
        else:
            newH, newW = int(minH), int(iW * minH // iH)
        flag = True

    # As in the original, this test looks at the *input* size, so a shrink
    # overrides an enlargement computed above.
    if iH > maxH or iW > maxW:
        fy = iH / maxH
        fx = iW / maxW
        # Shrink by the larger ratio so that both sides fit under the maximum.
        if iW * maxH >= iH * maxW:  # fx >= fy: width is the limiting side
            newW, newH = int(maxW), int(iH * maxW // iW)
        else:
            newH, newW = int(maxH), int(iW * maxH // iH)
        flag = True

    if not flag:
        return img

    # Raised explicitly (rather than via ``assert``) so that the check is not
    # stripped when Python runs with -O.
    if not (minH <= newH <= maxH) or not (minW <= newW <= maxW):
        raise AssertionError(
            'iH%d,iW:%d , newH:%d newW:%d, fx:%.1f fy:%.1f' % (iH, iW, newH, newW, fx, fy)
        )
    return cv2.resize(img, (newW, newH))
def postprocess(outputs, threshold=0.5):
    """Keep the predicted points of the first batch item that score above ``threshold``.

    Args:
        outputs: mapping holding the tensors ``'pred_logits'`` and
            ``'pred_points'``. The logits are soft-maxed over their last axis
            and the value at index 1 of that axis is used as the score
            (presumably the "person" class -- confirm against the model).
        threshold: a point is kept when its score is strictly greater than
            this value.

    Returns:
        ``(points, scores)`` as plain Python lists, taken from batch item 0
        only.
    """
    probabilities = torch.nn.functional.softmax(outputs['pred_logits'], -1)
    first_scores = probabilities[:, :, 1][0]
    first_points = outputs['pred_points'][0]

    keep = first_scores > threshold
    points = first_points[keep].detach().cpu().numpy().tolist()
    scores = first_scores[keep].detach().cpu().numpy().tolist()
    return points, scores
def toOBBformat(points, scores, cls=0):
    """Convert point detections into the oriented-bounding-box record layout.

    Each point is repeated four times to stand in for the four box corners,
    giving::

        [[[(x0, y0), (x1, y1), (x2, y2), (x3, y3)], score, cls],
         [[(x0, y0), (x1, y1), (x2, y2), (x3, y3)], score, cls],
         ...]

    Args:
        points: sequence of points.
        scores: sequence indexed in step with ``points``; must be at least
            as long (a shorter one raises ``IndexError``).
        cls: class value stored in every record.

    Returns:
        A list with one ``[corners, score, cls]`` record per point.
    """
    import copy

    # Each corner gets its own shallow copy: ``[pt] * 4`` would make the four
    # corners (and the caller's input point) one shared object, so editing a
    # single corner downstream would silently change all of them.
    return [
        [[copy.copy(pt) for _ in range(4)], scores[idx], cls]
        for idx, pt in enumerate(points)
    ]
def preprocess(img, mean, std, minShape, maxShape):
    """Turn an image into a normalised ``float32`` channel-first array.

    Steps, in order: convert a PIL image to a numpy array, fit it into the
    size bounds with ``letterImage``, round each side down to a multiple of
    128, divide by 255, normalise the first three channels with ``mean`` and
    ``std``, and reorder from (H, W, C) to (C, H, W). No batch axis is added.

    Args:
        img: numpy array laid out (H, W, C), or a ``PIL.Image.Image``.
            The original note says the input is RGB -- confirm with callers.
        mean: per-channel means, indexed 0..2.
        std: per-channel standard deviations, indexed 0..2.
        minShape: forwarded to ``letterImage``.
        maxShape: forwarded to ``letterImage``.

    Returns:
        ``float32`` array of shape (3, new_height, new_width).
    """
    if isinstance(img, PIL.Image.Image):
        img = np.array(img)
    img = letterImage(img, minShape, maxShape)

    # Round each side down to a multiple of 128.
    src_height, src_width = img.shape[0:2]
    new_height = src_height // 128 * 128
    new_width = src_width // 128 * 128
    img = cv2.resize(img, (new_width, new_height))
    img = img / 255.

    normalised = np.zeros((new_height, new_width, 3))
    for channel in range(3):
        normalised[:, :, channel] = (img[:, :, channel] - mean[channel]) / std[channel]

    # HWC -> CHW
    return normalised.transpose((2, 0, 1)).astype(np.float32)
class DeNormalize(object):
    """Callable that undoes a per-channel ``(x - mean) / std`` normalisation.

    The input tensor is modified **in place** and also returned.
    """

    def __init__(self, mean, std):
        # One mean and one std per entry along the tensor's first axis.
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Iterating a tensor walks its first axis; each slice is rescaled
        # in place: x * std + mean.
        for channel, mu, sigma in zip(tensor, self.mean, self.std):
            channel.mul_(sigma).add_(mu)
        return tensor
# generate the reference points in grid layout
def generate_anchor_points(stride=16, row=3, line=3):
    """Build a ``row`` x ``line`` grid of offsets centred on the origin.

    The grid covers a ``stride`` x ``stride`` cell: ``line`` evenly spaced
    x offsets and ``row`` evenly spaced y offsets, each placed at the centre
    of its sub-cell.

    Returns:
        Array of shape (row * line, 2) holding (x, y) offsets, x varying
        fastest.
    """
    half_cell = stride / 2
    x_offsets = (np.arange(1, line + 1) - 0.5) * (stride / line) - half_cell
    y_offsets = (np.arange(1, row + 1) - 0.5) * (stride / row) - half_cell

    grid_x, grid_y = np.meshgrid(x_offsets, y_offsets)
    return np.array([grid_x.ravel(), grid_y.ravel()]).T
def shift(shape, stride, anchor_points):
    """Replicate ``anchor_points`` around the centre of every feature-map cell.

    Args:
        shape: ``(rows, cols)`` of the feature map.
        stride: size of one cell in pixels; cell centres sit at
            ``(index + 0.5) * stride``.
        anchor_points: array of shape (A, 2) with (x, y) offsets.

    Returns:
        Array of shape (rows * cols * A, 2): for each cell (row-major, x
        varying fastest) the A offsets translated to that cell's centre.
    """
    centre_x = (np.arange(0, shape[1]) + 0.5) * stride
    centre_y = (np.arange(0, shape[0]) + 0.5) * stride
    grid_x, grid_y = np.meshgrid(centre_x, centre_y)
    centres = np.array([grid_x.ravel(), grid_y.ravel()]).T

    num_anchors = anchor_points.shape[0]
    num_cells = centres.shape[0]
    # Broadcast (1, A, 2) against (K, 1, 2) to get every cell/offset pair.
    combined = (anchor_points.reshape((1, num_anchors, 2))
                + centres.reshape((num_cells, 1, 2)))
    return combined.reshape((num_cells * num_anchors, 2))
class AnchorPointsf(object):
    """Generate the reference (anchor) points for an image batch.

    Args:
        pyramid_levels: feature-pyramid levels to generate points for;
            ``None`` selects ``[3, 4, 5, 6, 7]``.
        strides: pixel stride of each level; ``None`` derives ``2 ** level``
            for every entry of ``pyramid_levels``.
        row: number of reference-point rows per cell.
        line: number of reference points per row within a cell.
        device: ``'cpu'`` keeps the result on the CPU; anything else moves it
            to CUDA when CUDA is available.
    """

    def __init__(self, pyramid_levels=(3,), strides=None, row=3, line=3, device='cpu'):
        if pyramid_levels is None:
            self.pyramid_levels = [3, 4, 5, 6, 7]
        else:
            # Copy so the instance never aliases the caller's (or the
            # default) sequence.
            self.pyramid_levels = list(pyramid_levels)

        if strides is None:
            self.strides = [2 ** x for x in self.pyramid_levels]
        else:
            # Previously an explicit ``strides`` was dropped and ``eval``
            # failed with AttributeError.
            self.strides = list(strides)

        self.row = row
        self.line = line
        self.device = device

    def eval(self, image):
        """Return the reference points for ``image`` as a float32 tensor.

        Only ``image.shape[2:]`` is read (presumably the H, W of an
        (N, C, H, W) batch -- confirm with callers). The result has shape
        (1, total_points, 2).
        """
        image_shape = np.array(image.shape[2:])
        # Feature-map size of each level: ceil(image_size / 2 ** level).
        image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]

        # Collect the reference points of every level, then join them once.
        per_level = [np.zeros((0, 2)).astype(np.float32)]
        for idx, p in enumerate(self.pyramid_levels):
            anchor_points = generate_anchor_points(2 ** p, row=self.row, line=self.line)
            per_level.append(shift(image_shapes[idx], self.strides[idx], anchor_points))
        all_anchor_points = np.concatenate(per_level, axis=0)
        all_anchor_points = np.expand_dims(all_anchor_points, axis=0)

        points = torch.from_numpy(all_anchor_points.astype(np.float32))
        # Compare via ``str`` so that a ``torch.device('cpu')`` object is
        # recognised as CPU as well as the plain string.
        # NOTE(review): ``.cuda()`` targets the current CUDA device; any index
        # in ``self.device`` (e.g. 'cuda:1') is not honoured here.
        if torch.cuda.is_available() and str(self.device) != 'cpu':
            return points.cuda()
        return points
def vis(samples, targets, pred, vis_dir, des=None):
    '''
    Draw ground-truth and predicted points on each image and save both as JPEGs.

    samples -> tensor: [batch, 3, H, W], normalised with the mean/std used below
    targets -> list of dict: [{'point': [], 'image_id': str}]
    pred -> list with one entry per image, each [num_preds, 2]
    vis_dir -> directory the images are written to
    des -> optional tag inserted into the file names

    For every image two files are written: one with the ground-truth points
    in green ("..._gt.jpg") and one with the predictions in red
    ("..._pred.jpg"). ``samples`` is left unmodified.
    '''
    gts = [t['point'].tolist() for t in targets]

    pil_to_tensor = standard_transforms.ToTensor()

    restore_transform = standard_transforms.Compose([
        DeNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        standard_transforms.ToPILImage()
    ])
    size = 2  # radius of the drawn dots, in pixels
    # draw one by one
    for idx in range(samples.shape[0]):
        # DeNormalize works in place and samples[idx] is a view of the batch,
        # so de-normalise a clone to keep the caller's tensor intact.
        sample = restore_transform(samples[idx].clone())
        sample = pil_to_tensor(sample.convert('RGB')).numpy() * 255
        # CHW -> HWC, then reverse the channel order for OpenCV.
        canvas = sample.transpose([1, 2, 0])[:, :, ::-1].astype(np.uint8)
        sample_gt = canvas.copy()
        sample_pred = canvas.copy()

        # draw gt
        for t in gts[idx]:
            sample_gt = cv2.circle(sample_gt, (int(t[0]), int(t[1])), size, (0, 255, 0), -1)
        # draw predictions
        for p in pred[idx]:
            sample_pred = cv2.circle(sample_pred, (int(p[0]), int(p[1])), size, (0, 0, 255), -1)

        name = targets[idx]['image_id']
        # save the visualized images
        if des is not None:
            stem = '{}_{}'.format(int(name), des)
        else:
            stem = '{}'.format(int(name))
        base = '{}_gt_{}_pred_{}'.format(stem, len(gts[idx]), len(pred[idx]))
        cv2.imwrite(os.path.join(vis_dir, base + '_gt.jpg'), sample_gt)
        cv2.imwrite(os.path.join(vis_dir, base + '_pred.jpg'), sample_pred)
# the training routine
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    """Run one pass over ``data_loader`` and update ``model``.

    For every batch: forward pass, weighted sum of the criterion's losses,
    backward pass, optional gradient-norm clipping (when ``max_norm > 0``)
    and an optimizer step. The process exits with status 1 if the reduced
    loss is not finite.

    ``epoch`` is accepted but not used inside this function.

    Returns:
        dict mapping each meter name of the metric logger to its
        ``global_avg``.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    # iterate all training samples
    for samples, targets in data_loader:
        samples = samples.to(device)
        targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

        # forward + losses
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        # Only losses that have a weight contribute to the optimised total.
        losses = sum(loss_dict[name] * weight_dict[name]
                     for name in loss_dict.keys() if name in weight_dict)

        # Reduced copies of the losses, used for logging and the sanity check.
        reduced = utils.reduce_dict(loss_dict)
        reduced_unscaled = {f'{name}_unscaled': value for name, value in reduced.items()}
        reduced_scaled = {name: value * weight_dict[name]
                          for name, value in reduced.items() if name in weight_dict}
        loss_value = sum(reduced_scaled.values()).item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(reduced)
            sys.exit(1)

        # backward
        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        # update logger
        metric_logger.update(loss=loss_value, **reduced_scaled, **reduced_unscaled)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
# the inference routine
@torch.no_grad()
def evaluate_crowd_no_overlap(model, data_loader, device, vis_dir=None):
    """Count predictions per batch and compare them with the ground truth.

    For each batch only item 0 is evaluated: points whose softmax score at
    class index 1 is above 0.5 are counted and compared with the number of
    rows in ``targets[0]['point']``. When ``vis_dir`` is given the batch is
    also passed to ``vis``.

    Returns:
        ``(mae, mse)``: the mean absolute count error, and the square root of
        the mean squared count error (so the second value is an RMSE).
    """
    model.eval()

    # Created as in the original; not used further inside this function.
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))

    # 0.5 is used by default
    threshold = 0.5
    abs_errors = []
    sq_errors = []
    # run inference on all images to calc MAE
    for samples, targets in data_loader:
        samples = samples.to(device)
        outputs = model(samples)

        scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
        candidate_points = outputs['pred_points'][0]
        gt_cnt = targets[0]['point'].shape[0]

        keep = scores > threshold
        points = candidate_points[keep].detach().cpu().numpy().tolist()
        predict_cnt = int(keep.sum())

        # if specified, save the visualized images
        if vis_dir is not None:
            vis(samples, targets, [points], vis_dir)

        # accumulate MAE, MSE
        count_error = predict_cnt - gt_cnt
        abs_errors.append(float(abs(count_error)))
        sq_errors.append(float(count_error * count_error))

    # calc MAE, MSE
    mae = np.mean(abs_errors)
    mse = np.sqrt(np.mean(sq_errors))
    return mae, mse