
Add Hub results.pandas() method (#2725)

* Add Hub results.pandas() method

The new method converts results from torch tensors to pandas DataFrames with named columns.

This PR may partially resolve issue https://github.com/ultralytics/yolov5/issues/2703

```python
results = model(imgs)

print(results.pandas().xyxy[0])
         xmin        ymin        xmax        ymax  confidence  class    name
0   57.068970  391.770599  241.383545  905.797852    0.868964      0  person
1  667.661255  399.303589  810.000000  881.396667    0.851888      0  person
2  222.878387  414.774231  343.804474  857.825073    0.838376      0  person
3    4.205386  234.447678  803.739136  750.023376    0.658006      5     bus
4    0.000000  550.596008   76.681190  878.669922    0.450596      0  person
```
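For reference, a minimal end-to-end usage sketch (the hub model name and image URL are illustrative, not part of this PR):

```python
import torch

# Illustrative: load a pretrained YOLOv5s model from PyTorch Hub
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
imgs = ['https://ultralytics.com/images/zidane.jpg']  # batch of one image URL

results = model(imgs)  # inference
df = results.pandas().xyxy[0]  # detections for image 0 as a DataFrame
print(df[df['confidence'] > 0.5])  # DataFrames allow e.g. filtering by confidence
```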

* Update comments

The torch example input is now shown resized to size=640 and as a multiple of the P6 stride of 64 (see https://github.com/ultralytics/yolov5/issues/2722#issuecomment-814785930).
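The rounding behind that stride constraint can be sketched with the make_divisible helper already imported in models/common.py (reimplemented here for a self-contained example; the sizes are illustrative):

```python
import math

def make_divisible(x, divisor):
    # round x up to the nearest multiple of divisor (same idea as utils.general.make_divisible)
    return math.ceil(x / divisor) * divisor

stride = 64  # maximum stride of P6 models
print(make_divisible(640, stride))  # 640: already a multiple of 64, used as-is
print(make_divisible(700, stride))  # 704: rounded up to the next multiple of 64
```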

* apply decorators

* PEP8

* Update common.py

* pd.options.display.max_columns = 10

* Update common.py
5.0
Glenn Jocher · 3 years ago
commit c03d590320
3 changed files with 32 additions and 18 deletions:

1. hubconf.py (+1 −1)
2. models/common.py (+29 −17)
3. utils/general.py (+2 −0)

hubconf.py (+1 −1)

```python
fname = f'{name}.pt'  # checkpoint filename
attempt_download(fname)  # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
msd = model.state_dict()  # model state_dict
csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
model.load_state_dict(csd, strict=False)  # load
```
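This excerpt shows the usual filter-and-load pattern for partially compatible checkpoints: keep only tensors whose shapes match, then load with strict=False. A self-contained sketch of the same idea, with hypothetical toy models:

```python
import torch.nn as nn

# Hypothetical models: the "checkpoint" head has a different output size than the target model
src = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))   # stands in for ckpt['model']
dst = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 10))  # freshly constructed model

msd = dst.state_dict()  # model state_dict
csd = src.state_dict()  # checkpoint state_dict
csd = {k: v for k, v in csd.items() if k in msd and msd[k].shape == v.shape}  # filter
dst.load_state_dict(csd, strict=False)  # strict=False tolerates the skipped keys
print(f'transferred {len(csd)}/{len(msd)} items')  # -> transferred 2/4 items
```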

models/common.py (+29 −17)

```diff
 # YOLOv5 common modules

 import math
+from copy import copy
 from pathlib import Path

 import numpy as np
+import pandas as pd
 import requests
 import torch
 import torch.nn as nn
 from PIL import Image
-from torch.cuda import amp

 from utils.datasets import letterbox
 from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
```

```diff
         print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
         return self

+    @torch.no_grad()
+    @torch.cuda.amp.autocast()
     def forward(self, imgs, size=640, augment=False, profile=False):
-        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
+        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
         #   filename:   imgs = 'data/samples/zidane.jpg'
         #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
-        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
-        #   PIL:             = Image.open('image.jpg')  # HWC x(720,1280,3)
-        #   numpy:           = np.zeros((720,1280,3))  # HWC
-        #   torch:           = torch.zeros(16,3,720,1280)  # BCHW
+        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+        #   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)
+        #   numpy:           = np.zeros((640,1280,3))  # HWC
+        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
         #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
```
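The two new decorators replace an inner with-block in forward (see the next hunk). A minimal sketch of how @torch.no_grad() behaves as a decorator, using an illustrative toy function:

```python
import torch

@torch.no_grad()  # disables gradient tracking for the whole call, like `with torch.no_grad():`
def predict(x):
    return x * 2

x = torch.ones(3, requires_grad=True)
y = predict(x)
print(y.requires_grad)  # False: no autograd graph was recorded inside predict()
```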


```diff
         t = [time_synchronized()]
         x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
         t.append(time_synchronized())

-        with torch.no_grad(), amp.autocast(enabled=p.device.type != 'cpu'):
-            # Inference
-            y = self.model(x, augment, profile)[0]  # forward
-            t.append(time_synchronized())
+        # Inference
+        y = self.model(x, augment, profile)[0]  # forward
+        t.append(time_synchronized())

-            # Post-process
-            y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
-            for i in range(n):
-                scale_coords(shape1, y[i][:, :4], shape0[i])
+        # Post-process
+        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+        for i in range(n):
+            scale_coords(shape1, y[i][:, :4], shape0[i])

         t.append(time_synchronized())
         return Detections(imgs, y, files, t, self.names, x.shape)
```
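In this hunk, scale_coords maps boxes from the letterboxed inference shape (shape1) back to each original image shape (shape0). A simplified sketch of that rescaling, assuming centered letterbox padding (not the exact utils.general implementation):

```python
import torch

def scale_boxes(img1_shape, boxes, img0_shape):
    # rescale xyxy boxes from img1_shape (h, w) back to img0_shape (h, w)
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # resize gain
    pad_w = (img1_shape[1] - img0_shape[1] * gain) / 2  # width padding added by letterbox
    pad_h = (img1_shape[0] - img0_shape[0] * gain) / 2  # height padding added by letterbox
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] - pad_w) / gain  # x1, x2
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] - pad_h) / gain  # y1, y2
    return boxes

boxes = torch.tensor([[80., 64., 560., 320.]])  # box in 384x640 letterboxed space
print(scale_boxes((384, 640), boxes, (720, 1280)))  # mapped back to the 720x1280 original
```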
```diff
         self.display(render=True)  # render results
         return self.imgs

-    def __len__(self):
-        return self.n
+    def pandas(self):
+        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
+        new = copy(self)  # return copy
+        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
+        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
+        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
+            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
+            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+        return new

     def tolist(self):
         # return a list of Detections objects, i.e. 'for result in results.tolist():'
-        x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
+        x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
         for d in x:
             for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                 setattr(d, k, getattr(d, k)[0])  # pop out of list
         return x

+    def __len__(self):
+        return self.n


 class Classify(nn.Module):
     # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
```
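Per image, pandas() turns an (n, 6) tensor of xyxy, confidence, class into a DataFrame with named columns. A standalone sketch of that conversion, with made-up detections and a truncated names list:

```python
import pandas as pd
import torch

names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus']  # illustrative class names
pred = torch.tensor([[57.07, 391.77, 241.38, 905.80, 0.87, 0.],
                     [4.21, 234.45, 803.74, 750.02, 0.66, 5.]])  # xyxy, confidence, class

rows = [r[:5] + [int(r[5]), names[int(r[5])]] for r in pred.tolist()]  # append int class and name
df = pd.DataFrame(rows, columns=['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'])
print(df)
```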

utils/general.py (+2 −0)



```diff
 import cv2
 import numpy as np
+import pandas as pd
 import torch
 import torchvision
 import yaml

 # Settings
 torch.set_printoptions(linewidth=320, precision=5, profile='long')
 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
+pd.options.display.max_columns = 10
 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
```
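The max_columns setting keeps the 7-column detection DataFrames from being abbreviated with '...' when printed; a quick illustration (values arbitrary):

```python
import pandas as pd

pd.options.display.max_columns = 10  # allow up to 10 columns before pandas truncates the printout
df = pd.DataFrame([[0] * 7], columns=['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'])
print(df)  # all 7 columns appear in the output
```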


