@@ -43,7 +43,7 @@ def main(): | |||
if half: | |||
model.half() | |||
# STDC model | |||
# load args | |||
args = config.get_parser_for_inference().parse_args() | |||
# STDC model | |||
@@ -55,7 +55,6 @@ def main(): | |||
STDC_model.eval() | |||
# DMPR model | |||
args = config.get_parser_for_inference().parse_args() | |||
# DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device) | |||
# DMPRmodel.load_state_dict(torch.load(DMPRweights)) | |||
DMPRmodel = Model(args.cfg, ch=3).to(device) | |||
@@ -101,12 +100,9 @@ def main(): | |||
# 绘制角点 | |||
plot_points(img0, det1) | |||
# save | |||
# cv2.imwrite(file, img0) | |||
# yolo joint DMPR | |||
cls = 0 #需要过滤的box类别 | |||
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio) | |||
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border) | |||
t_joint = time.time() | |||
print(f't_joint. ({t_joint - t_dmpr:.3f}s)') |
@@ -4,7 +4,12 @@ import numpy as np | |||
import torch | |||
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio): | |||
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio, border=80): | |||
# 过滤在图像边界的box(防止出现一小半车辆的情况) | |||
x_c = (yolo_det[:, 0] + yolo_det[:, 2]) / 2 | |||
y_c = (yolo_det[:, 1] + yolo_det[:, 3]) / 2 | |||
tmp = (x_c >= border) & (x_c <= (img_shape[1] - border)) & (y_c >= border) & (y_c <= (img_shape[0] - border)) | |||
yolo_det = yolo_det[tmp] | |||
# 创建yolo_det_clone内容为x1, y1, x2, y2, conf, cls, unlabel (unlabel代表该类是否需要忽略,0:不忽略 其他:忽略) | |||
yolo_det_clone = yolo_det.copy() |
@@ -7,7 +7,7 @@ NUM_FEATURE_MAP_CHANNEL = 6 | |||
def add_common_arguments(parser): | |||
"""Add common arguments for training and inference.""" | |||
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth', | |||
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_299.pth', | |||
help="The weights of pretrained detector.") | |||
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', | |||
help='model.yaml path') | |||
@@ -46,6 +46,8 @@ def get_parser_for_inference(): | |||
help='IOU threshold for NMS') | |||
parser.add_argument('--scale-ratio', type=float, default=0.5, | |||
help='detected box scale ratio') | |||
parser.add_argument('--border', type=int, default=80, | |||
help='The valid border to boundary') | |||
parser.add_argument('--ovlap-thres', type=float, default=0.6, help='overlap threshold for OBS') | |||
parser.add_argument('--agnostic-nms', action='store_true', | |||
help='class-agnostic NMS') |