Browse Source

add ovlap_suppression branch

ovlap_suppression
Administrator 1 year ago
parent
commit
ec81e5439e
2 changed files with 48 additions and 1 deletions
  1. +5
    -1
      detect.py
  2. +43
    -0
      utils/general.py

+ 5
- 1
detect.py View File

from models.experimental import attempt_load from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \ from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, overlap_box_suppression
from utils.plots import plot_one_box from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized from utils.torch_utils import select_device, load_classifier, time_synchronized




# Apply NMS # Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)

pred = overlap_box_suppression(pred, opt.ovlap_thres)

t2 = time_synchronized() t2 = time_synchronized()


# Apply Classifier # Apply Classifier
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
parser.add_argument('--ovlap-thres', type=float, default=0.6, help='overlap threshold for OBS')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')

+ 43
- 0
utils/general.py View File



return output return output


def overlap_box_suppression(prediction, ovlap_thres=0.6):
    """Suppress same-class boxes that are largely covered by a higher-scoring box.

    For every image the detections are visited from highest to lowest
    confidence. The current best box is kept, and every remaining box of the
    SAME class is dropped when the intersection covers more than
    ``ovlap_thres`` of either the kept box or the candidate box
    ("intersection over box", IoB). Boxes of a different class are never
    suppressed by each other.

    Args:
        prediction: iterable of per-image tensors of shape (n, 6) laid out as
            [x1, y1, x2, y2, conf, cls].
        ovlap_thres: maximum allowed intersection/area ratio before a box is
            considered a duplicate.

    Returns:
        list with one (k, 6) tensor per image [xyxy, conf, cls], rows ordered
        by descending confidence. An empty ``prediction`` yields ``[]``.
    """

    def box_iob(box1, box2):
        """Return (inter / area(box1), inter / area(box2)), each of shape (N, M).

        box1 is (N, 4) and box2 is (M, 4), both in xyxy format. Zero-area
        boxes yield a ratio of 0 instead of NaN so that a degenerate box
        neither suppresses nor is suppressed by its neighbours.
        """

        def box_area(box):
            return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])

        # Explicit broadcast axes: rows index box1, columns index box2.
        area1 = box_area(box1)[:, None]  # (N, 1)
        area2 = box_area(box2)[None, :]  # (1, M)

        # inter(N, M) = (rb(N, M, 2) - lt(N, M, 2)).clamp(0).prod(2)
        lt = torch.max(box1[:, None, :2], box2[:, :2])  # (N, M, 2) top-left of the intersection
        rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # (N, M, 2) bottom-right of the intersection
        wh = (rb - lt).clamp(min=0)  # negative extent means no overlap
        inter = wh[:, :, 0] * wh[:, :, 1]

        zero = torch.zeros_like(inter)
        iob1 = torch.where(area1 > 0, inter / area1, zero)
        iob2 = torch.where(area2 > 0, inter / area2, zero)
        return iob1, iob2

    output = []
    for x in prediction:
        boxes = x[:, 0:4]
        scores = x[:, 4]
        cls = x[:, 5]
        idxs = scores.argsort()  # ascending, so the best candidate is last
        keep = []  # indices into x of the detections that survive
        while idxs.numel() > 0:
            keep_idx = idxs[-1]
            keep.append(keep_idx)
            # Remove the kept box from the candidates; the rest are compared against it.
            idxs = idxs[:-1]
            if idxs.numel() == 0:
                break
            iob_keep, iob_other = box_iob(boxes[keep_idx][None], boxes[idxs])  # each (1, M)
            low_overlap = (iob_keep[0] <= ovlap_thres) & (iob_other[0] <= ovlap_thres)
            other_class = cls[idxs] != cls[keep_idx]
            idxs = idxs[low_overlap | other_class]
        if keep:
            keep_t = torch.stack(keep)
        else:
            keep_t = torch.zeros((0,), dtype=torch.long, device=x.device)
        output.append(x[keep_t])
    return output


def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer() def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's' # Strip optimizer from 'f' to finalize training, optionally save as 's'

Loading…
Cancel
Save