浏览代码

Simplify autoshape() post-process (#1653)

* Simplify autoshape() post-process

* cleanup

* cleanup
5.0
Glenn Jocher GitHub 3 年前
父节点
当前提交
fa8f1fb0e9
找不到此签名对应的密钥 GPG 密钥 ID: 4AEE18F83AFDEB23
共有 4 个文件被更改,包括 9 次插入10 次删除
  1. +1
    -1
      hubconf.py
  2. +3
    -4
      models/common.py
  3. +4
    -4
      requirements.txt
  4. +1
    -1
      utils/general.py

+ 1
- 1
hubconf.py 查看文件

@@ -108,7 +108,7 @@ def yolov5x(pretrained=False, channels=3, classes=80):

if __name__ == '__main__':
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
model = model.fuse().autoshape() # for PIL/cv2/np inputs and NMS
model = model.autoshape() # for PIL/cv2/np inputs and NMS

# Verify inference
from PIL import Image

+ 3
- 4
models/common.py 查看文件

@@ -167,8 +167,7 @@ class autoShape(nn.Module):

# Post-process
for i in batch:
if y[i] is not None:
y[i][:, :4] = scale_coords(shape1, y[i][:, :4], shape0[i])
scale_coords(shape1, y[i][:, :4], shape0[i])

return Detections(imgs, y, self.names)

@@ -177,13 +176,13 @@ class Detections:
# detections class for YOLOv5 inference results
def __init__(self, imgs, pred, names=None):
super(Detections, self).__init__()
d = pred[0].device # device
gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
self.imgs = imgs # list of images as numpy arrays
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names
self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
d = pred[0].device # device
gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred)

+ 4
- 4
requirements.txt 查看文件

@@ -9,8 +9,8 @@ Pillow
PyYAML>=5.3
scipy>=1.4.1
tensorboard>=2.2
torch>=1.6.0
torchvision>=0.7.0
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.41.0

# logging -------------------------------------
@@ -26,5 +26,5 @@ pandas
# scikit-learn==0.19.2 # for coreml quantization

# extras --------------------------------------
# thop # FLOPS computation
# pycocotools>=2.0 # COCO mAP
thop # FLOPS computation
pycocotools>=2.0 # COCO mAP

+ 1
- 1
utils/general.py 查看文件

@@ -258,7 +258,7 @@ def wh_iou(wh1, wh2):
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)


def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, classes=None, agnostic=False, labels=()):
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
"""Performs Non-Maximum Suppression (NMS) on inference results

Returns:

正在加载...
取消
保存