DMPR-PS/train.py

156 lines
6.7 KiB
Python
Raw Normal View History

2018-10-02 15:54:42 +08:00
"""Train directional marking point detector."""
2018-10-04 09:30:25 +08:00
import math
2018-07-20 16:25:15 +08:00
import random
2023-12-26 16:50:16 +08:00
import numpy as np
2018-07-20 16:25:15 +08:00
import torch
2023-12-26 16:50:16 +08:00
import yaml
from torch import nn
2018-07-20 16:25:15 +08:00
from torch.utils.data import DataLoader
import config
2018-10-04 09:30:25 +08:00
import data
import util
from model import DirectionalPointDetector
2023-12-26 16:50:16 +08:00
from models.yolo import Model
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
2018-07-20 16:25:15 +08:00
2018-10-02 15:54:42 +08:00
def plot_prediction(logger, image, marking_points, prediction):
    """Visualize ground truth and prediction for one random sample of a batch.

    One index is drawn uniformly from the batch. Its ground-truth marking
    points are sent to the 'gt_marking_points' window. Points decoded from
    its prediction with a 0.01 confidence threshold are sent to the
    'pred_marking_points' window, but only if at least one point is decoded.

    Args:
        logger: object exposing ``plot_marking_points(image, points, win_name=...)``.
        image: batched image tensor; ``image.size(0)`` is the batch size.
        marking_points: per-sample ground-truth marking points.
        prediction: batched detector output.
    """
    pick = random.randint(0, image.size(0) - 1)

    gt_canvas = util.tensor2im(image[pick])
    logger.plot_marking_points(gt_canvas, marking_points[pick],
                               win_name='gt_marking_points')

    # The image is converted a second time for the prediction window instead
    # of reusing gt_canvas -- presumably because plotting may draw onto the
    # array it is given (TODO confirm in util.Logger).
    pred_canvas = util.tensor2im(image[pick])
    detections = data.get_predicted_points(prediction[pick], 0.01)
    if not detections:
        return
    # Each detection is indexed as a sequence; element 1 is what gets plotted.
    points = [det[1] for det in detections]
    logger.plot_marking_points(pred_canvas, points,
                               win_name='pred_marking_points')
2018-07-20 16:25:15 +08:00
2018-10-04 09:30:25 +08:00
def _cell_index(coord, size):
    """Return the feature-map cell holding normalized *coord*, clamped to [0, size-1]."""
    # Without the clamp, coord == 1.0 indexes one past the end (IndexError).
    # A slightly negative coord yields -1, which torch silently wraps to the
    # last cell.
    return min(max(math.floor(coord * size), 0), size - 1)


def generate_objective(marking_points_batch, device, *,
                       feature_map_size=None, num_channels=None):
    """Get regression objective and gradient mask for directional point detector.

    Args:
        marking_points_batch: sequence with one entry per image. Each entry is
            an iterable of marking points exposing ``x`` and ``y`` (normalized
            coordinates, expected in [0, 1]), ``shape`` and ``direction``
            (an angle passed to ``math.cos`` / ``math.sin``, i.e. radians).
        device: torch device on which both tensors are allocated.
        feature_map_size: side length of the square feature map. Defaults to
            ``config.FEATURE_MAP_SIZE``.
        num_channels: number of feature-map channels. Defaults to
            ``config.NUM_FEATURE_MAP_CHANNEL``; at least 6 are written.

    Returns:
        ``(objective, gradient)``, both zero-initialized tensors of shape
        ``(batch, num_channels, feature_map_size, feature_map_size)``.
        ``objective`` holds, at each cell containing a point:

        - channel 0: confidence (1)
        - channel 1: shape
        - channels 2-3: x/y offset within the cell
        - channels 4-5: cos/sin of the direction

        ``gradient`` is 1 on channel 0 everywhere, so confidence is
        supervised at every cell. Channels 1-5 are set to 1 only at cells
        containing a point. If two points fall in the same cell, the later
        one overwrites the earlier.
    """
    size = config.FEATURE_MAP_SIZE if feature_map_size is None else feature_map_size
    channels = (config.NUM_FEATURE_MAP_CHANNEL if num_channels is None
                else num_channels)
    batch_size = len(marking_points_batch)
    objective = torch.zeros(batch_size, channels, size, size, device=device)
    gradient = torch.zeros_like(objective)
    gradient[:, 0].fill_(1.)
    for batch_idx, marking_points in enumerate(marking_points_batch):
        for marking_point in marking_points:
            col = _cell_index(marking_point.x, size)
            row = _cell_index(marking_point.y, size)
            # Confidence Regression
            objective[batch_idx, 0, row, col] = 1.
            # Marking Point Shape Regression
            objective[batch_idx, 1, row, col] = marking_point.shape
            # Offset Regression
            objective[batch_idx, 2, row, col] = marking_point.x * size - col
            objective[batch_idx, 3, row, col] = marking_point.y * size - row
            # Direction Regression
            direction = marking_point.direction
            objective[batch_idx, 4, row, col] = math.cos(direction)
            objective[batch_idx, 5, row, col] = math.sin(direction)
            # Assign Gradient
            gradient[batch_idx, 1:6, row, col].fill_(1.)
    return objective, gradient
2023-12-26 16:50:16 +08:00
# class FocalLoss(nn.Module):
# # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
# def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
# super(FocalLoss, self).__init__()
# self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
# self.gamma = gamma
# self.alpha = alpha
# self.reduction = loss_fcn.reduction
# self.loss_fcn.reduction = 'none' # required to apply FL to each element
#
# def forward(self, pred, true):
# loss = self.loss_fcn(pred, true)
# # p_t = torch.exp(-loss)
# # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
#
# # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
# pred_prob = torch.sigmoid(pred) # prob from logits
# p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
# alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
# modulating_factor = (1.0 - p_t) ** self.gamma
# loss *= alpha_factor * modulating_factor
#
# if self.reduction == 'mean':
# return loss.mean()
# elif self.reduction == 'sum':
# return loss.sum()
# else: # 'none'
# return loss
2018-07-20 16:25:15 +08:00
def _collate_batch(batch):
    """Transpose a list of ``(images, marking_points)`` samples into per-field tuples.

    Defined at module level instead of as a lambda so that DataLoader worker
    processes started with the ``spawn`` method can pickle it.
    """
    return list(zip(*batch))


def _load_state(path, device):
    """Announce and load a saved state dict from *path*, mapping tensors onto *device*."""
    print("Loading weights: %s" % path)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    return torch.load(path, map_location=device)


def train_detector(args):
    """Train directional point detector.

    Reads these attributes from *args*: ``disable_cuda``, ``gpu_id``, ``hyp``
    (YAML hyper-parameter file), ``cfg`` (model config), ``detector_weights``,
    ``lr``, ``optimizer_weights``, ``enable_visdom``, ``dataset_directory``,
    ``batch_size``, ``data_loading_workers`` and ``num_epochs``. Also sets
    ``args.cuda``.

    After every epoch, writes ``weights/dp_detector_<epoch>.pth`` and
    overwrites ``weights/optimizer.pth``. The ``weights`` directory is
    created if missing.
    """
    import os  # local import: the file-level `import os` is commented out

    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(True)

    with open(args.hyp) as f:
        # An empty YAML file loads as None; treat it as "no hyper-parameters".
        hyp = yaml.load(f, Loader=yaml.SafeLoader) or {}
    dp_detector = Model(args.cfg, ch=3, anchors=hyp.get('anchors')).to(device)
    if args.detector_weights:
        dp_detector.load_state_dict(_load_state(args.detector_weights, device))
    dp_detector.train()

    optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
    if args.optimizer_weights:
        optimizer.load_state_dict(_load_state(args.optimizer_weights, device))

    logger = util.Logger(args.enable_visdom, ['train_loss'])
    data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=True,
                             num_workers=args.data_loading_workers,
                             # Pinned memory only speeds up host-to-GPU copies.
                             pin_memory=args.cuda,
                             collate_fn=_collate_batch)

    os.makedirs('weights', exist_ok=True)
    for epoch_idx in range(args.num_epochs):
        for iter_idx, (images, marking_points) in enumerate(data_loader):
            images = torch.stack(images).to(device)

            optimizer.zero_grad()
            prediction = dp_detector(images)
            objective, gradient = generate_objective(marking_points, device)
            # Element-wise squared error. Passing `gradient` to backward()
            # weights each element, so masked-out elements contribute nothing.
            loss = (prediction - objective) ** 2
            loss.backward(gradient)
            optimizer.step()

            # Masked loss summed over all elements, averaged per image.
            train_loss = torch.sum(loss * gradient).item() / loss.size(0)
            logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
            if args.enable_visdom:
                plot_prediction(logger, images, marking_points, prediction)
        torch.save(dp_detector.state_dict(),
                   'weights/dp_detector_%d.pth' % epoch_idx)
        torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
2018-07-20 16:25:15 +08:00
if __name__ == '__main__':
    # Script entry point: parse training options and start training.
    parser = config.get_parser_for_training()
    cli_args = parser.parse_args()
    train_detector(cli_args)