# YOLOv5 experimental modules

import numpy as np
import torch
import torch.nn as nn

from models.common import Conv, DWConv
from utils.google_utils import attempt_download


class CrossConv(nn.Module):
    # Cross Convolution Downsample
    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
        # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
        super(CrossConv, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, (1, k), (1, s))  # 1 x k conv, strides along the width
        self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)  # k x 1 conv, strides along the height
        self.add = shortcut and c1 == c2  # residual only when input and output channels match

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
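

# Illustrative sketch (hypothetical helper, not part of YOLOv5): CrossConv replaces one
# k x k kernel with a 1 x k conv followed by a k x 1 conv. This helper compares the number
# of kernel weights in the two layouts. It assumes each Conv wraps a single nn.Conv2d and
# ignores any batch-norm or bias parameters.
def cross_conv_weight_counts(c1, c2, k=3, e=1.0, g=1):
    """Return (weights of one k x k conv, weights of CrossConv's 1 x k + k x 1 convs)."""
    c_ = int(c2 * e)  # hidden channels, as in CrossConv
    full = c2 * c1 * k * k  # single k x k conv, groups=1
    cross = c_ * c1 * 1 * k + c2 * (c_ // g) * k * 1  # 1 x k conv, then grouped k x 1 conv
    return full, cross


# Usage: cross_conv_weight_counts(64, 64) returns (36864, 24576). With c1 == c2 == c_, the
# factorized form needs 2 / k of the weights, i.e. two thirds for a 3 x 3 kernel.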


class Sum(nn.Module):
    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    def __init__(self, n, weight=False):  # n: number of inputs
        super(Sum, self).__init__()
        self.weight = weight  # apply weights boolean
        self.iter = range(n - 1)  # iter object
        if weight:
            self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True)  # layer weights

    def forward(self, x):
        y = x[0]  # no weight
        if self.weight:
            w = torch.sigmoid(self.w) * 2
            for i in self.iter:
                y = y + x[i + 1] * w[i]
        else:
            for i in self.iter:
                y = y + x[i + 1]
        return y
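

# Illustrative sketch (hypothetical helpers, not part of YOLOv5): with weight=True, Sum
# stores raw parameters w = -[1, 2, ..., n-1] / 2 and rescales them as 2 * sigmoid(w), so
# every effective weight lies in (0, 2). At initialization the weights are all below 1 and
# get smaller for later inputs. Plain floats stand in for feature-map tensors here.
import math


def sum_initial_weights(n):
    """Effective weights 2 * sigmoid(-i / 2) for i = 1..n-1, as Sum initializes them."""
    return [2.0 / (1.0 + math.exp(i / 2)) for i in range(1, n)]


def weighted_sum(xs, weights):
    """Mirror Sum.forward: xs[0] is unweighted, xs[i + 1] is scaled by weights[i]."""
    y = xs[0]
    for x, w in zip(xs[1:], weights):
        y = y + x * w
    return y


# Usage: sum_initial_weights(3) gives about [0.755, 0.538], so summing three inputs of 1.0
# with weighted_sum yields roughly 2.293 before any training.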


class GhostConv(nn.Module):
    # Ghost Convolution https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super(GhostConv, self).__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)  # primary conv: half of the output channels
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)  # cheap 5x5 depthwise conv (groups=c_)

    def forward(self, x):
        y = self.cv1(x)
        return torch.cat([y, self.cv2(y)], 1)  # concatenate primary and "ghost" features
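

# Illustrative sketch (hypothetical helper, not part of YOLOv5): GhostConv computes c2 // 2
# channels with the primary conv, derives the other half with a cheap 5x5 depthwise conv,
# and concatenates the two. The helper counts kernel weights (ignoring batch norm) against
# a plain conv that produces all c2 channels, and reports the output channel count. Note
# that an odd c2 yields c2 - 1 output channels.
def ghost_conv_stats(c1, c2, k=1, g=1):
    """Return (GhostConv kernel weights, plain k x k conv weights, output channels)."""
    c_ = c2 // 2  # hidden channels, as in GhostConv
    primary = c_ * (c1 // g) * k * k  # cv1: c1 -> c_, groups=g
    cheap = c_ * 1 * 5 * 5  # cv2: depthwise 5x5, groups=c_, so one input channel per filter
    plain = c2 * (c1 // g) * k * k  # a single conv producing all c2 channels
    return primary + cheap, plain, 2 * c_


# Usage: ghost_conv_stats(64, 128) returns (5696, 8192, 128).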


class GhostBottleneck(nn.Module):
    # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
        super(GhostBottleneck, self).__init__()
        c_ = c2 // 2
        self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
                                  DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
                                  GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
                                      Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)
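

# Illustrative sketch (hypothetical helper, not part of YOLOv5): GhostBottleneck adds its
# main branch to a shortcut, so the two must produce the same shape. With s == 1 the
# shortcut is nn.Identity, which requires c1 == c2. With s == 2 both branches downsample.
# The helper assumes padding k // 2 (YOLOv5's autopad default for an odd integer k).
def ghost_bottleneck_out_shape(c1, c2, h, w, k=3, s=1):
    """Return the (channels, height, width) output shape, or raise ValueError on a mismatch."""
    main_c = 2 * (c2 // 2)  # the final GhostConv emits 2 * (c2 // 2) channels
    short_c = c2 if s == 2 else c1  # 1x1 Conv to c2 when s == 2, identity otherwise
    if main_c != short_c:
        raise ValueError(f'residual mismatch: main branch {main_c} vs shortcut {short_c} channels')
    if s == 2:
        p = k // 2  # assumed autopad
        h, w = (h + 2 * p - k) // s + 1, (w + 2 * p - k) // s + 1
    return main_c, h, w


# Usage: ghost_bottleneck_out_shape(64, 128, 80, 80, s=2) returns (128, 40, 40), while the
# same channels with s=1 raise ValueError because the identity shortcut keeps 64 channels.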


class MixConv2d(nn.Module):
    # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        super(MixConv2d, self).__init__()
        groups = len(k)
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * groups
            a = np.eye(groups + 1, groups, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        # residual add: the output must match the input shape (c1 == c2 and s == 1)
        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
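

# Illustrative sketch (hypothetical helper, not part of YOLOv5): with equal_ch=True,
# MixConv2d places c2 evenly spaced points on [0, groups - 1e-6] and floors them. Each
# floored value is the index of the kernel size that channel belongs to, so each kernel
# size receives about c2 / groups channels. This pure-Python version mirrors the
# torch.linspace(...).floor() logic in float64.
def mixconv_equal_channels(c2, groups):
    """Number of output channels assigned to each kernel size when equal_ch=True."""
    end = groups - 1e-6  # same upper bound as MixConv2d
    if c2 == 1:
        idx = [0]  # linspace with a single point returns just the start value
    else:
        idx = [int(end * j / (c2 - 1)) for j in range(c2)]  # int() floors non-negative values
    return [idx.count(g) for g in range(groups)]


# Usage: mixconv_equal_channels(8, 2) returns [4, 4] and mixconv_equal_channels(10, 3)
# returns [4, 3, 3].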


class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super(Ensemble, self).__init__()

    def forward(self, x, augment=False):
        y = []
        for module in self:
            y.append(module(x, augment)[0])
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 1)  # nms ensemble
        return y, None  # inference, train output
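

# Illustrative sketch (hypothetical helper, not part of YOLOv5): Ensemble.forward keeps
# the inference output of every member and concatenates the outputs along dim 1 (the
# "nms ensemble"), leaving non-maximum suppression to merge overlapping boxes. The
# commented-out lines would instead take an elementwise max or mean. Flat lists of floats
# stand in for the prediction tensors.
def combine_predictions(preds, mode='nms'):
    """Combine per-model prediction lists the way each Ensemble strategy would."""
    if mode == 'nms':  # concatenate all candidates, like torch.cat(y, 1)
        return [p for model_preds in preds for p in model_preds]
    if len({len(p) for p in preds}) != 1:  # torch.stack also needs equal shapes
        raise ValueError('max/mean ensembles need equal-length predictions')
    columns = list(zip(*preds))
    if mode == 'max':
        return [max(c) for c in columns]
    if mode == 'mean':
        return [sum(c) / len(c) for c in columns]
    raise ValueError(f'unknown mode: {mode}')


# Usage: with preds = [[1, 2], [3, 4]], 'nms' gives [1, 2, 3, 4], 'max' gives [3, 4] and
# 'mean' gives [2.0, 3.0].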


def attempt_load(weights, map_location=None, inplace=True):
    from models.yolo import Detect, Model  # deferred import: models.yolo imports this module

    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model

    # Compatibility updates
    for m in model.modules():
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
            m.inplace = inplace  # pytorch 1.7.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    else:
        print(f'Ensemble created with {weights}\n')
        for k in ['names']:
            setattr(model, k, getattr(model[-1], k))
        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
        return model  # return ensemble
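

# Illustrative sketch (hypothetical helpers, not part of YOLOv5): attempt_load accepts a
# single path or a list of paths. It uses a checkpoint's 'ema' entry when that entry is
# truthy, and otherwise falls back to 'model'. An ensemble takes the stride of the member
# with the largest maximum stride. Plain dicts and lists stand in for checkpoints and
# stride tensors.
def as_weight_list(weights):
    """Normalize weights=a or weights=[a, b, ...] to a list, as attempt_load does."""
    return weights if isinstance(weights, list) else [weights]


def pick_checkpoint_model(ckpt):
    """Prefer the EMA model when present and truthy, else the regular model."""
    return ckpt['ema' if ckpt.get('ema') else 'model']


def ensemble_stride(member_strides):
    """Return the stride list of the first member whose largest stride is greatest."""
    best = max(range(len(member_strides)), key=lambda i: max(member_strides[i]))
    return member_strides[best]


# Usage: ensemble_stride([[8, 16, 32], [8, 16, 32, 64]]) returns [8, 16, 32, 64].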