Browse Source

Merge branch 'master' into advanced_logging

5.0
Alex Stoken GitHub 4 years ago
parent
commit
e18e6811dc
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 481 additions and 297 deletions
  1. +24
    -10
      .github/ISSUE_TEMPLATE/--bug-report.md
  2. +4
    -4
      Dockerfile
  3. +24
    -18
      README.md
  4. +14
    -6
      detect.py
  5. +2
    -1
      models/common.py
  6. +1
    -2
      models/hub/yolov3-spp.yaml
  7. +45
    -0
      models/hub/yolov5-fpn.yaml
  8. +52
    -0
      models/hub/yolov5-panet.yaml
  9. +2
    -2
      models/onnx_export.py
  10. +9
    -8
      models/yolo.py
  11. +28
    -21
      models/yolov5l.yaml
  12. +28
    -21
      models/yolov5m.yaml
  13. +28
    -21
      models/yolov5s.yaml
  14. +28
    -21
      models/yolov5x.yaml
  15. +2
    -2
      requirements.txt
  16. +24
    -33
      test.py
  17. +24
    -36
      train.py
  18. +26
    -5
      utils/datasets.py
  19. +12
    -7
      utils/torch_utils.py
  20. +104
    -79
      utils/utils.py

+ 24
- 10
.github/ISSUE_TEMPLATE/--bug-report.md View File

@@ -7,29 +7,43 @@ assignees: ''

---

Before submitting a bug report, please ensure that you are using the latest versions of:
- Python
- PyTorch
- This repository (run `git fetch && git status -uno` to check and `git pull` to update)
Before submitting a bug report, please be aware that your issue **must be reproducible** with all of the following, otherwise it is non-actionable, and we can not help you:
- **Current repo**: run `git fetch && git status -uno` to check and `git pull` to update repo
- **Common dataset**: coco.yaml or coco128.yaml
- **Common environment**: Colab, Google Cloud, or Docker image. See https://github.com/ultralytics/yolov5#reproduce-our-environment
**Your issue must be reproducible on a public dataset (i.e COCO) using the latest version of the repository, and you must supply code to reproduce, or we can not help you.**

If this is a custom training question we suggest you include your `train*.jpg`, `test*.jpg` and `results.png` figures.
If this is a custom dataset/training question you **must include** your `train*.jpg`, `test*.jpg` and `results.png` figures, or we can not help you. You can generate these with `utils.plot_results()`.


## 🐛 Bug
A clear and concise description of what the bug is.

## To Reproduce
**REQUIRED**: Code to reproduce your issue below

## To Reproduce (REQUIRED)

Input:
```
import torch

a = torch.tensor([5])
c = a / 0
```

Output:
```
python train.py ...
Traceback (most recent call last):
File "/Users/glennjocher/opt/anaconda3/envs/env1/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-5-be04c762b799>", line 5, in <module>
c = a / 0
RuntimeError: ZeroDivisionError
```


## Expected behavior
A clear and concise description of what you expected to happen.


## Environment
If applicable, add screenshots to help explain your problem.


+ 4
- 4
Dockerfile View File

@@ -1,9 +1,6 @@
# Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
FROM nvcr.io/nvidia/pytorch:20.03-py3

# Install dependencies (pip or conda)
RUN pip install -U gsutil
# RUN pip install -U -r requirements.txt

# Create working directory
RUN mkdir -p /usr/src/app
@@ -12,6 +9,9 @@ WORKDIR /usr/src/app
# Copy contents
COPY . /usr/src/app

# Install dependencies (pip or conda)
#RUN pip install -r requirements.txt

# Copy weights
#RUN python3 -c "from models import *; \
#attempt_download('weights/yolov5s.pt'); \
@@ -41,7 +41,7 @@ COPY . /usr/src/app

# Bash into running container
# sudo docker container exec -it ba65811811ab bash
# python -c "from utils.utils import *; create_backbone('weights/last.pt')" && gsutil cp weights/backbone.pt gs://*
# python -c "from utils.utils import *; create_pretrained('weights/last.pt')" && gsutil cp weights/pretrained.pt gs://*

# Bash into stopped container
# sudo docker commit 6d525e299258 user/test_image && sudo docker run -it --gpus all --ipc=host -v "$(pwd)"/coco:/usr/src/coco --entrypoint=sh user/test_image

+ 24
- 18
README.md View File

@@ -4,26 +4,29 @@

This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.

<img src="https://user-images.githubusercontent.com/26833433/84200349-729f2680-aa5b-11ea-8f9a-604c9e01a658.png" width="1000">** GPU Latency measures end-to-end latency per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP32 inference, postprocessing and NMS.
<img src="https://user-images.githubusercontent.com/26833433/85340570-30360a80-b49b-11ea-87cf-bdf33d53ae15.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 8, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS.

- **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates to all YOLOv5 models. New models are faster, smaller and more accurate. Credit to @WongKinYiu for his excellent work with CSP.
- **May 27, 2020**: Public release of repo. YOLOv5 models are SOTA among all known YOLO implementations, YOLOv5 family will be undergoing architecture research and development over Q2/Q3 2020 to increase performance. Updates may include [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) bottlenecks, [YOLOv4](https://github.com/AlexeyAB/darknet) features, as well as PANet or BiFPN heads.
- **April 1, 2020**: Begin development of a 100% PyTorch, scaleable YOLOv3/4-based group of future models, in a range of compound-scaled sizes. Models will be defined by new user-friendly `*.yaml` files. New training methods will be simpler to start, faster to finish, and more robust to training a wider variety of custom dataset.
- **June 22, 2020**: [PANet](https://arxiv.org/abs/1803.01534) updates: new heads, reduced parameters, faster inference and improved mAP [364fcfd](https://github.com/ultralytics/yolov5/commit/364fcfd7dba53f46edd4f04c037a039c0a287972).
- **June 19, 2020**: [FP16](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.half) as new default for smaller checkpoints and faster inference [d4c6674](https://github.com/ultralytics/yolov5/commit/d4c6674c98e19df4c40e33a777610a18d1961145).
- **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates: improved speed, size, and accuracy (credit to @WongKinYiu for CSP).
- **May 27, 2020**: Public release of repo. YOLOv5 models are SOTA among all known YOLO implementations.
- **April 1, 2020**: Start development of future [YOLOv3](https://github.com/ultralytics/yolov3)/[YOLOv4](https://github.com/AlexeyAB/darknet)-based PyTorch models in a range of compound-scaled sizes.


## Pretrained Checkpoints

| Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Latency<sub>GPU</sub> | FPS<sub>GPU</sub> || params | FLOPs |
| Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>GPU</sub> | FPS<sub>GPU</sub> || params | FLOPS |
|---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: |
| YOLOv5-s ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 35.5 | 35.5 | 55.0 | **2.5ms** | **400** || 7.1M | 12.6B
| YOLOv5-m ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 42.7 | 42.7 | 62.4 | 4.4ms | 227 || 22.0M | 39.0B
| YOLOv5-l ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 45.7 | 45.9 | 65.1 | 6.8ms | 147 || 50.3M | 89.0B
| YOLOv5-x ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | **47.2** | **47.3** | **66.6** | 11.7ms | 85 || 95.9M | 170.3B
| YOLOv3-SPP ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 45.6 | 45.5 | 65.2 | 7.9ms | 127 || 63.0M | 118.0B
| [YOLOv5s](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 36.6 | 36.6 | 55.8 | **2.1ms** | **476** || 7.5M | 13.2B
| [YOLOv5m](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 43.4 | 43.4 | 62.4 | 3.0ms | 333 || 21.8M | 39.4B
| [YOLOv5l](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 46.6 | 46.7 | 65.4 | 3.9ms | 256 || 47.8M | 88.1B
| [YOLOv5x](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | **48.4** | **48.4** | **66.9** | 6.1ms | 164 || 89.0M | 166.4B
| [YOLOv3-SPP](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J) | 45.6 | 45.5 | 65.2 | 4.5ms | 222 || 63.0M | 118.0B


** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy.
** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --img 736 --conf 0.001`
** Latency<sub>GPU</sub> measures end-to-end latency per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP32 inference at batch size 32, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --img 640 --conf 0.1`
** Speed<sub>GPU</sub> measures end-to-end time per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP16 image inference at --batch-size 32 --img-size 640, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --img 640 --conf 0.1`
** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation).


@@ -37,10 +40,10 @@ $ pip install -U -r requirements.txt

## Tutorials

* <a href="https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
* [Notebook](https://github.com/ultralytics/yolov5/blob/master/tutorial.ipynb) <a href="https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
* [Train Custom Data](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data)
* [Google Cloud Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart)
* [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart)
* [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) ![Docker Pulls](https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker)


## Inference
@@ -74,9 +77,12 @@ Results saved to /content/yolov5/inference/output

## Reproduce Our Training

Run command below. Training times for yolov5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster).
Download [COCO](https://github.com/ultralytics/yolov5/blob/master/data/get_coco2017.sh), install [Apex](https://github.com/NVIDIA/apex) and run command below. Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices).
```bash
$ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size 16
$ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size 64
yolov5m 48
yolov5l 32
yolov5x 16
```
<img src="https://user-images.githubusercontent.com/26833433/84186698-c4d54d00-aa45-11ea-9bde-c632c1230ccd.png" width="900">

@@ -85,20 +91,20 @@ $ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size

To access an up-to-date working environment (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled), consider a:

- **GCP** Deep Learning VM with $300 free credit offer: See our [GCP Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart)
- **Google Cloud** Deep Learning VM with $300 free credit offer: See our [GCP Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart)
- **Google Colab Notebook** with 12 hours of free GPU time. <a href="https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
- **Docker Image** https://hub.docker.com/r/ultralytics/yolov5. See [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) ![Docker Pulls](https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker)


## Citation

[![DOI](https://zenodo.org/badge/146165888.svg)](https://zenodo.org/badge/latestdoi/146165888)
[![DOI](https://zenodo.org/badge/264818686.svg)](https://zenodo.org/badge/latestdoi/264818686)


## About Us

Ultralytics is a U.S.-based particle physics and AI startup with over 6 years of expertise supporting government, academic and business clients. We offer a wide range of vision AI services, spanning from simple expert advice up to delivery of fully customized, end-to-end production solutions, including:
- **Cloud-based AI** surveillance systems operating on **hundreds of HD video streams in realtime.**
- **Cloud-based AI** systems operating on **hundreds of HD video streams in realtime.**
- **Edge AI** integrated into custom iOS and Android apps for realtime **30 FPS video inference.**
- **Custom data training**, hyperparameter evolution, and model exportation to any destination.


+ 14
- 6
detect.py View File

@@ -1,5 +1,8 @@
import argparse

import torch.backends.cudnn as cudnn

from utils import google_utils
from utils.datasets import *
from utils.utils import *

@@ -36,14 +39,14 @@ def detect(save_img=False):
vid_path, vid_writer = None, None
if webcam:
view_img = True
torch.backends.cudnn.benchmark = True # set True to speed up constant image size inference
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz)
else:
save_img = True
dataset = LoadImages(source, img_size=imgsz)

# Get names and colors
names = model.names if hasattr(model, 'names') else model.modules.names
names = model.module.names if hasattr(model, 'module') else model.names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

# Run inference
@@ -62,8 +65,7 @@ def detect(save_img=False):
pred = model(img, augment=opt.augment)[0]

# Apply NMS
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
fast=True, classes=opt.classes, agnostic=opt.agnostic_nms)
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
t2 = torch_utils.time_synchronized()

# Apply Classifier
@@ -78,6 +80,7 @@ def detect(save_img=False):
p, s, im0 = path, '', im0s

save_path = str(Path(out) / Path(p).name)
txt_path = str(Path(out) / Path(p).stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] #  normalization gain whwh
if det is not None and len(det):
@@ -93,8 +96,8 @@ def detect(save_img=False):
for *xyxy, conf, cls in det:
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
with open(save_path[:save_path.rfind('.')] + '.txt', 'a') as file:
file.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
with open(txt_path + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format

if save_img or view_img: # Add bbox to image
label = '%s %.2f' % (names[int(cls)], conf)
@@ -154,3 +157,8 @@ if __name__ == '__main__':

with torch.no_grad():
detect()

# Update all models
# for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
# detect()
# create_pretrained(opt.weights, opt.weights)

+ 2
- 1
models/common.py View File

@@ -13,7 +13,8 @@ class Conv(nn.Module):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
super(Conv, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, k // 2, groups=g, bias=False)
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # padding
self.conv = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = nn.LeakyReLU(0.1, inplace=True) if act else nn.Identity()


models/yolov3-spp.yaml → models/hub/yolov3-spp.yaml View File

@@ -25,8 +25,7 @@ backbone:
[-1, 4, Bottleneck, [1024]], # 10
]

# yolov3-spp head
# na = len(anchors[0])
# YOLOv3-SPP head
head:
[[-1, 1, Bottleneck, [1024, False]], # 11
[-1, 1, SPP, [512, [5, 9, 13]]],

+ 45
- 0
models/hub/yolov5-fpn.yaml View File

@@ -0,0 +1,45 @@
# parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32

# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 9
]

# YOLOv5 FPN head
head:
[[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 11 (P5/32-large)

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 16 (P4/16-medium)

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 21 (P3/8-small)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]

+ 52
- 0
models/hub/yolov5-panet.yaml View File

@@ -0,0 +1,52 @@
# parameters
nc: 80 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple

# anchors
anchors:
- [116,90, 156,198, 373,326] # P5/32
- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
]

# YOLOv5 PANet head
head:
[[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, Conv, [512, 1, 1]], # 10

[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, Conv, [256, 1, 1]], # 14

[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]

+ 2
- 2
models/onnx_export.py View File

@@ -1,7 +1,6 @@
"""Exports a pytorch *.pt model to *.onnx format

Usage:
import torch
$ export PYTHONPATH="$PWD" && python models/onnx_export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
"""

@@ -10,6 +9,7 @@ import argparse
import onnx

from models.common import *
from utils import google_utils

if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -25,7 +25,7 @@ if __name__ == '__main__':

# Load pytorch model
google_utils.attempt_download(opt.weights)
model = torch.load(opt.weights, map_location=torch.device('cpu'))['model']
model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
model.eval()
model.fuse()


+ 9
- 8
models/yolo.py View File

@@ -1,7 +1,5 @@
import argparse

import yaml

from models.experimental import *


@@ -61,8 +59,9 @@ class Model(nn.Module):

# Build strides, anchors
m = self.model[-1] # Detect()
m.stride = torch.tensor([64 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 64, 64))]) # forward
m.stride = torch.tensor([128 / x.shape[-2] for x in self.forward(torch.zeros(1, ch, 128, 128))]) # forward
m.anchors /= m.stride.view(-1, 1, 1)
check_anchor_order(m)
self.stride = m.stride

# Init weights, biases
@@ -97,8 +96,11 @@ class Model(nn.Module):
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers

if profile:
import thop
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
try:
import thop
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
except:
o = 0
t = torch_utils.time_synchronized()
for _ in range(10):
_ = m(x)
@@ -208,7 +210,7 @@ if __name__ == '__main__':
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
opt = parser.parse_args()
opt.cfg = glob.glob('./**/' + opt.cfg, recursive=True)[0] # find file
opt.cfg = check_file(opt.cfg) # check file
device = torch_utils.select_device(opt.device)

# Create model
@@ -218,11 +220,10 @@ if __name__ == '__main__':
# Profile
# img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
# y = model(img, profile=True)
# print([y[0].shape] + [x.shape for x in y[1]])

# ONNX export
# model.model[-1].export = True
# torch.onnx.export(model, img, f.replace('.yaml', '.onnx'), verbose=True, opset_version=11)
# torch.onnx.export(model, img, opt.cfg.replace('.yaml', '.onnx'), verbose=True, opset_version=11)

# Tensorboard
# from torch.utils.tensorboard import SummaryWriter

+ 28
- 21
models/yolov5l.yaml View File

@@ -5,41 +5,48 @@ width_multiple: 1.0 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# yolov5 backbone
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 1-P1/2
[-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 10
]

# yolov5 head
# YOLOv5 head
head:
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 12 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 9

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 17 (P4/16-medium)
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]], # 13

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P3/8-small)
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]

+ 28
- 21
models/yolov5m.yaml View File

@@ -5,41 +5,48 @@ width_multiple: 0.75 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# yolov5 backbone
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 1-P1/2
[-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 10
]

# yolov5 head
# YOLOv5 head
head:
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 12 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 9

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 17 (P4/16-medium)
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]], # 13

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P3/8-small)
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]

+ 28
- 21
models/yolov5s.yaml View File

@@ -5,41 +5,48 @@ width_multiple: 0.50 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# yolov5 backbone
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 1-P1/2
[-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 10
]

# yolov5 head
# YOLOv5 head
head:
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 12 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 9

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 17 (P4/16-medium)
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]], # 13

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P3/8-small)
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]

+ 28
- 21
models/yolov5x.yaml View File

@@ -5,41 +5,48 @@ width_multiple: 1.25 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
- [30,61, 62,45, 59,119] # P4/16
- [10,13, 16,30, 33,23] # P3/8

# yolov5 backbone
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 1-P1/2
[-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, BottleneckCSP, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 6, BottleneckCSP, [1024]], # 10
]

# yolov5 head
# YOLOv5 head
head:
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 12 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 9

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 17 (P4/16-medium)
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]], # 13

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P3/8-small)
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)

[-2, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)

[-2, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
]

+ 2
- 2
requirements.txt View File

@@ -2,7 +2,7 @@
Cython
numpy==1.17
opencv-python
torch>=1.5
torch>=1.4
matplotlib
pillow
tensorboard
@@ -21,4 +21,4 @@ git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
# conda install -yc conda-forge scikit-image pycocotools tensorboard
# conda install -yc spyder-ide spyder-line-profiler
# conda install -yc pytorch pytorch torchvision
# conda install -yc conda-forge protobuf numpy && pip install onnx # https://github.com/onnx/onnx#linux-and-macos
# conda install -yc conda-forge protobuf numpy && pip install onnx==1.6.0 # https://github.com/onnx/onnx#linux-and-macos

+ 24
- 33
test.py View File

@@ -1,9 +1,7 @@
import argparse
import json

import yaml
from torch.utils.data import DataLoader

from utils import google_utils
from utils.datasets import *
from utils.utils import *

@@ -17,16 +15,18 @@ def test(data,
save_json=False,
single_cls=False,
augment=False,
verbose=False,
model=None,
dataloader=None,
fast=False,
verbose=False,
save_dir='.'):
save_dir='.',
merge=False):

# Initialize/load model and set device
if model is None:
training = False
device = torch_utils.select_device(opt.device, batch_size=batch_size)
half = device.type != 'cpu' # half precision only supported on CUDA

# Remove previous
for f in glob.glob(f'{save_dir}/test_batch*.jpg'):
@@ -38,18 +38,19 @@ def test(data,
torch_utils.model_info(model)
model.fuse()
model.to(device)
if half:
model.half() # to FP16

if device.type != 'cpu' and torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
# Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
# if device.type != 'cpu' and torch.cuda.device_count() > 1:
# model = nn.DataParallel(model)

else: # called by train.py
training = True
device = next(model.parameters()).device # get model device
half = device.type != 'cpu' # half precision only supported on CUDA
if half:
model.half() # to FP16

# Half
half = device.type != 'cpu' and torch.cuda.device_count() == 1 # half precision only supported on single-GPU
if half:
model.half() # to FP16

# Configure
model.eval()
@@ -57,29 +58,16 @@ def test(data,
data = yaml.load(f, Loader=yaml.FullLoader) # model dict
nc = 1 if single_cls else int(data['nc']) # number of classes
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for mAP@0.5:0.95
# iouv = iouv[0].view(1) # comment for mAP@0.5:0.95
niou = iouv.numel()

# Dataloader
if dataloader is None: # not training
merge = opt.merge # use Merge NMS
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once

fast |= conf_thres > 0.001 # enable fast mode
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images
dataset = LoadImagesAndLabels(path,
imgsz,
batch_size,
rect=True, # rectangular inference
single_cls=opt.single_cls, # single class mode
pad=0.5) # padding
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
dataloader = DataLoader(dataset,
batch_size=batch_size,
num_workers=nw,
pin_memory=True,
collate_fn=dataset.collate_fn)
dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt,
hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0]

seen = 0
names = model.names if hasattr(model, 'names') else model.module.names
@@ -109,7 +97,7 @@ def test(data,

# Run NMS
t = torch_utils.time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, fast=fast)
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
t1 += torch_utils.time_synchronized() - t

# Statistics per image
@@ -235,6 +223,7 @@ def test(data,
'See https://github.com/cocodataset/cocoapi/issues/356')

# Return results
model.float() # for training
maps = np.zeros(nc) + map
for i, c in enumerate(ap_class):
maps[c] = ap[i]
@@ -244,7 +233,7 @@ def test(data,
if __name__ == '__main__':
parser = argparse.ArgumentParser(prog='test.py')
parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path')
parser.add_argument('--data', type=str, default='data/coco.yaml', help='*.data path')
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch')
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
@@ -254,6 +243,7 @@ if __name__ == '__main__':
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--merge', action='store_true', help='use Merge NMS')
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
opt = parser.parse_args()
opt.img_size = check_img_size(opt.img_size)
@@ -271,12 +261,13 @@ if __name__ == '__main__':
opt.iou_thres,
opt.save_json,
opt.single_cls,
opt.augment)
opt.augment,
opt.verbose)

elif opt.task == 'study': # run over a range of settings and save/plot
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
x = list(range(288, 896, 64)) # x axis
x = list(range(352, 832, 64)) # x axis
y = [] # y axis
for i in x: # img-size
print('\nRunning %s point %s...' % (f, i))

+ 24
- 36
train.py View File

@@ -4,10 +4,12 @@ import torch.distributed as dist
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter

import test # import test.py to get mAP after each epoch
from models.yolo import Model
from utils import google_utils
from utils.datasets import *
from utils.utils import *

@@ -72,6 +74,7 @@ def train(hyp):
# Create model
model = Model(opt.cfg).to(device)
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
model.names = data_dict['names']

# Image sizes
gs = int(max(model.stride)) # grid size (max stride)
@@ -148,37 +151,17 @@ def train(hyp):
world_size=1, # number of nodes
rank=0) # node rank
model = torch.nn.parallel.DistributedDataParallel(model)
# pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html

# Dataset
dataset = LoadImagesAndLabels(train_path, imgsz, batch_size,
augment=True,
hyp=hyp, # augmentation hyperparameters
rect=opt.rect, # rectangular training
cache_images=opt.cache_images,
single_cls=opt.single_cls)
# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)

# Dataloader
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
num_workers=nw,
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
pin_memory=True,
collate_fn=dataset.collate_fn)

# Testloader
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
hyp=hyp,
rect=True,
cache_images=opt.cache_images,
single_cls=opt.single_cls),
batch_size=batch_size,
num_workers=nw,
pin_memory=True,
collate_fn=dataset.collate_fn)
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]

# Model parameters
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
@@ -186,7 +169,6 @@ def train(hyp):
model.hyp = hyp # attach hyperparameters to model
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.names = data_dict['names']

#save hyperparamter and training options in run folder
with open(os.path.join(log_dir, 'hyp.yaml'), 'w') as f:
@@ -200,11 +182,17 @@ def train(hyp):
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
# model._initialize_biases(cf.to(device))

#always plot labels to log_dir
plot_labels(labels, save_dir=log_dir)
tb_writer.add_histogram('classes', c, 0)

if tb_writer:
tb_writer.add_histogram('classes', c, 0)


# Check anchors
check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'], imgsz=imgsz)
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

# Exponential moving average
ema = torch_utils.ModelEMA(model)
@@ -216,7 +204,7 @@ def train(hyp):
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % nw)
print('Using %g dataloader workers' % dataloader.num_workers)
print('Starting training for %g epochs...' % epochs)
# torch.autograd.set_detect_anomaly(True)
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
@@ -285,10 +273,10 @@ def train(hyp):

# Plot
if ni < 3:
f = os.path.join(log_dir, 'train_batch%g.jpg' % i) # filename
res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer:
tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
f = os.path.join(log_dir, 'train_batch%g.jpg' % ni) # filename
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
if tb_writer and result is not None:
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard

# end batch ------------------------------------------------------------------------------------------------
@@ -307,7 +295,6 @@ def train(hyp):
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
fast=epoch < epochs / 2,
save_dir=log_dir)

# Write
@@ -362,7 +349,7 @@ def train(hyp):
if not opt.evolve:
plot_results(save_dir = log_dir) # save as results.png
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
torch.cuda.empty_cache()
return results

@@ -379,6 +366,7 @@ if __name__ == '__main__':
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--notest', action='store_true', help='only test final epoch')
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')

+ 26
- 5
utils/datasets.py View File

@@ -18,7 +18,7 @@ from utils.utils import xyxy2xywh, xywh2xyxy

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
vid_formats = ['.mov', '.avi', '.mp4']
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
@@ -41,6 +41,26 @@ def exif_size(img):
return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
cache_images=cache,
single_cls=opt.single_cls,
stride=stride,
pad=pad)

batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
num_workers=nw,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn)
return dataloader, dataset


class LoadImages: # for inference
def __init__(self, path, img_size=416):
path = str(Path(path)) # os-agnostic
@@ -63,7 +83,8 @@ class LoadImages: # for inference
self.new_video(videos[0]) # new video
else:
self.cap = None
assert self.nF > 0, 'No images or videos found in ' + path
assert self.nF > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
(path, img_formats, vid_formats)

def __iter__(self):
self.count = 0
@@ -257,7 +278,7 @@ class LoadStreams: # multiple IP or RTSP cameras

class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, pad=0.0):
cache_images=False, single_cls=False, stride=32, pad=0.0):
try:
path = str(Path(path)) # os-agnostic
parent = str(Path(path).parent) + os.sep
@@ -324,7 +345,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
elif mini > 1:
shapes[i] = [1, 1 / mini]

self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32. + pad).astype(np.int) * 32
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

# Cache labels
self.imgs = [None] * n
@@ -711,7 +732,7 @@ def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10,
area = w * h
area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16)) # aspect ratio
i = (w > 4) & (h > 4) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 10)
i = (w > 2) & (h > 2) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 20)

targets = targets[i]
targets[:, 1:5] = xy[i]

+ 12
- 7
utils/torch_utils.py View File

@@ -7,6 +7,7 @@ import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


def init_seeds(seed=0):
@@ -120,18 +121,22 @@ def model_info(model, verbose=False):

def load_classifier(name='resnet101', n=2):
# Loads a pretrained model reshaped to n-class output
import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
model = models.__dict__[name](pretrained=True)

# Display model properties
for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']:
input_size = [3, 224, 224]
input_space = 'RGB'
input_range = [0, 1]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
for x in [input_size, input_space, input_range, mean, std]:
print(x + ' =', eval(x))

# Reshape output to n classes
filters = model.last_linear.weight.shape[1]
model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
model.last_linear.out_features = n
filters = model.fc.weight.shape[1]
model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
model.fc.out_features = n
return model



+ 104
- 79
utils/utils.py View File

@@ -20,7 +20,7 @@ import yaml
from scipy.signal import butter, filtfilt
from tqdm import tqdm

from . import torch_utils, google_utils #  torch_utils, google_utils
from . import torch_utils #  torch_utils, google_utils

# Set printoptions
torch.set_printoptions(linewidth=320, precision=5, profile='long')
@@ -53,24 +53,52 @@ def check_git_status():

def check_img_size(img_size, s=32):
# Verify img_size is a multiple of stride s
if img_size % s != 0:
print('WARNING: --img-size %g must be multiple of max stride %g' % (img_size, s))
return make_divisible(img_size, s) # nearest gs-multiple
new_size = make_divisible(img_size, s) # ceil gs-multiple
if new_size != img_size:
print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
return new_size


def check_best_possible_recall(dataset, anchors, thr=4.0, imgsz=640):
# Check best possible recall of dataset with current anchors
def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary
print('\nAnalyzing anchors... ', end='')
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio
m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio
bpr = (m.min(1)[0] < thr).float().mean() # best possible recall
mr = (m < thr).float().mean() # match ratio
print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh

def metric(k): # compute metric
r = wh[:, None] / k[None]
x = torch.min(r, 1. / r).min(2)[0] # ratio metric
best = x.max(1)[0] # best_x
return (best > 1. / thr).float().mean() #  best possible recall

bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
if bpr < 0.99: # threshold to recompute
print('. Attempting to generate improved anchors, please wait...' % bpr)
na = m.anchor_grid.numel() // 2 # number of anchors
new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
new_bpr = metric(new_anchors.reshape(-1, 2))
if new_bpr > bpr: # replace anchors
new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
check_anchor_order(m)
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
else:
print('Original anchors better than new anchors. Proceeding with original anchors.')
print('') # newline


assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
a = m.anchor_grid.prod(-1).view(-1) # anchor area
da = a[-1] - a[0] # delta a
ds = m.stride[-1] - m.stride[0] # delta s
if da.sign() != ds.sign(): # same order
m.anchors[:] = m.anchors.flip(0)
m.anchor_grid[:] = m.anchor_grid.flip(0)


def check_file(file):
@@ -517,11 +545,11 @@ def build_targets(p, targets, model):
return tcls, tbox, indices, anch


def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
"""
Performs Non-Maximum Suppression on inference results
Returns detections with shape:
nx6 (x1, y1, x2, y2, conf, cls)
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False):
"""Performs Non-Maximum Suppression (NMS) on inference results
Returns:
detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
"""
if prediction.dtype is torch.float16:
prediction = prediction.float() # to FP32
@@ -534,12 +562,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
max_det = 300 # maximum number of detections per image
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
fast |= conf_thres > 0.001 # fast mode
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
if fast:
merge = False
else:
merge = True # merge for best mAP (adds 0.5ms/img)

t = time.time()
output = [None] * prediction.shape[0]
@@ -610,24 +633,24 @@ def strip_optimizer(f='weights/best.pt'): # from utils.utils import *; strip_op
# Strip optimizer from *.pt files for lighter files (reduced by 1/2 size)
x = torch.load(f, map_location=torch.device('cpu'))
x['optimizer'] = None
x['model'].half() # to FP16
torch.save(x, f)
print('Optimizer stripped from %s' % f)


def create_backbone(f='weights/best.pt', s='weights/backbone.pt'): # from utils.utils import *; create_backbone()
# create backbone 's' from 'f'
def create_pretrained(f='weights/best.pt', s='weights/pretrained.pt'): # from utils.utils import *; create_pretrained()
# create pretrained checkpoint 's' from 'f' (create_pretrained(x, x) for x in glob.glob('./*.pt'))
device = torch.device('cpu')
x = torch.load(f, map_location=device)
torch.save(x, s) # update model if SourceChangeWarning
x = torch.load(s, map_location=device)

x['optimizer'] = None
x['training_results'] = None
x['epoch'] = -1
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = True
torch.save(x, s)
print('%s modified for backbone use and saved as %s' % (f, s))
print('%s saved as pretrained checkpoint %s' % (f, s))


def coco_class_count(path='../coco/labels/train2014/'):
@@ -695,14 +718,14 @@ def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images


def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20, gen=1000):
def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
""" Creates kmeans-evolved anchors from training dataset

Arguments:
path: path to dataset *.yaml
path: path to dataset *.yaml, or a loaded dataset
n: number of anchors
img_size: (min, max) image size used for multi-scale training (can be same values)
thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
img_size: image size used for training
thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
gen: generations to evolve anchors using genetic algorithm

Return:
@@ -711,52 +734,47 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20
Usage:
from utils.utils import *; _ = kmean_anchors()
"""
thr = 1. / thr

def metric(k, wh): # compute metrics
r = wh[:, None] / k[None]
x = torch.min(r, 1. / r).min(2)[0] # ratio metric
# x = wh_iou(wh, torch.tensor(k)) # iou metric
return x, x.max(1)[0] # x, best_x

from utils.datasets import LoadImagesAndLabels
def fitness(k): # mutation fitness
_, best = metric(torch.tensor(k, dtype=torch.float32), wh)
return (best * (best > thr).float()).mean() # fitness

def print_results(k):
k = k[np.argsort(k.prod(1))] # sort small to large
iou = wh_iou(wh, torch.Tensor(k))
max_iou = iou.max(1)[0]
bpr, aat = (max_iou > thr).float().mean(), (iou > thr).float().mean() * n # best possible recall, anch > thr

# thr = 5.0
# r = wh[:, None] / k[None]
# ar = torch.max(r, 1. / r).max(2)[0]
# max_ar = ar.min(1)[0]
# bpr, aat = (max_ar < thr).float().mean(), (ar < thr).float().mean() * n # best possible recall, anch > thr

print('%.2f iou_thr: %.3f best possible recall, %.2f anchors > thr' % (thr, bpr, aat))
print('n=%g, img_size=%s, IoU_all=%.3f/%.3f-mean/best, IoU>thr=%.3f-mean: ' %
(n, img_size, iou.mean(), max_iou.mean(), iou[iou > thr].mean()), end='')
x, best = metric(k, wh0)
bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
(n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
for i, x in enumerate(k):
print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
return k

def fitness(k): # mutation fitness
iou = wh_iou(wh, torch.Tensor(k)) # iou
max_iou = iou.max(1)[0]
return (max_iou * (max_iou > thr).float()).mean() # product

# def fitness_ratio(k): # mutation fitness
# # wh(5316,2), k(9,2)
# r = wh[:, None] / k[None]
# x = torch.max(r, 1. / r).max(2)[0]
# m = x.min(1)[0]
# return 1. / (m * (m < 5).float()).mean() # product
if isinstance(path, str): # *.yaml file
with open(path) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
from utils.datasets import LoadImagesAndLabels
dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
else:
dataset = path # dataset

# Get label wh
wh = []
with open(path) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
nr = 1 if img_size[0] == img_size[1] else 3 # number augmentation repetitions
for s, l in zip(dataset.shapes, dataset.labels):
# wh.append(l[:, 3:5] * (s / s.max())) # image normalized to letterbox normalized wh
wh.append(l[:, 3:5] * s) # image normalized to pixels
wh = np.concatenate(wh, 0).repeat(nr, axis=0) # augment 3x
# wh *= np.random.uniform(img_size[0], img_size[1], size=(wh.shape[0], 1)) # normalized to pixels (multi-scale)
wh = wh[(wh > 2.0).all(1)] # remove below threshold boxes (< 2 pixels wh)
shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh

# Filter
i = (wh0 < 4.0).any(1).sum()
if i:
print('WARNING: Extremely small objects found. '
'%g of %g labels are < 4 pixels in width or height.' % (i, len(wh0)))
wh = wh0[(wh0 >= 4.0).any(1)] # filter > 2 pixels

# Kmeans calculation
from scipy.cluster.vq import kmeans
@@ -764,10 +782,11 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20
s = wh.std(0) # sigmas for whitening
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
k *= s
wh = torch.Tensor(wh)
wh = torch.tensor(wh, dtype=torch.float32) # filtered
wh0 = torch.tensor(wh0, dtype=torch.float32) # unflitered
k = print_results(k)

# # Plot
# Plot
# k, d = [None] * 20, [None] * 20
# for i in tqdm(range(1, 21)):
# k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
@@ -783,7 +802,8 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20
# Evolve
npr = np.random
f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
for _ in tqdm(range(gen), desc='Evolving anchors'):
pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar
for _ in pbar:
v = np.ones(sh)
while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
@@ -791,9 +811,11 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=(640, 640), thr=0.20
fg = fitness(kg)
if fg > f:
f, k = fg, kg.copy()
print_results(k)
k = print_results(k)
return k
pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
if verbose:
print_results(k)

return print_results(k)


def print_mutation(hyp, results, bucket=''):
@@ -1078,12 +1100,14 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st

ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.5, 39.1, 42.5, 45.9, 49., 50.5],
'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')

ax2.grid()
ax2.set_xlim(0, 30)
ax2.set_ylim(25, 50)
ax2.set_xlabel('GPU Latency (ms)')
ax2.set_ylim(28, 50)
ax2.set_yticks(np.arange(30, 55, 5))
ax2.set_xlabel('GPU Speed (ms/img)')
ax2.set_ylabel('COCO AP val')
ax2.legend(loc='lower right')
ax2.grid()
plt.savefig('study_mAP_latency.png', dpi=300)
plt.savefig(f.replace('.txt', '.png'), dpi=200)

@@ -1110,6 +1134,7 @@ def plot_labels(labels, save_dir= '.'):
ax[2].set_xlabel('width')
ax[2].set_ylabel('height')
plt.savefig(os.path.join(save_dir,'labels.png'), dpi=200)
plt.close()


def plot_evolution_results(hyp): # from utils.utils import *; plot_evolution_results(hyp)

Loading…
Cancel
Save