
Module `super().__init__()` (#4065)

* Module `super().__init__()`

* remove NMS
Glenn Jocher · 3 years ago · commit b1be685005
3 changed files with 27 additions and 47 deletions:

  1. models/common.py        +18 −24
  2. models/experimental.py  +6 −6
  3. models/yolo.py          +3 −17
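The change itself is mechanical: in Python 3 the zero-argument `super()` resolves the enclosing class and instance implicitly, so `super(Conv, self).__init__()` and `super().__init__()` behave identically. A minimal before/after sketch (illustrative classes, not the repo's code):

import torch.nn as nn

class ConvOld(nn.Module):
    def __init__(self, c1, c2):
        super(ConvOld, self).__init__()   # Python 2-compatible form this commit removes
        self.conv = nn.Conv2d(c1, c2, 1)

class ConvNew(nn.Module):
    def __init__(self, c1, c2):
        super().__init__()                # zero-argument form, identical behavior in Python 3
        self.conv = nn.Conv2d(c1, c2, 1)

assert type(ConvOld(3, 16).conv) is type(ConvNew(3, 16).conv)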

models/common.py  (+18 −24)

 class Conv(nn.Module):
     # Standard convolution
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
-        super(Conv, self).__init__()
+        super().__init__()
         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
         self.bn = nn.BatchNorm2d(c2)
         self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
...
 class Bottleneck(nn.Module):
     # Standard bottleneck
     def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
-        super(Bottleneck, self).__init__()
+        super().__init__()
         c_ = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, c_, 1, 1)
         self.cv2 = Conv(c_, c2, 3, 1, g=g)
...
 class BottleneckCSP(nn.Module):
     # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
-        super(BottleneckCSP, self).__init__()
+        super().__init__()
         c_ = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, c_, 1, 1)
         self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
...
 class C3(nn.Module):
     # CSP Bottleneck with 3 convolutions
     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
-        super(C3, self).__init__()
+        super().__init__()
         c_ = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, c_, 1, 1)
         self.cv2 = Conv(c1, c_, 1, 1)
...
         self.m = TransformerBlock(c_, c_, 4, n)




+class C3SPP(C3):
+    # C3 module with SPP()
+    def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)
+        self.m = SPP(c_, c_, k)
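The new `C3SPP` subclass shows why the zero-argument form matters for inheritance: it runs `C3.__init__` unchanged via `super().__init__(c1, c2, n, shortcut, g, e)` and then swaps the inner `self.m` stage for an SPP block. A self-contained sketch of that reuse-then-override pattern (an illustrative analogue, not the YOLOv5 classes):

import torch
import torch.nn as nn

class Block(nn.Module):
    # parent builds the default inner stage
    def __init__(self, c):
        super().__init__()
        self.m = nn.Conv2d(c, c, 3, padding=1)  # default inner op

    def forward(self, x):
        return self.m(x)

class PoolBlock(Block):
    # child reuses the parent's setup, then overrides self.m (the C3SPP pattern)
    def __init__(self, c, k=5):
        super().__init__(c)  # zero-argument super(), as in this commit
        self.m = nn.MaxPool2d(k, stride=1, padding=k // 2)  # swap in a pooling stage

x = torch.randn(1, 8, 32, 32)
assert PoolBlock(8)(x).shape == x.shape  # spatial dims preserved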


 class SPP(nn.Module):
     # Spatial pyramid pooling layer used in YOLOv3-SPP
     def __init__(self, c1, c2, k=(5, 9, 13)):
-        super(SPP, self).__init__()
+        super().__init__()
         c_ = c1 // 2  # hidden channels
         self.cv1 = Conv(c1, c_, 1, 1)
         self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
...
 class Focus(nn.Module):
     # Focus wh information into c-space
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
-        super(Focus, self).__init__()
+        super().__init__()
         self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
         # self.contract = Contract(gain=2)
...
 class Concat(nn.Module):
     # Concatenate a list of tensors along dimension
     def __init__(self, dimension=1):
-        super(Concat, self).__init__()
+        super().__init__()
         self.d = dimension

     def forward(self, x):
         return torch.cat(x, self.d)




-class NMS(nn.Module):
-    # Non-Maximum Suppression (NMS) module
-    conf = 0.25  # confidence threshold
-    iou = 0.45  # IoU threshold
-    classes = None  # (optional list) filter by class
-    max_det = 1000  # maximum number of detections per image
-
-    def __init__(self):
-        super(NMS, self).__init__()
-
-    def forward(self, x):
-        return non_max_suppression(x[0], self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)
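With the standalone `NMS` module gone, raw `Model` outputs need explicit post-processing. A hedged sketch of the direct equivalent, reusing the same `non_max_suppression` call the deleted `forward` wrapped (the random tensor stands in for real predictions, and the repo is assumed to be on `sys.path`):

import torch
from utils.general import non_max_suppression  # YOLOv5 helper the deleted module called

raw_pred = torch.rand(1, 100, 85)        # stand-in for model(img)[0]: (batch, boxes, 5 + 80 classes)
det = non_max_suppression(raw_pred,
                          conf_thres=0.25,   # same defaults the NMS module carried
                          iou_thres=0.45,
                          classes=None,      # optional class-index filter
                          max_det=1000)
print(det[0].shape)                      # (n, 6) per image: xyxy, confidence, class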


 class AutoShape(nn.Module):
     # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
     conf = 0.25  # NMS confidence threshold
     max_det = 1000  # maximum number of detections per image

     def __init__(self, model):
-        super(AutoShape, self).__init__()
+        super().__init__()
         self.model = model.eval()

     def autoshape(self):
...
 class Detections:
     # YOLOv5 detections class for inference results
     def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
-        super(Detections, self).__init__()
+        super().__init__()
         d = pred[0].device  # device
         gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs]  # normalizations
         self.imgs = imgs  # list of images as numpy arrays
...
 class Classify(nn.Module):
     # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
     def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
-        super(Classify, self).__init__()
+        super().__init__()
         self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
         self.flat = nn.Flatten()

models/experimental.py  (+6 −6)

     # Cross Convolution Downsample
     def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
         # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
-        super(CrossConv, self).__init__()
+        super().__init__()
         c_ = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, c_, (1, k), (1, s))
         self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
...
 class Sum(nn.Module):
     # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
     def __init__(self, n, weight=False):  # n: number of inputs
-        super(Sum, self).__init__()
+        super().__init__()
         self.weight = weight  # apply weights boolean
         self.iter = range(n - 1)  # iter object
         if weight:
...
 class GhostConv(nn.Module):
     # Ghost Convolution https://github.com/huawei-noah/ghostnet
     def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
-        super(GhostConv, self).__init__()
+        super().__init__()
         c_ = c2 // 2  # hidden channels
         self.cv1 = Conv(c1, c_, k, s, None, g, act)
         self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
...
 class GhostBottleneck(nn.Module):
     # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
     def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
-        super(GhostBottleneck, self).__init__()
+        super().__init__()
         c_ = c2 // 2
         self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
                                   DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
...
 class MixConv2d(nn.Module):
     # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
     def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
-        super(MixConv2d, self).__init__()
+        super().__init__()
         groups = len(k)
         if equal_ch:  # equal c_ per group
             i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
...
 class Ensemble(nn.ModuleList):
     # Ensemble of models
     def __init__(self):
-        super(Ensemble, self).__init__()
+        super().__init__()

     def forward(self, x, augment=False, profile=False, visualize=False):
         y = []

models/yolo.py  (+3 −17)

     onnx_dynamic = False  # ONNX export parameter

     def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
-        super(Detect, self).__init__()
+        super().__init__()
         self.nc = nc  # number of classes
         self.no = nc + 5  # number of outputs per anchor
         self.nl = len(anchors)  # number of detection layers
...
 class Model(nn.Module):
     def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
-        super(Model, self).__init__()
+        super().__init__()
         if isinstance(cfg, dict):
             self.yaml = cfg  # model dict
         else:  # is *.yaml
...
         self.info()
         return self


-    def nms(self, mode=True):  # add or remove NMS module
-        present = type(self.model[-1]) is NMS  # last layer is NMS
-        if mode and not present:
-            LOGGER.info('Adding NMS... ')
-            m = NMS()  # module
-            m.f = -1  # from
-            m.i = self.model[-1].i + 1  # index
-            self.model.add_module(name='%s' % m.i, module=m)  # add
-            self.eval()
-        elif not mode and present:
-            LOGGER.info('Removing NMS... ')
-            self.model = self.model[:-1]  # remove
-        return self
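With `Model.nms()` removed, the surviving route to inference-with-NMS is the `autoshape()` method kept just below, which returns an `AutoShape` module bundling preprocessing, inference and NMS. A hedged usage sketch (config path and threshold values are illustrative; assumes the repo is on `sys.path`):

import numpy as np
from models.yolo import Model

model = Model('models/yolov5s.yaml').autoshape()  # wrap: preprocessing + inference + NMS
model.conf = 0.25                                 # thresholds now live on the AutoShape wrapper
model.iou = 0.45
im = np.zeros((640, 640, 3), dtype=np.uint8)      # stand-in HWC image
results = model(im)                               # Detections object, already NMS-filtered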

     def autoshape(self):  # add AutoShape module
         LOGGER.info('Adding AutoShape... ')
         m = AutoShape(self)  # wrap model


         n = max(round(n * gd), 1) if n > 1 else n  # depth gain
         if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
-                 C3, C3TR]:
+                 C3, C3TR, C3SPP]:
             c1, c2 = ch[f], args[0]
             if c2 != no:  # if not output
                 c2 = make_divisible(c2 * gw, 8)
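Adding `C3SPP` to this membership list is what lets model YAMLs instantiate it by name, with `parse_model` scaling its output channels by the width multiple like any other block. A small sketch of that channel rounding (values illustrative; `make_divisible` reproduced from the repo's utility):

import math

def make_divisible(x, divisor):      # rounding rule parse_model applies to scaled channels
    return math.ceil(x / divisor) * divisor

gw = 0.50                            # width multiple, e.g. the yolov5s scale
c2 = 512                             # hypothetical args[0] from a C3SPP row in a YAML
print(make_divisible(c2 * gw, 8))    # -> 256: the actual ch_out after width scaling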
