DMPR-PS/train.py

71 lines
2.9 KiB
Python
Raw Normal View History

2018-10-02 15:54:42 +08:00
"""Train directional marking point detector."""
2018-07-20 16:25:15 +08:00
import random
import torch
from torch.utils.data import DataLoader
import config
2018-10-02 15:54:42 +08:00
from data import get_predicted_points
from data import generate_objective
from dataset import ParkingSlotDataset
2018-07-20 16:25:15 +08:00
from detector import DirectionalPointDetector
from log import Logger
2018-10-02 15:54:42 +08:00
from utils import tensor2im
2018-07-20 16:25:15 +08:00
2018-10-02 15:54:42 +08:00
def plot_prediction(logger, image, marking_points, prediction):
    """Visualize ground truth and prediction for one randomly chosen sample.

    One sample index is drawn uniformly from the batch. Its ground-truth
    marking points are sent to the 'gt_marking_points' window. Its predicted
    points are sent to the 'pred_marking_points' window, but only if
    ``get_predicted_points`` (threshold 0.01) returns a non-empty result.

    Args:
        logger: Object exposing ``plot_marking_points(image, points, win_name=...)``.
        image: Batched image tensor; dimension 0 indexes samples.
        marking_points: Per-sample ground-truth points, indexable by sample.
        prediction: Batched network output, indexable by sample.
    """
    idx = random.randrange(image.size(0))

    gt_view = tensor2im(image[idx])
    logger.plot_marking_points(gt_view, marking_points[idx],
                               win_name='gt_marking_points')

    # Convert the sample a second time instead of reusing gt_view.
    # NOTE(review): presumably because plotting may draw onto the array
    # in place; confirm in Logger.plot_marking_points.
    pred_view = tensor2im(image[idx])
    detections = get_predicted_points(prediction[idx], 0.01)
    if not detections:
        return
    # Keep the second field of each detection.
    # Presumably entries are (score, point) pairs; confirm against
    # get_predicted_points.
    points_only = [entry[1] for entry in detections]
    logger.plot_marking_points(pred_view, points_only,
                               win_name='pred_marking_points')


def collate_samples(batch):
    """Transpose a list of per-sample tuples into a list of per-field tuples.

    Used as the DataLoader ``collate_fn``. It is a module-level function
    rather than a lambda so that it can be pickled when DataLoader worker
    processes are started with the 'spawn' method (e.g. on Windows/macOS).
    """
    return list(zip(*batch))


def train_detector(args):
    """Train the directional point detector.

    Args:
        args: Parsed command-line namespace. Reads ``disable_cuda``,
            ``gpu_id``, ``depth_factor``, ``detector_weights``,
            ``optimizer_weights``, ``lr``, ``enable_visdom``,
            ``dataset_directory``, ``batch_size``, ``data_loading_workers``
            and ``num_epochs``. Sets ``args.cuda`` as a side effect.

    Side effects:
        Creates the ``weights`` directory if it does not exist. After every
        epoch, writes the detector state to ``weights/dp_detector_<epoch>.pth``.
        It also overwrites ``weights/optimizer.pth`` with the optimizer state.
    """
    import os

    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        # map_location lets checkpoints saved on a GPU load on a CPU-only run.
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))

    optimizer = torch.optim.Adam(dp_detector.parameters(), lr=args.lr)
    if args.optimizer_weights:
        optimizer.load_state_dict(
            torch.load(args.optimizer_weights, map_location=device))

    logger = Logger(['train_loss'] if args.enable_visdom else None)
    data_loader = DataLoader(ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=True,
                             num_workers=args.data_loading_workers,
                             collate_fn=collate_samples)

    # Make sure the checkpoint directory exists before spending an epoch of
    # training, rather than crashing in torch.save at the end of it.
    os.makedirs('weights', exist_ok=True)

    for epoch_idx in range(args.num_epochs):
        for iter_idx, (image, marking_points) in enumerate(data_loader):
            # The collate step yields a tuple of per-sample tensors; stack
            # them into one batch tensor.
            image = torch.stack(image).to(device)

            optimizer.zero_grad()
            prediction = dp_detector(image)
            objective, gradient = generate_objective(marking_points, device)
            # Element-wise squared error. `gradient` is passed to backward as
            # the upstream gradient, so it weights each element's contribution.
            loss = (prediction - objective) ** 2
            loss.backward(gradient)
            optimizer.step()

            # Weighted squared error summed over the tensor and divided by
            # dim 0 (the batch dimension of the stacked images).
            train_loss = torch.sum(loss * gradient).item() / loss.size(0)
            logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
            if args.enable_visdom:
                plot_prediction(logger, image, marking_points, prediction)
        torch.save(dp_detector.state_dict(),
                   'weights/dp_detector_%d.pth' % epoch_idx)
        torch.save(optimizer.state_dict(), 'weights/optimizer.pth')


if __name__ == '__main__':
    # Build the training CLI parser from config, parse argv, then train.
    cli_args = config.get_parser_for_training().parse_args()
    train_detector(cli_args)