ソースを参照

NMS fast mode

5.0
Glenn Jocher 4年前
コミット
eb97b2e413
4個のファイルの変更26行の追加22行の削除
  1. +1
    -1
      detect.py
  2. +2
    -2
      test.py
  3. +11
    -11
      train.py
  4. +12
    -8
      utils/utils.py

+ 1
- 1
detect.py ファイルの表示

@@ -76,7 +76,7 @@ def detect(save_img=False):

# Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
multi_label=False, classes=opt.classes, agnostic=opt.agnostic_nms)
fast=True, classes=opt.classes, agnostic=opt.agnostic_nms)

# Apply Classifier
if classify:

+ 2
- 2
test.py ファイルの表示

@@ -19,7 +19,7 @@ def test(data,
augment=False,
model=None,
dataloader=None,
multi_label=True,
fast=False,
verbose=False): # 0 fast, 1 accurate
# Initialize/load model and set device
if model is None:
@@ -92,7 +92,7 @@ def test(data,

# Run NMS
t = torch_utils.time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, multi_label=multi_label)
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, fast=fast)
t1 += torch_utils.time_synchronized() - t

# Statistics per image

+ 11
- 11
train.py ファイルの表示

@@ -293,13 +293,13 @@ def train(hyp):
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data,
batch_size=batch_size,
imgsz=imgsz_test,
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
multi_label=ni > n_burn)
batch_size=batch_size,
imgsz=imgsz_test,
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
fast=ni > n_burn)

# Write
with open(results_file, 'a') as f:
@@ -325,10 +325,10 @@ def train(hyp):
if save:
with open(results_file, 'r') as f: # create checkpoint
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': f.read(),
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict()}
'best_fitness': best_fitness,
'training_results': f.read(),
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict()}

# Save last, best and delete
torch.save(ckpt, last)

+ 12
- 8
utils/utils.py ファイルの表示

@@ -19,7 +19,7 @@ import torchvision
from scipy.signal import butter, filtfilt
from tqdm import tqdm

from . import torch_utils, google_utils # torch_utils, google_utils
from . import torch_utils, google_utils #  torch_utils, google_utils

# Set printoptions
torch.set_printoptions(linewidth=320, precision=5, profile='long')
@@ -460,29 +460,33 @@ def build_targets(p, targets, model):

return tcls, tbox, indices, anch


def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False):
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
"""
Performs Non-Maximum Suppression on inference results
Returns detections with shape:
nx6 (x1, y1, x2, y2, conf, cls)
"""
nc = prediction[0].shape[1] - 5 # number of classes

# Settings
merge = True # merge for best mAP
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
max_det = 300 # maximum number of detections per image
time_limit = 10.0 # seconds to quit after
redundant = conf_thres == 0.001 # require redundant detections
redundant = True # require redundant detections
fast |= conf_thres > 0.001 # fast mode
if fast:
merge = False
multi_label = False
else:
merge = True # merge for best mAP (adds 0.5ms/img)
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)

t = time.time()
nc = prediction[0].shape[1] - 5 # number of classes
multi_label &= nc > 1 # multiple labels per box
output = [None] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[x[:, 4] > conf_thres] # confidence
# x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height

# If none remain process next image
if not x.shape[0]:

読み込み中…
キャンセル
保存