
Add parking slot inference

v1
Teoge, 6 years ago
parent · current commit 396495f801
7 files changed, with 261 additions and 53 deletions
1. config.py  +29 -4
2. data/__init__.py  +2 -2
3. data/data_process.py  +63 -25
4. data/struct.py  +69 -0
5. inference.py  +89 -20
6. prepare_dataset.py  +1 -1
7. util/utils.py  +8 -1

config.py  +29 -4

@@ -8,10 +8,23 @@ INPUT_IMAGE_SIZE = 512
 NUM_FEATURE_MAP_CHANNEL = 6
 # image_size / 2^5 = 512 / 32 = 16
 FEATURE_MAP_SIZE = 16
+# Threshold used to filter marking points too close to the image boundary
+BOUNDARY_THRESH = 0.066666667
 
 # Thresholds to determine whether a detected point matches ground truth.
 SQUARED_DISTANCE_THRESH = 0.000277778
 DIRECTION_ANGLE_THRESH = 0.5
 
+VSLOT_MIN_DISTANCE = 0.044771278151623496
+VSLOT_MAX_DISTANCE = 0.1099427457599304
+HSLOT_MIN_DISTANCE = 0.15057789144568634
+HSLOT_MAX_DISTANCE = 0.44449496544202816
+
+BRIDGE_ANGLE_DIFF = 0.25
+SEPARATOR_ANGLE_DIFF = 0.5
+
+SLOT_SUPPRESSION_DOT_PRODUCT_THRESH = 0.8
 
 
 def add_common_arguments(parser):
     """Add common arguments for training and inference."""
@@ -34,9 +47,9 @@ def get_parser_for_training():
                         help="The weights of optimizer.")
     parser.add_argument('--batch_size', type=int, default=24,
                         help="Batch size.")
-    parser.add_argument('--data_loading_workers', type=int, default=48,
+    parser.add_argument('--data_loading_workers', type=int, default=32,
                         help="Number of workers for data loading.")
-    parser.add_argument('--num_epochs', type=int, default=100,
+    parser.add_argument('--num_epochs', type=int, default=10,
                         help="Number of epochs to train for.")
     parser.add_argument('--lr', type=float, default=1e-4,
                         help="The learning rate of back propagation.")
@@ -61,6 +74,18 @@ def get_parser_for_evaluation():
     return parser
 
 
+def get_parser_for_ps_evaluation():
+    """Return argument parser for parking-slot evaluation."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--label_directory', required=True,
+                        help="The location of the label directory.")
+    parser.add_argument('--image_directory', required=True,
+                        help="The location of the image directory.")
+    parser.add_argument('--enable_visdom', action='store_true',
+                        help="Enable Visdom to visualize training progress")
+    add_common_arguments(parser)
+    return parser
+
 def get_parser_for_inference():
     """Return argument parser for inference."""
     parser = argparse.ArgumentParser()
@@ -68,10 +93,10 @@ def get_parser_for_inference():
                         help="Inference image or video.")
     parser.add_argument('--video',
                         help="Video path if you choose to inference video.")
+    parser.add_argument('--inference_slot', action='store_true',
+                        help="Perform slot inference.")
     parser.add_argument('--thresh', type=float, default=0.5,
                         help="Detection threshold.")
-    parser.add_argument('--timing', action='store_true',
-                        help="Perform timing during reference.")
    parser.add_argument('--save', action='store_true',
                        help="Save detection result to file.")
    add_common_arguments(parser)
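A note on units, since the raw numbers above are opaque: marking-point coordinates are normalized to [0, 1], and the *_DISTANCE values are compared against calc_point_squre_dist (see data/struct.py below), so they are squared normalized distances. A small illustrative sketch, not part of the commit, converting them to pixels for the 512×512 input:

import math

INPUT_IMAGE_SIZE = 512
BOUNDARY_THRESH = 0.066666667
VSLOT_MIN_DISTANCE = 0.044771278151623496
VSLOT_MAX_DISTANCE = 0.1099427457599304

# Border margin inside which detected points are discarded, in pixels.
print(BOUNDARY_THRESH * INPUT_IMAGE_SIZE)                # ~34.1 px
# Vertical-slot entrance-width range, converted from squared
# normalized distance back to pixels.
print(math.sqrt(VSLOT_MIN_DISTANCE) * INPUT_IMAGE_SIZE)  # ~108.3 px
print(math.sqrt(VSLOT_MAX_DISTANCE) * INPUT_IMAGE_SIZE)  # ~169.8 px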

data/__init__.py  +2 -2

@@ -1,4 +1,4 @@
 """Data related package."""
-from .data_process import get_predicted_points, match_marking_points
+from .data_process import get_predicted_points, pair_marking_points, filter_slots
 from .dataset import ParkingSlotDataset
-from .struct import MarkingPoint, Slot
+from .struct import MarkingPoint, Slot, match_marking_points, match_slots
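The net effect is that the ground-truth matchers now come from data/struct.py while the new slot-pairing helpers come from data/data_process.py, with the package re-exporting both. A hypothetical caller-side import, for illustration only:

from data import get_predicted_points, pair_marking_points, filter_slots
from data import MarkingPoint, Slot, match_marking_points, match_slots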

data/data_process.py  +63 -25

@@ -1,8 +1,9 @@
 """Defines data structures and related functions to process these data."""
 import math
+import numpy as np
 import torch
 import config
-from data.struct import MarkingPoint
+from data.struct import MarkingPoint, calc_point_squre_dist, detemine_point_shape
 
 
 def non_maximum_suppression(pred_points):
@@ -10,11 +11,12 @@ def non_maximum_suppression(pred_points):
     suppressed = [False] * len(pred_points)
     for i in range(len(pred_points) - 1):
         for j in range(i + 1, len(pred_points)):
-            dist_square = cal_squre_dist(pred_points[i][1], pred_points[j][1])
-            # TODO: recalculate following parameter
-            # minimum distance in training set: 40.309
-            # (40.309 / 600)^2 = 0.004513376
-            if dist_square < 0.0045:
+            i_x = pred_points[i][1].x
+            i_y = pred_points[i][1].y
+            j_x = pred_points[j][1].x
+            j_y = pred_points[j][1].y
+            # 0.0625 = 1 / 16
+            if abs(j_x - i_x) < 0.0625 and abs(j_y - i_y) < 0.0625:
                 idx = i if pred_points[i][0] < pred_points[j][0] else j
                 suppressed[idx] = True
     if any(suppressed):
@@ -36,6 +38,9 @@ def get_predicted_points(prediction, thresh):
             if prediction[0, i, j] >= thresh:
                 xval = (j + prediction[2, i, j]) / prediction.shape[2]
                 yval = (i + prediction[3, i, j]) / prediction.shape[1]
+                if not (config.BOUNDARY_THRESH <= xval <= 1-config.BOUNDARY_THRESH
+                        and config.BOUNDARY_THRESH <= yval <= 1-config.BOUNDARY_THRESH):
+                    continue
                 cos_value = prediction[4, i, j]
                 sin_value = prediction[5, i, j]
                 direction = math.atan2(sin_value, cos_value)
@@ -45,24 +50,57 @@ def get_predicted_points(prediction, thresh):
     return non_maximum_suppression(predicted_points)
 
 
-def cal_squre_dist(point_a, point_b):
-    """Calculate distance between two marking points."""
-    distx = point_a.x - point_b.x
-    disty = point_a.y - point_b.y
-    return distx ** 2 + disty ** 2
+def pair_marking_points(point_a, point_b):
+    distance = calc_point_squre_dist(point_a, point_b)
+    if not (config.VSLOT_MIN_DISTANCE <= distance <= config.VSLOT_MAX_DISTANCE
+            or config.HSLOT_MIN_DISTANCE <= distance <= config.HSLOT_MAX_DISTANCE):
+        return 0
+    vector_ab = np.array([point_b.x - point_a.x, point_b.y - point_a.y])
+    vector_ab = vector_ab / np.linalg.norm(vector_ab)
+    point_shape_a = detemine_point_shape(point_a, vector_ab)
+    point_shape_b = detemine_point_shape(point_b, -vector_ab)
+    if point_shape_a.value == 0 or point_shape_b.value == 0:
+        return 0
+    if point_shape_a.value == 3 and point_shape_b.value == 3:
+        return 0
+    if point_shape_a.value > 3 and point_shape_b.value > 3:
+        return 0
+    if point_shape_a.value < 3 and point_shape_b.value < 3:
+        return 0
+    if point_shape_a.value != 3:
+        if point_shape_a.value > 3:
+            return 1
+        if point_shape_a.value < 3:
+            return -1
+    if point_shape_a.value == 3:
+        if point_shape_b.value < 3:
+            return 1
+        if point_shape_b.value > 3:
+            return -1
 
 
-def cal_direction_angle(point_a, point_b):
-    """Calculate angle between direction in rad."""
-    angle = abs(point_a.direction - point_b.direction)
-    if angle > math.pi:
-        angle = 2*math.pi - angle
-    return angle
-
-
-def match_marking_points(point_a, point_b):
-    """Determine whether a detected point match ground truth."""
-    dist_square = cal_squre_dist(point_a, point_b)
-    angle = cal_direction_angle(point_a, point_b)
-    return (dist_square < config.SQUARED_DISTANCE_THRESH
-            and angle < config.DIRECTION_ANGLE_THRESH)
+def filter_slots(marking_points, slots):
+    suppressed = [False] * len(slots)
+    for i, slot in enumerate(slots):
+        x1 = marking_points[slot[0]].x
+        y1 = marking_points[slot[0]].y
+        x2 = marking_points[slot[1]].x
+        y2 = marking_points[slot[1]].y
+        for point_idx, point in enumerate(marking_points):
+            if point_idx == slot[0] or point_idx == slot[1]:
+                continue
+            x0 = point.x
+            y0 = point.y
+            vec1 = np.array([x0 - x1, y0 - y1])
+            vec2 = np.array([x2 - x0, y2 - y0])
+            vec1 = vec1 / np.linalg.norm(vec1)
+            vec2 = vec2 / np.linalg.norm(vec2)
+            if np.dot(vec1, vec2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
+                suppressed[i] = True
+    if any(suppressed):
+        unsupres_slots = []
+        for i, supres in enumerate(suppressed):
+            if not supres:
+                unsupres_slots.append(slots[i])
+        return unsupres_slots
+    return slots
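For orientation: pair_marking_points returns a direction code — 1 when the slot entrance runs from point_a to point_b, -1 for the reverse, and 0 when the pair cannot form a slot — and inference_slots (in inference.py below) uses that code to order the index pair. filter_slots then discards any candidate whose endpoints have a third marking point lying almost on the line between them. A toy check of that collinearity test, with made-up coordinates (illustration only, not from the commit):

import numpy as np

SLOT_SUPPRESSION_DOT_PRODUCT_THRESH = 0.8   # value from config.py above

p1 = np.array([0.2, 0.5])    # one end of a candidate slot entrance
p2 = np.array([0.6, 0.5])    # the other end
mid = np.array([0.4, 0.52])  # a third marking point sitting between them

vec1 = (mid - p1) / np.linalg.norm(mid - p1)
vec2 = (p2 - mid) / np.linalg.norm(p2 - mid)
print(np.dot(vec1, vec2))    # ~0.98 > 0.8, so this candidate is suppressed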

data/struct.py  +69 -0

@@ -1,6 +1,75 @@
 """Defines data structure."""
+import math
 from collections import namedtuple
+from enum import Enum
+import config
 
 
 MarkingPoint = namedtuple('MarkingPoint', ['x', 'y', 'direction', 'shape'])
+Slot = namedtuple('Slot', ['x1', 'y1', 'x2', 'y2'])
+
+
+class PointShape(Enum):
+    """The point shape types used to pair two marking points into a slot."""
+    none = 0
+    l_down = 1
+    t_down = 2
+    t_middle = 3
+    t_up = 4
+    l_up = 5
+
+
+def direction_diff(direction_a, direction_b):
+    diff = abs(direction_a - direction_b)
+    return diff if diff < math.pi else 2*math.pi - diff
+
+
+def detemine_point_shape(point, vector):
+    vec_direct = math.atan2(vector[1], vector[0])
+    vec_direct_up = math.atan2(-vector[0], vector[1])
+    vec_direct_down = math.atan2(vector[0], -vector[1])
+    if point.shape < 0.5:
+        if direction_diff(vec_direct, point.direction) < config.BRIDGE_ANGLE_DIFF:
+            return PointShape.t_middle
+        if direction_diff(vec_direct_up, point.direction) < config.SEPARATOR_ANGLE_DIFF:
+            return PointShape.t_up
+        if direction_diff(vec_direct_down, point.direction) < config.SEPARATOR_ANGLE_DIFF:
+            return PointShape.t_down
+    else:
+        if direction_diff(vec_direct, point.direction) < config.BRIDGE_ANGLE_DIFF:
+            return PointShape.l_down
+        if direction_diff(vec_direct_up, point.direction) < config.SEPARATOR_ANGLE_DIFF:
+            return PointShape.l_up
+    return PointShape.none
+
+
+def calc_point_squre_dist(point_a, point_b):
+    """Calculate the squared distance between two marking points."""
+    distx = point_a.x - point_b.x
+    disty = point_a.y - point_b.y
+    return distx ** 2 + disty ** 2
+
+
+def calc_point_direction_angle(point_a, point_b):
+    """Calculate the angle between two point directions, in radians."""
+    return direction_diff(point_a.direction, point_b.direction)
+
+
+def match_marking_points(point_a, point_b):
+    """Determine whether a detected point matches ground truth."""
+    dist_square = calc_point_squre_dist(point_a, point_b)
+    angle = calc_point_direction_angle(point_a, point_b)
+    return (dist_square < config.SQUARED_DISTANCE_THRESH
+            and angle < config.DIRECTION_ANGLE_THRESH)
+
+
+def match_slots(slot_a, slot_b):
+    """Determine whether a detected slot matches ground truth."""
+    dist_x1 = slot_b.x1 - slot_a.x1
+    dist_y1 = slot_b.y1 - slot_a.y1
+    squared_dist1 = dist_x1**2 + dist_y1**2
+    dist_x2 = slot_b.x2 - slot_a.x2
+    dist_y2 = slot_b.y2 - slot_a.y2
+    squared_dist2 = dist_x2 ** 2 + dist_y2 ** 2
+    return (squared_dist1 < config.SQUARED_DISTANCE_THRESH
+            and squared_dist2 < config.SQUARED_DISTANCE_THRESH)
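direction_diff folds an angular difference into [0, pi], so directions on either side of the ±pi seam compare as nearly equal; detemine_point_shape depends on this when comparing a point's direction to the pair vector and its two perpendiculars. A standalone sanity check with illustrative values (the function body is copied from the diff above):

import math

def direction_diff(direction_a, direction_b):
    diff = abs(direction_a - direction_b)
    return diff if diff < math.pi else 2*math.pi - diff

print(direction_diff(3.1, -3.1))         # ~0.083, not ~6.2
print(direction_diff(0.0, math.pi / 2))  # ~1.571 (a right angle)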

inference.py  +89 -20

@@ -5,16 +5,18 @@ import numpy as np
 import torch
 from torchvision.transforms import ToTensor
 import config
-from data import get_predicted_points
+from data import get_predicted_points, pair_marking_points, filter_slots
 from model import DirectionalPointDetector
 from util import Timer
 
 
-def plot_marking_points(image, marking_points):
+def plot_points(image, pred_points):
     """Plot marking points on the image and show."""
+    if not pred_points:
+        return
     height = image.shape[0]
     width = image.shape[1]
-    for marking_point in marking_points:
+    for confidence, marking_point in pred_points:
         p0_x = width * marking_point.x - 0.5
         p0_y = height * marking_point.y - 0.5
         cos_val = math.cos(marking_point.direction)
@@ -32,6 +34,8 @@ def plot_marking_points(image, marking_points):
         p2_x = int(round(p2_x))
         p2_y = int(round(p2_y))
         cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255))
+        cv.putText(image, str(confidence), (p0_x, p0_y),
+                   cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
         if marking_point.shape > 0.5:
             cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255))
         else:
@@ -40,6 +44,38 @@ def plot_marking_points(image, marking_points):
         cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255))
 
 
+def plot_slots(image, pred_points, slots):
+    if not pred_points or not slots:
+        return
+    marking_points = list(list(zip(*pred_points))[1])
+    height = image.shape[0]
+    width = image.shape[1]
+    for slot in slots:
+        point_a = marking_points[slot[0]]
+        point_b = marking_points[slot[1]]
+        p0_x = width * point_a.x - 0.5
+        p0_y = height * point_a.y - 0.5
+        p1_x = width * point_b.x - 0.5
+        p1_y = height * point_b.y - 0.5
+        vec = np.array([p1_x - p0_x, p1_y - p0_y])
+        vec = vec / np.linalg.norm(vec)
+        p2_x = p0_x + 200*vec[1]
+        p2_y = p0_y - 200*vec[0]
+        p3_x = p1_x + 200*vec[1]
+        p3_y = p1_y - 200*vec[0]
+        p0_x = int(round(p0_x))
+        p0_y = int(round(p0_y))
+        p1_x = int(round(p1_x))
+        p1_y = int(round(p1_y))
+        p2_x = int(round(p2_x))
+        p2_y = int(round(p2_y))
+        p3_x = int(round(p3_x))
+        p3_y = int(round(p3_y))
+        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0))
+        cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0))
+        cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0))
+
+
 def preprocess_image(image):
     """Preprocess numpy image to torch tensor."""
     if image.shape[0] != 512 or image.shape[1] != 512:
@@ -47,6 +83,27 @@ def preprocess_image(image):
     return torch.unsqueeze(ToTensor()(image), 0)
 
 
+def detect_marking_points(detector, image, thresh, device):
+    """Given image read from opencv, return detected marking points."""
+    prediction = detector(preprocess_image(image).to(device))
+    return get_predicted_points(prediction[0], thresh)
+
+
+def inference_slots(marking_points):
+    """Infer slots based on marking points."""
+    num_detected = len(marking_points)
+    slots = []
+    for i in range(num_detected - 1):
+        for j in range(i + 1, num_detected):
+            result = pair_marking_points(marking_points[i], marking_points[j])
+            if result == 1:
+                slots.append((i, j))
+            elif result == -1:
+                slots.append((j, i))
+    slots = filter_slots(marking_points, slots)
+    return slots
+
+
 def detect_video(detector, device, args):
     """Demo for detecting video."""
     timer = Timer()
@@ -55,37 +112,49 @@ def detect_video(detector, device, args):
     frame_height = int(input_video.get(cv.CAP_PROP_FRAME_HEIGHT))
     output_video = cv.VideoWriter()
     if args.save:
-        output_video.open('record.avi', cv.VideoWriter_fourcc(* 'MJPG'),
+        output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
                           input_video.get(cv.CAP_PROP_FPS),
                           (frame_width, frame_height))
     frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
     while input_video.read(frame)[0]:
-        if args.timing:
-            timer.tic()
-        prediction = detector(preprocess_image(frame).to(device))
-        if args.timing:
-            timer.toc()
-        pred_points = get_predicted_points(prediction[0], args.thresh)
-        if pred_points:
-            plot_marking_points(frame, list(list(zip(*pred_points))[1]))
-        cv.imshow('demo', frame)
-        cv.waitKey(1)
+        timer.tic()
+        pred_points = detect_marking_points(
+            detector, frame, args.thresh, device)
+        slots = None
+        if pred_points and args.inference_slot:
+            marking_points = list(list(zip(*pred_points))[1])
+            slots = inference_slots(marking_points)
+        timer.toc()
+        plot_points(frame, pred_points)
+        plot_slots(frame, pred_points, slots)
+        cv.imshow('demo', frame)
+        cv.waitKey(1)
         if args.save:
             output_video.write(frame)
+    print("Average time: ", timer.calc_average_time(), "s.")
     input_video.release()
     output_video.release()
 
 
 def detect_image(detector, device, args):
     """Demo for detecting images."""
-    image_file = input('Enter image file path: ')
-    image = cv.imread(image_file)
-    prediction = detector(preprocess_image(image).to(device))
-    pred_points = get_predicted_points(prediction[0], args.thresh)
-    if pred_points:
-        plot_marking_points(image, list(list(zip(*pred_points))[1]))
+    timer = Timer()
+    while True:
+        image_file = input('Enter image file path: ')
+        image = cv.imread(image_file)
+        timer.tic()
+        pred_points = detect_marking_points(
+            detector, image, args.thresh, device)
+        slots = None  # fix: avoid NameError below when slot inference is skipped
+        if pred_points and args.inference_slot:
+            marking_points = list(list(zip(*pred_points))[1])
+            slots = inference_slots(marking_points)
+        timer.toc()
+        plot_points(image, pred_points)
+        plot_slots(image, pred_points, slots)
+        cv.imshow('demo', image)
+        cv.waitKey(1)
+        if args.save:
+            cv.imwrite('save.jpg', image, [int(cv.IMWRITE_JPEG_QUALITY), 100])
 
 
 def inference_detector(args):
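Putting the new pieces together, single-image slot inference reduces to a few calls. A hedged usage sketch: it assumes detector is an already-constructed and loaded DirectionalPointDetector on device (that setup lives in inference_detector, outside this hunk), and the file names are hypothetical:

import cv2 as cv

image = cv.imread('surround_view.jpg')   # hypothetical input path
pred_points = detect_marking_points(detector, image, 0.5, device)
if pred_points:
    marking_points = list(list(zip(*pred_points))[1])  # drop confidences
    slots = inference_slots(marking_points)            # ordered (i, j) pairs
    plot_points(image, pred_points)
    plot_slots(image, pred_points, slots)
    cv.imwrite('demo_result.jpg', image)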

prepare_dataset.py  +1 -1

@@ -109,7 +109,7 @@ def generate_dataset(args):
         centralied_marks = np.array(label['marks'])
         if len(centralied_marks.shape) < 2:
             centralied_marks = np.expand_dims(centralied_marks, axis=0)
-        centralied_marks[:, 0: 4] -= 300.5
+        centralied_marks[:, 0:4] -= 300.5
         if boundary_check(centralied_marks):
             output_name = os.path.join(args.output_directory, name)
             write_image_and_label(output_name, image,

util/utils.py  +8 -1

@@ -1,5 +1,4 @@
 """Utility classes and functions."""
-import math
 import time
 import cv2 as cv
 import torch
@@ -13,6 +12,8 @@ class Timer(object):
     def __init__(self):
         self.start_ticking = False
         self.start = 0.
+        self.count = 0
+        self.total_time = 0.
 
     def tic(self):
         """Start timer."""
@@ -24,6 +25,12 @@ class Timer(object):
         duration = time.time() - self.start
         self.start_ticking = False
         print("Time elapsed:", duration, "s.")
+        self.count += 1
+        self.total_time += duration
+
+    def calc_average_time(self):
+        """Calculate average elapsed time of timer."""
+        return self.total_time / self.count
 
 
 def tensor2array(image_tensor, imtype=np.uint8):
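A minimal usage sketch of the extended Timer (not from the commit). One caveat: calc_average_time() divides by self.count, so calling it before any tic()/toc() pair completes raises ZeroDivisionError:

import time

timer = Timer()
for _ in range(3):
    timer.tic()
    time.sleep(0.01)   # stand-in for one detector forward pass
    timer.toc()        # also prints the per-call elapsed time
print(timer.calc_average_time(), "s on average")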
