add ovlap_suppression branch

This commit is contained in:
Administrator 2023-04-21 14:21:56 +08:00
parent 9c689502b6
commit ec81e5439e
2 changed files with 48 additions and 1 deletions

View File

@ -10,7 +10,7 @@ from numpy import random
from models.experimental import attempt_load from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \ from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, overlap_box_suppression
from utils.plots import plot_one_box from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized from utils.torch_utils import select_device, load_classifier, time_synchronized
@ -73,6 +73,9 @@ def detect(save_img=False):
# Apply NMS # Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
pred = overlap_box_suppression(pred, opt.ovlap_thres)
t2 = time_synchronized() t2 = time_synchronized()
# Apply Classifier # Apply Classifier
@ -153,6 +156,7 @@ if __name__ == '__main__':
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
parser.add_argument('--ovlap-thres', type=float, default=0.6, help='overlap threshold for OBS')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')

View File

@ -508,6 +508,49 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
return output return output
def overlap_box_suppression(prediction, ovlap_thres = 0.6):
"""Runs overlap_box_suppression on inference results
delete the box that overlap of boxes bigger than ovlap_thres
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
def box_iob(box1, box2):
def box_area(box):
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
area1 = box_area(box1) # (N,)
area2 = box_area(box2) # (M,)
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] # N中一个和M个比较
rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) #小于0的为0 clamp 钳;夹钳;
inter = wh[:, :, 0] * wh[:, :, 1]
return torch.squeeze(inter / area1), torch.squeeze(inter / area2)
output = [torch.zeros((0, 6), device=prediction[0].device)] * len(prediction)
for i, x in enumerate(prediction):
keep = [] # 最终保留的结果, 在boxes中对应的索引
boxes = x[:, 0:4]
scores = x[:, 4]
cls = x[:, 5]
idxs = scores.argsort()
while idxs.numel() > 0:
keep_idx = idxs[-1]
keep_box = boxes[keep_idx][None, ] # [1, 4]
keep.append(keep_idx)
if idxs.size(0) == 1:
break
idxs = idxs[:-1] # 将得分最大框 从索引中删除; 剩余索引对应的框 和 得分最大框 计算iob
other_boxes = boxes[idxs]
this_cls = cls[keep_idx]
other_cls = cls[idxs]
iobs1, iobs2 = box_iob(keep_box, other_boxes) # 一个框和其余框比较 1XM
idxs = idxs[((iobs1 <= ovlap_thres) & (iobs2 <= ovlap_thres)) | (other_cls != this_cls)]
keep = idxs.new(keep) # Tensor
output[i] = x[keep]
return output
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer() def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's' # Strip optimizer from 'f' to finalize training, optionally save as 's'