Quellcode durchsuchen

Merge `develop` branch into `master` (#3518)

* update ci-testing.yml (#3322)

* update ci-testing.yml

* update greetings.yml

* bring back os matrix

* update ci-testing.yml (#3322)

* update ci-testing.yml

* update greetings.yml

* bring back os matrix

* Enable direct `--weights URL` definition (#3373)

* Enable direct `--weights URL` definition

@KalenMike this PR will enable direct --weights URL definition. Example use case:
```
python train.py --weights https://storage.googleapis.com/bucket/dir/model.pt
```

* cleanup

* bug fixes

* weights = attempt_download(weights)

* Update experimental.py

* Update hubconf.py

* return bug fix

* comment mirror

* min_bytes

* Update tutorial.ipynb (#3368)

add Open in Kaggle badge

* `cv2.imread(img, -1)` for IMREAD_UNCHANGED (#3379)

* Update datasets.py

* comment

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* COCO evolution fix (#3388)

* COCO evolution fix

* cleanup

* update print

* print fix

* Create `is_pip()` function (#3391)

Returns `True` if file is part of pip package. Useful for contextual behavior modification.

```python
def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).absolute().parts
```

* Revert "`cv2.imread(img, -1)` for IMREAD_UNCHANGED (#3379)" (#3395)

This reverts commit 21a9607e00.

* Update FLOPs description (#3422)

* Update README.md

* Changing FLOPS to FLOPs.

Co-authored-by: BuildTools <unconfigured@null.spigotmc.org>

* Parse URL authentication (#3424)

* Parse URL authentication

* urllib.parse.unquote()

* improved error handling

* improved error handling

* remove %3F

* update check_file()

* Add FLOPs title to table (#3453)

* Suppress jit trace warning + graph once (#3454)

* Suppress jit trace warning + graph once

Suppress harmless jit trace warning on TensorBoard add_graph call. Also fix multiple add_graph() calls bug, now only on batch 0.

* Update train.py

* Update MixUp augmentation `alpha=beta=32.0` (#3455)

Per VOC empirical results https://github.com/ultralytics/yolov5/issues/3380#issuecomment-853001307 by @developer0hye

* Add `timeout()` class (#3460)

* Add `timeout()` class

* rearrange order

* Faster HSV augmentation (#3462)

remove datatype conversion process that can be skipped

* Add `check_git_status()` 5 second timeout (#3464)

* Add check_git_status() 5 second timeout

This should prevent the SSH Git bug that we were discussing @KalenMike

* cleanup

* replace timeout with check_output built-in timeout

* Improved `check_requirements()` offline-handling (#3466)

Improve robustness of `check_requirements()` function to offline environments (do not attempt pip installs when offline).

* Add `output_names` argument for ONNX export with dynamic axes (#3456)

* Add output names & dynamic axes for onnx export

Add output_names and dynamic_axes names for all outputs in torch.onnx.export. The first four outputs of the model will have names output0, output1, output2, output3

* use first output only + cleanup

Co-authored-by: Samridha Shrestha <samridha.shrestha@g42.ai>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Revert FP16 `test.py` and `detect.py` inference to FP32 default (#3423)

* fixed inference bug ,while use half precision

* replace --use-half with --half

* replace space and PEP8 in detect.py

* PEP8 detect.py

* update --half help comment

* Update test.py

* revert space

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Add additional links/resources to stale.yml message (#3467)

* Update stale.yml

* cleanup

* Update stale.yml

* reformat

* Update stale.yml HUB URL (#3468)

* Stale `github.actor` bug fix (#3483)

* Explicit `model.eval()` call `if opt.train=False` (#3475)

* call model.eval() when opt.train is False

call model.eval() when opt.train is False

* single-line if statement

* cleanup

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* check_requirements() exclude `opencv-python` (#3495)

Fix for 3rd party or contrib versions of installed OpenCV as in https://github.com/ultralytics/yolov5/issues/3494.

* Earlier `assert` for cpu and half option (#3508)

* early assert for cpu and half option

early assert for cpu and half option

* Modified comment

Modified comment

* Update tutorial.ipynb (#3510)

* Reduce test.py results spacing (#3511)

* Update README.md (#3512)

* Update README.md

Minor modifications

* 850 width

* Update greetings.yml

revert greeting change as PRs will now merge to master.

Co-authored-by: Piotr Skalski <SkalskiP@users.noreply.github.com>
Co-authored-by: SkalskiP <piotr.skalski92@gmail.com>
Co-authored-by: Peretz Cohen <pizzaz93@users.noreply.github.com>
Co-authored-by: tudoulei <34886368+tudoulei@users.noreply.github.com>
Co-authored-by: chocosaj <chocosaj@users.noreply.github.com>
Co-authored-by: BuildTools <unconfigured@null.spigotmc.org>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Sam_S <SamSamhuns@users.noreply.github.com>
Co-authored-by: Samridha Shrestha <samridha.shrestha@g42.ai>
Co-authored-by: edificewang <609552430@qq.com>
modifyDataloader
Glenn Jocher GitHub vor 3 Jahren
Ursprung
Commit
f3c3d2ce5d
Es konnte kein GPG-Schlüssel zu dieser Signatur gefunden werden GPG-Schlüssel-ID: 4AEE18F83AFDEB23
16 geänderte Dateien mit 187 neuen und 123 gelöschten Zeilen
  1. +2
    -4
      .github/workflows/ci-testing.yml
  2. +20
    -2
      .github/workflows/stale.yml
  3. +13
    -13
      README.md
  4. +2
    -1
      detect.py
  5. +1
    -2
      hubconf.py
  6. +1
    -2
      models/experimental.py
  7. +9
    -9
      models/export.py
  8. +3
    -3
      models/yolo.py
  9. +1
    -1
      requirements.txt
  10. +4
    -2
      test.py
  11. +36
    -36
      train.py
  12. +7
    -6
      tutorial.ipynb
  13. +3
    -3
      utils/datasets.py
  14. +42
    -12
      utils/general.py
  15. +37
    -21
      utils/google_utils.py
  16. +6
    -6
      utils/torch_utils.py

+ 2
- 4
.github/workflows/ci-testing.yml Datei anzeigen

@@ -2,12 +2,10 @@ name: CI CPU testing

on: # https://help.github.com/en/actions/reference/events-that-trigger-workflows
push:
branches: [ master ]
branches: [ master, develop ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ master ]
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
branches: [ master, develop ]

jobs:
cpu-tests:

+ 20
- 2
.github/workflows/stale.yml Datei anzeigen

@@ -10,8 +10,26 @@ jobs:
- uses: actions/stale@v3
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.'
stale-pr-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.'
stale-issue-message: |
👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.

Access additional [YOLOv5](https://ultralytics.com/yolov5) 🚀 resources:
- **Wiki** – https://github.com/ultralytics/yolov5/wiki
- **Tutorials** – https://github.com/ultralytics/yolov5#tutorials
- **Docs** – https://docs.ultralytics.com

Access additional [Ultralytics](https://ultralytics.com) ⚡ resources:
- **Ultralytics HUB** – https://ultralytics.com/pricing
- **Vision API** – https://ultralytics.com/yolov5
- **About Us** – https://ultralytics.com/about
- **Join Our Team** – https://ultralytics.com/work
- **Contact Us** – https://ultralytics.com/contact

Feel free to inform us of any other **issues** you discover or **feature requests** that come to mind in the future. Pull Requests (PRs) are also always welcomed!

Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐!

stale-pr-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions YOLOv5 🚀 and Vision AI ⭐.'
days-before-stale: 30
days-before-close: 5
exempt-issue-labels: 'documentation,tutorial'

+ 13
- 13
README.md Datei anzeigen

@@ -1,5 +1,5 @@
<a align="left" href="https://apps.apple.com/app/id1452689527" target="_blank">
<img width="800" src="https://user-images.githubusercontent.com/26833433/98699617-a1595a00-2377-11eb-8145-fc674eb9b1a7.jpg"></a>
<img width="850" src="https://user-images.githubusercontent.com/26833433/121094150-72607500-c7ee-11eb-9f39-1d9e4ce89a9e.jpg"></a>
&nbsp

<a href="https://github.com/ultralytics/yolov5/actions"><img src="https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg" alt="CI CPU testing"></a>
@@ -30,19 +30,19 @@ This repository represents Ultralytics open-source research into future object d

[assets]: https://github.com/ultralytics/yolov5/releases

Model |size<br><sup>(pixels) |mAP<sup>val<br>0.5:0.95 |mAP<sup>test<br>0.5:0.95 |mAP<sup>val<br>0.5 |Speed<br><sup>V100 (ms) | |params<br><sup>(M) |FLOPS<br><sup>640 (B)
--- |--- |--- |--- |--- |--- |---|--- |---
[YOLOv5s][assets] |640 |36.7 |36.7 |55.4 |**2.0** | |7.3 |17.0
[YOLOv5m][assets] |640 |44.5 |44.5 |63.1 |2.7 | |21.4 |51.3
[YOLOv5l][assets] |640 |48.2 |48.2 |66.9 |3.8 | |47.0 |115.4
[YOLOv5x][assets] |640 |**50.4** |**50.4** |**68.8** |6.1 | |87.7 |218.8
|Model |size<br><sup>(pixels) |mAP<sup>val<br>0.5:0.95 |mAP<sup>test<br>0.5:0.95 |mAP<sup>val<br>0.5 |Speed<br><sup>V100 (ms) | |params<br><sup>(M) |FLOPs<br><sup>640 (B)
|--- |--- |--- |--- |--- |--- |---|--- |---
|[YOLOv5s][assets] |640 |36.7 |36.7 |55.4 |**2.0** | |7.3 |17.0
|[YOLOv5m][assets] |640 |44.5 |44.5 |63.1 |2.7 | |21.4 |51.3
|[YOLOv5l][assets] |640 |48.2 |48.2 |66.9 |3.8 | |47.0 |115.4
|[YOLOv5x][assets] |640 |**50.4** |**50.4** |**68.8** |6.1 | |87.7 |218.8
| | | | | | || |
[YOLOv5s6][assets] |1280 |43.3 |43.3 |61.9 |**4.3** | |12.7 |17.4
[YOLOv5m6][assets] |1280 |50.5 |50.5 |68.7 |8.4 | |35.9 |52.4
[YOLOv5l6][assets] |1280 |53.4 |53.4 |71.1 |12.3 | |77.2 |117.7
[YOLOv5x6][assets] |1280 |**54.4** |**54.4** |**72.0** |22.4 | |141.8 |222.9
|[YOLOv5s6][assets] |1280 |43.3 |43.3 |61.9 |**4.3** | |12.7 |17.4
|[YOLOv5m6][assets] |1280 |50.5 |50.5 |68.7 |8.4 | |35.9 |52.4
|[YOLOv5l6][assets] |1280 |53.4 |53.4 |71.1 |12.3 | |77.2 |117.7
|[YOLOv5x6][assets] |1280 |**54.4** |**54.4** |**72.0** |22.4 | |141.8 |222.9
| | | | | | || |
[YOLOv5x6][assets] TTA |1280 |**55.0** |**55.0** |**72.0** |70.8 | |- |-
|[YOLOv5x6][assets] TTA |1280 |**55.0** |**55.0** |**72.0** |70.8 | |- |-

<details>
<summary>Table Notes (click to expand)</summary>
@@ -112,7 +112,7 @@ Namespace(agnostic_nms=False, augment=False, classes=None, conf_thres=0.25, devi
YOLOv5 v4.0-96-g83dc1b4 torch 1.7.0+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16160.5MB)

Fusing layers...
Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPS
Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPs
image 1/2 /content/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 bus, Done. (0.010s)
image 2/2 /content/yolov5/data/images/zidane.jpg: 384x640 2 persons, 1 tie, Done. (0.011s)
Results saved to runs/detect/exp2

+ 2
- 1
detect.py Datei anzeigen

@@ -28,7 +28,7 @@ def detect(opt):
# Initialize
set_logging()
device = select_device(opt.device)
half = device.type != 'cpu' # half precision only supported on CUDA
half = opt.half and device.type != 'cpu' # half precision only supported on CUDA

# Load model
model = attempt_load(weights, map_location=device) # load FP32 model
@@ -172,6 +172,7 @@ if __name__ == '__main__':
parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', type=bool, default=False, help='use FP16 half-precision inference')
opt = parser.parse_args()
print(opt)
check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))

+ 1
- 2
hubconf.py Datei anzeigen

@@ -42,8 +42,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter

+ 1
- 2
models/experimental.py Datei anzeigen

@@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True):
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
attempt_download(w)
ckpt = torch.load(w, map_location=map_location) # load
ckpt = torch.load(attempt_download(w), map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model

# Compatibility updates

+ 9
- 9
models/export.py Datei anzeigen

@@ -44,22 +44,19 @@ if __name__ == '__main__':

# Load PyTorch model
device = select_device(opt.device)
assert not (opt.device.lower() == 'cpu' and opt.half), '--half only compatible with GPU export, i.e. use --device 0'
model = attempt_load(opt.weights, map_location=device) # load FP32 model
labels = model.names

# Checks
# Input
gs = int(max(model.stride)) # grid size (max stride)
opt.img_size = [check_img_size(x, gs) for x in opt.img_size] # verify img_size are gs-multiples
assert not (opt.device.lower() == 'cpu' and opt.half), '--half only compatible with GPU export, i.e. use --device 0'

# Input
img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device) # image size(1,3,320,192) iDetection

# Update model
if opt.half:
img, model = img.half(), model.half() # to FP16
if opt.train:
model.train() # training mode (no grid construction in Detect layer)
model.train() if opt.train else model.eval() # training mode = no Detect() layer grid construction
for k, m in model.named_modules():
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if isinstance(m, models.common.Conv): # assign export-friendly activations
@@ -96,11 +93,14 @@ if __name__ == '__main__':

print(f'{prefix} starting export with onnx {onnx.__version__}...')
f = opt.weights.replace('.pt', '.onnx') # filename
torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'],
torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version,
training=torch.onnx.TrainingMode.TRAINING if opt.train else torch.onnx.TrainingMode.EVAL,
do_constant_folding=not opt.train,
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
input_names=['images'],
output_names=['output'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
} if opt.dynamic else None)

# Checks
model_onnx = onnx.load(f) # load onnx model

+ 3
- 3
models/yolo.py Datei anzeigen

@@ -21,7 +21,7 @@ from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, s
select_device, copy_attr

try:
import thop # for FLOPS computation
import thop # for FLOPs computation
except ImportError:
thop = None

@@ -140,13 +140,13 @@ class Model(nn.Module):
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers

if profile:
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
t = time_synchronized()
for _ in range(10):
_ = m(x)
dt.append((time_synchronized() - t) * 100)
if m == self.model[0]:
logger.info(f"{'time (ms)':>10s} {'GFLOPS':>10s} {'params':>10s} {'module'}")
logger.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
logger.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')

x = m(x) # run

+ 1
- 1
requirements.txt Datei anzeigen

@@ -27,4 +27,4 @@ pandas
# extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
pycocotools>=2.0 # COCO mAP
thop # FLOPS computation
thop # FLOPs computation

+ 4
- 2
test.py Datei anzeigen

@@ -95,7 +95,7 @@ def test(data,
confusion_matrix = ConfusionMatrix(nc=nc)
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
coco91class = coco80_to_coco91_class()
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
@@ -228,7 +228,7 @@ def test(data,
nt = torch.zeros(1)

# Print results
pf = '%20s' + '%12i' * 2 + '%12.3g' * 4 # print format
pf = '%20s' + '%11i' * 2 + '%11.3g' * 4 # print format
print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))

# Print results per class
@@ -306,6 +306,7 @@ if __name__ == '__main__':
parser.add_argument('--project', default='runs/test', help='save to project/name')
parser.add_argument('--name', default='exp', help='save to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--half', type=bool, default=False, help='use FP16 half-precision inference')
opt = parser.parse_args()
opt.save_json |= opt.data.endswith('coco.yaml')
opt.data = check_file(opt.data) # check file
@@ -326,6 +327,7 @@ if __name__ == '__main__':
save_txt=opt.save_txt | opt.save_hybrid,
save_hybrid=opt.save_hybrid,
save_conf=opt.save_conf,
half_precision=opt.half,
opt=opt
)


+ 36
- 36
train.py Datei anzeigen

@@ -4,6 +4,7 @@ import math
import os
import random
import time
import warnings
from copy import deepcopy
from pathlib import Path
from threading import Thread
@@ -62,7 +63,6 @@ def train(hyp, opt, device, tb_writer=None):
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.safe_load(f) # data dict
is_coco = opt.data.endswith('coco.yaml')

# Logging- Doing this before checking the dataset. Might update data_dict
loggers = {'wandb': None} # loggers dict
@@ -78,12 +78,13 @@ def train(hyp, opt, device, tb_writer=None):
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset

# Model
pretrained = weights.endswith('.pt')
if pretrained:
with torch_distributed_zero_first(rank):
attempt_download(weights) # download if not found locally
weights = attempt_download(weights) # download if not found locally
ckpt = torch.load(weights, map_location=device) # load checkpoint
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
@@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None):
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
pbar.set_description(s)

# Plot
if plots and ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
if tb_writer:
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # model graph
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
if tb_writer and ni == 0:
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # graph
elif plots and ni == 10 and wandb_logger.wandb:
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
save_dir.glob('train*.jpg') if x.exists()]})

# end batch ------------------------------------------------------------------------------------------------
@@ -358,6 +360,7 @@ def train(hyp, opt, device, tb_writer=None):
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=is_coco and final_epoch,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
wandb_logger=wandb_logger,
@@ -409,41 +412,38 @@ def train(hyp, opt, device, tb_writer=None):
# end epoch ----------------------------------------------------------------------------------------------------
# end training
if rank in [-1, 0]:
# Plots
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
if plots:
plot_results(save_dir=save_dir) # save as results.png
if wandb_logger.wandb:
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
# Test best.pt
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
results, _, _ = test.test(opt.data,
batch_size=batch_size * 2,
imgsz=imgsz_test,
conf_thres=0.001,
iou_thres=0.7,
model=attempt_load(m, device).half(),
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=True,
plots=False,
is_coco=is_coco)

# Strip optimizers
final = best if best.exists() else last # final model
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
if opt.bucket:
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
if wandb_logger.wandb and not opt.evolve: # Log the stripped model
wandb_logger.wandb.log_artifact(str(final), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])

if not opt.evolve:
if is_coco: # COCO dataset
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
results, _, _ = test.test(opt.data,
batch_size=batch_size * 2,
imgsz=imgsz_test,
conf_thres=0.001,
iou_thres=0.7,
model=attempt_load(m, device).half(),
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=True,
plots=False,
is_coco=is_coco)

# Strip optimizers
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
if wandb_logger.wandb: # Log the stripped model
wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
name='run_' + wandb_logger.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])
wandb_logger.finish_run()
else:
dist.destroy_process_group()

+ 7
- 6
tutorial.ipynb Datei anzeigen

@@ -517,7 +517,8 @@
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>",
"<a href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ultralytics/yolov5/blob/master/tutorial.ipynb\" target=\"_parent\"><img alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a>"
]
},
{
@@ -529,7 +530,7 @@
"<img src=\"https://user-images.githubusercontent.com/26833433/98702494-b71c4e80-237a-11eb-87ed-17fcd6b3f066.jpg\">\n",
"\n",
"This is the **official YOLOv5 🚀 notebook** authored by **Ultralytics**, and is freely available for redistribution under the [GPL-3.0 license](https://choosealicense.com/licenses/gpl-3.0/). \n",
"For more information please visit https://github.com/ultralytics/yolov5 and https://www.ultralytics.com. Thank you!"
"For more information please visit https://github.com/ultralytics/yolov5 and https://ultralytics.com. Thank you!"
]
},
{
@@ -610,7 +611,7 @@
"YOLOv5 🚀 v5.0-1-g0f395b3 torch 1.8.1+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16160.5MB)\n",
"\n",
"Fusing layers... \n",
"Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPS\n",
"Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPs\n",
"image 1/2 /content/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 bus, Done. (0.008s)\n",
"image 2/2 /content/yolov5/data/images/zidane.jpg: 384x640 2 persons, 2 ties, Done. (0.008s)\n",
"Results saved to runs/detect/exp\n",
@@ -733,7 +734,7 @@
"100% 168M/168M [00:05<00:00, 32.3MB/s]\n",
"\n",
"Fusing layers... \n",
"Model Summary: 476 layers, 87730285 parameters, 0 gradients, 218.8 GFLOPS\n",
"Model Summary: 476 layers, 87730285 parameters, 0 gradients, 218.8 GFLOPs\n",
"\u001b[34m\u001b[1mval: \u001b[0mScanning '../coco/val2017' images and labels... 4952 found, 48 missing, 0 empty, 0 corrupted: 100% 5000/5000 [00:01<00:00, 3102.29it/s]\n",
"\u001b[34m\u001b[1mval: \u001b[0mNew cache created: ../coco/val2017.cache\n",
" Class Images Labels P R mAP@.5 mAP@.5:.95: 100% 157/157 [01:23<00:00, 1.87it/s]\n",
@@ -963,7 +964,7 @@
" 22 [-1, 10] 1 0 models.common.Concat [1] \n",
" 23 -1 1 1182720 models.common.C3 [512, 512, 1, False] \n",
" 24 [17, 20, 23] 1 229245 models.yolo.Detect [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]\n",
"Model Summary: 283 layers, 7276605 parameters, 7276605 gradients, 17.1 GFLOPS\n",
"Model Summary: 283 layers, 7276605 parameters, 7276605 gradients, 17.1 GFLOPs\n",
"\n",
"Transferred 362/362 items from yolov5s.pt\n",
"Scaled weight_decay = 0.0005\n",
@@ -1260,4 +1261,4 @@
"outputs": []
}
]
}
}

+ 3
- 3
utils/datasets.py Datei anzeigen

@@ -535,7 +535,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# MixUp https://arxiv.org/pdf/1710.09412.pdf
if random.random() < hyp['mixup']:
img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
img = (img * r + img2 * (1 - r)).astype(np.uint8)
labels = np.concatenate((labels, labels2), 0)

@@ -655,12 +655,12 @@ def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
dtype = img.dtype # uint8

x = np.arange(0, 256, dtype=np.int16)
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed



+ 42
- 12
utils/general.py Datei anzeigen

@@ -1,5 +1,6 @@
# YOLOv5 general utils

import contextlib
import glob
import logging
import math
@@ -7,11 +8,13 @@ import os
import platform
import random
import re
import subprocess
import signal
import time
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output

import cv2
import numpy as np
@@ -33,6 +36,26 @@ cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with Py
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads


class timeout(contextlib.ContextDecorator):
# Usage: @timeout(seconds) decorator or 'with timeout(seconds):' context manager
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
self.seconds = int(seconds)
self.timeout_message = timeout_msg
self.suppress = bool(suppress_timeout_errors)

def _timeout_handler(self, signum, frame):
raise TimeoutError(self.timeout_message)

def __enter__(self):
signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
signal.alarm(self.seconds) # start countdown for SIGALRM to be raised

def __exit__(self, exc_type, exc_val, exc_tb):
signal.alarm(0) # Cancel SIGALRM if it's scheduled
if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
return True


def set_logging(rank=-1, verbose=True):
logging.basicConfig(
format="%(message)s",
@@ -53,12 +76,12 @@ def get_latest_run(search_dir='.'):


def is_docker():
# Is environment a Docker container
# Is environment a Docker container?
return Path('/workspace').exists() # or Path('/.dockerenv').exists()


def is_colab():
# Is environment a Google Colab instance
# Is environment a Google Colab instance?
try:
import google.colab
return True
@@ -66,6 +89,11 @@ def is_colab():
return False


def is_pip():
# Is file in a pip package?
return 'site-packages' in Path(__file__).absolute().parts


def emojis(str=''):
# Return platform-dependent emoji-safe version of string
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
@@ -80,13 +108,13 @@ def check_online():
# Check internet connectivity
import socket
try:
socket.create_connection(("1.1.1.1", 443), 5) # check host accesability
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
return True
except OSError:
return False


def check_git_status():
def check_git_status(err_msg=', for updates see https://github.com/ultralytics/yolov5'):
# Recommend 'git pull' if code is out of date
print(colorstr('github: '), end='')
try:
@@ -95,9 +123,9 @@ def check_git_status():
assert check_online(), 'skipping check (offline)'

cmd = 'git fetch && git config --get remote.origin.url'
url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git') # github repo url
branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
if n > 0:
s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
f"Use 'git pull' to update or 'git clone {url}' to download latest."
@@ -105,7 +133,7 @@ def check_git_status():
s = f'up to date with {url} ✅'
print(emojis(s)) # emoji-safe
except Exception as e:
print(e)
print(f'{e}{err_msg}')


def check_python(minimum='3.7.0', required=True):
@@ -135,10 +163,11 @@ def check_requirements(requirements='requirements.txt', exclude=()):
try:
pkg.require(r)
except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
n += 1
print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
try:
print(subprocess.check_output(f"pip install '{r}'", shell=True).decode())
assert check_online(), f"'pip install {r}' skipped (offline)"
print(check_output(f"pip install '{r}'", shell=True).decode())
n += 1
except Exception as e:
print(f'{prefix} {e}')

@@ -178,7 +207,8 @@ def check_file(file):
if Path(file).is_file() or file == '': # exists
return file
elif file.startswith(('http://', 'https://')): # download
url, file = file, Path(file).name
url, file = file, Path(urllib.parse.unquote(str(file))).name # url, file (decode '%2F' to '/' etc.)
file = file.split('?')[0] # parse authentication https://url.com/file.txt?auth...
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, file)
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check

+ 37
- 21
utils/google_utils.py Datei anzeigen

@@ -4,6 +4,7 @@ import os
import platform
import subprocess
import time
import urllib
from pathlib import Path

import requests
@@ -16,11 +17,39 @@ def gsutil_getsize(url=''):
return eval(s.split(' ')[0]) if len(s) else 0 # bytes


def attempt_download(file, repo='ultralytics/yolov5'):
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
try: # url1
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, str(file))
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
except Exception as e: # url2
file.unlink(missing_ok=True) # remove partial downloads
print(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < min_bytes: # check
file.unlink(missing_ok=True) # remove partial downloads
print(f"ERROR: {assert_msg}\n{error_msg}")
print('')


def attempt_download(file, repo='ultralytics/yolov5'): # from utils.google_utils import *; attempt_download()
# Attempt file download if does not exist
file = Path(str(file).strip().replace("'", ''))

if not file.exists():
# URL specified
name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
if str(file).startswith(('http:/', 'https:/')): # download
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
name = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
safe_download(file=name, url=url, min_bytes=1E5)
return name

# GitHub assets
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
try:
response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
@@ -34,27 +63,14 @@ def attempt_download(file, repo='ultralytics/yolov5'):
except:
tag = 'v5.0' # current release

name = file.name
if name in assets:
msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
redundant = False # second download option
try: # GitHub
url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
print(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, file)
assert file.exists() and file.stat().st_size > 1E6 # check
except Exception as e: # GCP
print(f'Download error: {e}')
assert redundant, 'No secondary mirror'
url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
print(f'Downloading {url} to {file}...')
os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < 1E6: # check
file.unlink(missing_ok=True) # remove partial downloads
print(f'ERROR: Download failure: {msg}')
print('')
return
safe_download(file,
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
# url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional)
min_bytes=1E5,
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')

return str(file)


def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):

+ 6
- 6
utils/torch_utils.py Datei anzeigen

@@ -18,7 +18,7 @@ import torch.nn.functional as F
import torchvision

try:
import thop # for FLOPS computation
import thop # for FLOPs computation
except ImportError:
thop = None
logger = logging.getLogger(__name__)
@@ -105,13 +105,13 @@ def profile(x, ops, n=100, device=None):
x = x.to(device)
x.requires_grad = True
print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
print(f"\n{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
for m in ops if isinstance(ops, list) else [ops]:
m = m.to(device) if hasattr(m, 'to') else m # device
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
try:
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
except:
flops = 0

@@ -219,13 +219,13 @@ def model_info(model, verbose=False, img_size=640):
print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

try: # FLOPS
try: # FLOPs
from thop import profile
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS
fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
except (ImportError, Exception):
fs = ''


Laden…
Abbrechen
Speichern