Release
This commit is contained in:
parent
b79b06ee6c
commit
16a179a847
13
config.py
13
config.py
|
|
@ -1,6 +1,7 @@
|
||||||
"""Configurate arguments."""
|
"""Configurate arguments."""
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
# TODO: reproduce all number
|
||||||
|
|
||||||
INPUT_IMAGE_SIZE = 512
|
INPUT_IMAGE_SIZE = 512
|
||||||
# 0: confidence, 1: point_shape, 2: offset_x, 3: offset_y, 4: cos(direction),
|
# 0: confidence, 1: point_shape, 2: offset_x, 3: offset_y, 4: cos(direction),
|
||||||
|
|
@ -20,7 +21,10 @@ VSLOT_MAX_DIST = 0.1099427457599304
|
||||||
HSLOT_MIN_DIST = 0.15057789144568634
|
HSLOT_MIN_DIST = 0.15057789144568634
|
||||||
HSLOT_MAX_DIST = 0.44449496544202816
|
HSLOT_MAX_DIST = 0.44449496544202816
|
||||||
|
|
||||||
# angle_prediction_error = 0.1384059287593468
|
SHORT_SEPARATOR_LENGTH = 0.199519231
|
||||||
|
LONG_SEPARATOR_LENGTH = 0.46875
|
||||||
|
|
||||||
|
# angle_prediction_error = 0.1384059287593468 collected from evaluate.py
|
||||||
BRIDGE_ANGLE_DIFF = 0.09757113548987695 + 0.1384059287593468
|
BRIDGE_ANGLE_DIFF = 0.09757113548987695 + 0.1384059287593468
|
||||||
SEPARATOR_ANGLE_DIFF = 0.284967562063968 + 0.1384059287593468
|
SEPARATOR_ANGLE_DIFF = 0.284967562063968 + 0.1384059287593468
|
||||||
|
|
||||||
|
|
@ -52,7 +56,7 @@ def get_parser_for_training():
|
||||||
help="Batch size.")
|
help="Batch size.")
|
||||||
parser.add_argument('--data_loading_workers', type=int, default=32,
|
parser.add_argument('--data_loading_workers', type=int, default=32,
|
||||||
help="Number of workers for data loading.")
|
help="Number of workers for data loading.")
|
||||||
parser.add_argument('--num_epochs', type=int, default=10,
|
parser.add_argument('--num_epochs', type=int, default=20,
|
||||||
help="Number of epochs to train for.")
|
help="Number of epochs to train for.")
|
||||||
parser.add_argument('--lr', type=float, default=1e-4,
|
parser.add_argument('--lr', type=float, default=1e-4,
|
||||||
help="The learning rate of back propagation.")
|
help="The learning rate of back propagation.")
|
||||||
|
|
@ -67,10 +71,6 @@ def get_parser_for_evaluation():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--dataset_directory', required=True,
|
parser.add_argument('--dataset_directory', required=True,
|
||||||
help="The location of dataset.")
|
help="The location of dataset.")
|
||||||
parser.add_argument('--batch_size', type=int, default=32,
|
|
||||||
help="Batch size.")
|
|
||||||
parser.add_argument('--data_loading_workers', type=int, default=64,
|
|
||||||
help="Number of workers for data loading.")
|
|
||||||
parser.add_argument('--enable_visdom', action='store_true',
|
parser.add_argument('--enable_visdom', action='store_true',
|
||||||
help="Enable Visdom to visualize training progress")
|
help="Enable Visdom to visualize training progress")
|
||||||
add_common_arguments(parser)
|
add_common_arguments(parser)
|
||||||
|
|
@ -89,6 +89,7 @@ def get_parser_for_ps_evaluation():
|
||||||
add_common_arguments(parser)
|
add_common_arguments(parser)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def get_parser_for_inference():
|
def get_parser_for_inference():
|
||||||
"""Return argument parser for inference."""
|
"""Return argument parser for inference."""
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Data related package."""
|
"""Data related package."""
|
||||||
from .data_process import get_predicted_points, pair_marking_points, pass_through_third_point
|
from .process import get_predicted_points, pair_marking_points, pass_through_third_point
|
||||||
from .dataset import ParkingSlotDataset
|
from .dataset import ParkingSlotDataset
|
||||||
from .struct import MarkingPoint, Slot, match_marking_points, match_slots, calc_point_squre_dist, calc_point_direction_angle
|
from .struct import MarkingPoint, Slot, match_marking_points, match_slots, calc_point_squre_dist, calc_point_direction_angle
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Defines data structure and related function to process these data."""
|
"""Defines related function to process defined data structure."""
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
"""Evaluate directional marking point detector."""
|
"""Evaluate directional marking point detector."""
|
||||||
import cv2 as cv
|
|
||||||
import torch
|
import torch
|
||||||
import config
|
import config
|
||||||
import util
|
import util
|
||||||
from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
|
from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
|
||||||
from data import ParkingSlotDataset
|
from data import ParkingSlotDataset
|
||||||
from inference import plot_points
|
|
||||||
from model import DirectionalPointDetector
|
from model import DirectionalPointDetector
|
||||||
from train import generate_objective
|
from train import generate_objective
|
||||||
|
|
||||||
|
|
@ -74,12 +72,6 @@ def evaluate_detector(args):
|
||||||
pred_points = get_predicted_points(prediction[0], 0.01)
|
pred_points = get_predicted_points(prediction[0], 0.01)
|
||||||
predictions_list.append(pred_points)
|
predictions_list.append(pred_points)
|
||||||
|
|
||||||
# if not is_gt_and_pred_matched(marking_points, pred_points,
|
|
||||||
# config.CONFID_THRESH_FOR_POINT):
|
|
||||||
# cvimage = util.tensor2array(image[0]).copy()
|
|
||||||
# plot_points(cvimage, pred_points)
|
|
||||||
# cv.imwrite('flaw/%d.jpg' % iter_idx, cvimage,
|
|
||||||
# [int(cv.IMWRITE_JPEG_QUALITY), 100])
|
|
||||||
dists, angles = collect_error(marking_points, pred_points,
|
dists, angles = collect_error(marking_points, pred_points,
|
||||||
config.CONFID_THRESH_FOR_POINT)
|
config.CONFID_THRESH_FOR_POINT)
|
||||||
position_errors += dists
|
position_errors += dists
|
||||||
|
|
|
||||||
15
inference.py
15
inference.py
|
|
@ -60,10 +60,15 @@ def plot_slots(image, pred_points, slots):
|
||||||
p1_y = height * point_b.y - 0.5
|
p1_y = height * point_b.y - 0.5
|
||||||
vec = np.array([p1_x - p0_x, p1_y - p0_y])
|
vec = np.array([p1_x - p0_x, p1_y - p0_y])
|
||||||
vec = vec / np.linalg.norm(vec)
|
vec = vec / np.linalg.norm(vec)
|
||||||
p2_x = p0_x + 200*vec[1]
|
distance = calc_point_squre_dist(point_a, point_b)
|
||||||
p2_y = p0_y - 200*vec[0]
|
if config.VSLOT_MIN_DIST <= distance <= config.VSLOT_MAX_DIST:
|
||||||
p3_x = p1_x + 200*vec[1]
|
separating_length = config.LONG_SEPARATOR_LENGTH
|
||||||
p3_y = p1_y - 200*vec[0]
|
elif config.HSLOT_MIN_DIST <= distance <= config.HSLOT_MAX_DIST:
|
||||||
|
separating_length = config.SHORT_SEPARATOR_LENGTH
|
||||||
|
p2_x = p0_x + height * separating_length * vec[1]
|
||||||
|
p2_y = p0_y - width * separating_length * vec[0]
|
||||||
|
p3_x = p1_x + height * separating_length * vec[1]
|
||||||
|
p3_y = p1_y - width * separating_length * vec[0]
|
||||||
p0_x = int(round(p0_x))
|
p0_x = int(round(p0_x))
|
||||||
p0_y = int(round(p0_y))
|
p0_y = int(round(p0_y))
|
||||||
p1_x = int(round(p1_x))
|
p1_x = int(round(p1_x))
|
||||||
|
|
@ -122,7 +127,7 @@ def detect_video(detector, device, args):
|
||||||
frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
|
frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
|
||||||
output_video = cv.VideoWriter()
|
output_video = cv.VideoWriter()
|
||||||
if args.save:
|
if args.save:
|
||||||
output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
|
output_video.open('record.avi', cv.VideoWriter_fourcc(*'XVID'),
|
||||||
input_video.get(cv.CAP_PROP_FPS),
|
input_video.get(cv.CAP_PROP_FPS),
|
||||||
(frame_width, frame_height), True)
|
(frame_width, frame_height), True)
|
||||||
frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
|
frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
|
||||||
|
|
|
||||||
|
|
@ -55,8 +55,7 @@ class DirectionalPointDetector(nn.modules.Module):
|
||||||
self.predict = nn.Sequential(*layers)
|
self.predict = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, *x):
|
def forward(self, *x):
|
||||||
feature = self.extract_feature(x[0])
|
prediction = self.predict(self.extract_feature(x[0]))
|
||||||
prediction = self.predict(feature)
|
|
||||||
# 4 represents that there are 4 value: confidence, shape, offset_x,
|
# 4 represents that there are 4 value: confidence, shape, offset_x,
|
||||||
# offset_y, whose range is between [0, 1].
|
# offset_y, whose range is between [0, 1].
|
||||||
point_pred, angle_pred = torch.split(prediction, 4, dim=1)
|
point_pred, angle_pred = torch.split(prediction, 4, dim=1)
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from inference import detect_marking_points, inference_slots
|
||||||
|
|
||||||
|
|
||||||
def get_ground_truths(label):
|
def get_ground_truths(label):
|
||||||
|
"""Read label to get ground truth slot."""
|
||||||
slots = np.array(label['slots'])
|
slots = np.array(label['slots'])
|
||||||
if slots.size == 0:
|
if slots.size == 0:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
12
train.py
12
train.py
|
|
@ -33,8 +33,8 @@ def generate_objective(marking_points_batch, device):
|
||||||
gradient[:, 0].fill_(1.)
|
gradient[:, 0].fill_(1.)
|
||||||
for batch_idx, marking_points in enumerate(marking_points_batch):
|
for batch_idx, marking_points in enumerate(marking_points_batch):
|
||||||
for marking_point in marking_points:
|
for marking_point in marking_points:
|
||||||
col = math.floor(marking_point.x * 16)
|
col = math.floor(marking_point.x * config.FEATURE_MAP_SIZE)
|
||||||
row = math.floor(marking_point.y * 16)
|
row = math.floor(marking_point.y * config.FEATURE_MAP_SIZE)
|
||||||
# Confidence Regression
|
# Confidence Regression
|
||||||
objective[batch_idx, 0, row, col] = 1.
|
objective[batch_idx, 0, row, col] = 1.
|
||||||
# Makring Point Shape Regression
|
# Makring Point Shape Regression
|
||||||
|
|
@ -76,11 +76,11 @@ def train_detector(args):
|
||||||
collate_fn=lambda x: list(zip(*x)))
|
collate_fn=lambda x: list(zip(*x)))
|
||||||
|
|
||||||
for epoch_idx in range(args.num_epochs):
|
for epoch_idx in range(args.num_epochs):
|
||||||
for iter_idx, (image, marking_points) in enumerate(data_loader):
|
for iter_idx, (images, marking_points) in enumerate(data_loader):
|
||||||
image = torch.stack(image).to(device)
|
images = torch.stack(images).to(device)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
prediction = dp_detector(image)
|
prediction = dp_detector(images)
|
||||||
objective, gradient = generate_objective(marking_points, device)
|
objective, gradient = generate_objective(marking_points, device)
|
||||||
loss = (prediction - objective) ** 2
|
loss = (prediction - objective) ** 2
|
||||||
loss.backward(gradient)
|
loss.backward(gradient)
|
||||||
|
|
@ -89,7 +89,7 @@ def train_detector(args):
|
||||||
train_loss = torch.sum(loss*gradient).item() / loss.size(0)
|
train_loss = torch.sum(loss*gradient).item() / loss.size(0)
|
||||||
logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
|
logger.log(epoch=epoch_idx, iter=iter_idx, train_loss=train_loss)
|
||||||
if args.enable_visdom:
|
if args.enable_visdom:
|
||||||
plot_prediction(logger, image, marking_points, prediction)
|
plot_prediction(logger, images, marking_points, prediction)
|
||||||
torch.save(dp_detector.state_dict(),
|
torch.save(dp_detector.state_dict(),
|
||||||
'weights/dp_detector_%d.pth' % epoch_idx)
|
'weights/dp_detector_%d.pth' % epoch_idx)
|
||||||
torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
|
torch.save(optimizer.state_dict(), 'weights/optimizer.pth')
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# -*- coding: utf-8 -*-
|
"""Class for logging."""
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from visdom import Visdom
|
from visdom import Visdom
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ class Timer(object):
|
||||||
|
|
||||||
def calc_average_time(self):
|
def calc_average_time(self):
|
||||||
"""Calculate average elapsed time of timer."""
|
"""Calculate average elapsed time of timer."""
|
||||||
|
if self.count == 0:
|
||||||
|
return 0.
|
||||||
return self.total_time / self.count
|
return self.total_time / self.count
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue