DMPR-PS/train.py

100 lines
4.3 KiB
Python

"""Train directional marking point detector."""
import math
import random
import torch
from torch.utils.data import DataLoader
import config
import data
import util
from model import DirectionalPointDetector
def plot_prediction(logger, image, marking_points, prediction):
"""Plot the ground truth and prediction of a random sample in a batch."""
rand_sample = random.randint(0, image.size(0)-1)
sampled_image = util.tensor2im(image[rand_sample])
logger.plot_marking_points(sampled_image, marking_points[rand_sample],
win_name='gt_marking_points')
sampled_image = util.tensor2im(image[rand_sample])
pred_points = data.get_predicted_points(prediction[rand_sample], 0.01)
if pred_points:
logger.plot_marking_points(sampled_image,
list(list(zip(*pred_points))[1]),
win_name='pred_marking_points')
def generate_objective(marking_points_batch, device):
"""Get regression objective and gradient for directional point detector."""
batch_size = len(marking_points_batch)
objective = torch.zeros(batch_size, config.NUM_FEATURE_MAP_CHANNEL,
config.FEATURE_MAP_SIZE, config.FEATURE_MAP_SIZE,
device=device)
gradient = torch.zeros_like(objective)
gradient[:, 0].fill_(1.)
for batch_idx, marking_points in enumerate(marking_points_batch):
for marking_point in marking_points:
col = math.floor(marking_point.x * 16)
row = math.floor(marking_point.y * 16)
# Confidence Regression
objective[batch_idx, 0, row, col] = 1.
# Makring Point Shape Regression
objective[batch_idx, 1, row, col] = marking_point.shape
# Offset Regression
objective[batch_idx, 2, row, col] = marking_point.x*16 - col
objective[batch_idx, 3, row, col] = marking_point.y*16 - row
# Direction Regression
direction = marking_point.direction
objective[batch_idx, 4, row, col] = math.cos(direction)
objective[batch_idx, 5, row, col] = math.sin(direction)
# Assign Gradient
gradient[batch_idx, 1:6, row, col].fill_(1.)
return objective, gradient
def train_detector(args):
"""Train directional point detector."""
args.cuda = not args.disable_cuda and torch.cuda.is_available()
device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
torch.set_grad_enabled(True)
dp_detector = DirectionalPointDetector(
3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
if args.detector_weights:
print("Loading weights: %s" % args.detector_weights)
dp_detector.load_state_dict(torch.load(args.detector_weights))
dp_detector.train()
optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
if args.optimizer_weights:
print("Loading weights: %s" % args.optimizer_weights)
optimizer.load_state_dict(torch.load(args.optimizer_weights))
logger = util.Logger(args.enable_visdom, ['train_loss'])
data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
batch_size=args.batch_size, shuffle=True,
num_workers=args.data_loading_workers,
collate_fn=lambda x: list(zip(*x)))
for epoch_idx in range(args.num_epochs):
for iter_idx, (image, marking_points) in enumerate(data_loader):
image = torch.stack(image).to(device)
optimizer.zero_grad()
prediction = dp_detector(image)
objective, gradient = generate_objective(marking_points, device)
loss = (prediction - objective) ** 2
loss.backward(gradient)
optimizer.step()
train_loss = torch.sum(loss*gradient).item() / loss.size(0)
logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
if args.enable_visdom:
plot_prediction(logger, image, marking_points, prediction)
torch.save(dp_detector.state_dict(),
'weights/dp_detector_%d.pth' % epoch_idx)
torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
if __name__ == '__main__':
train_detector(config.get_parser_for_training().parse_args())