@@ -87,7 +87,7 @@ To access an up-to-date working environment (with all dependencies including CUD
 - **GCP** Deep Learning VM with $300 free credit offer: See our [GCP Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart)
 - **Google Colab Notebook** with 12 hours of free GPU time. <a href="https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
-- **Docker Image** https://hub.docker.com/r/ultralytics/yolov5. See [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart)
+- **Docker Image** https://hub.docker.com/r/ultralytics/yolov5. See [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) ![Docker Pulls](https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker)

 ## Citation
@@ -0,0 +1,90 @@
+"""File for accessing YOLOv5 via PyTorch Hub https://pytorch.org/hub/
+
+Usage:
+    import torch
+    model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
+"""
+
+dependencies = ['torch', 'yaml']  # import names (checked by torch.hub), not pip package names
+
+import torch
+
+from models.yolo import Model
+from utils import google_utils
+
+
+def create(name, pretrained, channels, classes):
+    """Creates a specified YOLOv5 model
+
+    Arguments:
+        name (str): name of model, i.e. 'yolov5s'
+        pretrained (bool): load pretrained weights into the model
+        channels (int): number of input channels
+        classes (int): number of model classes
+
+    Returns:
+        pytorch model
+    """
+    model = Model('models/%s.yaml' % name, channels, classes)
+    if pretrained:
+        ckpt = '%s.pt' % name  # checkpoint filename
+        google_utils.attempt_download(ckpt)  # download if not found locally
+        state_dict = torch.load(ckpt)['model'].state_dict()
+        state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].numel() == v.numel()}  # filter
+        model.load_state_dict(state_dict, strict=False)  # load
+    return model
+
+
+def yolov5s(pretrained=False, channels=3, classes=80):
+    """YOLOv5-small model from https://github.com/ultralytics/yolov5
+
+    Arguments:
+        pretrained (bool): load pretrained weights into the model, default=False
+        channels (int): number of input channels, default=3
+        classes (int): number of model classes, default=80
+
+    Returns:
+        pytorch model
+    """
+    return create('yolov5s', pretrained, channels, classes)
+
+
+def yolov5m(pretrained=False, channels=3, classes=80):
+    """YOLOv5-medium model from https://github.com/ultralytics/yolov5
+
+    Arguments:
+        pretrained (bool): load pretrained weights into the model, default=False
+        channels (int): number of input channels, default=3
+        classes (int): number of model classes, default=80
+
+    Returns:
+        pytorch model
+    """
+    return create('yolov5m', pretrained, channels, classes)
+
+
+def yolov5l(pretrained=False, channels=3, classes=80):
+    """YOLOv5-large model from https://github.com/ultralytics/yolov5
+
+    Arguments:
+        pretrained (bool): load pretrained weights into the model, default=False
+        channels (int): number of input channels, default=3
+        classes (int): number of model classes, default=80
+
+    Returns:
+        pytorch model
+    """
+    return create('yolov5l', pretrained, channels, classes)
+
+
+def yolov5x(pretrained=False, channels=3, classes=80):
+    """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5
+
+    Arguments:
+        pretrained (bool): load pretrained weights into the model, default=False
+        channels (int): number of input channels, default=3
+        classes (int): number of model classes, default=80
+
+    Returns:
+        pytorch model
+    """
+    return create('yolov5x', pretrained, channels, classes)
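
As a quick sanity check for the new Hub entrypoints, here is a minimal usage sketch. The `torch.hub.load` call mirrors the docstring above; the 640x640 zero tensor and the "pre-NMS" description of the output are illustrative assumptions, not part of this PR:

```python
import torch

# Load yolov5s through PyTorch Hub, as documented in hubconf.py above
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
model.eval()

img = torch.zeros(1, 3, 640, 640)  # dummy RGB batch; side must be a multiple of the max stride (32)
with torch.no_grad():
    pred = model(img)  # raw detector output (pre-NMS)
```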
@@ -1,8 +1,6 @@
 # This file contains modules common to various models
 import torch.nn.functional as F

 from utils.utils import *
@@ -58,17 +56,6 @@ class BottleneckCSP(nn.Module):
         return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))


-class ConvPlus(nn.Module):
-    # Plus-shaped convolution
-    def __init__(self, c1, c2, k=3, s=1, g=1, bias=True):  # ch_in, ch_out, kernel, stride, groups
-        super(ConvPlus, self).__init__()
-        self.cv1 = nn.Conv2d(c1, c2, (k, 1), s, (k // 2, 0), groups=g, bias=bias)
-        self.cv2 = nn.Conv2d(c1, c2, (1, k), s, (0, k // 2), groups=g, bias=bias)
-
-    def forward(self, x):
-        return self.cv1(x) + self.cv2(x)
-
-
 class SPP(nn.Module):
     # Spatial pyramid pooling layer used in YOLOv3-SPP
     def __init__(self, c1, c2, k=(5, 9, 13)):
@@ -107,27 +94,3 @@ class Concat(nn.Module):
     def forward(self, x):
         return torch.cat(x, self.d)
-
-
-class MixConv2d(nn.Module):
-    # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
-    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
-        super(MixConv2d, self).__init__()
-        groups = len(k)
-        if equal_ch:  # equal c_ per group
-            i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
-            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
-        else:  # equal weight.numel() per group
-            b = [c2] + [0] * groups
-            a = np.eye(groups + 1, groups, k=-1)
-            a -= np.roll(a, 1, axis=1)
-            a *= np.array(k) ** 2
-            a[0] = 1
-            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b
-
-        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
-        self.bn = nn.BatchNorm2d(c2)
-        self.act = nn.LeakyReLU(0.1, inplace=True)
-
-    def forward(self, x):
-        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
@@ -2,7 +2,7 @@ from models.common import *


 class Sum(nn.Module):
-    # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
+    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
     def __init__(self, n, weight=False):  # n: number of inputs
         super(Sum, self).__init__()
         self.weight = weight  # apply weights boolean
@@ -23,6 +23,7 @@ class Sum(nn.Module):
+

 class GhostConv(nn.Module):
     # Ghost Convolution https://github.com/huawei-noah/ghostnet
     def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
         super(GhostConv, self).__init__()
         c_ = c2 // 2  # hidden channels
@@ -35,6 +36,7 @@ class GhostConv(nn.Module):
+

 class GhostBottleneck(nn.Module):
     # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
     def __init__(self, c1, c2, k, s):
         super(GhostBottleneck, self).__init__()
         c_ = c2 // 2
@@ -46,3 +48,38 @@ class GhostBottleneck(nn.Module):

     def forward(self, x):
         return self.conv(x) + self.shortcut(x)
+
+
+class ConvPlus(nn.Module):
+    # Plus-shaped convolution
+    def __init__(self, c1, c2, k=3, s=1, g=1, bias=True):  # ch_in, ch_out, kernel, stride, groups
+        super(ConvPlus, self).__init__()
+        self.cv1 = nn.Conv2d(c1, c2, (k, 1), s, (k // 2, 0), groups=g, bias=bias)
+        self.cv2 = nn.Conv2d(c1, c2, (1, k), s, (0, k // 2), groups=g, bias=bias)
+
+    def forward(self, x):
+        return self.cv1(x) + self.cv2(x)
+
+
+class MixConv2d(nn.Module):
+    # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
+    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
+        super(MixConv2d, self).__init__()
+        groups = len(k)
+        if equal_ch:  # equal c_ per group
+            i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
+            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
+        else:  # equal weight.numel() per group
+            b = [c2] + [0] * groups
+            a = np.eye(groups + 1, groups, k=-1)
+            a -= np.roll(a, 1, axis=1)
+            a *= np.array(k) ** 2
+            a[0] = 1
+            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b
+
+        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
+        self.bn = nn.BatchNorm2d(c2)
+        self.act = nn.LeakyReLU(0.1, inplace=True)
+
+    def forward(self, x):
+        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
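
For readers tracking the move, a hedged shape check of the two relocated modules (assumes the repo root is on `sys.path`; the channel and spatial sizes are illustrative only). Note that `MixConv2d.forward` adds a residual `x`, so it requires `c1 == c2`:

```python
import torch
from models.experimental import ConvPlus, MixConv2d  # new home per this diff

x = torch.zeros(1, 64, 32, 32)  # N, C, H, W
print(ConvPlus(64, 128, k=3)(x).shape)       # torch.Size([1, 128, 32, 32]): (k,1) and (1,k) branches summed
print(MixConv2d(64, 64, k=(1, 3))(x).shape)  # torch.Size([1, 64, 32, 32]): residual forces c1 == c2
```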
@@ -2,7 +2,7 @@ import argparse

 import yaml

-from models.common import *
+from models.experimental import *


 class Detect(nn.Module):
@@ -56,12 +56,12 @@ class Model(nn.Module):
         # Define model
         if nc:
             self.md['nc'] = nc  # override yaml value
-        self.model, self.save, ch = parse_model(self.md, ch=[ch])  # model, savelist, ch_out
-        # print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])
+        self.model, self.save = parse_model(self.md, ch=[ch])  # model, savelist
+        # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

         # Build strides, anchors
         m = self.model[-1]  # Detect()
-        m.stride = torch.tensor([64 / x.shape[-2] for x in self.forward(torch.zeros(1, 3, 64, 64))])  # forward
+        m.stride = torch.tensor([64 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 64, 64))])  # forward
         m.anchors /= m.stride.view(-1, 1, 1)
         self.stride = m.stride
@@ -200,7 +200,7 @@ def parse_model(md, ch):  # model_dict, input_channels(3)
         save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
         layers.append(m_)
         ch.append(c2)
-    return nn.Sequential(*layers), sorted(save), ch
+    return nn.Sequential(*layers), sorted(save)


 if __name__ == '__main__':
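
For intuition on the stride computation above: the dummy `1 x ch x 64 x 64` forward returns one feature map per detection layer, and dividing 64 by each map's grid height recovers the stride. A worked example, assuming the usual P3/8, P4/16, P5/32 outputs:

```python
grid_heights = [8, 4, 2]  # x.shape[-2] for each output of a 64x64 dummy forward
strides = [64 / h for h in grid_heights]
print(strides)  # [8.0, 16.0, 32.0] -> m.stride = tensor([8., 16., 32.])
```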
@@ -1,55 +0,0 @@
-# parameters
-nc: 80  # number of classes
-depth_multiple: 1.0  # expand model depth
-width_multiple: 1.0  # expand layer channels
-
-# anchors
-anchors:
-  - [10,13, 16,30, 33,23]  # P3/8
-  - [30,61, 62,45, 59,119]  # P4/16
-  - [116,90, 156,198, 373,326]  # P5/32
-
-# darknet53 backbone
-backbone:
-  # [from, number, module, args]
-  [[-1, 1, Conv, [32, 3, 1]],  # 0
-   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
-   [-1, 1, BottleneckCSP, [64]],
-   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
-   [-1, 2, BottleneckCSP, [128]],
-   [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
-   [-1, 8, BottleneckCSP, [256]],
-   [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
-   [-1, 8, BottleneckCSP, [512]],
-   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
-   [-1, 4, BottleneckCSP, [1024]],  # 10
-  ]
-
-# yolov3-spp head
-# na = len(anchors[0])
-head:
-  [[-1, 1, Bottleneck, [1024, False]],  # 11
-   [-1, 1, SPP, [512, [5, 9, 13]]],
-   [-1, 1, Conv, [1024, 3, 1]],
-   [-1, 1, Conv, [512, 1, 1]],
-   [-1, 1, Conv, [1024, 3, 1]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 16 (P5/32-large)
-   [-3, 1, Conv, [256, 1, 1]],
-   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
-   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
-   [-1, 1, Bottleneck, [512, False]],
-   [-1, 1, Bottleneck, [512, False]],
-   [-1, 1, Conv, [256, 1, 1]],
-   [-1, 1, Conv, [512, 3, 1]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 24 (P4/16-medium)
-   [-3, 1, Conv, [128, 1, 1]],
-   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
-   [[-1, 6], 1, Concat, [1]],  # cat backbone P3
-   [-1, 1, Bottleneck, [256, False]],
-   [-1, 2, Bottleneck, [256, False]],
-   [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]],  # 30 (P3/8-small)
-   [[], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
-  ]
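
As a closing sanity check on the deleted config: each final `nn.Conv2d` in the head outputs `na * (nc + 5)` channels, where `na = 3` (each anchor row above lists three width,height pairs) and the 5 covers box x, y, w, h plus objectness:

```python
na, nc = 3, 80        # anchors per level, classes (values from the yaml above)
print(na * (nc + 5))  # 255 output channels per detection conv
```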