300 lines
11 KiB
Python
300 lines
11 KiB
Python
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|||
|
|
"""
|
|||
|
|
Train and eval functions used in main.py
|
|||
|
|
Mostly copy-paste from DETR (https://github.com/facebookresearch/detr).
|
|||
|
|
"""
|
|||
|
|
import math
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
from typing import Iterable
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
#print( os.path.abspath( os.path.dirname(__file__) ) )
|
|||
|
|
sys.path.append( os.path.abspath( os.path.dirname(__file__) ) )
|
|||
|
|
import util.misc as utils
|
|||
|
|
from util.misc import NestedTensor
|
|||
|
|
import numpy as np
|
|||
|
|
import time
|
|||
|
|
import torchvision.transforms as standard_transforms
|
|||
|
|
import cv2
|
|||
|
|
import PIL
|
|||
|
|
|
|||
|
|
class DictToObject:
    """Expose the entries of a dictionary as instance attributes.

    Nested dictionaries are wrapped recursively, so ``{'a': {'b': 1}}`` can be
    read as ``obj.a.b``. Any value that is not itself a ``dict`` (including a
    list that contains dicts) is stored unchanged.
    """

    def __init__(self, dictionary):
        for name, entry in dictionary.items():
            wrapped = DictToObject(entry) if isinstance(entry, dict) else entry
            setattr(self, name, wrapped)
|
|||
|
|
|
|||
|
|
def letterImage(img,minShape,maxShape):
    """Rescale ``img`` (aspect ratio kept) so its height/width fit the bounds.

    Args:
        img: array whose first two dimensions are (height, width); it is
            handed to ``cv2.resize`` when a rescale is needed.
        minShape: sequence whose elements from index 2 onward are exactly
            (min_height, min_width) -- presumably an (N, C, H, W) shape;
            confirm against the caller.
        maxShape: same layout as ``minShape``, holding the upper bounds.

    Returns:
        ``img`` itself (no copy) when it already lies within the bounds,
        otherwise the resized image.

    Raises:
        AssertionError: if the rescaled size still violates a bound, i.e. the
            aspect ratio cannot be fitted into the [min, max] box.
    """
    iH, iW = img.shape[0:2]
    minH, minW = minShape[2:]
    maxH, maxW = maxShape[2:]

    # int() truncates, and a float round trip such as iH / (iH / minH) can
    # come out a hair below the exact integer. The small bias keeps that case
    # from landing one pixel under the bound and tripping the checks below.
    eps = 1e-6

    flag = False
    if iH < minH or iW < minW:
        # Upscale: dividing by the smaller ratio brings both sides up to at
        # least their minimum.
        fy = iH / minH
        fx = iW / minW
        ff = min(fx, fy)
        newH, newW = int(iH / ff + eps), int(iW / ff + eps)
        flag = True
    if iH > maxH or iW > maxW:
        # Downscale: dividing by the larger ratio brings both sides down to at
        # most their maximum. This overrides an upscale decided above.
        fy = iH / maxH
        fx = iW / maxW
        ff = max(fx, fy)
        newH, newW = int(iH / ff + eps), int(iW / ff + eps)
        flag = True

    if not flag:
        return img

    assert minH<=newH and newH<= maxH , 'iH%d,iW:%d , newH:%d newW:%d, fx:%.1f fy:%.1f'%(iH,iW,newH,newW,fx,fy)
    assert minW<=newW and newW<= maxW, 'iH%d,iW:%d , newH:%d newW:%d, fx:%.1f fy:%.1f'%(iH,iW,newH,newW,fx,fy)
    # cv2.resize takes the target size as (width, height)
    return cv2.resize(img, (newW, newH))
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def postprocess(outputs,threshold=0.5):
    """Pick the confident point predictions of the first batch element.

    ``outputs['pred_logits']`` is soft-maxed over its last dimension and
    channel 1 is used as the score. Entries of ``outputs['pred_points']``
    whose score is strictly greater than ``threshold`` are kept.

    Returns:
        (points, scores): two plain Python lists of equal length.
    """
    probabilities = torch.nn.functional.softmax(outputs['pred_logits'], -1)
    first_scores = probabilities[:, :, 1][0]
    first_points = outputs['pred_points'][0]

    # one boolean mask selects both the coordinates and their scores
    keep = first_scores > threshold
    points = first_points[keep].detach().cpu().numpy().tolist()
    scores = first_scores[keep].detach().cpu().numpy().tolist()
    return points, scores
|
|||
|
|
|
|||
|
|
def toOBBformat(points,scores,cls=0):
    """Convert point detections into oriented-box style records.

    Every point becomes ``[[pt, pt, pt, pt], score, cls]``: the four box
    corners are the very same point object repeated four times, paired with
    the score found at the same index in ``scores``.
    """
    return [[[pt] * 4, scores[idx], cls] for idx, pt in enumerate(points)]
|
|||
|
|
|
|||
|
|
# toOBBformat output layout: [ [ [(x0,y0),(x1,y1),(x2,y2),(x3,y3)], score, cls ], [ [(x0,y0),(x1,y1),(x2,y2),(x3,y3)], score, cls ], ... ]
|
|||
|
|
|
|||
|
|
def preprocess(img,mean,std,minShape,maxShape):
    """Turn an image into a normalised float32 array laid out (C, H, W).

    Steps: convert a PIL image to numpy, rescale it with ``letterImage`` to
    respect ``minShape``/``maxShape``, floor both sides to a multiple of 128,
    scale pixel values by 1/255 and apply ``(x - mean[c]) / std[c]`` to the
    first three channels. No batch dimension is added.
    """
    # img: numpy array laid out (H, W, C)
    # network input: RGB format, laid out (C, H, W)
    if isinstance(img, PIL.Image.Image):
        img = np.array(img)

    img = letterImage(img, minShape, maxShape)
    rows, cols = img.shape[0:2]

    # NOTE(review): a side shorter than 128 floors to 0 here -- confirm that
    # minShape always rules this out.
    new_width = cols // 128 * 128
    new_height = rows // 128 * 128
    img = cv2.resize(img, (new_width, new_height))

    scaled = img / 255.
    normalised = np.zeros((new_height, new_width, 3))
    for channel in range(3):
        normalised[:, :, channel] = (scaled[:, :, channel] - mean[channel]) / std[channel]

    # HWC -> CHW
    return normalised.transpose((2, 0, 1)).astype(np.float32)
|
|||
|
|
|
|||
|
|
class DeNormalize(object):
    """Undo a per-channel ``(x - mean) / std`` normalisation, in place."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Multiply each channel by its std, add its mean, and return ``tensor``.

        The input is modified in place and the very same object is returned.
        Iteration stops at the shortest of (channels, mean, std).
        """
        for channel, mu, sigma in zip(tensor, self.mean, self.std):
            channel.mul_(sigma).add_(mu)
        return tensor
|
|||
|
|
|
|||
|
|
# generate the reference points in grid layout
|
|||
|
|
def generate_anchor_points(stride=16, row=3, line=3):
    """Lay out ``row`` x ``line`` reference points inside one stride-sized cell.

    The points sit at the centres of an evenly divided ``row`` by ``line``
    grid, and are expressed as offsets from the cell centre (hence the
    ``- stride / 2``).

    Returns:
        Float array of shape (row * line, 2) with columns (x, y); x varies
        fastest.
    """
    half = stride / 2
    xs = (np.arange(1, line + 1) - 0.5) * (stride / line) - half
    ys = (np.arange(1, row + 1) - 0.5) * (stride / row) - half

    grid_x, grid_y = np.meshgrid(xs, ys)
    return np.stack((grid_x.ravel(), grid_y.ravel()), axis=1)
|
|||
|
|
def shift(shape, stride, anchor_points):
    """Replicate cell-local anchor points over every cell of a feature map.

    Args:
        shape: (rows, cols) of the feature map.
        stride: size of one cell; cell centres are at ``(index + 0.5) * stride``.
        anchor_points: array of shape (A, 2) holding (x, y) offsets.

    Returns:
        Array of shape (rows * cols * A, 2). Cells are visited row by row, and
        the A points of one cell are adjacent.
    """
    centers_x = (np.arange(0, shape[1]) + 0.5) * stride
    centers_y = (np.arange(0, shape[0]) + 0.5) * stride
    grid_x, grid_y = np.meshgrid(centers_x, centers_y)
    cell_centers = np.stack((grid_x.ravel(), grid_y.ravel()), axis=1)

    num_anchors = anchor_points.shape[0]
    num_cells = cell_centers.shape[0]
    # (K, 1, 2) + (1, A, 2) broadcasts to (K, A, 2)
    combined = (cell_centers.reshape((num_cells, 1, 2))
                + anchor_points.reshape((1, num_anchors, 2)))
    return combined.reshape((num_cells * num_anchors, 2))
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AnchorPointsf(object):
    """Builds the reference (anchor) points covering an image batch.

    Attributes are plain values set in ``__init__``:
    ``pyramid_levels`` (list of ints), ``strides`` (one per level),
    ``row`` / ``line`` (points per cell along y / x) and ``device``.
    """

    def __init__(self, pyramid_levels=[3,], strides=None, row=3, line=3,device='cpu'):
        """Store the grid configuration.

        Args:
            pyramid_levels: levels ``p`` to generate points for; ``None``
                selects ``[3, 4, 5, 6, 7]``.
            strides: cell size per level; ``None`` selects ``2 ** p``.
            row: number of points per cell along the vertical axis.
            line: number of points per cell along the horizontal axis.
            device: ``'cpu'`` keeps the result on the CPU; any other value
                moves it to CUDA when CUDA is available.
        """
        if pyramid_levels is None:
            self.pyramid_levels = [3, 4, 5, 6, 7]
        else:
            # Copy: the default argument is one shared list object, and
            # storing it directly would let one instance's edits leak into
            # every later instance.
            self.pyramid_levels = list(pyramid_levels)

        if strides is None:
            self.strides = [2 ** x for x in self.pyramid_levels]
        else:
            # Previously left unset, which made eval() fail with
            # AttributeError whenever explicit strides were supplied.
            self.strides = list(strides)

        self.row = row
        self.line = line
        self.device = device

    def eval(self, image):
        """Return the anchor points for ``image`` as a (1, N, 2) float32 tensor.

        ``image.shape[2:]`` is taken as the spatial size, so ``image`` is
        expected to carry two leading dimensions (presumably batch and
        channel -- confirm against the caller).
        """
        image_shape = np.array(image.shape[2:])
        # ceil-divide the spatial size by 2**level to get each level's grid
        image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]

        all_anchor_points = np.zeros((0, 2)).astype(np.float32)
        # get reference points for each level
        for idx, p in enumerate(self.pyramid_levels):
            anchor_points = generate_anchor_points(2**p, row=self.row, line=self.line)
            shifted_anchor_points = shift(image_shapes[idx], self.strides[idx], anchor_points)
            all_anchor_points = np.append(all_anchor_points, shifted_anchor_points, axis=0)

        all_anchor_points = np.expand_dims(all_anchor_points, axis=0)
        points = torch.from_numpy(all_anchor_points.astype(np.float32))
        # send reference points to device
        if torch.cuda.is_available() and self.device!='cpu':
            return points.cuda()
        return points
|
|||
|
|
def vis(samples, targets, pred, vis_dir, des=None):
    '''
    Save ground-truth and prediction overlays for every image of a batch.

    samples -> tensor: [batch, 3, H, W], normalised with the mean/std
               hard-coded below; it is NOT modified by this function
    targets -> list of dict: [{'point': ..., 'image_id': ...}]; 'point' must
               offer .tolist() and 'image_id' must be convertible with int()
    pred -> list (one entry per image) of point lists: [num_preds, 2]
    vis_dir -> directory the JPEG files are written to
    des -> optional tag inserted into the file names

    Two files are written per image, '<stem>_gt.jpg' (green dots) and
    '<stem>_pred.jpg' (red dots), where <stem> encodes the image id, the
    optional tag and the ground-truth / prediction counts.
    '''
    gts = [t['point'].tolist() for t in targets]

    pil_to_tensor = standard_transforms.ToTensor()

    restore_transform = standard_transforms.Compose([
        DeNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        standard_transforms.ToPILImage()
    ])
    size = 2
    # draw one by one
    for idx in range(samples.shape[0]):
        # DeNormalize works in place and samples[idx] is a view of the batch;
        # clone so the caller's tensor keeps its normalised values.
        sample = restore_transform(samples[idx].clone())
        sample = pil_to_tensor(sample.convert('RGB')).numpy() * 255
        # CHW RGB -> HWC BGR, the channel order the cv2 calls below expect
        bgr = sample.transpose([1, 2, 0])[:, :, ::-1].astype(np.uint8)
        sample_gt = bgr.copy()
        sample_pred = bgr.copy()

        # draw gt
        for t in gts[idx]:
            sample_gt = cv2.circle(sample_gt, (int(t[0]), int(t[1])), size, (0, 255, 0), -1)
        # draw predictions
        for p in pred[idx]:
            sample_pred = cv2.circle(sample_pred, (int(p[0]), int(p[1])), size, (0, 0, 255), -1)

        name = targets[idx]['image_id']
        # save the visualized images
        if des is not None:
            stem = '{}_{}_gt_{}_pred_{}'.format(int(name), des, len(gts[idx]), len(pred[idx]))
        else:
            stem = '{}_gt_{}_pred_{}'.format(int(name), len(gts[idx]), len(pred[idx]))
        cv2.imwrite(os.path.join(vis_dir, stem + '_gt.jpg'), sample_gt)
        cv2.imwrite(os.path.join(vis_dir, stem + '_pred.jpg'), sample_pred)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# the training routine
|
|||
|
|
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    """Run one pass over ``data_loader`` and return the averaged statistics.

    For every batch: forward pass, weighted sum of the criterion's losses,
    backward pass, optional gradient-norm clipping (when ``max_norm > 0``)
    and an optimizer step. Losses are also reduced with ``utils.reduce_dict``
    for logging. The process exits with status 1 as soon as the reduced loss
    is not finite. ``epoch`` is accepted but not used here.

    Returns:
        dict mapping each meter name to its ``global_avg``.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    # iterate all training samples
    for samples, targets in data_loader:
        samples = samples.to(device)
        targets = [{key: value.to(device) for key, value in target.items()}
                   for target in targets]

        # forward, then the weighted training loss (only terms that have a weight)
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(value * weight_dict[name]
                     for name, value in loss_dict.items() if name in weight_dict)

        # reduced copies of the losses, used for logging and the sanity check
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        unscaled_terms = {}
        scaled_terms = {}
        for name, value in loss_dict_reduced.items():
            unscaled_terms[f'{name}_unscaled'] = value
            if name in weight_dict:
                scaled_terms[name] = value * weight_dict[name]
        loss_value = sum(scaled_terms.values()).item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        # backward
        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        # update logger
        metric_logger.update(loss=loss_value, **scaled_terms, **unscaled_terms)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
|
|||
|
|
|
|||
|
|
# the inference routine
|
|||
|
|
@torch.no_grad()
def evaluate_crowd_no_overlap(model, data_loader, device, vis_dir=None):
    """Count predictions per batch and compare them with the ground truth.

    Only the first element of each batch is scored: a point is counted when
    its soft-max score (channel 1) is strictly above 0.5, and the ground-truth
    count is the number of rows in ``targets[0]['point']``. When ``vis_dir``
    is given, the kept points are also drawn via ``vis``.

    Returns:
        (mae, mse): the mean absolute count error and the square root of the
        mean squared count error.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))

    # per-image absolute and squared count errors
    abs_errors = []
    sq_errors = []
    for samples, targets in data_loader:
        samples = samples.to(device)

        outputs = model(samples)
        probabilities = torch.nn.functional.softmax(outputs['pred_logits'], -1)
        outputs_scores = probabilities[:, :, 1][0]
        outputs_points = outputs['pred_points'][0]

        gt_cnt = targets[0]['point'].shape[0]
        # 0.5 is used by default
        threshold = 0.5
        keep = outputs_scores > threshold

        points = outputs_points[keep].detach().cpu().numpy().tolist()
        predict_cnt = int(keep.sum())
        # if specified, save the visualized images
        if vis_dir is not None:
            vis(samples, targets, [points], vis_dir)

        # accumulate MAE, MSE
        diff = predict_cnt - gt_cnt
        abs_errors.append(float(abs(diff)))
        sq_errors.append(float(diff * diff))

    # calc MAE, MSE
    mae = np.mean(abs_errors)
    mse = np.sqrt(np.mean(sq_errors))

    return mae, mse
|