This commit is contained in:
parent
219a97f315
commit
4c861391aa
|
|
@ -43,7 +43,7 @@ def main():
|
|||
if half:
|
||||
model.half()
|
||||
|
||||
# STDC model
|
||||
# load args
|
||||
args = config.get_parser_for_inference().parse_args()
|
||||
|
||||
# STDC model
|
||||
|
|
@ -55,7 +55,6 @@ def main():
|
|||
STDC_model.eval()
|
||||
|
||||
# DMPR model
|
||||
args = config.get_parser_for_inference().parse_args()
|
||||
# DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
|
||||
# DMPRmodel.load_state_dict(torch.load(DMPRweights))
|
||||
DMPRmodel = Model(args.cfg, ch=3).to(device)
|
||||
|
|
@ -101,12 +100,9 @@ def main():
|
|||
# 绘制角点
|
||||
plot_points(img0, det1)
|
||||
|
||||
# save
|
||||
# cv2.imwrite(file, img0)
|
||||
|
||||
# yolo joint DMPR
|
||||
cls = 0 #需要过滤的box类别
|
||||
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio)
|
||||
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border)
|
||||
|
||||
t_joint = time.time()
|
||||
print(f't_joint. ({t_joint - t_dmpr:.3f}s)')
|
||||
|
|
|
|||
|
|
@ -4,7 +4,12 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
|
||||
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio):
|
||||
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio, border=80):
|
||||
# 过滤在图像边界的box(防止出现一小半车辆的情况)
|
||||
x_c = (yolo_det[:, 0] + yolo_det[:, 2]) / 2
|
||||
y_c = (yolo_det[:, 1] + yolo_det[:, 3]) / 2
|
||||
tmp = (x_c >= border) & (x_c <= (img_shape[1] - border)) & (y_c >= border) & (y_c <= (img_shape[0] - border))
|
||||
yolo_det = yolo_det[tmp]
|
||||
|
||||
# 创建yolo_det_clone内容为x1, y1, x2, y2, conf, cls, unlabel (unlabel代表该类是否需要忽略,0:不忽略 其他:忽略)
|
||||
yolo_det_clone = yolo_det.copy()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ NUM_FEATURE_MAP_CHANNEL = 6
|
|||
|
||||
def add_common_arguments(parser):
|
||||
"""Add common arguments for training and inference."""
|
||||
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth',
|
||||
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_299.pth',
|
||||
help="The weights of pretrained detector.")
|
||||
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml',
|
||||
help='model.yaml path')
|
||||
|
|
@ -46,6 +46,8 @@ def get_parser_for_inference():
|
|||
help='IOU threshold for NMS')
|
||||
parser.add_argument('--scale-ratio', type=float, default=0.5,
|
||||
help='detected box scale ratio')
|
||||
parser.add_argument('--border', type=int, default=80,
|
||||
help='The valid border to boundary')
|
||||
parser.add_argument('--ovlap-thres', type=float, default=0.6, help='overlap threshold for OBS')
|
||||
parser.add_argument('--agnostic-nms', action='store_true',
|
||||
help='class-agnostic NMS')
|
||||
|
|
|
|||
Loading…
Reference in New Issue