2018-10-02 15:54:42 +08:00
|
|
|
"""Evaluate directional marking point detector."""
|
|
|
|
|
import torch
|
|
|
|
|
import config
|
2018-10-04 09:30:25 +08:00
|
|
|
import util
|
2019-12-27 10:57:50 +08:00
|
|
|
from thop import profile
|
2018-11-21 15:45:22 +08:00
|
|
|
from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
|
2018-10-04 09:30:25 +08:00
|
|
|
from data import ParkingSlotDataset
|
|
|
|
|
from model import DirectionalPointDetector
|
|
|
|
|
from train import generate_objective
|
2018-10-02 15:54:42 +08:00
|
|
|
|
|
|
|
|
|
2018-11-21 15:45:22 +08:00
|
|
|
def is_gt_and_pred_matched(ground_truths, predictions, thresh):
    """Check if there is any false positive or false negative.

    Only predictions whose first element (confidence) is ``>= thresh`` are
    considered. Each ground truth is matched against those predictions with
    ``util.match_gt_with_preds`` using ``match_marking_points``.

    Args:
        ground_truths: Iterable of ground-truth marking points.
        predictions: Sequence of predictions; element ``[0]`` of each is
            compared against ``thresh``.
        thresh: Minimum confidence for a prediction to be kept.

    Returns:
        False if some ground truth gets no match (a negative index) or if
        some kept prediction is never matched by any ground truth;
        True otherwise.
    """
    confident = [candidate for candidate in predictions
                 if candidate[0] >= thresh]
    claimed = [False for _ in confident]
    for truth in ground_truths:
        hit = util.match_gt_with_preds(truth, confident,
                                       match_marking_points)
        if hit < 0:
            # Ground truth with no matching prediction: a false negative.
            return False
        claimed[hit] = True
    # Any kept prediction never claimed by a ground truth is a false positive.
    return all(claimed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collect_error(ground_truths, predictions, thresh):
    """Collect errors for those correctly detected points.

    Predictions whose first element (confidence) is below ``thresh`` are
    discarded. Each ground truth is then matched against the remaining
    predictions via ``util.match_gt_with_preds``. For every ground truth
    with a match (non-negative index), the matched prediction's element
    ``[1]`` is compared to it with ``calc_point_squre_dist`` and
    ``calc_point_direction_angle``. Unmatched ground truths are skipped.

    Args:
        ground_truths: Iterable of ground-truth marking points.
        predictions: Sequence of predictions; ``[0]`` is compared against
            ``thresh`` and ``[1]`` is the detected point.
        thresh: Minimum confidence for a prediction to be considered.

    Returns:
        Tuple ``(dists, angles)`` of equally long lists, one entry per
        matched ground truth, in ground-truth order.
    """
    confident = [candidate for candidate in predictions
                 if candidate[0] >= thresh]
    position_errs, direction_errs = [], []
    for truth in ground_truths:
        hit = util.match_gt_with_preds(truth, confident,
                                       match_marking_points)
        if hit < 0:
            # Missed ground truth: there is no detection to measure.
            continue
        point = confident[hit][1]
        position_errs.append(calc_point_squre_dist(point, truth))
        direction_errs.append(calc_point_direction_angle(point, truth))
    return position_errs, direction_errs
|
|
|
|
|
|
|
|
|
|
|
2018-10-02 15:54:42 +08:00
|
|
|
def evaluate_detector(args):
    """Evaluate directional point detector.

    Builds a ``DirectionalPointDetector``, optionally loads weights, and runs
    it over every sample of ``ParkingSlotDataset(args.dataset_directory)``.
    For each sample it does the following:

    - accumulates the squared error against ``train.generate_objective``,
      weighted by the returned ``gradient`` tensor;
    - stores predicted points for a precision/recall computation;
    - collects position/direction errors of confidently matched points.

    Finally it profiles the network with ``thop.profile``. It then logs the
    average loss, average precision, FLOPs and parameter count through
    ``util.Logger``.

    Args:
        args: Parsed namespace from ``config.get_parser_for_evaluation``.
            Fields read: ``disable_cuda``, ``gpu_id``, ``depth_factor``,
            ``detector_weights``, ``dataset_directory``, ``enable_visdom``.
            ``args.cuda`` is set as a side effect.

    Raises:
        ValueError: If the dataset contains no samples. Previously this
            surfaced as a ZeroDivisionError when averaging the loss.
    """
    args.cuda = not args.disable_cuda and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
    # Evaluation only: autograd is switched off globally (kept as before).
    torch.set_grad_enabled(False)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
    if args.detector_weights:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # run (or onto the selected gpu_id) instead of failing or landing
        # on the device it was saved from.
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))
    dp_detector.eval()

    psdataset = ParkingSlotDataset(args.dataset_directory)
    num_samples = len(psdataset)
    if num_samples == 0:
        raise ValueError('No samples found in dataset directory: '
                         + str(args.dataset_directory))
    logger = util.Logger(enable_visdom=args.enable_visdom)

    total_loss = 0
    # NOTE(review): these errors are accumulated but never logged or
    # returned; consider reporting their statistics.
    position_errors = []
    direction_errors = []
    ground_truths_list = []
    predictions_list = []
    for iter_idx, (image, marking_points) in enumerate(psdataset):
        ground_truths_list.append(marking_points)

        # Add a leading batch dimension of size 1 before the forward pass.
        image = torch.unsqueeze(image, 0).to(device)
        prediction = dp_detector(image)
        objective, gradient = generate_objective([marking_points], device)
        loss = (prediction - objective) ** 2
        total_loss += torch.sum(loss * gradient).item()

        # Low cutoff (0.01, presumably a confidence threshold) so the
        # precision/recall computation also sees low-confidence candidates.
        pred_points = get_predicted_points(prediction[0], 0.01)
        predictions_list.append(pred_points)

        dists, angles = collect_error(marking_points, pred_points,
                                      config.CONFID_THRESH_FOR_POINT)
        position_errors += dists
        direction_errors += angles

        logger.log(iter=iter_idx, total_loss=total_loss)

    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_marking_points)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)

    # Measure model cost on one random input of the configured image size.
    sample = torch.randn(1, 3, config.INPUT_IMAGE_SIZE,
                         config.INPUT_IMAGE_SIZE)
    flops, params = profile(dp_detector, inputs=(sample.to(device), ))
    logger.log(average_loss=total_loss / num_samples,
               average_precision=average_precision,
               flops=flops,
               params=params)
|
2018-10-02 15:54:42 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the evaluation CLI options, then evaluate.
    evaluation_args = config.get_parser_for_evaluation().parse_args()
    evaluate_detector(evaluation_args)
|