autoShape() default for PyTorch Hub models (#1692)

* Add autoshape parameter

* Remove autoshape call in ReadMe

* Update hubconf.py

* file/URI inputs and autoshape check passthrough

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
NanoCode012 2020-12-27 10:58:26 +07:00 committed by GitHub
parent c0ffcdf998
commit 14b0abe2d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 33 deletions

View File

@ -106,7 +106,7 @@ import torch
from PIL import Image from PIL import Image
# Model # Model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).autoshape() # for PIL/cv2/np inputs and NMS model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # for PIL/cv2/np inputs and NMS
# Images # Images
img1 = Image.open('zidane.jpg') img1 = Image.open('zidane.jpg')

View File

@ -17,7 +17,7 @@ dependencies = ['torch', 'yaml']
set_logging() set_logging()
def create(name, pretrained, channels, classes): def create(name, pretrained, channels, classes, autoshape):
"""Creates a specified YOLOv5 model """Creates a specified YOLOv5 model
Arguments: Arguments:
@ -41,7 +41,8 @@ def create(name, pretrained, channels, classes):
model.load_state_dict(state_dict, strict=False) # load model.load_state_dict(state_dict, strict=False) # load
if len(ckpt['model'].names) == classes: if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute model.names = ckpt['model'].names # set class names attribute
# model = model.autoshape() # for PIL/cv2/np inputs and NMS if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
return model return model
except Exception as e: except Exception as e:
@ -50,7 +51,7 @@ def create(name, pretrained, channels, classes):
raise Exception(s) from e raise Exception(s) from e
def yolov5s(pretrained=False, channels=3, classes=80): def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True):
"""YOLOv5-small model from https://github.com/ultralytics/yolov5 """YOLOv5-small model from https://github.com/ultralytics/yolov5
Arguments: Arguments:
@ -61,10 +62,10 @@ def yolov5s(pretrained=False, channels=3, classes=80):
Returns: Returns:
pytorch model pytorch model
""" """
return create('yolov5s', pretrained, channels, classes) return create('yolov5s', pretrained, channels, classes, autoshape)
def yolov5m(pretrained=False, channels=3, classes=80): def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True):
"""YOLOv5-medium model from https://github.com/ultralytics/yolov5 """YOLOv5-medium model from https://github.com/ultralytics/yolov5
Arguments: Arguments:
@ -75,10 +76,10 @@ def yolov5m(pretrained=False, channels=3, classes=80):
Returns: Returns:
pytorch model pytorch model
""" """
return create('yolov5m', pretrained, channels, classes) return create('yolov5m', pretrained, channels, classes, autoshape)
def yolov5l(pretrained=False, channels=3, classes=80): def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True):
"""YOLOv5-large model from https://github.com/ultralytics/yolov5 """YOLOv5-large model from https://github.com/ultralytics/yolov5
Arguments: Arguments:
@ -89,10 +90,10 @@ def yolov5l(pretrained=False, channels=3, classes=80):
Returns: Returns:
pytorch model pytorch model
""" """
return create('yolov5l', pretrained, channels, classes) return create('yolov5l', pretrained, channels, classes, autoshape)
def yolov5x(pretrained=False, channels=3, classes=80): def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True):
"""YOLOv5-xlarge model from https://github.com/ultralytics/yolov5 """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
Arguments: Arguments:
@ -103,10 +104,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):
Returns: Returns:
pytorch model pytorch model
""" """
return create('yolov5x', pretrained, channels, classes) return create('yolov5x', pretrained, channels, classes, autoshape)
def custom(path_or_model='path/to/model.pt'): def custom(path_or_model='path/to/model.pt', autoshape=True):
"""YOLOv5-custom model from https://github.com/ultralytics/yolov5 """YOLOv5-custom model from https://github.com/ultralytics/yolov5
Arguments (3 options): Arguments (3 options):
@ -124,13 +125,12 @@ def custom(path_or_model='path/to/model.pt'):
hub_model = Model(model.yaml).to(next(model.parameters()).device) # create hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
hub_model.load_state_dict(model.float().state_dict()) # load state_dict hub_model.load_state_dict(model.float().state_dict()) # load state_dict
hub_model.names = model.names # class names hub_model.names = model.names # class names
return hub_model return hub_model.autoshape() if autoshape else hub_model
if __name__ == '__main__': if __name__ == '__main__':
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # pretrained example model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example
# model = custom(path_or_model='path/to/model.pt') # custom example # model = custom(path_or_model='path/to/model.pt') # custom example
model = model.autoshape() # for PIL/cv2/np inputs and NMS
# Verify inference # Verify inference
from PIL import Image from PIL import Image

View File

@ -2,6 +2,7 @@
import math import math
import numpy as np import numpy as np
import requests
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
@ -143,35 +144,42 @@ class autoShape(nn.Module):
super(autoShape, self).__init__() super(autoShape, self).__init__()
self.model = model.eval() self.model = model.eval()
def autoshape(self):
print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
return self
def forward(self, imgs, size=640, augment=False, profile=False): def forward(self, imgs, size=640, augment=False, profile=False):
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are: # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
# opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) # filename: imgs = 'data/samples/zidane.jpg'
# PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3) # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
# numpy: imgs = np.zeros((720,1280,3)) # HWC # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# torch: imgs = torch.zeros(16,3,720,1280) # BCHW # PIL: = Image.open('image.jpg') # HWC x(720,1280,3)
# multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images # numpy: = np.zeros((720,1280,3)) # HWC
# torch: = torch.zeros(16,3,720,1280) # BCHW
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
p = next(self.model.parameters()) # for device and type p = next(self.model.parameters()) # for device and type
if isinstance(imgs, torch.Tensor): # torch if isinstance(imgs, torch.Tensor): # torch
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
# Pre-process # Pre-process
if not isinstance(imgs, list): n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
imgs = [imgs]
shape0, shape1 = [], [] # image and inference shapes shape0, shape1 = [], [] # image and inference shapes
batch = range(len(imgs)) # batch size for i, im in enumerate(imgs):
for i in batch: if isinstance(im, str): # filename or uri
imgs[i] = np.array(imgs[i]) # to numpy im = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im) # open
if imgs[i].shape[0] < 5: # image in CHW im = np.array(im) # to numpy
imgs[i] = imgs[i].transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) if im.shape[0] < 5: # image in CHW
imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
s = imgs[i].shape[:2] # HWC im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
s = im.shape[:2] # HWC
shape0.append(s) # image shape shape0.append(s) # image shape
g = (size / max(s)) # gain g = (size / max(s)) # gain
shape1.append([y * g for y in s]) shape1.append([y * g for y in s])
imgs[i] = im # update
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack x = np.stack(x, 0) if n > 1 else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32 x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
@ -181,7 +189,7 @@ class autoShape(nn.Module):
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
# Post-process # Post-process
for i in batch: for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i]) scale_coords(shape1, y[i][:, :4], shape0[i])
return Detections(imgs, y, self.names) return Detections(imgs, y, self.names)