Browse Source

Fix Detections class `tolist()` method (#5945)

* Fix tolist() to add the file for each Detection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix PEP8 requirement for 2 spaces before an inline comment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
Yono Mittlefehldt GitHub 2 years ago
parent
commit
8f354362cd
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 5 deletions
  1. +7
    -5
      models/common.py

+ 7
- 5
models/common.py View File



class Detections: class Detections:
# YOLOv5 detections class for inference results # YOLOv5 detections class for inference results
def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
super().__init__() super().__init__()
d = pred[0].device # device d = pred[0].device # device
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names self.names = names # class names
self.files = files # image filenames self.files = files # image filenames
self.times = times # profiling times
self.xyxy = pred # xyxy pixels self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized


def tolist(self): def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():' # return a list of Detections objects, i.e. 'for result in results.tolist():'
x = [Detections([self.imgs[i]], [self.pred[i]], names=self.names, shape=self.s) for i in range(self.n)]
for d in x:
for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
setattr(d, k, getattr(d, k)[0]) # pop out of list
r = range(self.n) # iterable
x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
# for d in x:
# for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
# setattr(d, k, getattr(d, k)[0]) # pop out of list
return x return x


def __len__(self): def __len__(self):

Loading…
Cancel
Save