Browse Source

remove fast, add merge

5.0
Glenn Jocher 4 years ago
parent
commit
1f1917ef56
3 changed files with 6 additions and 11 deletions
  1. +4
    -3
      test.py
  2. +1
    -2
      train.py
  3. +1
    -6
      utils/utils.py

+ 4
- 3
test.py View File

@@ -19,7 +19,7 @@ def test(data,
verbose=False,
model=None,
dataloader=None,
fast=False):
merge=False):
# Initialize/load model and set device
if model is None:
training = False
@@ -65,7 +65,7 @@ def test(data,
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once

fast |= conf_thres > 0.001 # enable fast mode
merge = opt.merge # use Merge NMS
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
dataset = LoadImagesAndLabels(path,
imgsz,
@@ -109,7 +109,7 @@ def test(data,

# Run NMS
t = torch_utils.time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, fast=fast)
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
t1 += torch_utils.time_synchronized() - t

# Statistics per image
@@ -254,6 +254,7 @@ if __name__ == '__main__':
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--merge', action='store_true', help='use Merge NMS')
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
opt = parser.parse_args()
opt.img_size = check_img_size(opt.img_size)

+ 1
- 2
train.py View File

@@ -305,8 +305,7 @@ def train(hyp):
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
fast=epoch < epochs / 2)
dataloader=testloader)

# Write
with open(results_file, 'a') as f:

+ 1
- 6
utils/utils.py View File

@@ -527,7 +527,7 @@ def build_targets(p, targets, model):
return tcls, tbox, indices, anch


def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
"""Performs Non-Maximum Suppression (NMS) on inference results

Returns:
@@ -544,12 +544,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
max_det = 300 # maximum number of detections per image
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
fast |= conf_thres > 0.001 # fast mode
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
if fast:
merge = False
else:
merge = True # merge for best mAP (adds 0.5ms/img)

t = time.time()
output = [None] * prediction.shape[0]

Loading…
Cancel
Save