DMPR-PS/train.py

100 lines
4.2 KiB
Python
Raw Normal View History

2018-07-20 16:25:15 +08:00
"""Train directional point detector."""
import math
import random
import torch
from torch.utils.data import DataLoader
import config
from data import ParkingSlotDataset
from detector import DirectionalPointDetector
from log import Logger
from utils import tensor2im, get_marking_points
def get_objective_from_labels(marking_points_batch, device):
"""Get regression objective and gradient for directional point detector."""
batch_size = len(marking_points_batch)
objective = torch.zeros(batch_size, config.NUM_FEATURE_MAP_CHANNEL,
config.FEATURE_MAP_SIZE, config.FEATURE_MAP_SIZE,
device=device)
gradient = torch.zeros_like(objective)
gradient[:, 0].fill_(1.)
for batch_idx, marking_points in enumerate(marking_points_batch):
for marking_point in marking_points:
col = math.floor(marking_point[0] * 16)
row = math.floor(marking_point[1] * 16)
# Confidence Regression
objective[batch_idx, 0, row, col] = 1.
# Offset Regression
offset_x = marking_point[0]*16 - col
offset_y = marking_point[1]*16 - row
objective[batch_idx, 1, row, col] = offset_x
objective[batch_idx, 2, row, col] = offset_y
# Direction Regression
direction = marking_point[2]
objective[batch_idx, 3, row, col] = math.cos(direction)
objective[batch_idx, 4, row, col] = math.sin(direction)
# Assign Gradient
gradient[batch_idx, 1:5, row, col].fill_(1.)
return objective, gradient
def plot_random_prediction(logger, image, marking_points, prediction):
"""Plot the ground truth and prediction of a random sample in a batch."""
rand_sample = random.randint(0, image.size(0)-1)
sampled_image = tensor2im(image[rand_sample])
logger.plot_marking_points(sampled_image, marking_points[rand_sample],
win_name='gt_marking_points')
sampled_image = tensor2im(image[rand_sample])
pred_points = get_marking_points(prediction[rand_sample], 0.01)
logger.plot_marking_points(sampled_image, pred_points,
win_name='pred_marking_points')
def train_detector(args):
"""Train directional point detector."""
args.cuda = not args.disable_cuda and torch.cuda.is_available()
device = torch.device("cuda:"+str(args.gpu_id) if args.cuda else "cpu")
dp_detector = DirectionalPointDetector(3, args.depth_factor, 5).to(device)
if args.detector_weights is not None:
dp_detector.load_state_dict(torch.load(args.detector_weights))
optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
if args.optimizer_weights is not None:
optimizer.load_state_dict(torch.load(args.optimizer_weights))
if args.enable_visdom:
logger = Logger(['loss'])
else:
logger = Logger()
data_loader = DataLoader(ParkingSlotDataset(args.dataset_directory),
batch_size=args.batch_size, shuffle=True,
collate_fn=lambda x: list(zip(*x)))
for epoch_idx in range(args.num_epochs):
for iter_idx, (image, marking_points) in enumerate(data_loader):
image = torch.stack(image)
image = image.to(device)
optimizer.zero_grad()
prediction = dp_detector(image)
objective, gradient = get_objective_from_labels(marking_points,
device)
loss = (prediction - objective) ** 2
loss.backward(gradient)
optimizer.step()
logger.log(epoch=epoch_idx, iter=iter_idx,
loss=torch.sum(loss * gradient).item())
if args.enable_visdom:
logger.plot_curve()
plot_random_prediction(logger, image, marking_points,
prediction)
torch.save(dp_detector.state_dict(),
'weights/dp_detector_%d.pth' % epoch_idx)
torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
if __name__ == '__main__':
train_detector(config.get_parser_for_training().parse_args())