@@ -0,0 +1,84 @@
"""Collect the value ranges of different properties of the ps dataset."""
import argparse
import json
import math
import os

import matplotlib.pyplot as plt
import numpy as np

from data import MarkingPoint
from data.struct import calc_point_squre_dist, direction_diff
from prepare_dataset import generalize_marks


def get_parser():
    """Return argument parser for collecting thresholds."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--label_directory', required=True,
                        help="The location of label directory.")
    return parser


def collect_thresholds(args):
    """Collect value ranges from ground truth to determine thresholds."""
    distances = []
    separator_angles = []
    bridge_angles = []
    for label_file in os.listdir(args.label_directory):
        print(label_file)
        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            label = json.load(file)
        marks = np.array(label['marks'])
        slots = np.array(label['slots'])
        if slots.size == 0:
            continue
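        # A single mark/slot is stored as a flat array in the label file;
        # promote it to 2-D so the indexing below works uniformly.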
        if len(marks.shape) < 2:
            marks = np.expand_dims(marks, axis=0)
        if len(slots.shape) < 2:
            slots = np.expand_dims(slots, axis=0)
        marks[:, 0:4] -= 300.5
        marks = [MarkingPoint(*mark) for mark in generalize_marks(marks)]
        for slot in slots:
            mark_a = marks[slot[0] - 1]
            mark_b = marks[slot[1] - 1]
            distances.append(calc_point_squre_dist(mark_a, mark_b))
            vector_ab = np.array([mark_b.x - mark_a.x, mark_b.y - mark_a.y])
            vector_ab = vector_ab / np.linalg.norm(vector_ab)
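            # Candidate directions relative to the A->B edge: along AB
            # (bridge direction seen from A), along BA (bridge direction
            # seen from B), and perpendicular to AB (the separator line).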
            ab_bridge_direction = math.atan2(vector_ab[1], vector_ab[0])
            ba_bridge_direction = math.atan2(-vector_ab[1], -vector_ab[0])
            separator_direction = math.atan2(-vector_ab[0], vector_ab[1])
            sangle = direction_diff(separator_direction, mark_a.direction)
            if mark_a.shape > 0.5:
                separator_angles.append(sangle)
            else:
                bangle = direction_diff(ab_bridge_direction, mark_a.direction)
                if sangle < bangle:
                    separator_angles.append(sangle)
                else:
                    bridge_angles.append(bangle)
            bangle = direction_diff(ba_bridge_direction, mark_b.direction)
            if mark_b.shape > 0.5:
                bridge_angles.append(bangle)
            else:
                sangle = direction_diff(separator_direction, mark_b.direction)
                if sangle < bangle:
                    separator_angles.append(sangle)
                else:
                    bridge_angles.append(bangle)
    distances = sorted(distances)
    separator_angles = sorted(separator_angles)
    bridge_angles = sorted(bridge_angles)
    plt.figure()
    plt.hist(distances, len(distances) // 10)
    plt.figure()
    plt.hist(separator_angles, len(separator_angles) // 10)
    plt.figure()
    plt.hist(bridge_angles, len(bridge_angles) // 3)
    # Without this the script exits before the histograms are displayed.
    plt.show()


if __name__ == '__main__':
    collect_thresholds(get_parser().parse_args())
@@ -9,22 +9,25 @@ NUM_FEATURE_MAP_CHANNEL = 6
# image_size / 2^5 = 512 / 32 = 16
FEATURE_MAP_SIZE = 16
# Threshold used to filter out marking points too close to the image boundary
BOUNDARY_THRESH = 0.066666667
BOUNDARY_THRESH = 0.05
# Thresholds to determine whether a detected point matches ground truth.
SQUARED_DISTANCE_THRESH = 0.000277778
DIRECTION_ANGLE_THRESH = 0.5
SQUARED_DISTANCE_THRESH = 0.000277778  # 10 pixels in a 600x600 image: (10/600)**2
DIRECTION_ANGLE_THRESH = 0.5235987755982988  # 30 degrees in radians
VSLOT_MIN_DISTANCE = 0.044771278151623496
VSLOT_MAX_DISTANCE = 0.1099427457599304
HSLOT_MIN_DISTANCE = 0.15057789144568634
HSLOT_MAX_DISTANCE = 0.44449496544202816
VSLOT_MIN_DIST = 0.044771278151623496
VSLOT_MAX_DIST = 0.1099427457599304
HSLOT_MIN_DIST = 0.15057789144568634
HSLOT_MAX_DIST = 0.44449496544202816
BRIDGE_ANGLE_DIFF = 0.25
SEPARATOR_ANGLE_DIFF = 0.5
# angle_prediction_error = 0.1384059287593468
BRIDGE_ANGLE_DIFF = 0.09757113548987695 + 0.1384059287593468
SEPARATOR_ANGLE_DIFF = 0.284967562063968 + 0.1384059287593468
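# The new angle thresholds appear to be the spread measured from ground
# truth (see collect_thresholds) plus the angle prediction error above.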
SLOT_SUPPRESSION_DOT_PRODUCT_THRESH = 0.8
# precision = 0.995585, recall = 0.995805
CONFID_THRESH_FOR_POINT = 0.11676871


def add_common_arguments(parser):
    """Add common arguments for training and inference."""
@@ -1,4 +1,4 @@
"""Data related package."""
from .data_process import get_predicted_points, pair_marking_points, filter_slots
from .data_process import get_predicted_points, pair_marking_points, pass_through_third_point
from .dataset import ParkingSlotDataset
from .struct import MarkingPoint, Slot, match_marking_points, match_slots
from .struct import MarkingPoint, Slot, match_marking_points, match_slots, calc_point_squre_dist, calc_point_direction_angle
@@ -3,7 +3,7 @@ import math
import numpy as np
import torch

import config
from data.struct import MarkingPoint, calc_point_squre_dist, detemine_point_shape
from data.struct import MarkingPoint, detemine_point_shape


def non_maximum_suppression(pred_points):
@@ -50,11 +50,28 @@ def get_predicted_points(prediction, thresh):
    return non_maximum_suppression(predicted_points)


def pass_through_third_point(marking_points, i, j):
    """See whether the line between two points passes through a third point."""
    x_1 = marking_points[i].x
    y_1 = marking_points[i].y
    x_2 = marking_points[j].x
    y_2 = marking_points[j].y
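    # A third point lies (approximately) on segment ij when the unit vector
    # from i to it and the unit vector from it to j are nearly parallel,
    # i.e. their dot product approaches 1.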
    for point_idx, point in enumerate(marking_points):
        if point_idx == i or point_idx == j:
            continue
        x_0 = point.x
        y_0 = point.y
        vec1 = np.array([x_0 - x_1, y_0 - y_1])
        vec2 = np.array([x_2 - x_0, y_2 - y_0])
        vec1 = vec1 / np.linalg.norm(vec1)
        vec2 = vec2 / np.linalg.norm(vec2)
        if np.dot(vec1, vec2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
            return True
    return False


def pair_marking_points(point_a, point_b):
    """See whether two marking points form a slot."""
    distance = calc_point_squre_dist(point_a, point_b)
    if not (config.VSLOT_MIN_DISTANCE <= distance <= config.VSLOT_MAX_DISTANCE
            or config.HSLOT_MIN_DISTANCE <= distance <= config.HSLOT_MAX_DISTANCE):
        return 0
    vector_ab = np.array([point_b.x - point_a.x, point_b.y - point_a.y])
    vector_ab = vector_ab / np.linalg.norm(vector_ab)
    point_shape_a = detemine_point_shape(point_a, vector_ab)
@@ -77,30 +94,3 @@ def pair_marking_points(point_a, point_b):
        return 1
    if point_shape_b.value > 3:
        return -1


def filter_slots(marking_points, slots):
    suppressed = [False] * len(slots)
    for i, slot in enumerate(slots):
        x1 = marking_points[slot[0]].x
        y1 = marking_points[slot[0]].y
        x2 = marking_points[slot[1]].x
        y2 = marking_points[slot[1]].y
        for point_idx, point in enumerate(marking_points):
            if point_idx == slot[0] or point_idx == slot[1]:
                continue
            x0 = point.x
            y0 = point.y
            vec1 = np.array([x0 - x1, y0 - y1])
            vec2 = np.array([x2 - x0, y2 - y0])
            vec1 = vec1 / np.linalg.norm(vec1)
            vec2 = vec2 / np.linalg.norm(vec2)
            if np.dot(vec1, vec2) > config.SLOT_SUPPRESSION_DOT_PRODUCT_THRESH:
                suppressed[i] = True
    if any(suppressed):
        unsupres_slots = []
        for i, supres in enumerate(suppressed):
            if not supres:
                unsupres_slots.append(slots[i])
        return unsupres_slots
    return slots
@@ -20,11 +20,13 @@ class PointShape(Enum):


def direction_diff(direction_a, direction_b):
    """Calculate the angle between two directions."""
    diff = abs(direction_a - direction_b)
    return diff if diff < math.pi else 2*math.pi - diff


def detemine_point_shape(point, vector):
    """Determine which category the point is in."""
    vec_direct = math.atan2(vector[1], vector[0])
    vec_direct_up = math.atan2(-vector[0], vector[1])
    vec_direct_down = math.atan2(vector[0], -vector[1])
@@ -59,6 +61,10 @@ def match_marking_points(point_a, point_b):
"""Determine whether a detected point match ground truth.""" | |||
dist_square = calc_point_squre_dist(point_a, point_b) | |||
angle = calc_point_direction_angle(point_a, point_b) | |||
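    # Reject pairs whose shape classes disagree (shape is thresholded at
    # 0.5), no matter how close they are in position and direction.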
    if point_a.shape > 0.5 and point_b.shape < 0.5:
        return False
    if point_a.shape < 0.5 and point_b.shape > 0.5:
        return False
    return (dist_square < config.SQUARED_DISTANCE_THRESH
            and angle < config.DIRECTION_ANGLE_THRESH)
@@ -1,14 +1,47 @@
"""Evaluate directional marking point detector."""
import cv2 as cv
import torch
from torch.utils.data import DataLoader

import config
import util
from data import get_predicted_points, match_marking_points
from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
from data import ParkingSlotDataset
from inference import plot_points
from model import DirectionalPointDetector
from train import generate_objective


def is_gt_and_pred_matched(ground_truths, predictions, thresh):
    """Return True only if there is no false positive and no false negative."""
    predictions = [pred for pred in predictions if pred[0] >= thresh]
    prediction_matched = [False] * len(predictions)
    for ground_truth in ground_truths:
        idx = util.match_gt_with_preds(ground_truth, predictions,
                                       match_marking_points)
        if idx < 0:
            return False
        prediction_matched[idx] = True
    if not all(prediction_matched):
        return False
    return True


def collect_error(ground_truths, predictions, thresh):
    """Collect errors for those correctly detected points."""
    dists = []
    angles = []
    predictions = [pred for pred in predictions if pred[0] >= thresh]
    for ground_truth in ground_truths:
        idx = util.match_gt_with_preds(ground_truth, predictions,
                                       match_marking_points)
        if idx >= 0:
            detected_point = predictions[idx][1]
            dists.append(calc_point_squre_dist(detected_point, ground_truth))
            angles.append(calc_point_direction_angle(detected_point, ground_truth))
    return dists, angles


def evaluate_detector(args):
    """Evaluate directional point detector."""
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
@@ -21,30 +54,37 @@ def evaluate_detector(args):
    dp_detector.load_state_dict(torch.load(args.detector_weights))
    dp_detector.eval()
    torch.multiprocessing.set_sharing_strategy('file_system')
    data_loader = DataLoader(ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=True,
                             num_workers=args.data_loading_workers,
                             collate_fn=lambda x: list(zip(*x)))
    psdataset = ParkingSlotDataset(args.dataset_directory)
    logger = util.Logger(enable_visdom=args.enable_visdom)
    total_loss = 0
    num_evaluation = 0
    position_errors = []
    direction_errors = []
    ground_truths_list = []
    predictions_list = []
    for iter_idx, (image, marking_points) in enumerate(data_loader):
        image = torch.stack(image)
        image = image.to(device)
        ground_truths_list += list(marking_points)
    for iter_idx, (image, marking_points) in enumerate(psdataset):
        ground_truths_list.append(marking_points)
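        # The dataset yields single images, so add a batch dimension.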
        image = torch.unsqueeze(image, 0).to(device)
        prediction = dp_detector(image)
        objective, gradient = generate_objective(marking_points, device)
        objective, gradient = generate_objective([marking_points], device)
        loss = (prediction - objective) ** 2
        total_loss += torch.sum(loss*gradient).item()
        num_evaluation += loss.size(0)
        pred_points = [get_predicted_points(pred, 0.01) for pred in prediction]
        predictions_list += pred_points
        pred_points = get_predicted_points(prediction[0], 0.01)
        predictions_list.append(pred_points)
        # if not is_gt_and_pred_matched(marking_points, pred_points,
        #                               config.CONFID_THRESH_FOR_POINT):
        #     cvimage = util.tensor2array(image[0]).copy()
        #     plot_points(cvimage, pred_points)
        #     cv.imwrite('flaw/%d.jpg' % iter_idx, cvimage,
        #                [int(cv.IMWRITE_JPEG_QUALITY), 100])
        dists, angles = collect_error(marking_points, pred_points,
                                      config.CONFID_THRESH_FOR_POINT)
        position_errors += dists
        direction_errors += angles
        logger.log(iter=iter_idx, total_loss=total_loss)
    precisions, recalls = util.calc_precision_recall(
@@ -52,7 +92,7 @@ def evaluate_detector(args):
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_loss=total_loss / num_evaluation,
    logger.log(average_loss=total_loss / len(psdataset),
               average_precision=average_precision)
@@ -5,13 +5,13 @@ import numpy as np
import torch
from torchvision.transforms import ToTensor

import config
from data import get_predicted_points, pair_marking_points, filter_slots
from data import get_predicted_points, pair_marking_points, calc_point_squre_dist, pass_through_third_point
from model import DirectionalPointDetector
from util import Timer


def plot_points(image, pred_points):
    """Plot marking points on the image."""
    if not pred_points:
        return
    height = image.shape[0]
@@ -33,18 +33,19 @@ def plot_points(image, pred_points):
        p1_y = int(round(p1_y))
        p2_x = int(round(p2_x))
        p2_y = int(round(p2_y))
        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255))
        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255), 2)
        cv.putText(image, str(confidence), (p0_x, p0_y),
                   cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
        if marking_point.shape > 0.5:
            cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255))
            cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255), 2)
        else:
            p3_x = int(round(p3_x))
            p3_y = int(round(p3_y))
            cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255))
            cv.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255), 2)


def plot_slots(image, pred_points, slots):
    """Plot parking slots on the image."""
    if not pred_points or not slots:
        return
    marking_points = list(list(zip(*pred_points))[1])
@@ -71,9 +72,9 @@ def plot_slots(image, pred_points, slots):
        p2_y = int(round(p2_y))
        p3_x = int(round(p3_x))
        p3_y = int(round(p3_y))
        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0))
        cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0))
        cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0))
        cv.line(image, (p0_x, p0_y), (p1_x, p1_y), (255, 0, 0), 2)
        cv.line(image, (p0_x, p0_y), (p2_x, p2_y), (255, 0, 0), 2)
        cv.line(image, (p1_x, p1_y), (p3_x, p3_y), (255, 0, 0), 2)


def preprocess_image(image):
@@ -95,12 +96,21 @@ def inference_slots(marking_points):
    slots = []
    for i in range(num_detected - 1):
        for j in range(i + 1, num_detected):
            result = pair_marking_points(marking_points[i], marking_points[j])
            point_i = marking_points[i]
            point_j = marking_points[j]
            # Step 1: length filtration.
            distance = calc_point_squre_dist(point_i, point_j)
            if not (config.VSLOT_MIN_DIST <= distance <= config.VSLOT_MAX_DIST
                    or config.HSLOT_MIN_DIST <= distance <= config.HSLOT_MAX_DIST):
                continue
            # Step 2: pass through filtration.
            if pass_through_third_point(marking_points, i, j):
                continue
            result = pair_marking_points(point_i, point_j)
            if result == 1:
                slots.append((i, j))
            elif result == -1:
                slots.append((j, i))
    slots = filter_slots(marking_points, slots)
    return slots
@@ -114,7 +124,7 @@ def detect_video(detector, device, args):
    if args.save:
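        # The added final argument is VideoWriter's isColor flag; True
        # requests 3-channel color output.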
        output_video.open('record.avi', cv.VideoWriter_fourcc(*'MJPG'),
                          input_video.get(cv.CAP_PROP_FPS),
                          (frame_width, frame_height))
                          (frame_width, frame_height), True)
    frame = np.empty([frame_height, frame_width, 3], dtype=np.uint8)
    while input_video.read(frame)[0]:
        timer.tic()
@@ -145,6 +155,7 @@ def detect_image(detector, device, args):
    timer.tic()
    pred_points = detect_marking_points(
        detector, image, args.thresh, device)
    slots = None
    if pred_points and args.inference_slot:
        marking_points = list(list(zip(*pred_points))[1])
        slots = inference_slots(marking_points)
@@ -32,6 +32,20 @@ def define_halve_unit(basic_channel_size):
    return layers


def define_depthwise_expand_unit(basic_channel_size):
    """Define a 1x1 expand convolution followed by a 3x3 depthwise convolution."""
    conv1 = nn.Conv2d(basic_channel_size, 2 * basic_channel_size,
                      kernel_size=1, stride=1, padding=0, bias=False)
    norm1 = nn.BatchNorm2d(2 * basic_channel_size)
    relu1 = nn.LeakyReLU(0.1)
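    # groups equal to the channel count makes this a depthwise convolution:
    # each channel is filtered independently with its own 3x3 kernel.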
    conv2 = nn.Conv2d(2 * basic_channel_size, 2 * basic_channel_size, kernel_size=3,
                      stride=1, padding=1, bias=False, groups=2 * basic_channel_size)
    norm2 = nn.BatchNorm2d(2 * basic_channel_size)
    relu2 = nn.LeakyReLU(0.1)
    layers = [conv1, norm1, relu1, conv2, norm2, relu2]
    return layers


def define_detector_block(basic_channel_size):
    """Define a unit composed of a squeeze unit and an expand unit."""
    layers = []
@@ -110,7 +110,7 @@ def generate_dataset(args):
        if len(centralied_marks.shape) < 2:
            centralied_marks = np.expand_dims(centralied_marks, axis=0)
        centralied_marks[:, 0:4] -= 300.5
        if boundary_check(centralied_marks):
        if boundary_check(centralied_marks) or args.dataset == 'test':
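            # Test samples are kept even when marks fall near the image
            # boundary, presumably so evaluation covers the full test set.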
            output_name = os.path.join(args.output_directory, name)
            write_image_and_label(output_name, image,
                                  centralied_marks, name_list)
@@ -0,0 +1,79 @@
"""Evaluate the directional marking point detector on parking slot detection."""
import json
import os

import cv2 as cv
import numpy as np
import torch

import config
import util
from data import match_slots, Slot
from model import DirectionalPointDetector
from inference import detect_marking_points, inference_slots


def get_ground_truths(label):
    """Extract ground truth parking slots from a label dict."""
    slots = np.array(label['slots'])
    if slots.size == 0:
        return []
    if len(slots.shape) < 2:
        slots = np.expand_dims(slots, axis=0)
    marks = np.array(label['marks'])
    if len(marks.shape) < 2:
        marks = np.expand_dims(marks, axis=0)
    ground_truths = []
    for slot in slots:
        mark_a = marks[slot[0] - 1]
        mark_b = marks[slot[1] - 1]
        coords = np.array([mark_a[0], mark_a[1], mark_b[0], mark_b[1]])
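        # Normalize 600x600 pixel coordinates to the unit square.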
        coords = (coords - 0.5) / 600
        ground_truths.append(Slot(*coords))
    return ground_truths


def psevaluate_detector(args):
    """Evaluate directional point detector."""
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    torch.set_grad_enabled(False)
    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        dp_detector.load_state_dict(torch.load(args.detector_weights))
    dp_detector.eval()
    logger = util.Logger(enable_visdom=args.enable_visdom)
    ground_truths_list = []
    predictions_list = []
    for idx, label_file in enumerate(os.listdir(args.label_directory)):
        name = os.path.splitext(label_file)[0]
        print(idx, name)
        image = cv.imread(os.path.join(args.image_directory, name + '.jpg'))
        pred_points = detect_marking_points(
            dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)
        pred_slots = []
        if pred_points:
            marking_points = list(list(zip(*pred_points))[1])
            slots = inference_slots(marking_points)
            for slot in slots:
                point_a = marking_points[slot[0]]
                point_b = marking_points[slot[1]]
                prob = min(pred_points[slot[0]][0], pred_points[slot[1]][0])
                pred_slots.append(
                    (prob, Slot(point_a.x, point_a.y, point_b.x, point_b.y)))
        # Append even when nothing was detected, so predictions stay
        # aligned one-to-one with ground truths.
        predictions_list.append(pred_slots)
        with open(os.path.join(args.label_directory, label_file), 'r') as file:
            ground_truths_list.append(get_ground_truths(json.load(file)))
    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_slots)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_precision=average_precision)


if __name__ == '__main__':
    psevaluate_detector(config.get_parser_for_ps_evaluation().parse_args())
@@ -69,8 +69,7 @@ def train_detector(args):
        print("Loading weights: %s" % args.optimizer_weights)
        optimizer.load_state_dict(torch.load(args.optimizer_weights))

    logger = util.Logger(args.enable_visdom,
                         ['train_loss'] if args.enable_visdom else None)
    logger = util.Logger(args.enable_visdom, ['train_loss'])
    data_loader = DataLoader(data.ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=True,
                             num_workers=args.data_loading_workers,
@@ -78,8 +77,7 @@ def train_detector(args):
    for epoch_idx in range(args.num_epochs):
        for iter_idx, (image, marking_points) in enumerate(data_loader):
            image = torch.stack(image)
            image = image.to(device)
            image = torch.stack(image).to(device)
            optimizer.zero_grad()
            prediction = dp_detector(image)
@@ -1,4 +1,4 @@
"""Utility related package."""
from .log import Logger
from .precision_recall import calc_precision_recall, calc_average_precision
from .precision_recall import calc_precision_recall, calc_average_precision, match_gt_with_preds
from .utils import Timer, tensor2array, tensor2im