* Add autoshape parameter * Remove autoshape call in ReadMe * Update hubconf.py * file/URI inputs and autoshape check passthrough Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>5.0
@@ -106,7 +106,7 @@ import torch | |||
from PIL import Image | |||
# Model | |||
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).autoshape() # for PIL/cv2/np inputs and NMS | |||
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # for PIL/cv2/np inputs and NMS | |||
# Images | |||
img1 = Image.open('zidane.jpg') |
@@ -17,7 +17,7 @@ dependencies = ['torch', 'yaml'] | |||
set_logging() | |||
def create(name, pretrained, channels, classes): | |||
def create(name, pretrained, channels, classes, autoshape): | |||
"""Creates a specified YOLOv5 model | |||
Arguments: | |||
@@ -41,7 +41,8 @@ def create(name, pretrained, channels, classes): | |||
model.load_state_dict(state_dict, strict=False) # load | |||
if len(ckpt['model'].names) == classes: | |||
model.names = ckpt['model'].names # set class names attribute | |||
# model = model.autoshape() # for PIL/cv2/np inputs and NMS | |||
if autoshape: | |||
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS | |||
return model | |||
except Exception as e: | |||
@@ -50,7 +51,7 @@ def create(name, pretrained, channels, classes): | |||
raise Exception(s) from e | |||
def yolov5s(pretrained=False, channels=3, classes=80): | |||
def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True): | |||
"""YOLOv5-small model from https://github.com/ultralytics/yolov5 | |||
Arguments: | |||
@@ -61,10 +62,10 @@ def yolov5s(pretrained=False, channels=3, classes=80): | |||
Returns: | |||
pytorch model | |||
""" | |||
return create('yolov5s', pretrained, channels, classes) | |||
return create('yolov5s', pretrained, channels, classes, autoshape) | |||
def yolov5m(pretrained=False, channels=3, classes=80): | |||
def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True): | |||
"""YOLOv5-medium model from https://github.com/ultralytics/yolov5 | |||
Arguments: | |||
@@ -75,10 +76,10 @@ def yolov5m(pretrained=False, channels=3, classes=80): | |||
Returns: | |||
pytorch model | |||
""" | |||
return create('yolov5m', pretrained, channels, classes) | |||
return create('yolov5m', pretrained, channels, classes, autoshape) | |||
def yolov5l(pretrained=False, channels=3, classes=80): | |||
def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True): | |||
"""YOLOv5-large model from https://github.com/ultralytics/yolov5 | |||
Arguments: | |||
@@ -89,10 +90,10 @@ def yolov5l(pretrained=False, channels=3, classes=80): | |||
Returns: | |||
pytorch model | |||
""" | |||
return create('yolov5l', pretrained, channels, classes) | |||
return create('yolov5l', pretrained, channels, classes, autoshape) | |||
def yolov5x(pretrained=False, channels=3, classes=80): | |||
def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True): | |||
"""YOLOv5-xlarge model from https://github.com/ultralytics/yolov5 | |||
Arguments: | |||
@@ -103,10 +104,10 @@ def yolov5x(pretrained=False, channels=3, classes=80): | |||
Returns: | |||
pytorch model | |||
""" | |||
return create('yolov5x', pretrained, channels, classes) | |||
return create('yolov5x', pretrained, channels, classes, autoshape) | |||
def custom(path_or_model='path/to/model.pt'): | |||
def custom(path_or_model='path/to/model.pt', autoshape=True): | |||
"""YOLOv5-custom model from https://github.com/ultralytics/yolov5 | |||
Arguments (3 options): | |||
@@ -124,13 +125,12 @@ def custom(path_or_model='path/to/model.pt'): | |||
hub_model = Model(model.yaml).to(next(model.parameters()).device) # create | |||
hub_model.load_state_dict(model.float().state_dict()) # load state_dict | |||
hub_model.names = model.names # class names | |||
return hub_model | |||
return hub_model.autoshape() if autoshape else hub_model | |||
if __name__ == '__main__': | |||
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # pretrained example | |||
model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example | |||
# model = custom(path_or_model='path/to/model.pt') # custom example | |||
model = model.autoshape() # for PIL/cv2/np inputs and NMS | |||
# Verify inference | |||
from PIL import Image |
@@ -2,6 +2,7 @@ | |||
import math | |||
import numpy as np | |||
import requests | |||
import torch | |||
import torch.nn as nn | |||
from PIL import Image, ImageDraw | |||
@@ -143,35 +144,42 @@ class autoShape(nn.Module): | |||
super(autoShape, self).__init__() | |||
self.model = model.eval() | |||
def autoshape(self): | |||
print('autoShape already enabled, skipping... ') # model already converted to model.autoshape() | |||
return self | |||
def forward(self, imgs, size=640, augment=False, profile=False): | |||
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are: | |||
# opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) | |||
# PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3) | |||
# numpy: imgs = np.zeros((720,1280,3)) # HWC | |||
# torch: imgs = torch.zeros(16,3,720,1280) # BCHW | |||
# multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images | |||
# Inference from various sources. For height=720, width=1280, RGB images example inputs are: | |||
# filename: imgs = 'data/samples/zidane.jpg' | |||
# URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg' | |||
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) | |||
# PIL: = Image.open('image.jpg') # HWC x(720,1280,3) | |||
# numpy: = np.zeros((720,1280,3)) # HWC | |||
# torch: = torch.zeros(16,3,720,1280) # BCHW | |||
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images | |||
p = next(self.model.parameters()) # for device and type | |||
if isinstance(imgs, torch.Tensor): # torch | |||
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference | |||
# Pre-process | |||
if not isinstance(imgs, list): | |||
imgs = [imgs] | |||
n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images | |||
shape0, shape1 = [], [] # image and inference shapes | |||
batch = range(len(imgs)) # batch size | |||
for i in batch: | |||
imgs[i] = np.array(imgs[i]) # to numpy | |||
if imgs[i].shape[0] < 5: # image in CHW | |||
imgs[i] = imgs[i].transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) | |||
imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input | |||
s = imgs[i].shape[:2] # HWC | |||
for i, im in enumerate(imgs): | |||
if isinstance(im, str): # filename or uri | |||
im = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im) # open | |||
im = np.array(im) # to numpy | |||
if im.shape[0] < 5: # image in CHW | |||
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) | |||
im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input | |||
s = im.shape[:2] # HWC | |||
shape0.append(s) # image shape | |||
g = (size / max(s)) # gain | |||
shape1.append([y * g for y in s]) | |||
imgs[i] = im # update | |||
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape | |||
x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad | |||
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack | |||
x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad | |||
x = np.stack(x, 0) if n > 1 else x[0][None] # stack | |||
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW | |||
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32 | |||
@@ -181,7 +189,7 @@ class autoShape(nn.Module): | |||
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS | |||
# Post-process | |||
for i in batch: | |||
for i in range(n): | |||
scale_coords(shape1, y[i][:, :4], shape0[i]) | |||
return Detections(imgs, y, self.names) |