Release
This commit is contained in:
parent
b79b06ee6c
commit
16a179a847
13
config.py
13
config.py
|
|
@ -1,6 +1,7 @@
|
|||
"""Configurate arguments."""
|
||||
import argparse
|
||||
|
||||
# TODO: reproduce all number
|
||||
|
||||
INPUT_IMAGE_SIZE = 512
|
||||
# 0: confidence, 1: point_shape, 2: offset_x, 3: offset_y, 4: cos(direction),
|
||||
|
|
@ -20,7 +21,10 @@ VSLOT_MAX_DIST = 0.1099427457599304
|
|||
HSLOT_MIN_DIST = 0.15057789144568634
|
||||
HSLOT_MAX_DIST = 0.44449496544202816
|
||||
|
||||
# angle_prediction_error = 0.1384059287593468
|
||||
SHORT_SEPARATOR_LENGTH = 0.199519231
|
||||
LONG_SEPARATOR_LENGTH = 0.46875
|
||||
|
||||
# angle_prediction_error = 0.1384059287593468 collected from evaluate.py
|
||||
BRIDGE_ANGLE_DIFF = 0.09757113548987695 + 0.1384059287593468
|
||||
SEPARATOR_ANGLE_DIFF = 0.284967562063968 + 0.1384059287593468
|
||||
|
||||
|
|
@ -52,7 +56,7 @@ def get_parser_for_training():
|
|||
help="Batch size.")
|
||||
parser.add_argument('--data_loading_workers', type=int, default=32,
|
||||
help="Number of workers for data loading.")
|
||||
parser.add_argument('--num_epochs', type=int, default=10,
|
||||
parser.add_argument('--num_epochs', type=int, default=20,
|
||||
help="Number of epochs to train for.")
|
||||
parser.add_argument('--lr', type=float, default=1e-4,
|
||||
help="The learning rate of back propagation.")
|
||||
|
|
@ -67,10 +71,6 @@ def get_parser_for_evaluation():
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset_directory', required=True,
|
||||
help="The location of dataset.")
|
||||
parser.add_argument('--batch_size', type=int, default=32,
|
||||
help="Batch size.")
|
||||
parser.add_argument('--data_loading_workers', type=int, default=64,
|
||||
help="Number of workers for data loading.")
|
||||
parser.add_argument('--enable_visdom', action='store_true',
|
||||
help="Enable Visdom to visualize training progress")
|
||||
add_common_arguments(parser)
|
||||
|
|
@ -89,6 +89,7 @@ def get_parser_for_ps_evaluation():
|
|||
add_common_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
def get_parser_for_inference():
|
||||
"""Return argument parser for inference."""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Data related package."""
|
||||
from .data_process import get_predicted_points, pair_marking_points, pass_through_third_point
|
||||
from .process import get_predicted_points, pair_marking_points, pass_through_third_point
|
||||
from .dataset import ParkingSlotDataset
|
||||
from .struct import MarkingPoint, Slot, match_marking_points, match_slots, calc_point_squre_dist, calc_point_direction_angle
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Defines data structure and related function to process these data."""
|
||||
"""Defines related function to process defined data structure."""
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
@ -1,11 +1,9 @@
|
|||
"""Evaluate directional marking point detector."""
|
||||
import cv2 as cv
|
||||
import torch
|
||||
import config
|
||||
import util
|
||||
from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
|
||||
from data import ParkingSlotDataset
|
||||
from inference import plot_points
|
||||
from model import DirectionalPointDetector
|
||||
from train import generate_objective
|
||||
|
||||
|
|
@ -74,12 +72,6 @@ def evaluate_detector(args):
|
|||
pred_points = get_predicted_points(prediction[0], 0.01)
|
||||
predictions_list.append(pred_points)
|
||||
|
||||
# if not is_gt_and_pred_matched(marking_points, pred_points,
|
||||
# config.CONFID_THRESH_FOR_POINT):
|
||||
# cvimage = util.tensor2array(image[0]).copy()
|
||||
# plot_points(cvimage, pred_points)
|
||||
# cv.imwrite('flaw/%d.jpg' % iter_idx, cvimage,
|
||||
# [int(cv.IMWRITE_JPEG_QUALITY), 100])
|
||||
dists, angles = collect_error(marking_points, pred_points,
|
||||
config.CONFID_THRESH_FOR_POINT)
|
||||
position_errors += dists
|
||||
|
|
|
|||
15
inference.py
15
inference.py
|
|
@ -60,10 +60,15 @@ def plot_slots(image, pred_points, slots):
|
|||
p1_y = height * point_b.y - 0.5
|
||||
vec = np.array([p1_x - p0_x, p1_y - p0_y])
|
||||
vec = vec / np.linalg.norm(vec)
|
||||
p2_x = p0_x + 200*vec[1]
|
||||
p2_y = p0_y - 200*vec[0]
|
||||
p3_x = p1_x + 200*vec[1]
|
||||
p3_y = p1_y - 200*vec[0]
|
||||
distance = calc_point_squre_dist(point_a, point_b)
|
||||
if config.VSLOT_MIN_DIST <= distance <= config.VSLOT_MAX_DIST:
|
||||
separating_length = config.LONG_SEPARATOR_LENGTH
|
||||
elif config.HSLOT_MIN_DIST <= distance <= config.HSLOT_MAX_DIST:
|
||||
separating_length = config.SHORT_SEPARATOR_LENGTH
|
||||
p2_x = p0_x + height * separating_length * vec[1]
|
||||
p2_y = p0_y - width * separating_length * vec[0]
|
||||
p3_x = p1_x + height * separating_length * vec[1]
|
||||
p3_y = p1_y - width * separating_length * vec[0]
|
||||
p0_x = int(round(p0_x))
|
||||
p0_y = int(round(p0_y))
|
||||
p1_x = int(round(p1_x))
|
||||
|
|
@ -122,7 +127,7 @@ def detect_video(detector, device, args):
|
|||
frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
|
||||
output_video = cv.VideoWriter()
|
||||
if args.save:
|
||||
output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
|
||||
output_video.open('record.avi', cv.VideoWriter_fourcc(*'XVID'),
|
||||
input_video.get(cv.CAP_PROP_FPS),
|
||||
(frame_width, frame_height), True)
|
||||
frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
|
||||
|
|
|
|||
|
|
@ -55,8 +55,7 @@ class DirectionalPointDetector(nn.modules.Module):
|
|||
self.predict = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, *x):
|
||||
feature = self.extract_feature(x[0])
|
||||
prediction = self.predict(feature)
|
||||
prediction = self.predict(self.extract_feature(x[0]))
|
||||
# 4 represents that there are 4 value: confidence, shape, offset_x,
|
||||
# offset_y, whose range is between [0, 1].
|
||||
point_pred, angle_pred = torch.split(prediction, 4, dim=1)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from inference import detect_marking_points, inference_slots
|
|||
|
||||
|
||||
def get_ground_truths(label):
|
||||
"""Read label to get ground truth slot."""
|
||||
slots = np.array(label['slots'])
|
||||
if slots.size == 0:
|
||||
return []
|
||||
|
|
|
|||
12
train.py
12
train.py
|
|
@ -33,8 +33,8 @@ def generate_objective(marking_points_batch, device):
|
|||
gradient[:, 0].fill_(1.)
|
||||
for batch_idx, marking_points in enumerate(marking_points_batch):
|
||||
for marking_point in marking_points:
|
||||
col = math.floor(marking_point.x * 16)
|
||||
row = math.floor(marking_point.y * 16)
|
||||
col = math.floor(marking_point.x * config.FEATURE_MAP_SIZE)
|
||||
row = math.floor(marking_point.y * config.FEATURE_MAP_SIZE)
|
||||
# Confidence Regression
|
||||
objective[batch_idx, 0, row, col] = 1.
|
||||
# Makring Point Shape Regression
|
||||
|
|
@ -76,11 +76,11 @@ def train_detector(args):
|
|||
collate_fn=lambda x: list(zip(*x)))
|
||||
|
||||
for epoch_idx in range(args.num_epochs):
|
||||
for iter_idx, (image, marking_points) in enumerate(data_loader):
|
||||
image = torch.stack(image).to(device)
|
||||
for iter_idx, (images, marking_points) in enumerate(data_loader):
|
||||
images = torch.stack(images).to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
prediction = dp_detector(image)
|
||||
prediction = dp_detector(images)
|
||||
objective, gradient = generate_objective(marking_points, device)
|
||||
loss = (prediction - objective) ** 2
|
||||
loss.backward(gradient)
|
||||
|
|
@ -89,7 +89,7 @@ def train_detector(args):
|
|||
train_loss = torch.sum(loss*gradient).item() / loss.size(0)
|
||||
logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
|
||||
if args.enable_visdom:
|
||||
plot_prediction(logger, image, marking_points, prediction)
|
||||
plot_prediction(logger, images, marking_points, prediction)
|
||||
torch.save(dp_detector.state_dict(),
|
||||
'weights/dp_detector_%d.pth' % epoch_idx)
|
||||
torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""Class for logging."""
|
||||
import math
|
||||
import numpy as np
|
||||
from visdom import Visdom
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ class Timer(object):
|
|||
|
||||
def calc_average_time(self):
|
||||
"""Calculate average elapsed time of timer."""
|
||||
if self.count == 0:
|
||||
return 0.
|
||||
return self.total_time / self.count
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue