Refine thresholds
parent 396495f801
commit b79b06ee6c

@@ -0,0 +1,84 @@
"""Collect the value range of different propertity of ps dataset."""
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from data import MarkingPoint
|
||||
from data.struct import calc_point_squre_dist, direction_diff
|
||||
from prepare_dataset import generalize_marks
|
||||
|
||||
|
||||
def get_parser():
|
||||
"""Return argument parser for collecting thresholds."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--label_directory', required=True,
|
||||
help="The location of label directory.")
|
||||
return parser
|
||||
|
||||
|
||||
def collect_thresholds(args):
|
||||
"""Collect range of value from ground truth to determine threshold."""
|
||||
distances = []
|
||||
separator_angles = []
|
||||
bridge_angles = []
|
||||
|
||||
for label_file in os.listdir(args.label_directory):
|
||||
print(label_file)
|
||||
with open(os.path.join(args.label_directory, label_file), 'r') as file:
|
||||
label = json.load(file)
|
||||
marks = np.array(label['marks'])
|
||||
slots = np.array(label['slots'])
|
||||
if slots.size == 0:
|
||||
continue
|
||||
if len(marks.shape) < 2:
|
||||
marks = np.expand_dims(marks, axis=0)
|
||||
if len(slots.shape) < 2:
|
||||
slots = np.expand_dims(slots, axis=0)
|
||||
marks[:, 0:4] -= 300.5
|
||||
marks = [MarkingPoint(*mark) for mark in generalize_marks(marks)]
|
||||
for slot in slots:
|
||||
mark_a = marks[slot[0] - 1]
|
||||
mark_b = marks[slot[1] - 1]
|
||||
distances.append(calc_point_squre_dist(mark_a, mark_b))
|
||||
|
||||
vector_ab = np.array([mark_b.x - mark_a.x, mark_b.y - mark_a.y])
|
||||
vector_ab = vector_ab / np.linalg.norm(vector_ab)
|
||||
ab_bridge_direction = math.atan2(vector_ab[1], vector_ab[0])
|
||||
ba_bridge_direction = math.atan2(-vector_ab[1], -vector_ab[0])
|
||||
separator_direction = math.atan2(-vector_ab[0], vector_ab[1])
|
||||
|
||||
sangle = direction_diff(separator_direction, mark_a.direction)
|
||||
if mark_a.shape > 0.5:
|
||||
separator_angles.append(sangle)
|
||||
else:
|
||||
bangle = direction_diff(ab_bridge_direction, mark_a.direction)
|
||||
if sangle < bangle:
|
||||
separator_angles.append(sangle)
|
||||
else:
|
||||
bridge_angles.append(bangle)
|
||||
|
||||
bangle = direction_diff(ba_bridge_direction, mark_b.direction)
|
||||
if mark_b.shape > 0.5:
|
||||
bridge_angles.append(bangle)
|
||||
else:
|
||||
sangle = direction_diff(separator_direction, mark_b.direction)
|
||||
if sangle < bangle:
|
||||
separator_angles.append(sangle)
|
||||
else:
|
||||
bridge_angles.append(bangle)
|
||||
|
||||
distances = sorted(distances)
|
||||
separator_angles = sorted(separator_angles)
|
||||
bridge_angles = sorted(bridge_angles)
|
||||
plt.figure()
|
||||
plt.hist(distances, len(distances) // 10)
|
||||
plt.figure()
|
||||
plt.hist(separator_angles, len(separator_angles) // 10)
|
||||
plt.figure()
|
||||
plt.hist(bridge_angles, len(bridge_angles) // 3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
collect_thresholds(get_parser().parse_args())
|
||||
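
For reference, a minimal sketch of how to run this collector and look at the resulting distributions. The module name collect_thresholds and the label path are assumptions (the diff does not show the new file's name), and plt.show() is needed because the script only builds the figures:

# Hypothetical usage sketch; file and directory names are assumptions.
import matplotlib.pyplot as plt
from argparse import Namespace
from collect_thresholds import collect_thresholds  # assumed module name

collect_thresholds(Namespace(label_directory='ps_json_label/training'))
plt.show()  # inspect the distance, separator-angle and bridge-angle histograms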

config.py
@@ -9,22 +9,25 @@ NUM_FEATURE_MAP_CHANNEL = 6
# image_size / 2^5 = 512 / 32 = 16
FEATURE_MAP_SIZE = 16
# Threshold used to filter marking points too close to the image boundary
BOUNDARY_THRESH = 0.066666667
BOUNDARY_THRESH = 0.05

# Thresholds to determine whether a detected point matches ground truth.
SQUARED_DISTANCE_THRESH = 0.000277778
DIRECTION_ANGLE_THRESH = 0.5
SQUARED_DISTANCE_THRESH = 0.000277778  # 10 pixels in a 600x600 image
DIRECTION_ANGLE_THRESH = 0.5235987755982988  # 30 degrees in radians

VSLOT_MIN_DISTANCE = 0.044771278151623496
VSLOT_MAX_DISTANCE = 0.1099427457599304
HSLOT_MIN_DISTANCE = 0.15057789144568634
HSLOT_MAX_DISTANCE = 0.44449496544202816
VSLOT_MIN_DIST = 0.044771278151623496
VSLOT_MAX_DIST = 0.1099427457599304
HSLOT_MIN_DIST = 0.15057789144568634
HSLOT_MAX_DIST = 0.44449496544202816

BRIDGE_ANGLE_DIFF = 0.25
SEPARATOR_ANGLE_DIFF = 0.5
# angle_prediction_error = 0.1384059287593468
BRIDGE_ANGLE_DIFF = 0.09757113548987695 + 0.1384059287593468
SEPARATOR_ANGLE_DIFF = 0.284967562063968 + 0.1384059287593468

SLOT_SUPPRESSION_DOT_PRODUCT_THRESH = 0.8

# precision = 0.995585, recall = 0.995805
CONFID_THRESH_FOR_POINT = 0.11676871

def add_common_arguments(parser):
    """Add common arguments for training and inference."""
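
As a quick sanity check on the refined constants above (a standalone sketch, not part of the commit): the matching thresholds follow directly from their comments, and the new pairing tolerances add the quoted angle_prediction_error to what appear to be angle spreads collected from ground truth by the new collection script.

import math

print((10 / 600) ** 2)         # 0.000277... == SQUARED_DISTANCE_THRESH (10 pixels in a 600x600 image)
print(math.radians(30))        # 0.5235987755982988 == DIRECTION_ANGLE_THRESH
print(0.09757113548987695 + 0.1384059287593468)  # BRIDGE_ANGLE_DIFF ~= 0.2360
print(0.284967562063968 + 0.1384059287593468)    # SEPARATOR_ANGLE_DIFF ~= 0.4234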

data/__init__.py
@@ -1,4 +1,4 @@
"""Data related package."""
|
||||
from .data_process import get_predicted_points, pair_marking_points, filter_slots
|
||||
from .data_process import get_predicted_points, pair_marking_points, pass_through_third_point
|
||||
from .dataset import ParkingSlotDataset
|
||||
from .struct import MarkingPoint, Slot, match_marking_points, match_slots
|
||||
from .struct import MarkingPoint, Slot, match_marking_points, match_slots, calc_point_squre_dist, calc_point_direction_angle
|
||||

data/data_process.py
@@ -3,7 +3,7 @@ import math
import numpy as np
import torch
import config
from data.struct import MarkingPoint, calc_point_squre_dist, detemine_point_shape
from data.struct import MarkingPoint, detemine_point_shape


def non_maximum_suppression(pred_points):

@@ -50,11 +50,28 @@ def get_predicted_points(prediction, thresh):
    return non_maximum_suppression(predicted_points)


def pass_through_third_point(marking_points, i, j):
    """See whether the line between two points passes through a third point."""
    x_1 = marking_points[i].x
    y_1 = marking_points[i].y
    x_2 = marking_points[j].x
    y_2 = marking_points[j].y
    for point_idx, point in enumerate(marking_points):
        if point_idx == i or point_idx == j:
            continue
        x_0 = point.x
        y_0 = point.y
        vec1 = np.array([x_0 - x_1, y_0 - y_1])
        vec2 = np.array([x_2 - x_0, y_2 - y_0])
        vec1 = vec1 / np.linalg.norm(vec1)
        vec2 = vec2 / np.linalg.norm(vec2)
        if np.dot(vec1, vec2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
            return True
    return False
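
To make the suppression criterion concrete, a small standalone sketch with made-up coordinates: when a third detected point lies on the segment between the pair, the two unit vectors are parallel, their dot product exceeds SLOT_SUPPRESSION_DOT_PRODUCT_THRESH (0.8), and the candidate pair is rejected.

import numpy as np

# made-up points: `third` sits on the segment between point_a and point_b
point_a, third, point_b = np.array([0.1, 0.2]), np.array([0.3, 0.4]), np.array([0.5, 0.6])
vec1 = (third - point_a) / np.linalg.norm(third - point_a)
vec2 = (point_b - third) / np.linalg.norm(point_b - third)
print(np.dot(vec1, vec2))  # 1.0 > 0.8, so this (a, b) candidate would be dropped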


def pair_marking_points(point_a, point_b):
    distance = calc_point_squre_dist(point_a, point_b)
    if not (config.VSLOT_MIN_DISTANCE <= distance <= config.VSLOT_MAX_DISTANCE
            or config.HSLOT_MIN_DISTANCE <= distance <= config.HSLOT_MAX_DISTANCE):
        return 0
    """See whether two marking points form a slot."""
    vector_ab = np.array([point_b.x - point_a.x, point_b.y - point_a.y])
    vector_ab = vector_ab / np.linalg.norm(vector_ab)
    point_shape_a = detemine_point_shape(point_a, vector_ab)

@@ -77,30 +94,3 @@ def pair_marking_points(point_a, point_b):
        return 1
    if point_shape_b.value > 3:
        return -1


def filter_slots(marking_points, slots):
    suppressed = [False] * len(slots)
    for i, slot in enumerate(slots):
        x1 = marking_points[slot[0]].x
        y1 = marking_points[slot[0]].y
        x2 = marking_points[slot[1]].x
        y2 = marking_points[slot[1]].y
        for point_idx, point in enumerate(marking_points):
            if point_idx == slot[0] or point_idx == slot[1]:
                continue
            x0 = point.x
            y0 = point.y
            vec1 = np.array([x0 - x1, y0 - y1])
            vec2 = np.array([x2 - x0, y2 - y0])
            vec1 = vec1 / np.linalg.norm(vec1)
            vec2 = vec2 / np.linalg.norm(vec2)
            if np.dot(vec1, vec2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
                suppressed[i] = True
    if any(suppressed):
        unsupres_slots = []
        for i, supres in enumerate(suppressed):
            if not supres:
                unsupres_slots.append(slots[i])
        return unsupres_slots
    return slots

data/struct.py
@@ -20,11 +20,13 @@ class PointShape(Enum):


def direction_diff(direction_a, direction_b):
    """Calculate the angle between two directions."""
    diff = abs(direction_a - direction_b)
    return diff if diff < math.pi else 2*math.pi - diff


def detemine_point_shape(point, vector):
    """Determine which category the point is in."""
    vec_direct = math.atan2(vector[1], vector[0])
    vec_direct_up = math.atan2(-vector[0], vector[1])
    vec_direct_down = math.atan2(vector[0], -vector[1])
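
direction_diff folds the raw angular difference back into [0, pi], so directions just either side of the +/-pi seam compare as nearly equal. A minimal standalone check (the function body is copied from above):

import math

def direction_diff(direction_a, direction_b):
    diff = abs(direction_a - direction_b)
    return diff if diff < math.pi else 2*math.pi - diff

print(direction_diff(3.04, -3.04))  # ~0.203 rad, not 6.08
print(direction_diff(0.2, 1.2))     # 1.0 rad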

@@ -59,6 +61,10 @@ def match_marking_points(point_a, point_b):
"""Determine whether a detected point match ground truth."""
|
||||
dist_square = calc_point_squre_dist(point_a, point_b)
|
||||
angle = calc_point_direction_angle(point_a, point_b)
|
||||
if point_a.shape > 0.5 and point_b.shape < 0.5:
|
||||
return False
|
||||
if point_a.shape < 0.5 and point_b.shape > 0.5:
|
||||
return False
|
||||
return (dist_square < config.SQUARED_DISTANCE_THRESH
|
||||
and angle < config.DIRECTION_ANGLE_THRESH)
|
||||

evaluate.py
@@ -1,14 +1,47 @@
"""Evaluate directional marking point detector."""
|
||||
import cv2 as cv
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import config
|
||||
import util
|
||||
from data import get_predicted_points, match_marking_points
|
||||
from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
|
||||
from data import ParkingSlotDataset
|
||||
from inference import plot_points
|
||||
from model import DirectionalPointDetector
|
||||
from train import generate_objective
|
||||
|
||||
|
||||
def is_gt_and_pred_matched(ground_truths, predictions, thresh):
|
||||
"""Check if there is any false positive or false negative."""
|
||||
predictions = [pred for pred in predictions if pred[0] >= thresh]
|
||||
prediction_matched = [False] * len(predictions)
|
||||
for ground_truth in ground_truths:
|
||||
idx = util.match_gt_with_preds(ground_truth, predictions,
|
||||
match_marking_points)
|
||||
if idx < 0:
|
||||
return False
|
||||
prediction_matched[idx] = True
|
||||
if not all(prediction_matched):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def collect_error(ground_truths, predictions, thresh):
|
||||
"""Collect errors for those correctly detected points."""
|
||||
dists = []
|
||||
angles = []
|
||||
predictions = [pred for pred in predictions if pred[0] >= thresh]
|
||||
for ground_truth in ground_truths:
|
||||
idx = util.match_gt_with_preds(ground_truth, predictions,
|
||||
match_marking_points)
|
||||
if idx >= 0:
|
||||
detected_point = predictions[idx][1]
|
||||
dists.append(calc_point_squre_dist(detected_point, ground_truth))
|
||||
angles.append(calc_point_direction_angle(detected_point, ground_truth))
|
||||
else:
|
||||
continue
|
||||
return dists, angles
|
||||
|
||||
|
||||
def evaluate_detector(args):
|
||||
"""Evaluate directional point detector."""
|
||||
args.cuda = not args.disable_cuda and torch.cuda.is_available()
|
||||
|
|
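
collect_error returns squared distances and angle differences only for points that pass match_marking_points. The commit does not show how these were reduced to the angle_prediction_error = 0.1384... quoted in config.py, so the summary below is only a guess at that step (the choice of statistic and the sample values are assumptions):

import numpy as np

# made-up per-point errors standing in for collect_error output
dists = [1.2e-05, 8.0e-06, 2.5e-05]   # squared distances, normalized coordinates
angles = [0.05, 0.11, 0.1384]         # radians
print(float(np.max(angles)))                  # one plausible reading of angle_prediction_error
print(float(np.mean(np.sqrt(dists))) * 600)   # mean position error in pixels of a 600x600 image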

@@ -21,30 +54,37 @@ def evaluate_detector(args):
        dp_detector.load_state_dict(torch.load(args.detector_weights))
    dp_detector.eval()

    torch.multiprocessing.set_sharing_strategy('file_system')
    data_loader = DataLoader(ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=True,
                             num_workers=args.data_loading_workers,
                             collate_fn=lambda x: list(zip(*x)))
    psdataset = ParkingSlotDataset(args.dataset_directory)
    logger = util.Logger(enable_visdom=args.enable_visdom)

    total_loss = 0
    num_evaluation = 0
    position_errors = []
    direction_errors = []
    ground_truths_list = []
    predictions_list = []
    for iter_idx, (image, marking_points) in enumerate(data_loader):
        image = torch.stack(image)
        image = image.to(device)
        ground_truths_list += list(marking_points)
    for iter_idx, (image, marking_points) in enumerate(psdataset):
        ground_truths_list.append(marking_points)

        image = torch.unsqueeze(image, 0).to(device)
        prediction = dp_detector(image)
        objective, gradient = generate_objective(marking_points, device)
        objective, gradient = generate_objective([marking_points], device)
        loss = (prediction - objective) ** 2
        total_loss += torch.sum(loss*gradient).item()
        num_evaluation += loss.size(0)

        pred_points = [get_predicted_points(pred, 0.01) for pred in prediction]
        predictions_list += pred_points
        pred_points = get_predicted_points(prediction[0], 0.01)
        predictions_list.append(pred_points)

        # if not is_gt_and_pred_matched(marking_points, pred_points,
        #                               config.CONFID_THRESH_FOR_POINT):
        #     cvimage = util.tensor2array(image[0]).copy()
        #     plot_points(cvimage, pred_points)
        #     cv.imwrite('flaw/%d.jpg' % iter_idx, cvimage,
        #                [int(cv.IMWRITE_JPEG_QUALITY), 100])
        dists, angles = collect_error(marking_points, pred_points,
                                      config.CONFID_THRESH_FOR_POINT)
        position_errors += dists
        direction_errors += angles

        logger.log(iter=iter_idx, total_loss=total_loss)

    precisions, recalls = util.calc_precision_recall(

@@ -52,7 +92,7 @@ def evaluate_detector(args):
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_loss=total_loss / num_evaluation,
    logger.log(average_loss=total_loss / len(psdataset),
               average_precision=average_precision)

inference.py
@@ -5,13 +5,13 @@ import numpy as np
import torch
from torchvision.transforms import ToTensor
import config
from data import get_predicted_points, pair_marking_points, filter_slots
from data import get_predicted_points, pair_marking_points, calc_point_squre_dist, pass_through_third_point
from model import DirectionalPointDetector
from util import Timer


def plot_points(image, pred_points):
    """Plot marking points on the image and show."""
    """Plot marking points on the image."""
    if not pred_points:
        return
    height = image.shape[0]

@@ -33,18 +33,19 @@ def plot_points(image, pred_points):
        p1_y = int(round(p1_y))
        p2_x = int(round(p2_x))
        p2_y = int(round(p2_y))
        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255))
        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255), 2)
        cv.putText(image, str(confidence), (p0_x, p0_y),
                   cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
        if marking_point.shape > 0.5:
            cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255))
            cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255), 2)
        else:
            p3_x = int(round(p3_x))
            p3_y = int(round(p3_y))
            cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255))
            cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255), 2)


def plot_slots(image, pred_points, slots):
    """Plot parking slots on the image."""
    if not pred_points or not slots:
        return
    marking_points = list(list(zip(*pred_points))[1])

@@ -71,9 +72,9 @@ def plot_slots(image, pred_points, slots):
        p2_y = int(round(p2_y))
        p3_x = int(round(p3_x))
        p3_y = int(round(p3_y))
        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0))
        cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0))
        cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0))
        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0), 2)
        cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0), 2)
        cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0), 2)


def preprocess_image(image):

@@ -95,12 +96,21 @@ def inference_slots(marking_points):
    slots = []
    for i in range(num_detected - 1):
        for j in range(i + 1, num_detected):
            result = pair_marking_points(marking_points[i], marking_points[j])
            point_i = marking_points[i]
            point_j = marking_points[j]
            # Step 1: length filtration.
            distance = calc_point_squre_dist(point_i, point_j)
            if not (config.VSLOT_MIN_DIST <= distance <= config.VSLOT_MAX_DIST
                    or config.HSLOT_MIN_DIST <= distance <= config.HSLOT_MAX_DIST):
                continue
            # Step 2: pass through filtration.
            if pass_through_third_point(marking_points, i, j):
                continue
            result = pair_marking_points(point_i, point_j)
            if result == 1:
                slots.append((i, j))
            elif result == -1:
                slots.append((j, i))
    slots = filter_slots(marking_points, slots)
    return slots
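
Because the distances are squared and expressed in image-normalized coordinates, the length filtration above maps to concrete pixel ranges on a 600x600 image; a standalone conversion of the config bounds:

import math

for name, value in [('VSLOT_MIN_DIST', 0.044771278151623496),
                    ('VSLOT_MAX_DIST', 0.1099427457599304),
                    ('HSLOT_MIN_DIST', 0.15057789144568634),
                    ('HSLOT_MAX_DIST', 0.44449496544202816)]:
    print(name, round(math.sqrt(value) * 600, 1), 'px')
# roughly 127-199 px between the points of a vertical slot, 233-400 px for a horizontal slot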

@@ -114,7 +124,7 @@ def detect_video(detector, device, args):
    if args.save:
        output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
                          input_video.get(cv.CAP_PROP_FPS),
                          (frame_width, frame_height))
                          (frame_width, frame_height), True)
    frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
    while input_video.read(frame)[0]:
        timer.tic()

@@ -145,6 +155,7 @@ def detect_image(detector, device, args):
        timer.tic()
        pred_points = detect_marking_points(
            detector, image, args.thresh, device)
        slots = None
        if pred_points and args.inference_slot:
            marking_points = list(list(zip(*pred_points))[1])
            slots = inference_slots(marking_points)

@@ -32,6 +32,20 @@ def define_halve_unit(basic_channel_size):
    return layers


def define_depthwise_expand_unit(basic_channel_size):
    """Define a 1x1 expand convolution followed by a 3x3 depthwise convolution."""
    conv1 = nn.Conv2d(basic_channel_size, 2 * basic_channel_size,
                      kernel_size=1, stride=1, padding=0, bias=False)
    norm1 = nn.BatchNorm2d(2 * basic_channel_size)
    relu1 = nn.LeakyReLU(0.1)
    conv2 = nn.Conv2d(2 * basic_channel_size, 2 * basic_channel_size, kernel_size=3,
                      stride=1, padding=1, bias=False, groups=2 * basic_channel_size)
    norm2 = nn.BatchNorm2d(2 * basic_channel_size)
    relu2 = nn.LeakyReLU(0.1)
    layers = [conv1, norm1, relu1, conv2, norm2, relu2]
    return layers
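
For intuition, the returned layer list doubles the channel count while keeping the spatial size, with the 3x3 convolution applied depthwise (one filter per channel) to keep the parameter count low. Wrapping the list in nn.Sequential below is an assumption about how the surrounding model code consumes it:

import torch
from torch import nn

channels = 32
unit = nn.Sequential(
    nn.Conv2d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(2 * channels),
    nn.LeakyReLU(0.1),
    nn.Conv2d(2 * channels, 2 * channels, kernel_size=3, stride=1, padding=1,
              bias=False, groups=2 * channels),
    nn.BatchNorm2d(2 * channels),
    nn.LeakyReLU(0.1))
print(unit(torch.randn(1, channels, 16, 16)).shape)  # torch.Size([1, 64, 16, 16])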


def define_detector_block(basic_channel_size):
    """Define a unit composed of a squeeze unit and an expand unit."""
    layers = []

prepare_dataset.py
@@ -110,7 +110,7 @@ def generate_dataset(args):
        if len(centralied_marks.shape) < 2:
            centralied_marks = np.expand_dims(centralied_marks, axis=0)
        centralied_marks[:, 0:4] -= 300.5
        if boundary_check(centralied_marks):
        if boundary_check(centralied_marks) or args.dataset == 'test':
            output_name = os.path.join(args.output_directory, name)
            write_image_and_label(output_name, image,
                                  centralied_marks, name_list)

@@ -0,0 +1,79 @@
"""Evaluate directional marking point detector."""
|
||||
import json
|
||||
import os
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import torch
|
||||
import config
|
||||
import util
|
||||
from data import match_slots, Slot
|
||||
from model import DirectionalPointDetector
|
||||
from inference import detect_marking_points, inference_slots
|
||||
|
||||
|
||||
def get_ground_truths(label):
|
||||
slots = np.array(label['slots'])
|
||||
if slots.size == 0:
|
||||
return []
|
||||
if len(slots.shape) < 2:
|
||||
slots = np.expand_dims(slots, axis=0)
|
||||
marks = np.array(label['marks'])
|
||||
if len(marks.shape) < 2:
|
||||
marks = np.expand_dims(marks, axis=0)
|
||||
ground_truths = []
|
||||
for slot in slots:
|
||||
mark_a = marks[slot[0] - 1]
|
||||
mark_b = marks[slot[1] - 1]
|
||||
coords = np.array([mark_a[0], mark_a[1], mark_b[0], mark_b[1]])
|
||||
coords = (coords - 0.5) / 600
|
||||
ground_truths.append(Slot(*coords))
|
||||
return ground_truths
|
||||
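
The (coords - 0.5) / 600 step maps the label pixel coordinates (which appear to be 1-based) onto the unit square the detector predicts in, so the resulting Slot is directly comparable with predicted slots. A quick standalone check (pixel values are made up):

import numpy as np

coords = np.array([300.5, 300.5, 450.5, 300.5])  # two marks in 600x600 pixel coordinates
print((coords - 0.5) / 600)                      # [0.5  0.5  0.75 0.5]: the image centre maps to 0.5
print(((450.5 - 300.5) / 600) ** 2)              # 0.0625, inside the [VSLOT_MIN_DIST, VSLOT_MAX_DIST] range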


def psevaluate_detector(args):
    """Evaluate directional point detector."""
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(False)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        dp_detector.load_state_dict(torch.load(args.detector_weights))
    dp_detector.eval()

    logger = util.Logger(enable_visdom=args.enable_visdom)

    ground_truths_list = []
    predictions_list = []
    for idx, label_file in enumerate(os.listdir(args.label_directory)):
        name = os.path.splitext(label_file)[0]
        print(idx, name)
        image = cv.imread(os.path.join(args.image_directory, name + '.jpg'))
        pred_points = detect_marking_points(
            dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)
        if pred_points:
            marking_points = list(list(zip(*pred_points))[1])
            slots = inference_slots(marking_points)
            pred_slots = []
            for slot in slots:
                point_a = marking_points[slot[0]]
                point_b = marking_points[slot[1]]
                prob = min((pred_points[slot[0]][0], pred_points[slot[1]][0]))
                pred_slots.append(
                    (prob, Slot(point_a.x, point_a.y, point_b.x, point_b.y)))
            predictions_list.append(pred_slots)

        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            ground_truths_list.append(get_ground_truths(json.load(file)))

    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_slots)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_precision=average_precision)


if __name__ == '__main__':
    psevaluate_detector(config.get_parser_for_ps_evaluation().parse_args())

train.py
@@ -69,8 +69,7 @@ def train_detector(args):
print("Loading weights: %s" % args.optimizer_weights)
|
||||
optimizer.load_state_dict(torch.load(args.optimizer_weights))
|
||||
|
||||
logger = util.Logger(args.enable_visdom,
|
||||
['train_loss'] if args.enable_visdom else None)
|
||||
logger = util.Logger(args.enable_visdom, ['train_loss'])
|
||||
data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
|
||||
batch_size=args.batch_size, shuffle=True,
|
||||
num_workers=args.data_loading_workers,
|
||||
|
|

@@ -78,8 +77,7 @@ def train_detector(args):

    for epoch_idx in range(args.num_epochs):
        for iter_idx, (image, marking_points) in enumerate(data_loader):
            image = torch.stack(image)
            image = image.to(device)
            image = torch.stack(image).to(device)

            optimizer.zero_grad()
            prediction = dp_detector(image)

util/__init__.py
@@ -1,4 +1,4 @@
"""Utility related package."""
|
||||
from .log import Logger
|
||||
from .precision_recall import calc_precision_recall, calc_average_precision
|
||||
from .precision_recall import calc_precision_recall, calc_average_precision, match_gt_with_preds
|
||||
from .utils import Timer, tensor2array, tensor2im
|
||||
|
|
|
|||