Browse Source

2

add_stdc_seg
nyh 1 year ago
parent
commit
4c861391aa
3 changed files with 11 additions and 8 deletions
  1. +2
    -6
      AI_example.py
  2. +6
    -1
      DMPR_YOLO/jointUtil.py
  3. +3
    -1
      conf/config.py

+ 2
- 6
AI_example.py View File

@@ -43,7 +43,7 @@ def main():
if half:
model.half()

# STDC model
# load args
args = config.get_parser_for_inference().parse_args()

# STDC model
@@ -55,7 +55,6 @@ def main():
STDC_model.eval()

# DMPR model
args = config.get_parser_for_inference().parse_args()
# DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
# DMPRmodel.load_state_dict(torch.load(DMPRweights))
DMPRmodel = Model(args.cfg, ch=3).to(device)
@@ -101,12 +100,9 @@ def main():
# 绘制角点
plot_points(img0, det1)

# save
# cv2.imwrite(file, img0)

# yolo joint DMPR
cls = 0 #需要过滤的box类别
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio)
joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border)

t_joint = time.time()
print(f't_joint. ({t_joint - t_dmpr:.3f}s)')

+ 6
- 1
DMPR_YOLO/jointUtil.py View File

@@ -4,7 +4,12 @@ import numpy as np
import torch


def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio):
def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio, border=80):
# 过滤在图像边界的box(防止出现一小半车辆的情况)
x_c = (yolo_det[:, 0] + yolo_det[:, 2]) / 2
y_c = (yolo_det[:, 1] + yolo_det[:, 3]) / 2
tmp = (x_c >= border) & (x_c <= (img_shape[1] - border)) & (y_c >= border) & (y_c <= (img_shape[0] - border))
yolo_det = yolo_det[tmp]

# 创建yolo_det_clone内容为x1, y1, x2, y2, conf, cls, unlabel (unlabel代表该类是否需要忽略,0:不忽略 其他:忽略)
yolo_det_clone = yolo_det.copy()

+ 3
- 1
conf/config.py View File

@@ -7,7 +7,7 @@ NUM_FEATURE_MAP_CHANNEL = 6

def add_common_arguments(parser):
"""Add common arguments for training and inference."""
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth',
parser.add_argument('--detector_weights', default=r'E:\pycharmProject\DMPR-PS\weights\dp_detector_299.pth',
help="The weights of pretrained detector.")
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml',
help='model.yaml path')
@@ -46,6 +46,8 @@ def get_parser_for_inference():
help='IOU threshold for NMS')
parser.add_argument('--scale-ratio', type=float, default=0.5,
help='detected box scale ratio')
parser.add_argument('--border', type=int, default=80,
help='The valid border to boundary')
parser.add_argument('--ovlap-thres', type=float, default=0.6, help='overlap threshold for OBS')
parser.add_argument('--agnostic-nms', action='store_true',
help='class-agnostic NMS')

Loading…
Cancel
Save