Browse Source

Update test.py for IoU in native image-space (#1439)

* Update test.py for IoU in native image-space

* remove redundant

* gn to device

* remove output scale_coords

* --img-size correction

* update

* native-space labels

* pred to predn

* remove clip_coords()
5.0
Glenn Jocher GitHub 4 years ago
parent
commit
225845e781
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 13 deletions
  1. +14
    -0
      models/common.py
  2. +11
    -13
      test.py

+ 14
- 0
models/common.py View File

@@ -148,6 +148,8 @@ class autoShape(nn.Module):
batch = range(len(imgs)) # batch size
for i in batch:
imgs[i] = np.array(imgs[i]) # to numpy
if imgs[i].shape[0] < 5: # image in CHW
imgs[i] = imgs[i].transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input
s = imgs[i].shape[:2] # HWC
shape0.append(s) # image shape
@@ -184,6 +186,7 @@ class Detections:
gn = [torch.Tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.]) for im in imgs] # normalization gains
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred)

def display(self, pprint=False, show=False, save=False):
colors = color_list()
@@ -216,6 +219,17 @@ class Detections:
def save(self):
self.display(save=True) # save results

def __len__(self):
return self.n

def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():'
x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
for d in x:
for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
setattr(d, k, getattr(d, k)[0]) # pop out of list
return x


class Flatten(nn.Module):
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions

+ 11
- 13
test.py View File

@@ -12,7 +12,7 @@ from tqdm import tqdm
from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, clip_coords, set_logging, increment_path
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
from utils.loss import compute_loss
from utils.metrics import ap_per_class
from utils.plots import plot_images, output_to_target
@@ -124,6 +124,7 @@ def test(data,
labels = targets[targets[:, 0] == si, 1:]
nl = len(labels)
tcls = labels[:, 0].tolist() if nl else [] # target class
path = Path(paths[si])
seen += 1

if len(pred) == 0:
@@ -131,13 +132,14 @@ def test(data,
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
continue

# Predictions
predn = pred.clone()
scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1]) # native-space pred

# Append to text file
path = Path(paths[si])
if save_txt:
gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]] # normalization gain whwh
x = pred.clone()
x[:, :4] = scale_coords(img[si].shape[1:], x[:, :4], shapes[si][0], shapes[si][1]) # to original
for *xyxy, conf, cls in x:
for *xyxy, conf, cls in predn.tolist():
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f:
@@ -150,19 +152,14 @@ def test(data,
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}}
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))

# Clip boxes to image bounds
clip_coords(pred, (height, width))

# Append to pycocotools JSON dictionary
if save_json:
# [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
image_id = int(path.stem) if path.stem.isnumeric() else path.stem
box = pred[:, :4].clone() # xyxy
scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1]) # to original shape
box = xyxy2xywh(box) # xywh
box = xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for p, b in zip(pred.tolist(), box.tolist()):
jdict.append({'image_id': image_id,
@@ -178,6 +175,7 @@ def test(data,

# target boxes
tbox = xywh2xyxy(labels[:, 1:5]) * whwh
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels

# Per target class
for cls in torch.unique(tcls_tensor):
@@ -187,7 +185,7 @@ def test(data,
# Search for detections
if pi.shape[0]:
# Prediction to target ious
ious, i = box_iou(pred[pi, :4], tbox[ti]).max(1) # best ious, indices
ious, i = box_iou(predn[pi, :4], tbox[ti]).max(1) # best ious, indices

# Append detections
detected_set = set()

Loading…
Cancel
Save