Browse Source

PyTorch Hub updates

5.0
Glenn Jocher 4 years ago
parent
commit
a814720403
6 changed files with 134 additions and 99 deletions
  1. +1
    -1
      README.md
  2. +90
    -0
      hubconf.py
  3. +0
    -37
      models/common.py
  4. +38
    -1
      models/experimental.py
  5. +5
    -5
      models/yolo.py
  6. +0
    -55
      models/yolov3-spp_csp.yaml

+ 1
- 1
README.md View File

@@ -87,7 +87,7 @@ To access an up-to-date working environment (with all dependencies including CUD

- **GCP** Deep Learning VM with $300 free credit offer: See our [GCP Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart)
- **Google Colab Notebook** with 12 hours of free GPU time. <a href="https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
- **Docker Image** https://hub.docker.com/r/ultralytics/yolov5. See [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart)
- **Docker Image** https://hub.docker.com/r/ultralytics/yolov5. See [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) ![Docker Pulls](https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker)


## Citation

+ 90
- 0
hubconf.py View File

@@ -0,0 +1,90 @@
"""File for accessing YOLOv5 via PyTorch Hub https://pytorch.org/hub/

Usage:
import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
"""

# PyTorch Hub convention: package names that must be installed to load the entry points in this file
dependencies = ['torch', 'pyyaml']
import torch

from models.yolo import Model
from utils import google_utils


def create(name, pretrained, channels, classes):
    """Creates a specified YOLOv5 model

    Builds the model from 'models/<name>.yaml' and, when requested, copies in
    every pretrained tensor from '<name>.pt' whose key and shape match the
    freshly built model. Tensors that do not match (for example when a
    non-default number of channels or classes is requested) are left at
    their initial values.

    Arguments:
        name (str): name of model, i.e. 'yolov5s'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes

    Returns:
        pytorch model
    """
    model = Model('models/%s.yaml' % name, channels, classes)
    if pretrained:
        ckpt = '%s.pt' % name  # checkpoint filename
        google_utils.attempt_download(ckpt)  # download if not found locally

        # NOTE(review): torch.load unpickles the file - only use checkpoints from a trusted source.
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        state_dict = torch.load(ckpt, map_location='cpu')['model'].state_dict()

        # Keep only tensors the new model can accept. load_state_dict(strict=False) tolerates
        # missing/unexpected keys but still raises on a shape mismatch, so compare shapes here.
        target = model.state_dict()  # built once rather than once per key
        state_dict = {k: v for k, v in state_dict.items()
                      if k in target and target[k].shape == v.shape}  # filter
        model.load_state_dict(state_dict, strict=False)  # load
    return model


def yolov5s(pretrained=False, channels=3, classes=80):
    """Build the YOLOv5-small model from https://github.com/ultralytics/yolov5

    Thin entry point that forwards its arguments to create() with name 'yolov5s'.

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=False
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80

    Returns:
        pytorch model
    """
    return create(name='yolov5s', pretrained=pretrained, channels=channels, classes=classes)


def yolov5m(pretrained=False, channels=3, classes=80):
    """Build the YOLOv5-medium model from https://github.com/ultralytics/yolov5

    Thin entry point that forwards its arguments to create() with name 'yolov5m'.

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=False
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80

    Returns:
        pytorch model
    """
    return create(name='yolov5m', pretrained=pretrained, channels=channels, classes=classes)


def yolov5l(pretrained=False, channels=3, classes=80):
    """Build the YOLOv5-large model from https://github.com/ultralytics/yolov5

    Thin entry point that forwards its arguments to create() with name 'yolov5l'.

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=False
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80

    Returns:
        pytorch model
    """
    return create(name='yolov5l', pretrained=pretrained, channels=channels, classes=classes)


def yolov5x(pretrained=False, channels=3, classes=80):
    """Build the YOLOv5-xlarge model from https://github.com/ultralytics/yolov5

    Thin entry point that forwards its arguments to create() with name 'yolov5x'.

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=False
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80

    Returns:
        pytorch model
    """
    return create(name='yolov5x', pretrained=pretrained, channels=channels, classes=classes)

+ 0
- 37
models/common.py View File

@@ -1,8 +1,6 @@
# This file contains modules common to various models


import torch.nn.functional as F

from utils.utils import *


@@ -58,17 +56,6 @@ class BottleneckCSP(nn.Module):
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))


class ConvPlus(nn.Module):
    # Plus-shaped convolution: a (k x 1) conv and a (1 x k) conv applied to the same input and summed
    def __init__(self, c1, c2, k=3, s=1, g=1, bias=True):  # ch_in, ch_out, kernel, stride, groups
        super().__init__()
        half = k // 2  # padding along the axis the kernel extends over
        shared = dict(stride=s, groups=g, bias=bias)
        # cv1 spans k rows x 1 column, cv2 spans 1 row x k columns
        self.cv1 = nn.Conv2d(c1, c2, kernel_size=(k, 1), padding=(half, 0), **shared)
        self.cv2 = nn.Conv2d(c1, c2, kernel_size=(1, k), padding=(0, half), **shared)

    def forward(self, x):
        # Sum of the two one-dimensional branches
        tall = self.cv1(x)
        wide = self.cv2(x)
        return tall + wide


class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13)):
@@ -107,27 +94,3 @@ class Concat(nn.Module):

def forward(self, x):
return torch.cat(x, self.d)


class MixConv2d(nn.Module):
    # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
    # Parallel convs with different kernel sizes, concatenated on the channel axis, then BN + LeakyReLU,
    # added back onto the input (so the output must be shape-compatible with x)
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        super().__init__()
        n_branches = len(k)
        if equal_ch:  # equal c_ per group
            bucket = torch.linspace(0, n_branches - 1E-6, c2).floor()  # branch index for each of the c2 channels
            widths = [(bucket == j).sum() for j in range(n_branches)]  # intermediate channels
        else:  # equal weight.numel() per group
            rhs = [c2] + [0] * n_branches
            coeffs = np.eye(n_branches + 1, n_branches, k=-1)
            coeffs -= np.roll(coeffs, 1, axis=1)
            coeffs *= np.array(k) ** 2
            coeffs[0] = 1  # first equation: branch widths sum to c2
            widths = np.linalg.lstsq(coeffs, rhs, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        branches = []
        for width, kernel in zip(widths, k):
            branches.append(nn.Conv2d(c1, int(width), kernel, s, kernel // 2, bias=False))
        self.m = nn.ModuleList(branches)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        mixed = torch.cat([conv(x) for conv in self.m], 1)  # concatenate branch outputs on channels
        return x + self.act(self.bn(mixed))

+ 38
- 1
models/experimental.py View File

@@ -2,7 +2,7 @@ from models.common import *


class Sum(nn.Module):
# weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
# Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
def __init__(self, n, weight=False): # n: number of inputs
super(Sum, self).__init__()
self.weight = weight # apply weights boolean
@@ -23,6 +23,7 @@ class Sum(nn.Module):


class GhostConv(nn.Module):
# Ghost Convolution https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
super(GhostConv, self).__init__()
c_ = c2 // 2 # hidden channels
@@ -35,6 +36,7 @@ class GhostConv(nn.Module):


class GhostBottleneck(nn.Module):
# Ghost Bottleneck https://github.com/huawei-noah/ghostnet
def __init__(self, c1, c2, k, s):
super(GhostBottleneck, self).__init__()
c_ = c2 // 2
@@ -46,3 +48,38 @@ class GhostBottleneck(nn.Module):

def forward(self, x):
return self.conv(x) + self.shortcut(x)


class ConvPlus(nn.Module):
    # Plus-shaped convolution: a (k x 1) conv and a (1 x k) conv applied to the same input and summed
    def __init__(self, c1, c2, k=3, s=1, g=1, bias=True):  # ch_in, ch_out, kernel, stride, groups
        super().__init__()
        half = k // 2  # padding along the axis the kernel extends over
        shared = dict(stride=s, groups=g, bias=bias)
        # cv1 spans k rows x 1 column, cv2 spans 1 row x k columns
        self.cv1 = nn.Conv2d(c1, c2, kernel_size=(k, 1), padding=(half, 0), **shared)
        self.cv2 = nn.Conv2d(c1, c2, kernel_size=(1, k), padding=(0, half), **shared)

    def forward(self, x):
        # Sum of the two one-dimensional branches
        tall = self.cv1(x)
        wide = self.cv2(x)
        return tall + wide


class MixConv2d(nn.Module):
    # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
    # Parallel convs with different kernel sizes, concatenated on the channel axis, then BN + LeakyReLU,
    # added back onto the input (so the output must be shape-compatible with x)
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        super().__init__()
        n_branches = len(k)
        if equal_ch:  # equal c_ per group
            bucket = torch.linspace(0, n_branches - 1E-6, c2).floor()  # branch index for each of the c2 channels
            widths = [(bucket == j).sum() for j in range(n_branches)]  # intermediate channels
        else:  # equal weight.numel() per group
            rhs = [c2] + [0] * n_branches
            coeffs = np.eye(n_branches + 1, n_branches, k=-1)
            coeffs -= np.roll(coeffs, 1, axis=1)
            coeffs *= np.array(k) ** 2
            coeffs[0] = 1  # first equation: branch widths sum to c2
            widths = np.linalg.lstsq(coeffs, rhs, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        branches = []
        for width, kernel in zip(widths, k):
            branches.append(nn.Conv2d(c1, int(width), kernel, s, kernel // 2, bias=False))
        self.m = nn.ModuleList(branches)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        mixed = torch.cat([conv(x) for conv in self.m], 1)  # concatenate branch outputs on channels
        return x + self.act(self.bn(mixed))

+ 5
- 5
models/yolo.py View File

@@ -2,7 +2,7 @@ import argparse

import yaml

from models.common import *
from models.experimental import *


class Detect(nn.Module):
@@ -56,12 +56,12 @@ class Model(nn.Module):
# Define model
if nc:
self.md['nc'] = nc # override yaml value
self.model, self.save, ch = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
# print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])
self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

# Build strides, anchors
m = self.model[-1] # Detect()
m.stride = torch.tensor([64 / x.shape[-2] for x in self.forward(torch.zeros(1, 3, 64, 64))]) # forward
m.stride = torch.tensor([64 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 64, 64))]) # forward
m.anchors /= m.stride.view(-1, 1, 1)
self.stride = m.stride

@@ -200,7 +200,7 @@ def parse_model(md, ch): # model_dict, input_channels(3)
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
ch.append(c2)
return nn.Sequential(*layers), sorted(save), ch
return nn.Sequential(*layers), sorted(save)


if __name__ == '__main__':

+ 0
- 55
models/yolov3-spp_csp.yaml View File

@@ -1,55 +0,0 @@
# parameters
nc: 80 # number of classes
depth_multiple: 1.0 # expand model depth
width_multiple: 1.0 # expand layer channels

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32

# darknet53 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [32, 3, 1]], # 0
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2
[-1, 1, BottleneckCSP, [64]],
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4
[-1, 2, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 5-P3/8
[-1, 8, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 7-P4/16
[-1, 8, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
[-1, 4, BottleneckCSP, [1024]], # 10
]

# yolov3-spp head
# na = len(anchors[0])
head:
[[-1, 1, Bottleneck, [1024, False]], # 11
[-1, 1, SPP, [512, [5, 9, 13]]],
[-1, 1, Conv, [1024, 3, 1]],
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, Conv, [1024, 3, 1]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 16 (P5/32-large)

[-3, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 8], 1, Concat, [1]], # cat backbone P4
[-1, 1, Bottleneck, [512, False]],
[-1, 1, Bottleneck, [512, False]],
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [512, 3, 1]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 24 (P4/16-medium)

[-3, 1, Conv, [128, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P3
[-1, 1, Bottleneck, [256, False]],
[-1, 2, Bottleneck, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 30 (P3/8-small)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]

Loading…
Cancel
Save