This commit is contained in:
nyh 2023-11-01 17:02:00 +08:00
parent 219a97f315
commit 4c861391aa
3 changed files with 11 additions and 8 deletions

View File

@ -43,7 +43,7 @@ def main():
if half:
model.half()
# STDC model
# load args
args = config.get_parser_for_inference().parse_args()
# STDC model
@ -55,7 +55,6 @@ def main():
STDC_model.eval()
# DMPR model
args = config.get_parser_for_inference().parse_args()
# DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
# DMPRmodel.load_state_dict(torch.load(DMPRweights))
DMPRmodel = Model(args.cfg, ch=3).to(device)
@ -101,12 +100,9 @@ def main():
# 绘制角点
plot_points(img0, det1)
# save
# cv2.imwrite(file, img0)
# yolo joint DMPR
cls = 0 #需要过滤的box类别
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio)
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border)
t_joint = time.time()
print(f't_joint. ({t_joint - t_dmpr:.3f}s)')

View File

@ -4,7 +4,12 @@ import numpy as np
import torch
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio):
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio, border=80):
# 过滤在图像边界的box(防止出现一小半车辆的情况)
x_c = (yolo_det[:, 0] + yolo_det[:, 2]) / 2
y_c = (yolo_det[:, 1] + yolo_det[:, 3]) / 2
tmp = (x_c >= border) & (x_c <= (img_shape[1] - border)) & (y_c >= border) & (y_c <= (img_shape[0] - border))
yolo_det = yolo_det[tmp]
# 创建yolo_det_clone内容为x1, y1, x2, y2, conf, cls, unlabel (unlabel代表该类是否需要忽略0不忽略 其他:忽略)
yolo_det_clone = yolo_det.copy()

View File

@ -7,7 +7,7 @@ NUM_FEATURE_MAP_CHANNEL = 6
def add_common_arguments(parser):
"""Add common arguments for training and inference."""
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth',
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_299.pth',
help="The weights of pretrained detector.")
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml',
help='model.yaml path')
@ -46,6 +46,8 @@ def get_parser_for_inference():
help='IOU threshold for NMS')
parser.add_argument('--scale-ratio', type=float, default=0.5,
help='detected box scale ratio')
parser.add_argument('--border', type=int, default=80,
help='The valid border to boundary')
parser.add_argument('--ovlap-thres', type=float, default=0.6, help='overlap threshold for OBS')
parser.add_argument('--agnostic-nms', action='store_true',
help='class-agnostic NMS')