소스 검색

Update inference default to multi_label=False (#2252)

* Update inference default to multi_label=False

* bug fix

* Update plots.py

* Update plots.py
5.0
Glenn Jocher GitHub 3 년 전
부모
커밋
c09964c27c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4개의 변경된 파일11개의 추가작업 그리고 10개의 파일을 삭제
  1. +1
    -1
      models/common.py
  2. +4
    -4
      test.py
  3. +5
    -4
      utils/general.py
  4. +1
    -1
      utils/plots.py

+ 1
- 1
models/common.py 파일 보기

@@ -7,7 +7,7 @@ import numpy as np
import requests
import torch
import torch.nn as nn
from PIL import Image, ImageDraw
from PIL import Image

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh

+ 4
- 4
test.py 파일 보기

@@ -106,7 +106,7 @@ def test(data,
with torch.no_grad():
# Run model
t = time_synchronized()
inf_out, train_out = model(img, augment=augment) # inference and training outputs
out, train_out = model(img, augment=augment) # inference and training outputs
t0 += time_synchronized() - t

# Compute loss
@@ -117,11 +117,11 @@ def test(data,
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
t = time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb)
out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True)
t1 += time_synchronized() - t

# Statistics per image
for si, pred in enumerate(output):
for si, pred in enumerate(out):
labels = targets[targets[:, 0] == si, 1:]
nl = len(labels)
tcls = labels[:, 0].tolist() if nl else [] # target class
@@ -209,7 +209,7 @@ def test(data,
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()
Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start()

# Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy

+ 5
- 4
utils/general.py 파일 보기

@@ -390,11 +390,12 @@ def wh_iou(wh1, wh2):
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
"""Performs Non-Maximum Suppression (NMS) on inference results
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=()):
"""Runs Non-Maximum Suppression (NMS) on inference results

Returns:
detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""

nc = prediction.shape[2] - 5 # number of classes
@@ -406,7 +407,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS

t = time.time()

+ 1
- 1
utils/plots.py 파일 보기

@@ -54,7 +54,7 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
return filtfilt(b, a, data) # forward-backward filter


def plot_one_box(x, img, color=None, label=None, line_thickness=None):
def plot_one_box(x, img, color=None, label=None, line_thickness=3):
# Plots one bounding box on image img
tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
color = color or [random.randint(0, 255) for _ in range(3)]

Loading…
취소
저장