Refine thresholds

This commit is contained in:
Teoge 2018-11-21 15:45:22 +08:00
parent 396495f801
commit b79b06ee6c
12 changed files with 302 additions and 77 deletions

84
collect_thresholds.py Normal file
View File

@ -0,0 +1,84 @@
"""Collect the value range of different propertity of ps dataset."""
import argparse
import json
import math
import os
import matplotlib.pyplot as plt
import numpy as np
from data import MarkingPoint
from data.struct import calc_point_squre_dist, direction_diff
from prepare_dataset import generalize_marks
def get_parser():
    """Build the command-line parser for the threshold collection script.

    Returns:
        An ``argparse.ArgumentParser`` that accepts the single, mandatory
        ``--label_directory`` option.
    """
    arg_parser = argparse.ArgumentParser()
    label_directory_options = {
        'required': True,
        'help': "The location of label directory.",
    }
    arg_parser.add_argument('--label_directory', **label_directory_options)
    return arg_parser
def collect_thresholds(args):
    """Collect range of value from ground truth to determine threshold.

    Every label file in ``args.label_directory`` is read, and for each
    labelled slot (a pair of marking points A and B) three quantities are
    gathered:

    * the squared distance between A and B,
    * the angle between a point's direction and the separator direction
      (perpendicular to AB),
    * the angle between a point's direction and the bridge direction
      (along AB for A, along BA for B).

    One histogram per quantity is drawn and the figures are then displayed.

    Args:
        args: Parsed command-line arguments; only ``label_directory`` is read.
    """
    distances = []
    separator_angles = []
    bridge_angles = []
    for label_file in os.listdir(args.label_directory):
        print(label_file)
        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            label = json.load(file)
        marks = np.array(label['marks'])
        slots = np.array(label['slots'])
        if slots.size == 0:
            continue
        # A single mark / slot is stored as a flat list; promote it to 2-D.
        if len(marks.shape) < 2:
            marks = np.expand_dims(marks, axis=0)
        if len(slots.shape) < 2:
            slots = np.expand_dims(slots, axis=0)
        # Same centring offset as applied in prepare_dataset before
        # generalize_marks is called.
        marks[:, 0:4] -= 300.5
        marks = [MarkingPoint(*mark) for mark in generalize_marks(marks)]
        for slot in slots:
            # Slot entries are 1-based indices into the mark list.
            mark_a = marks[slot[0] - 1]
            mark_b = marks[slot[1] - 1]
            vector_ab = np.array([mark_b.x - mark_a.x, mark_b.y - mark_a.y])
            length_ab = np.linalg.norm(vector_ab)
            if length_ab == 0:
                # Coincident end points have no direction; normalising would
                # give NaN angles that break the histograms below.
                continue
            distances.append(calc_point_squre_dist(mark_a, mark_b))
            vector_ab = vector_ab / length_ab
            ab_bridge_direction = math.atan2(vector_ab[1], vector_ab[0])
            ba_bridge_direction = math.atan2(-vector_ab[1], -vector_ab[0])
            separator_direction = math.atan2(-vector_ab[0], vector_ab[1])

            # Point A: shape > 0.5 always counts as a separator sample,
            # otherwise the closer of the two candidate directions wins.
            sangle = direction_diff(separator_direction, mark_a.direction)
            if mark_a.shape > 0.5:
                separator_angles.append(sangle)
            else:
                bangle = direction_diff(ab_bridge_direction, mark_a.direction)
                _file_smaller_angle(sangle, bangle,
                                    separator_angles, bridge_angles)

            # Point B: shape > 0.5 always counts as a bridge sample,
            # otherwise the closer of the two candidate directions wins.
            bangle = direction_diff(ba_bridge_direction, mark_b.direction)
            if mark_b.shape > 0.5:
                bridge_angles.append(bangle)
            else:
                sangle = direction_diff(separator_direction, mark_b.direction)
                _file_smaller_angle(sangle, bangle,
                                    separator_angles, bridge_angles)

    _plot_histogram(sorted(distances), 10)
    _plot_histogram(sorted(separator_angles), 10)
    _plot_histogram(sorted(bridge_angles), 3)
    # Without this call the figures are created but never displayed when the
    # module is run as a script.
    plt.show()


def _file_smaller_angle(sangle, bangle, separator_angles, bridge_angles):
    """Append the smaller of the two angles to the list of its own kind."""
    if sangle < bangle:
        separator_angles.append(sangle)
    else:
        bridge_angles.append(bangle)


def _plot_histogram(values, samples_per_bin):
    """Draw ``values`` as a histogram in a new figure, using at least one bin."""
    plt.figure()
    # ``len(values) // samples_per_bin`` is 0 for small samples, and a bin
    # count of 0 is rejected by the histogram routine.
    plt.hist(values, max(1, len(values) // samples_per_bin))
if __name__ == '__main__':
    # Parse the command line first, then run the collection on the result.
    cli_args = get_parser().parse_args()
    collect_thresholds(cli_args)

View File

@ -9,22 +9,25 @@ NUM_FEATURE_MAP_CHANNEL = 6
# image_size / 2^5 = 512 / 32 = 16 # image_size / 2^5 = 512 / 32 = 16
FEATURE_MAP_SIZE = 16 FEATURE_MAP_SIZE = 16
# Threshold used to filter marking points too close to image boundary # Threshold used to filter marking points too close to image boundary
BOUNDARY_THRESH = 0.066666667 BOUNDARY_THRESH = 0.05
# Thresholds to determine whether an detected point match ground truth. # Thresholds to determine whether an detected point match ground truth.
SQUARED_DISTANCE_THRESH = 0.000277778 SQUARED_DISTANCE_THRESH = 0.000277778 # 10 pixel in 600*600 image
DIRECTION_ANGLE_THRESH = 0.5 DIRECTION_ANGLE_THRESH = 0.5235987755982988 # 30 degree in rad
VSLOT_MIN_DISTANCE = 0.044771278151623496 VSLOT_MIN_DIST = 0.044771278151623496
VSLOT_MAX_DISTANCE = 0.1099427457599304 VSLOT_MAX_DIST = 0.1099427457599304
HSLOT_MIN_DISTANCE = 0.15057789144568634 HSLOT_MIN_DIST = 0.15057789144568634
HSLOT_MAX_DISTANCE = 0.44449496544202816 HSLOT_MAX_DIST = 0.44449496544202816
BRIDGE_ANGLE_DIFF = 0.25 # angle_prediction_error = 0.1384059287593468
SEPARATOR_ANGLE_DIFF = 0.5 BRIDGE_ANGLE_DIFF = 0.09757113548987695 + 0.1384059287593468
SEPARATOR_ANGLE_DIFF = 0.284967562063968 + 0.1384059287593468
SLOT_SUPPRESSION_DOT_PRODUCT_THRESH = 0.8 SLOT_SUPPRESSION_DOT_PRODUCT_THRESH = 0.8
# precision = 0.995585, recall = 0.995805
CONFID_THRESH_FOR_POINT = 0.11676871
def add_common_arguments(parser): def add_common_arguments(parser):
"""Add common arguments for training and inference.""" """Add common arguments for training and inference."""

View File

@ -1,4 +1,4 @@
"""Data related package.""" """Data related package."""
from .data_process import get_predicted_points, pair_marking_points, filter_slots from .data_process import get_predicted_points, pair_marking_points, pass_through_third_point
from .dataset import ParkingSlotDataset from .dataset import ParkingSlotDataset
from .struct import MarkingPoint, Slot, match_marking_points, match_slots from .struct import MarkingPoint, Slot, match_marking_points, match_slots, calc_point_squre_dist, calc_point_direction_angle

View File

@ -3,7 +3,7 @@ import math
import numpy as np import numpy as np
import torch import torch
import config import config
from data.struct import MarkingPoint, calc_point_squre_dist, detemine_point_shape from data.struct import MarkingPoint, detemine_point_shape
def non_maximum_suppression(pred_points): def non_maximum_suppression(pred_points):
@ -50,11 +50,28 @@ def get_predicted_points(prediction, thresh):
return non_maximum_suppression(predicted_points) return non_maximum_suppression(predicted_points)
def pass_through_third_point(marking_points, i, j):
    """See whether the line between two points pass through a third point.

    For every other point P, the unit vector from point ``i`` to P and the
    unit vector from P to point ``j`` are compared. When their dot product
    exceeds ``config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH`` the two vectors
    point in nearly the same direction, so P is treated as lying on the
    segment between the two points.

    Args:
        marking_points: Sequence of objects exposing ``x`` and ``y``.
        i: Index of the first end point in ``marking_points``.
        j: Index of the second end point in ``marking_points``.

    Returns:
        True if some other point lies between the two end points, else False.
    """
    x_1 = marking_points[i].x
    y_1 = marking_points[i].y
    x_2 = marking_points[j].x
    y_2 = marking_points[j].y
    for point_idx, point in enumerate(marking_points):
        if point_idx in (i, j):
            continue
        x_0 = point.x
        y_0 = point.y
        vec1 = np.array([x_0 - x_1, y_0 - y_1])
        vec2 = np.array([x_2 - x_0, y_2 - y_0])
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            # The point coincides with an end point: it has no direction
            # relative to it and cannot lie strictly between the two. Skipping
            # avoids the 0/0 division (NaN plus a RuntimeWarning).
            continue
        if np.dot(vec1 / norm1, vec2 / norm2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
            return True
    return False
def pair_marking_points(point_a, point_b): def pair_marking_points(point_a, point_b):
distance = calc_point_squre_dist(point_a, point_b) """See whether two marking points form a slot."""
if not (config.VSLOT_MIN_DISTANCE <= distance <= config.VSLOT_MAX_DISTANCE
or config.HSLOT_MIN_DISTANCE <= distance <= config.HSLOT_MAX_DISTANCE):
return 0
vector_ab = np.array([point_b.x - point_a.x, point_b.y - point_a.y]) vector_ab = np.array([point_b.x - point_a.x, point_b.y - point_a.y])
vector_ab = vector_ab / np.linalg.norm(vector_ab) vector_ab = vector_ab / np.linalg.norm(vector_ab)
point_shape_a = detemine_point_shape(point_a, vector_ab) point_shape_a = detemine_point_shape(point_a, vector_ab)
@ -77,30 +94,3 @@ def pair_marking_points(point_a, point_b):
return 1 return 1
if point_shape_b.value > 3: if point_shape_b.value > 3:
return -1 return -1
def filter_slots(marking_points, slots):
    """Remove slots whose two end points have another marking point between them.

    A slot is suppressed when, for some other marking point P, the unit
    vector from the slot's first point to P and the unit vector from P to
    the slot's second point have a dot product above
    ``config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH``.

    Args:
        marking_points: Sequence of objects exposing ``x`` and ``y``.
        slots: Sequence of index pairs into ``marking_points``.

    Returns:
        ``slots`` itself when nothing is suppressed, otherwise a new list
        holding only the slots that survived.
    """
    def _is_suppressed(slot):
        head_idx, tail_idx = slot[0], slot[1]
        head = marking_points[head_idx]
        tail = marking_points[tail_idx]
        blocked = False
        for other_idx, other in enumerate(marking_points):
            if other_idx in (head_idx, tail_idx):
                continue
            to_other = np.array([other.x - head.x, other.y - head.y])
            from_other = np.array([tail.x - other.x, tail.y - other.y])
            to_other = to_other / np.linalg.norm(to_other)
            from_other = from_other / np.linalg.norm(from_other)
            if np.dot(to_other, from_other) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
                blocked = True
        return blocked

    flags = [_is_suppressed(slot) for slot in slots]
    if not any(flags):
        return slots
    return [slot for slot, flagged in zip(slots, flags) if not flagged]

View File

@ -20,11 +20,13 @@ class PointShape(Enum):
def direction_diff(direction_a, direction_b): def direction_diff(direction_a, direction_b):
"""Calculate the angle between two direction."""
diff = abs(direction_a - direction_b) diff = abs(direction_a - direction_b)
return diff if diff < math.pi else 2*math.pi - diff return diff if diff < math.pi else 2*math.pi - diff
def detemine_point_shape(point, vector): def detemine_point_shape(point, vector):
"""Determine which category the point is in."""
vec_direct = math.atan2(vector[1], vector[0]) vec_direct = math.atan2(vector[1], vector[0])
vec_direct_up = math.atan2(-vector[0], vector[1]) vec_direct_up = math.atan2(-vector[0], vector[1])
vec_direct_down = math.atan2(vector[0], -vector[1]) vec_direct_down = math.atan2(vector[0], -vector[1])
@ -59,6 +61,10 @@ def match_marking_points(point_a, point_b):
"""Determine whether a detected point match ground truth.""" """Determine whether a detected point match ground truth."""
dist_square = calc_point_squre_dist(point_a, point_b) dist_square = calc_point_squre_dist(point_a, point_b)
angle = calc_point_direction_angle(point_a, point_b) angle = calc_point_direction_angle(point_a, point_b)
if point_a.shape > 0.5 and point_b.shape < 0.5:
return False
if point_a.shape < 0.5 and point_b.shape > 0.5:
return False
return (dist_square < config.SQUARED_DISTANCE_THRESH return (dist_square < config.SQUARED_DISTANCE_THRESH
and angle < config.DIRECTION_ANGLE_THRESH) and angle < config.DIRECTION_ANGLE_THRESH)

View File

@ -1,14 +1,47 @@
"""Evaluate directional marking point detector.""" """Evaluate directional marking point detector."""
import cv2 as cv
import torch import torch
from torch.utils.data import DataLoader
import config import config
import util import util
from data import get_predicted_points, match_marking_points from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
from data import ParkingSlotDataset from data import ParkingSlotDataset
from inference import plot_points
from model import DirectionalPointDetector from model import DirectionalPointDetector
from train import generate_objective from train import generate_objective
def is_gt_and_pred_matched(ground_truths, predictions, thresh):
    """Tell whether predictions and ground truths pair up without leftovers.

    Only predictions whose first element (the confidence) is at least
    ``thresh`` are considered.

    Args:
        ground_truths: Iterable of ground-truth marking points.
        predictions: Sequence of prediction entries indexed as ``pred[0]``
            for the confidence.
        thresh: Minimum confidence for a prediction to be kept.

    Returns:
        False as soon as a ground truth has no matching prediction (false
        negative) or when some kept prediction was never matched (false
        positive); True otherwise.
    """
    confident = list(filter(lambda pred: pred[0] >= thresh, predictions))
    was_matched = [False] * len(confident)
    for gt_point in ground_truths:
        # A negative index is treated as "no prediction matched".
        matched_idx = util.match_gt_with_preds(gt_point, confident,
                                               match_marking_points)
        if matched_idx < 0:
            return False
        was_matched[matched_idx] = True
    return all(was_matched)
def collect_error(ground_truths, predictions, thresh):
    """Gather position and direction errors of the matched detections.

    Only predictions whose first element (the confidence) is at least
    ``thresh`` take part in the matching. Ground truths without a match
    contribute nothing.

    Args:
        ground_truths: Iterable of ground-truth marking points.
        predictions: Sequence of ``(confidence, marking_point)`` entries.
        thresh: Minimum confidence for a prediction to be kept.

    Returns:
        A pair ``(dists, angles)``: the values returned by
        ``calc_point_squre_dist`` and ``calc_point_direction_angle`` for
        every matched (detection, ground truth) pair.
    """
    kept = [pred for pred in predictions if pred[0] >= thresh]
    dists, angles = [], []
    for gt_point in ground_truths:
        matched_idx = util.match_gt_with_preds(gt_point, kept,
                                               match_marking_points)
        if matched_idx < 0:
            # Nothing matched this ground truth; there is no error to record.
            continue
        detected_point = kept[matched_idx][1]
        dists.append(calc_point_squre_dist(detected_point, gt_point))
        angles.append(calc_point_direction_angle(detected_point, gt_point))
    return dists, angles
def evaluate_detector(args): def evaluate_detector(args):
"""Evaluate directional point detector.""" """Evaluate directional point detector."""
args.cuda = not args.disable_cuda and torch.cuda.is_available() args.cuda = not args.disable_cuda and torch.cuda.is_available()
@ -21,30 +54,37 @@ def evaluate_detector(args):
dp_detector.load_state_dict(torch.load(args.detector_weights)) dp_detector.load_state_dict(torch.load(args.detector_weights))
dp_detector.eval() dp_detector.eval()
torch.multiprocessing.set_sharing_strategy('file_system') psdataset = ParkingSlotDataset(args.dataset_directory)
data_loader = DataLoader(ParkingSlotDataset(args.dataset_directory),
batch_size=args.batch_size, shuffle=True,
num_workers=args.data_loading_workers,
collate_fn=lambda x: list(zip(*x)))
logger = util.Logger(enable_visdom=args.enable_visdom) logger = util.Logger(enable_visdom=args.enable_visdom)
total_loss = 0 total_loss = 0
num_evaluation = 0 position_errors = []
direction_errors = []
ground_truths_list = [] ground_truths_list = []
predictions_list = [] predictions_list = []
for iter_idx, (image, marking_points) in enumerate(data_loader): for iter_idx, (image, marking_points) in enumerate(psdataset):
image = torch.stack(image) ground_truths_list.append(marking_points)
image = image.to(device)
ground_truths_list += list(marking_points)
image = torch.unsqueeze(image, 0).to(device)
prediction = dp_detector(image) prediction = dp_detector(image)
objective, gradient = generate_objective(marking_points, device) objective, gradient = generate_objective([marking_points], device)
loss = (prediction - objective) ** 2 loss = (prediction - objective) ** 2
total_loss += torch.sum(loss*gradient).item() total_loss += torch.sum(loss*gradient).item()
num_evaluation += loss.size(0)
pred_points = [get_predicted_points(pred, 0.01) for pred in prediction] pred_points = get_predicted_points(prediction[0], 0.01)
predictions_list += pred_points predictions_list.append(pred_points)
# if not is_gt_and_pred_matched(marking_points, pred_points,
# config.CONFID_THRESH_FOR_POINT):
# cvimage = util.tensor2array(image[0]).copy()
# plot_points(cvimage, pred_points)
# cv.imwrite('flaw/%d.jpg' % iter_idx, cvimage,
# [int(cv.IMWRITE_JPEG_QUALITY), 100])
dists, angles = collect_error(marking_points, pred_points,
config.CONFID_THRESH_FOR_POINT)
position_errors += dists
direction_errors += angles
logger.log(iter=iter_idx, total_loss=total_loss) logger.log(iter=iter_idx, total_loss=total_loss)
precisions, recalls = util.calc_precision_recall( precisions, recalls = util.calc_precision_recall(
@ -52,7 +92,7 @@ def evaluate_detector(args):
average_precision = util.calc_average_precision(precisions, recalls) average_precision = util.calc_average_precision(precisions, recalls)
if args.enable_visdom: if args.enable_visdom:
logger.plot_curve(precisions, recalls) logger.plot_curve(precisions, recalls)
logger.log(average_loss=total_loss / num_evaluation, logger.log(average_loss=total_loss / len(psdataset),
average_precision=average_precision) average_precision=average_precision)

View File

@ -5,13 +5,13 @@ import numpy as np
import torch import torch
from torchvision.transforms import ToTensor from torchvision.transforms import ToTensor
import config import config
from data import get_predicted_points, pair_marking_points, filter_slots from data import get_predicted_points, pair_marking_points, calc_point_squre_dist, pass_through_third_point
from model import DirectionalPointDetector from model import DirectionalPointDetector
from util import Timer from util import Timer
def plot_points(image, pred_points): def plot_points(image, pred_points):
"""Plot marking points on the image and show.""" """Plot marking points on the image."""
if not pred_points: if not pred_points:
return return
height = image.shape[0] height = image.shape[0]
@ -33,18 +33,19 @@ def plot_points(image, pred_points):
p1_y = int(round(p1_y)) p1_y = int(round(p1_y))
p2_x = int(round(p2_x)) p2_x = int(round(p2_x))
p2_y = int(round(p2_y)) p2_y = int(round(p2_y))
cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255)) cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255), 2)
cv.putText(image, str(confidence), (p0_x, p0_y), cv.putText(image, str(confidence), (p0_x, p0_y),
cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0)) cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
if marking_point.shape > 0.5: if marking_point.shape > 0.5:
cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255)) cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255), 2)
else: else:
p3_x = int(round(p3_x)) p3_x = int(round(p3_x))
p3_y = int(round(p3_y)) p3_y = int(round(p3_y))
cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255)) cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255), 2)
def plot_slots(image, pred_points, slots): def plot_slots(image, pred_points, slots):
"""Plot parking slots on the image."""
if not pred_points or not slots: if not pred_points or not slots:
return return
marking_points = list(list(zip(*pred_points))[1]) marking_points = list(list(zip(*pred_points))[1])
@ -71,9 +72,9 @@ def plot_slots(image, pred_points, slots):
p2_y = int(round(p2_y)) p2_y = int(round(p2_y))
p3_x = int(round(p3_x)) p3_x = int(round(p3_x))
p3_y = int(round(p3_y)) p3_y = int(round(p3_y))
cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0)) cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0), 2)
cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0)) cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0), 2)
cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0)) cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0), 2)
def preprocess_image(image): def preprocess_image(image):
@ -95,12 +96,21 @@ def inference_slots(marking_points):
slots = [] slots = []
for i in range(num_detected - 1): for i in range(num_detected - 1):
for j in range(i + 1, num_detected): for j in range(i + 1, num_detected):
result = pair_marking_points(marking_points[i], marking_points[j]) point_i = marking_points[i]
point_j = marking_points[j]
# Step 1: length filtration.
distance = calc_point_squre_dist(point_i, point_j)
if not (config.VSLOT_MIN_DIST <= distance <= config.VSLOT_MAX_DIST
or config.HSLOT_MIN_DIST <= distance <= config.HSLOT_MAX_DIST):
continue
# Step 2: pass through filtration.
if pass_through_third_point(marking_points, i, j):
continue
result = pair_marking_points(point_i, point_j)
if result == 1: if result == 1:
slots.append((i, j)) slots.append((i, j))
elif result == -1: elif result == -1:
slots.append((j, i)) slots.append((j, i))
slots = filter_slots(marking_points, slots)
return slots return slots
@ -114,7 +124,7 @@ def detect_video(detector, device, args):
if args.save: if args.save:
output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'), output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
input_video.get(cv.CAP_PROP_FPS), input_video.get(cv.CAP_PROP_FPS),
(frame_width, frame_height)) (frame_width, frame_height), True)
frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8) frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
while input_video.read(frame)[0]: while input_video.read(frame)[0]:
timer.tic() timer.tic()
@ -145,6 +155,7 @@ def detect_image(detector, device, args):
timer.tic() timer.tic()
pred_points = detect_marking_points( pred_points = detect_marking_points(
detector, image, args.thresh, device) detector, image, args.thresh, device)
slots = None
if pred_points and args.inference_slot: if pred_points and args.inference_slot:
marking_points = list(list(zip(*pred_points))[1]) marking_points = list(list(zip(*pred_points))[1])
slots = inference_slots(marking_points) slots = inference_slots(marking_points)

View File

@ -32,6 +32,20 @@ def define_halve_unit(basic_channel_size):
return layers return layers
def define_depthwise_expand_unit(basic_channel_size):
    """Define an expand unit built from a 1x1 and a depthwise 3x3 convolution.

    The 1x1 convolution doubles the channel count; the 3x3 convolution keeps
    it and uses one group per channel. Each convolution is bias-free and is
    followed by batch normalisation and a LeakyReLU with slope 0.1.

    Args:
        basic_channel_size: Number of input channels.

    Returns:
        The six layers as a list, in application order.
    """
    expanded_size = 2 * basic_channel_size
    pointwise_conv = nn.Conv2d(basic_channel_size, expanded_size,
                               kernel_size=1, stride=1, padding=0, bias=False)
    depthwise_conv = nn.Conv2d(expanded_size, expanded_size,
                               kernel_size=3, stride=1, padding=1, bias=False,
                               groups=expanded_size)
    return [
        pointwise_conv,
        nn.BatchNorm2d(expanded_size),
        nn.LeakyReLU(0.1),
        depthwise_conv,
        nn.BatchNorm2d(expanded_size),
        nn.LeakyReLU(0.1),
    ]
def define_detector_block(basic_channel_size): def define_detector_block(basic_channel_size):
"""Define a unit composite of a squeeze and expand unit.""" """Define a unit composite of a squeeze and expand unit."""
layers = [] layers = []

View File

@ -110,7 +110,7 @@ def generate_dataset(args):
if len(centralied_marks.shape) < 2: if len(centralied_marks.shape) < 2:
centralied_marks = np.expand_dims(centralied_marks, axis=0) centralied_marks = np.expand_dims(centralied_marks, axis=0)
centralied_marks[:, 0:4] -= 300.5 centralied_marks[:, 0:4] -= 300.5
if boundary_check(centralied_marks): if boundary_check(centralied_marks) or args.dataset == 'test':
output_name = os.path.join(args.output_directory, name) output_name = os.path.join(args.output_directory, name)
write_image_and_label(output_name, image, write_image_and_label(output_name, image,
centralied_marks, name_list) centralied_marks, name_list)

79
ps_evaluate.py Normal file
View File

@ -0,0 +1,79 @@
"""Evaluate directional marking point detector."""
import json
import os
import cv2 as cv
import numpy as np
import torch
import config
import util
from data import match_slots, Slot
from model import DirectionalPointDetector
from inference import detect_marking_points, inference_slots
def get_ground_truths(label):
    """Convert the slots of one label into a list of ``Slot`` objects.

    Args:
        label: Mapping with a ``'slots'`` entry (1-based index pairs into the
            marks) and a ``'marks'`` entry (rows whose first two values are
            used as coordinates).

    Returns:
        One ``Slot`` per labelled slot, built from the coordinates of its two
        marks after applying ``(value - 0.5) / 600``; an empty list when the
        label has no slots.
    """
    slots = np.array(label['slots'])
    if slots.size == 0:
        return []
    # A single slot / mark is stored as a flat list; promote it to 2-D.
    if slots.ndim < 2:
        slots = slots[np.newaxis]
    marks = np.array(label['marks'])
    if marks.ndim < 2:
        marks = marks[np.newaxis]

    def _to_slot(slot):
        first = marks[slot[0] - 1]
        second = marks[slot[1] - 1]
        raw = np.array([first[0], first[1], second[0], second[1]])
        # presumably maps pixel coordinates of a 600x600 image to [0, 1]
        # -- TODO confirm against the dataset description
        return Slot(*((raw - 0.5) / 600))

    return [_to_slot(slot) for slot in slots]
def psevaluate_detector(args):
    """Evaluate parking slot inference against the labelled slots.

    For every label file in ``args.label_directory`` the image with the same
    base name (``.jpg``) is read from ``args.image_directory``, marking
    points are detected, slots are inferred from them, and the result is
    compared with the ground-truth slots of the label. Precision/recall and
    the average precision over all images are then logged.

    Args:
        args: Parsed command-line arguments. The attributes read here are
            ``disable_cuda``, ``gpu_id``, ``depth_factor``,
            ``detector_weights``, ``enable_visdom``, ``label_directory`` and
            ``image_directory``; ``cuda`` is written.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(False)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        dp_detector.load_state_dict(torch.load(args.detector_weights))
    dp_detector.eval()

    logger = util.Logger(enable_visdom=args.enable_visdom)

    ground_truths_list = []
    predictions_list = []
    for idx, label_file in enumerate(os.listdir(args.label_directory)):
        name = os.path.splitext(label_file)[0]
        print(idx, name)
        image = cv.imread(os.path.join(args.image_directory, name + '.jpg'))
        pred_points = detect_marking_points(
            dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)

        # Reset per image: when nothing is detected there must be no slots.
        # Otherwise the name is undefined on the first image, or still holds
        # the previous image's slots on later ones.
        marking_points = []
        slots = []
        if pred_points:
            marking_points = list(list(zip(*pred_points))[1])
            slots = inference_slots(marking_points)

        pred_slots = []
        for slot in slots:
            point_a = marking_points[slot[0]]
            point_b = marking_points[slot[1]]
            # A slot is only as confident as its weaker end point.
            prob = min(pred_points[slot[0]][0], pred_points[slot[1]][0])
            pred_slots.append(
                (prob, Slot(point_a.x, point_a.y, point_b.x, point_b.y)))
        predictions_list.append(pred_slots)

        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            ground_truths_list.append(get_ground_truths(json.load(file)))

    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_slots)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_precision=average_precision)
if __name__ == '__main__':
    # Build the evaluation parser from config, then evaluate with its result.
    cli_args = config.get_parser_for_ps_evaluation().parse_args()
    psevaluate_detector(cli_args)

View File

@ -69,8 +69,7 @@ def train_detector(args):
print("Loading weights: %s" % args.optimizer_weights) print("Loading weights: %s" % args.optimizer_weights)
optimizer.load_state_dict(torch.load(args.optimizer_weights)) optimizer.load_state_dict(torch.load(args.optimizer_weights))
logger = util.Logger(args.enable_visdom, logger = util.Logger(args.enable_visdom, ['train_loss'])
['train_loss'] if args.enable_visdom else None)
data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory), data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
batch_size=args.batch_size, shuffle=True, batch_size=args.batch_size, shuffle=True,
num_workers=args.data_loading_workers, num_workers=args.data_loading_workers,
@ -78,8 +77,7 @@ def train_detector(args):
for epoch_idx in range(args.num_epochs): for epoch_idx in range(args.num_epochs):
for iter_idx, (image, marking_points) in enumerate(data_loader): for iter_idx, (image, marking_points) in enumerate(data_loader):
image = torch.stack(image) image = torch.stack(image).to(device)
image = image.to(device)
optimizer.zero_grad() optimizer.zero_grad()
prediction = dp_detector(image) prediction = dp_detector(image)

View File

@ -1,4 +1,4 @@
"""Utility related package.""" """Utility related package."""
from .log import Logger from .log import Logger
from .precision_recall import calc_precision_recall, calc_average_precision from .precision_recall import calc_precision_recall, calc_average_precision, match_gt_with_preds
from .utils import Timer, tensor2array, tensor2im from .utils import Timer, tensor2array, tensor2im