Browse Source

new CSP model release

5.0
Glenn Jocher 4 years ago
parent
commit
c14368d768
8 changed files with 54 additions and 53 deletions
  1. +13
    -13
      README.md
  2. +0
    -2
      detect.py
  3. +8
    -5
      models/yolo.py
  4. +6
    -6
      models/yolov5l.yaml
  5. +6
    -6
      models/yolov5m.yaml
  6. +6
    -6
      models/yolov5s.yaml
  7. +6
    -6
      models/yolov5x.yaml
  8. +9
    -9
      utils/utils.py

+ 13
- 13
README.md View File



This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk. This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk.


<img src="https://user-images.githubusercontent.com/26833433/83359175-63b6c680-a32d-11ea-970a-9f602e022468.png" width="1000">** GPU Latency measures end-to-end latency per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 16, and includes image preprocessing, FP32 inference, postprocessing and NMS.
<img src="https://user-images.githubusercontent.com/26833433/84200349-729f2680-aa5b-11ea-8f9a-604c9e01a658.png" width="1000">** GPU Latency measures end-to-end latency per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, FP32 inference, postprocessing and NMS.


- **May 27, 2020**: Public release of repo. yolov3-spp implementation (this repo) is SOTA at 45.5 mAP among all known yolo implementations, yolov5 family will be undergoing architecture research and development over Q2/Q3 2020 to increase performance. Updates may include [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) bottlenecks from [yolov4](https://github.com/AlexeyAB/darknet), as well as PANet or BiFPN head features.
- **May 24, 2020**: Training yolov5s/x and yolov3-spp. yolov5m/l suffered early overfitting and also code 137 early docker terminations, cause unknown. yolov5l underperforms yolov3-spp due to earlier overfitting, cause unknown.
- **April 1, 2020**: Begin development of a 100% pytorch scaleable yolov3/4-based group of future models, in small, medium, large and extra large sizes, collectively known as yolov5. Models will be defined by new user-friendly yaml-based configuration files for ease of construction and modification. Datasets will likewise use yaml configuration files. New training platform will be simpler use, harder to break, and more robust to training a wider variety of custom dataset.
- **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates to all YOLOv5 models. New models are faster, smaller and more accurate. Credit to @WongKinYiu for his excellent work with CSP.
- **May 27, 2020**: Public release of repo. YOLOv5 models are SOTA among all known YOLO implementations, YOLOv5 family will be undergoing architecture research and development over Q2/Q3 2020 to increase performance. Updates may include [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) bottlenecks, [YOLOv4](https://github.com/AlexeyAB/darknet) features, as well as PANet or BiFPN heads.
- **April 1, 2020**: Begin development of a 100% PyTorch, scaleable YOLOv3/4-based group of future models, in a range of compound-scaled sizes, collectively known as YOLOv5. Models will be defined by new user-friendly *.yaml files. New training platform will be simpler use, harder to break, and more robust to training a wider variety of custom dataset.




## Ultralytics Professional Support ## Ultralytics Professional Support


| Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Latency<sub>GPU</sub> | FPS<sub>GPU</sub> || params | FLOPs | | Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Latency<sub>GPU</sub> | FPS<sub>GPU</sub> || params | FLOPs |
|---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: | |---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: |
| YOLOv5-s ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 33.0 | 33.0 | 53.2 | **2.9ms** | **345** || 7.0M | 14.0B
| YOLOv5-m ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 41.4 | 41.4 | 61.5 | 5.0ms | 200 || 25.2M | 50.2B
| YOLOv5-l ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 44.3 | 44.5 | 64.3 | 8.9ms | 112 || 61.8M | 123.1B
| YOLOv5-x ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | **47.1** | **47.2** | **66.7** | 15.2ms | 66 || 123.1M | 245.7B
| YOLOv3-SPP ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 45.6 | 45.5 | 65.2 | 8.3ms | 120 || 63.0M | 118.0B
| YOLOv5-s ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 35.5 | 35.5 | 55.0 | **2.5ms** | **400** || 7.1M | 12.6B
| YOLOv5-m ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 42.7 | 42.7 | 62.4 | 4.4ms | 227 || 22.0M | 39.0B
| YOLOv5-l ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 45.7 | 45.9 | 65.1 | 6.8ms | 147 || 50.3M | 89.0B
| YOLOv5-x ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | **47.2** | **47.3** | **66.6** | 11.7ms | 85 || 95.9M | 170.3B
| YOLOv3-SPP ([ckpt](https://drive.google.com/open?id=1Drs_Aiu7xx6S-ix95f9kNsA6ueKRpN2J)) | 45.6 | 45.5 | 65.2 | 7.9ms | 127 || 63.0M | 118.0B


** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy. ** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results in the table denote val2017 accuracy.
** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --img 736 --conf 0.001` ** All AP numbers are for single-model single-scale without ensemble or test-time augmentation. Reproduce by `python test.py --img 736 --conf 0.001`
** Latency<sub>GPU</sub> measures end-to-end latency per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, pytorch FP32 inference at batch size 16, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --img 640 --conf 0.1`
** Latency<sub>GPU</sub> measures end-to-end latency per image averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) instance with one V100 GPU, and includes image preprocessing, PyTorch FP32 inference at batch size 32, postprocessing and NMS. Average NMS time included in this chart is 1-2ms/img. Reproduce by `python test.py --img 640 --conf 0.1`
** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation). ** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation).






## Reproduce Our Training ## Reproduce Our Training


Run commands below. Training takes a few days for yolov5s, to a few weeks for yolov5x on a 2080Ti GPU.
Run command below. Training times for yolov5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster).
```bash ```bash
$ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size 16
$ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size 16
``` ```
<img src="https://user-images.githubusercontent.com/26833433/82960433-5a191180-9f6f-11ea-85cc-c49dbd1555e1.png" width="900">
<img src="https://user-images.githubusercontent.com/26833433/84186698-c4d54d00-aa45-11ea-9bde-c632c1230ccd.png" width="900">




## Reproduce Our Environment ## Reproduce Our Environment

+ 0
- 2
detect.py View File

from utils.datasets import * from utils.datasets import *
from utils.utils import * from utils.utils import *


ONNX_EXPORT = False



def detect(save_img=False): def detect(save_img=False):
out, source, weights, half, view_img, save_txt, imgsz = \ out, source, weights, half, view_img, save_txt, imgsz = \

+ 8
- 5
models/yolo.py View File





class Model(nn.Module): class Model(nn.Module):
def __init__(self, model_yaml='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
def __init__(self, model_cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
super(Model, self).__init__() super(Model, self).__init__()
with open(model_yaml) as f:
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
if nc:
self.md['nc'] = nc # override yaml value
if type(model_cfg) is dict:
self.md = model_cfg # model dict
else: # is *.yaml
with open(model_cfg) as f:
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict


# Define model # Define model
if nc:
self.md['nc'] = nc # override yaml value
self.model, self.save, ch = parse_model(self.md, ch=[ch]) # model, savelist, ch_out self.model, self.save, ch = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
# print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))]) # print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])



+ 6
- 6
models/yolov5l.yaml View File

[-1, 1, Conv, [128, 3, 2]], # 2-P2/4 [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]], [-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8 [-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[-1, 9, Bottleneck, [256]],
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16 [-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 9, Bottleneck, [512]],
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32 [-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]], [-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, Bottleneck, [1024]], # 10
[-1, 6, BottleneckCSP, [1024]], # 10
] ]


# yolov5 head # yolov5 head
head: head:
[[-1, 3, Bottleneck, [1024, False]], # 11
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 12 (P5/32-large) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 12 (P5/32-large)


[-2, 1, nn.Upsample, [None, 2, 'nearest']], [-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4 [[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]], [-1, 1, Conv, [512, 1, 1]],
[-1, 3, Bottleneck, [512, False]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 17 (P4/16-medium) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 17 (P4/16-medium)


[-2, 1, nn.Upsample, [None, 2, 'nearest']], [-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3 [[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]], [-1, 1, Conv, [256, 1, 1]],
[-1, 3, Bottleneck, [256, False]],
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 22 (P3/8-small) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 22 (P3/8-small)


[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) [[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)

+ 6
- 6
models/yolov5m.yaml View File

[-1, 1, Conv, [128, 3, 2]], # 2-P2/4 [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]], [-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8 [-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[-1, 9, Bottleneck, [256]],
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16 [-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 9, Bottleneck, [512]],
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32 [-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]], [-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, Bottleneck, [1024]], # 10
[-1, 6, BottleneckCSP, [1024]], # 10
] ]


# yolov5 head # yolov5 head
head: head:
[[-1, 3, Bottleneck, [1024, False]], # 11
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 12 (P5/32-large) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 12 (P5/32-large)


[-2, 1, nn.Upsample, [None, 2, 'nearest']], [-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4 [[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]], [-1, 1, Conv, [512, 1, 1]],
[-1, 3, Bottleneck, [512, False]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 17 (P4/16-medium) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 17 (P4/16-medium)


[-2, 1, nn.Upsample, [None, 2, 'nearest']], [-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3 [[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]], [-1, 1, Conv, [256, 1, 1]],
[-1, 3, Bottleneck, [256, False]],
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 22 (P3/8-small) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 22 (P3/8-small)


[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) [[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)

+ 6
- 6
models/yolov5s.yaml View File

[-1, 1, Conv, [128, 3, 2]], # 2-P2/4 [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]], [-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8 [-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[-1, 9, Bottleneck, [256]],
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16 [-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 9, Bottleneck, [512]],
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32 [-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]], [-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, Bottleneck, [1024]], # 10
[-1, 6, BottleneckCSP, [1024]], # 10
] ]


# yolov5 head # yolov5 head
head: head:
[[-1, 3, Bottleneck, [1024, False]], # 11
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 12 (P5/32-large) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 12 (P5/32-large)


[-2, 1, nn.Upsample, [None, 2, 'nearest']], [-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4 [[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]], [-1, 1, Conv, [512, 1, 1]],
[-1, 3, Bottleneck, [512, False]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 17 (P4/16-medium) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 17 (P4/16-medium)


[-2, 1, nn.Upsample, [None, 2, 'nearest']], [-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3 [[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]], [-1, 1, Conv, [256, 1, 1]],
[-1, 3, Bottleneck, [256, False]],
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 22 (P3/8-small) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 22 (P3/8-small)


[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) [[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)

+ 6
- 6
models/yolov5x.yaml View File

[-1, 1, Conv, [128, 3, 2]], # 2-P2/4 [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
[-1, 3, Bottleneck, [128]], [-1, 3, Bottleneck, [128]],
[-1, 1, Conv, [256, 3, 2]], # 4-P3/8 [-1, 1, Conv, [256, 3, 2]], # 4-P3/8
[-1, 9, Bottleneck, [256]],
[-1, 9, BottleneckCSP, [256]],
[-1, 1, Conv, [512, 3, 2]], # 6-P4/16 [-1, 1, Conv, [512, 3, 2]], # 6-P4/16
[-1, 9, Bottleneck, [512]],
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 8-P5/32 [-1, 1, Conv, [1024, 3, 2]], # 8-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]], [-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, Bottleneck, [1024]], # 10
[-1, 6, BottleneckCSP, [1024]], # 10
] ]


# yolov5 head # yolov5 head
head: head:
[[-1, 3, Bottleneck, [1024, False]], # 11
[[-1, 3, BottleneckCSP, [1024, False]], # 11
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 12 (P5/32-large) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 12 (P5/32-large)


[-2, 1, nn.Upsample, [None, 2, 'nearest']], [-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4 [[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]], [-1, 1, Conv, [512, 1, 1]],
[-1, 3, Bottleneck, [512, False]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 17 (P4/16-medium) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 17 (P4/16-medium)


[-2, 1, nn.Upsample, [None, 2, 'nearest']], [-2, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3 [[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]], [-1, 1, Conv, [256, 1, 1]],
[-1, 3, Bottleneck, [256, False]],
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 22 (P3/8-small) [-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1, 0]], # 22 (P3/8-small)


[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) [[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)

+ 9
- 9
utils/utils.py View File

ax = ax.ravel() ax = ax.ravel()


fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True) fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18][:-1]), [33.5, 39.1, 42.5, 45.9, 49., 50.5][:-1],
'.-', linewidth=2, markersize=8, alpha=0.3, label='EfficientDet')

for f in sorted(glob.glob('study*.txt')):
for f in ['coco_study/study_coco_yolov5%s.txt' % x for x in ['s', 'm', 'l', 'x']]:
y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
x = np.arange(y.shape[1]) if x is None else np.array(x) x = np.arange(y.shape[1]) if x is None else np.array(x)
s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)'] s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8, ax2.plot(y[6, :j], y[3, :j] * 1E2, '.-', linewidth=2, markersize=8,
label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO')) label=Path(f).stem.replace('study_coco_', '').replace('yolo', 'YOLO'))


ax2.set_xlim(0)
ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [33.5, 39.1, 42.5, 45.9, 49., 50.5],
'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
ax2.set_xlim(0, 30)
ax2.set_ylim(23, 50) ax2.set_ylim(23, 50)
ax2.set_xlabel('GPU Latency (ms)') ax2.set_xlabel('GPU Latency (ms)')
ax2.set_ylabel('COCO AP val') ax2.set_ylabel('COCO AP val')
fig.savefig(f.replace('.txt', '.png'), dpi=200) fig.savefig(f.replace('.txt', '.png'), dpi=200)




def plot_results(start=0, stop=0, bucket='', id=()): # from utils.utils import *; plot_results()
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov3#training
def plot_results(start=0, stop=0, bucket='', id=(), labels=()): # from utils.utils import *; plot_results()
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training
fig, ax = plt.subplots(2, 5, figsize=(12, 6)) fig, ax = plt.subplots(2, 5, figsize=(12, 6))
ax = ax.ravel() ax = ax.ravel()
s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall', s = ['GIoU', 'Objectness', 'Classification', 'Precision', 'Recall',
files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id] files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id]
else: else:
files = glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt') files = glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')
for f in sorted(files):
for fi, f in enumerate(files):
try: try:
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T
n = results.shape[1] # number of rows n = results.shape[1] # number of rows
if i in [0, 1, 2, 5, 6, 7]: if i in [0, 1, 2, 5, 6, 7]:
y[y == 0] = np.nan # dont show zero loss values y[y == 0] = np.nan # dont show zero loss values
# y /= y[0] # normalize # y /= y[0] # normalize
ax[i].plot(x, y, marker='.', label=Path(f).stem, linewidth=2, markersize=8)
label = labels[fi] if len(labels) else Path(f).stem
ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8)
ax[i].set_title(s[i]) ax[i].set_title(s[i])
# if i in [5, 6, 7]: # share train and val loss y axes # if i in [5, 6, 7]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])

Loading…
Cancel
Save