experimental.py
```python
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Experimental modules
"""

import numpy as np
import torch
import torch.nn as nn

from models.common import Conv
from utils.downloads import attempt_download
```
```python
class CrossConv(nn.Module):
    # Cross Convolution Downsample
    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
        # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, (1, k), (1, s))  # (1, k) kernel, stride s along width
        self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)  # (k, 1) kernel, stride s along height
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
```
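CrossConv replaces a single k×k convolution with a 1×k convolution followed by a k×1 convolution, each striding along one axis only. The sketch below tracks the spatial output size and the convolution weight count for both options in plain Python. It assumes `Conv` pads each kernel axis by `k // 2` (what YOLOv5's `autopad` is understood to compute) and ignores BatchNorm and bias parameters; the helper names are hypothetical.

```python
# Hypothetical helpers illustrating CrossConv's factorized kernel (not part of YOLOv5).


def conv_out(size, k, s):
    """Output length along one spatial axis, assuming padding k // 2 on that axis."""
    return (size + 2 * (k // 2) - k) // s + 1


def crossconv_shape(h, w, k=3, s=1):
    """Spatial size after cv1 (kernel (1, k), stride (1, s)) then cv2 (kernel (k, 1), stride (s, 1))."""
    h1, w1 = conv_out(h, 1, 1), conv_out(w, k, s)  # cv1 acts on width only
    h2, w2 = conv_out(h1, k, s), conv_out(w1, 1, 1)  # cv2 acts on height only
    return h2, w2


def weight_counts(c1, c2, k=3, g=1, e=1.0):
    """Conv weight counts: one k x k conv versus CrossConv's (1, k) + (k, 1) pair."""
    c_ = int(c2 * e)  # hidden channels, as in CrossConv.__init__
    full = c1 * c2 * k * k
    cross = c1 * c_ * k + (c_ // g) * c2 * k  # grouped cv2 sees c_ // g inputs per filter
    return full, cross


print(crossconv_shape(80, 80, k=3, s=2))  # (40, 40)
print(weight_counts(64, 64, k=3))  # (36864, 24576)
```

When `c1 == c2 == c_`, the pair needs 2/k of the weights of a full k×k kernel while still downsampling both axes by `s`.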
```python
class Sum(nn.Module):
    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    def __init__(self, n, weight=False):  # n: number of inputs
        super().__init__()
        self.weight = weight  # apply weights boolean
        self.iter = range(n - 1)  # iter object
        if weight:
            self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True)  # layer weights

    def forward(self, x):
        y = x[0]  # no weight
        if self.weight:
            w = torch.sigmoid(self.w) * 2
            for i in self.iter:
                y = y + x[i + 1] * w[i]
        else:
            for i in self.iter:
                y = y + x[i + 1]
        return y
```
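With `weight=True`, `Sum` learns one scalar per extra input, initialised to `-arange(1, n) / 2` and mapped through `2 * sigmoid(w)`, so every scale stays in (0, 2) and the first input is never weighted. A minimal sketch of that arithmetic on plain floats instead of tensors; `weighted_sum` is a hypothetical name.

```python
import math

# Hypothetical float version of Sum.forward; weights follow Sum.__init__'s initialisation.


def weighted_sum(xs, w=None):
    """xs[0] is added as-is; xs[i + 1] is scaled by 2 * sigmoid(w[i]) when w is given, else by 1."""
    y = xs[0]
    for i, x in enumerate(xs[1:]):
        scale = 2.0 / (1.0 + math.exp(-w[i])) if w is not None else 1.0  # 2 * sigmoid(w[i]), in (0, 2)
        y = y + x * scale
    return y


n = 3
w0 = [-j / 2 for j in range(1, n)]  # like -torch.arange(1., n) / 2 -> [-0.5, -1.0]
print(weighted_sum([1.0, 1.0, 1.0]))  # 3.0 (weight=False)
print(round(weighted_sum([1.0, 1.0, 1.0], w0), 4))  # 2.293 (scales ~0.755 and ~0.538)
```

Because the initial weights decrease with the input index, later inputs start with smaller contributions; training can then move each scale anywhere inside (0, 2).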
```python
class MixConv2d(nn.Module):
    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
        super().__init__()
        groups = len(k)
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * groups
            a = np.eye(groups + 1, groups, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        # residual add: x and the concatenated output must have the same shape (e.g. c1 == c2 and s == 1)
        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
```
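`MixConv2d` splits its `c2` output channels across the kernel sizes in `k`. With `equal_ch=True` each kernel size receives (nearly) the same number of channels. Otherwise the least-squares system asks for `c_g * k_g**2` to be equal across groups, i.e. the same number of weights per kernel size, which works out to `c_g` proportional to `1 / k_g**2`. The sketch below reproduces both rules in plain Python (float64, so boundary cases could in principle differ from torch's float32 `linspace`); the function names are hypothetical, and the closed form stands in for the `np.linalg.lstsq` call on the assumption that the system has an exact solution.

```python
import math

# Hypothetical helpers reproducing MixConv2d's channel split in plain Python.


def equal_ch_split(c2, groups):
    """equal_ch=True: floor(linspace(0, groups - 1e-6, c2)), then count channels per group (c2 > 1)."""
    step = (groups - 1e-6) / (c2 - 1)
    idx = [math.floor(i * step) for i in range(c2)]
    return [idx.count(g) for g in range(groups)]


def equal_weight_split(c2, k):
    """equal_ch=False: make c_g * k_g**2 equal across groups with sum(c_g) == c2, then round.

    Closed form c_g = c2 * (1 / k_g**2) / sum_j(1 / k_j**2), standing in for np.linalg.lstsq.
    """
    inv = [1 / kg ** 2 for kg in k]
    total = sum(inv)
    return [round(c2 * v / total) for v in inv]


print(equal_ch_split(64, 2))  # [32, 32]
print(equal_ch_split(64, 3))  # [22, 21, 21]
print(equal_weight_split(64, (1, 3)))  # [58, 6]
```

With `k=(1, 3)` the 3×3 branch gets far fewer channels under the equal-weight rule, since it spends 9 weights per input-output channel pair where the 1×1 branch spends one. Because the result is rounded, the split is not guaranteed to sum to exactly `c2`.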
```python
class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        y = []
        for module in self:
            y.append(module(x, augment, profile, visualize)[0])
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 1)  # nms ensemble
        return y, None  # inference, train output
```
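`Ensemble.forward` runs every member on the same input and keeps the first (inference) output of each. The active strategy, labelled `nms ensemble`, concatenates the members' predictions along dimension 1, presumably so a single downstream NMS pass filters the pooled boxes; the commented-out max and mean alternatives reduce element-wise and therefore need identically shaped outputs. A sketch on nested lists shaped `[batch][detections][values]` per model; the helpers and the simplified row layout are hypothetical.

```python
# Hypothetical stand-ins for Ensemble.forward's combination step, using nested lists instead of tensors.


def nms_ensemble(outputs):
    """Like torch.cat(y, 1): per image, pool the detection rows of every model."""
    batch = len(outputs[0])
    return [[row for out in outputs for row in out[i]] for i in range(batch)]


def stack_ensemble(outputs, reduce):
    """Like torch.stack(y).max(0)[0] or .mean(0): element-wise reduce across models.

    Every model must return the same number of rows and values per image.
    """
    return [
        [[reduce(vals) for vals in zip(*rows)] for rows in zip(*imgs)]
        for imgs in zip(*outputs)
    ]


def mean(vals):
    return sum(vals) / len(vals)


# Two models, one image, one row each; the [x, y, w, h, conf] layout is simplified for illustration.
a = [[[10.0, 10.0, 4.0, 4.0, 0.75]]]
b = [[[12.0, 10.0, 4.0, 4.0, 0.5]]]
print(nms_ensemble([a, b]))  # [[[10.0, 10.0, 4.0, 4.0, 0.75], [12.0, 10.0, 4.0, 4.0, 0.5]]]
print(stack_ensemble([a, b], max))  # [[[12.0, 10.0, 4.0, 4.0, 0.75]]]
print(stack_ensemble([a, b], mean))  # [[[11.0, 10.0, 4.0, 4.0, 0.625]]]
```

Concatenation tolerates members that predict different numbers of boxes and keeps every candidate for NMS to arbitrate, whereas max and mean blend coordinates from unrelated boxes when rows do not line up.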
```python
def attempt_load(weights, map_location=None, inplace=True, fuse=True):
    from models.yolo import Detect, Model

    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
        if fuse:
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
        else:
            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # without layer fuse

    # Compatibility updates
    for m in model.modules():
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
            m.inplace = inplace  # pytorch 1.7.0 compatibility
            if type(m) is Detect:
                if not isinstance(m.anchor_grid, list):  # new Detect Layer compatibility
                    delattr(m, 'anchor_grid')
                    setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    else:
        print(f'Ensemble created with {weights}\n')
        for k in ['names']:
            setattr(model, k, getattr(model[-1], k))
        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
        return model  # return ensemble
```
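Three small decisions inside `attempt_load` can be checked without PyTorch: `weights` may be a single path or a list, each checkpoint contributes its `'ema'` entry when that entry is present and truthy (otherwise `'model'`), and an ensemble adopts the stride of whichever member has the largest maximum stride. A sketch with dicts and lists standing in for checkpoints and stride tensors; all names and file names are hypothetical.

```python
# Hypothetical stand-ins for attempt_load's selection logic, without torch.


def as_weight_list(weights):
    """Accept weights=a, weights=[a] or weights=[a, b, c], as attempt_load does."""
    return weights if isinstance(weights, list) else [weights]


def pick_weights_key(ckpt):
    """Mirror ckpt['ema' if ckpt.get('ema') else 'model']: a missing or falsy 'ema' falls back to 'model'."""
    return 'ema' if ckpt.get('ema') else 'model'


def ensemble_stride(strides):
    """Return the stride list of the member whose largest stride is biggest.

    Python's max() keeps the first of several equal keys.
    """
    best = max(range(len(strides)), key=lambda i: max(strides[i]))
    return strides[best]


print(as_weight_list('a.pt'))  # ['a.pt']
print(pick_weights_key({'ema': None, 'model': 'm'}))  # model
print(ensemble_stride([[8, 16, 32], [8, 16, 32, 64]]))  # [8, 16, 32, 64]
```

Taking the largest stride lets callers round input sizes to a multiple that every member in the ensemble can handle.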