add ovlap_suppression branch
This commit is contained in:
parent
9c689502b6
commit
ec81e5439e
|
|
@ -10,7 +10,7 @@ from numpy import random
|
|||
from models.experimental import attempt_load
|
||||
from utils.datasets import LoadStreams, LoadImages
|
||||
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
|
||||
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
|
||||
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, overlap_box_suppression
|
||||
from utils.plots import plot_one_box
|
||||
from utils.torch_utils import select_device, load_classifier, time_synchronized
|
||||
|
||||
|
|
@ -73,6 +73,9 @@ def detect(save_img=False):
|
|||
|
||||
# Apply NMS
|
||||
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
|
||||
|
||||
pred = overlap_box_suppression(pred, opt.ovlap_thres)
|
||||
|
||||
t2 = time_synchronized()
|
||||
|
||||
# Apply Classifier
|
||||
|
|
@ -153,6 +156,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
|
||||
parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
|
||||
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
|
||||
parser.add_argument('--ovlap-thres', type=float, default=0.6, help='overlap threshold for OBS')
|
||||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||
parser.add_argument('--view-img', action='store_true', help='display results')
|
||||
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
|
||||
|
|
|
|||
|
|
@ -508,6 +508,49 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
|
|||
|
||||
return output
|
||||
|
||||
def overlap_box_suppression(prediction, ovlap_thres = 0.6):
|
||||
"""Runs overlap_box_suppression on inference results
|
||||
delete the box that overlap of boxes bigger than ovlap_thres
|
||||
Returns:
|
||||
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
||||
"""
|
||||
def box_iob(box1, box2):
|
||||
def box_area(box):
|
||||
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
||||
|
||||
area1 = box_area(box1) # (N,)
|
||||
area2 = box_area(box2) # (M,)
|
||||
|
||||
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
||||
lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] # N中一个和M个比较;
|
||||
rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2]
|
||||
wh = (rb - lt).clamp(min=0) #小于0的为0 clamp 钳;夹钳;
|
||||
inter = wh[:, :, 0] * wh[:, :, 1]
|
||||
|
||||
return torch.squeeze(inter / area1), torch.squeeze(inter / area2)
|
||||
|
||||
output = [torch.zeros((0, 6), device=prediction[0].device)] * len(prediction)
|
||||
for i, x in enumerate(prediction):
|
||||
keep = [] # 最终保留的结果, 在boxes中对应的索引;
|
||||
boxes = x[:, 0:4]
|
||||
scores = x[:, 4]
|
||||
cls = x[:, 5]
|
||||
idxs = scores.argsort()
|
||||
while idxs.numel() > 0:
|
||||
keep_idx = idxs[-1]
|
||||
keep_box = boxes[keep_idx][None, ] # [1, 4]
|
||||
keep.append(keep_idx)
|
||||
if idxs.size(0) == 1:
|
||||
break
|
||||
idxs = idxs[:-1] # 将得分最大框 从索引中删除; 剩余索引对应的框 和 得分最大框 计算iob;
|
||||
other_boxes = boxes[idxs]
|
||||
this_cls = cls[keep_idx]
|
||||
other_cls = cls[idxs]
|
||||
iobs1, iobs2 = box_iob(keep_box, other_boxes) # 一个框和其余框比较 1XM
|
||||
idxs = idxs[((iobs1 <= ovlap_thres) & (iobs2 <= ovlap_thres)) | (other_cls != this_cls)]
|
||||
keep = idxs.new(keep) # Tensor
|
||||
output[i] = x[keep]
|
||||
return output
|
||||
|
||||
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
|
||||
# Strip optimizer from 'f' to finalize training, optionally save as 's'
|
||||
|
|
|
|||
Loading…
Reference in New Issue