- """Evaluate directional marking point detector."""
- import torch
- from torch.utils.data import DataLoader
- import config
- import util
- from data import get_predicted_points, match_marking_points
- from data import ParkingSlotDataset
- from model import DirectionalPointDetector
- from train import generate_objective
-
-
- def evaluate_detector(args):
- """Evaluate directional point detector."""
- args.cuda = not args.disable_cuda and torch.cuda.is_available()
- device = torch.device('cuda:' + str(args.gpu_id) if args.cuda else 'cpu')
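    # Inference only: no autograd state is needed, so disable it globally.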
    torch.set_grad_enabled(False)

    dp_detector = DirectionalPointDetector(
        3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
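    # Restore trained parameters when a checkpoint path is supplied.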
    if args.detector_weights:
        # map_location lets weights trained on GPU load on a CPU-only host.
        dp_detector.load_state_dict(
            torch.load(args.detector_weights, map_location=device))
    dp_detector.eval()

    # The 'file_system' strategy avoids "too many open files" errors when
    # loader worker processes share tensors with the main process.
    torch.multiprocessing.set_sharing_strategy('file_system')
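    # Evaluation gains nothing from shuffling, so keep a deterministic order;
    # the collate_fn transposes each batch into (images, marking_points)
    # without stacking the variable-length marking-point lists.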
    data_loader = DataLoader(ParkingSlotDataset(args.dataset_directory),
                             batch_size=args.batch_size, shuffle=False,
                             num_workers=args.data_loading_workers,
                             collate_fn=lambda x: list(zip(*x)))
    logger = util.Logger(enable_visdom=args.enable_visdom)

    total_loss = 0
    num_evaluation = 0
    ground_truths_list = []
    predictions_list = []
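    # Single pass over the dataset: accumulate the masked squared-error loss
    # and collect per-image predictions for the precision-recall sweep.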
    for iter_idx, (image, marking_points) in enumerate(data_loader):
        image = torch.stack(image)
        image = image.to(device)
        ground_truths_list += list(marking_points)

        prediction = dp_detector(image)
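        # `gradient` acts as a mask over the objective tensor, so only the
        # entries that carry supervision contribute to the loss.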
        objective, gradient = generate_objective(marking_points, device)
        loss = (prediction - objective) ** 2
        total_loss += torch.sum(loss * gradient).item()
        num_evaluation += loss.size(0)

        # A deliberately low confidence threshold (0.01) keeps nearly all
        # detected points, so precision and recall can be traced over the
        # full confidence range.
        pred_points = [get_predicted_points(pred, 0.01) for pred in prediction]
        predictions_list += pred_points
        logger.log(iter=iter_idx, total_loss=total_loss)

    # Match predictions to ground truths, trace out the precision-recall
    # curve, and summarize it as average precision.
    precisions, recalls = util.calc_precision_recall(
        ground_truths_list, predictions_list, match_marking_points)
    average_precision = util.calc_average_precision(precisions, recalls)
    if args.enable_visdom:
        logger.plot_curve(precisions, recalls)
    logger.log(average_loss=total_loss / num_evaluation,
               average_precision=average_precision)


if __name__ == '__main__':
    evaluate_detector(config.get_parser_for_evaluation().parse_args())
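# Example invocation (a sketch; the exact flag names come from
# config.get_parser_for_evaluation() and may differ in your checkout):
#   python evaluate.py --dataset_directory <annotated_data_dir> \
#       --detector_weights <checkpoint.pth> --batch_size 16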