* Update C3 module * Update C3 module * Update C3 module * Update C3 module * update * update * update * update * update * update * update * update * update * updates * updates * updates * updates * updates * updates * updates * updates * updates * updates * update * update * update * update * updates * updates * updates * updates * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update datasets * update * update * update * update attempt_downlaod() * merge * merge * update * update * update * update * update * update * update * update * update * update * parameterize eps * comments * gs-multiple * update * max_nms implemented * Create one_cycle() function * update * update * update * update * update * update * update * update study.png * update study.png * Update datasets.py5.0
@@ -4,28 +4,32 @@ | |||
![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg) | |||
This repository represents Ultralytics open-source research into future object detection methods, and incorporates our lessons learned and best practices evolved over training thousands of models on custom client datasets with our previous YOLO repository https://github.com/ultralytics/yolov3. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk. | |||
This repository represents Ultralytics open-source research into future object detection methods, and incorporates lessons learned and best practices evolved over thousands of hours of training and evolution on anonymized client datasets. **All code and models are under active development, and are subject to modification or deletion without notice.** Use at your own risk. | |||
<img src="https://user-images.githubusercontent.com/26833433/90187293-6773ba00-dd6e-11ea-8f90-cd94afc0427f.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8. | |||
<img src="https://user-images.githubusercontent.com/26833433/103594689-455e0e00-4eae-11eb-9cdf-7d753e2ceeeb.png" width="1000">** GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS. EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8. | |||
- **January 5, 2021**: [v4.0 release](https://github.com/ultralytics/yolov5/releases/tag/v4.0): nn.SiLU() activations, [Weights & Biases](https://wandb.ai/) logging, [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/) integration. | |||
- **August 13, 2020**: [v3.0 release](https://github.com/ultralytics/yolov5/releases/tag/v3.0): nn.Hardswish() activations, data autodownload, native AMP. | |||
- **July 23, 2020**: [v2.0 release](https://github.com/ultralytics/yolov5/releases/tag/v2.0): improved model definition, training and mAP. | |||
- **June 22, 2020**: [PANet](https://arxiv.org/abs/1803.01534) updates: new heads, reduced parameters, improved speed and mAP [364fcfd](https://github.com/ultralytics/yolov5/commit/364fcfd7dba53f46edd4f04c037a039c0a287972). | |||
- **June 19, 2020**: [FP16](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.half) as new default for smaller checkpoints and faster inference [d4c6674](https://github.com/ultralytics/yolov5/commit/d4c6674c98e19df4c40e33a777610a18d1961145). | |||
- **June 9, 2020**: [CSP](https://github.com/WongKinYiu/CrossStagePartialNetworks) updates: improved speed, size, and accuracy (credit to @WongKinYiu for CSP). | |||
- **May 27, 2020**: Public release. YOLOv5 models are SOTA among all known YOLO implementations. | |||
## Pretrained Checkpoints | |||
| Model | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>GPU</sub> | FPS<sub>GPU</sub> || params | GFLOPS | | |||
|---------- |------ |------ |------ | -------- | ------| ------ |------ | :------: | | |||
| [YOLOv5s](https://github.com/ultralytics/yolov5/releases) | 37.0 | 37.0 | 56.2 | **2.4ms** | **416** || 7.5M | 17.5 | |||
| [YOLOv5m](https://github.com/ultralytics/yolov5/releases) | 44.3 | 44.3 | 63.2 | 3.4ms | 294 || 21.8M | 52.3 | |||
| [YOLOv5l](https://github.com/ultralytics/yolov5/releases) | 47.7 | 47.7 | 66.5 | 4.4ms | 227 || 47.8M | 117.2 | |||
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) | **49.2** | **49.2** | **67.7** | 6.9ms | 145 || 89.0M | 221.5 | |||
| | | | | | || | | |||
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) + TTA|**50.8**| **50.8** | **68.9** | 25.5ms | 39 || 89.0M | 801.0 | |||
| Model | size | AP<sup>val</sup> | AP<sup>test</sup> | AP<sub>50</sub> | Speed<sub>V100</sub> | FPS<sub>V100</sub> || params | GFLOPS | | |||
|---------- |------ |------ |------ |------ | -------- | ------| ------ |------ | :------: | | |||
| [YOLOv5s](https://github.com/ultralytics/yolov5/releases) |640 |36.8 |36.8 |55.6 |**2.2ms** |**455** ||7.3M |17.0 | |||
| [YOLOv5m](https://github.com/ultralytics/yolov5/releases) |640 |44.5 |44.5 |63.1 |2.9ms |345 ||21.4M |51.3 | |||
| [YOLOv5l](https://github.com/ultralytics/yolov5/releases) |640 |48.1 |48.1 |66.4 |3.8ms |264 ||47.0M |115.4 | |||
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) |640 |**50.1** |**50.1** |**68.7** |6.0ms |167 ||87.7M |218.8 | |||
| | | | | | | || | | |||
| [YOLOv5x](https://github.com/ultralytics/yolov5/releases) + TTA |832 |**51.9** |**51.9** |**69.6** |24.9ms |40 ||87.7M |1005.3 | |||
<!--- | |||
| [YOLOv5l6](https://github.com/ultralytics/yolov5/releases) |640 |49.0 |49.0 |67.4 |4.1ms |244 ||77.2M |117.7 | |||
| [YOLOv5l6](https://github.com/ultralytics/yolov5/releases) |1280 |53.0 |53.0 |70.8 |12.3ms |81 ||77.2M |117.7 | |||
---> | |||
** AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results denote val2017 accuracy. | |||
** All AP numbers are for single-model single-scale without ensemble or TTA. **Reproduce mAP** by `python test.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65` | |||
@@ -33,6 +37,7 @@ This repository represents Ultralytics open-source research into future object d | |||
** All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation). | |||
** Test Time Augmentation ([TTA](https://github.com/ultralytics/yolov5/issues/303)) runs at 3 image sizes. **Reproduce TTA** by `python test.py --data coco.yaml --img 832 --iou 0.65 --augment` | |||
## Requirements | |||
Python 3.8 or later with all [requirements.txt](https://github.com/ultralytics/yolov5/blob/master/requirements.txt) dependencies installed, including `torch>=1.7`. To install run: | |||
@@ -106,7 +111,7 @@ import torch | |||
from PIL import Image | |||
# Model | |||
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) # for PIL/cv2/np inputs and NMS | |||
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) | |||
# Images | |||
img1 = Image.open('zidane.jpg') | |||
@@ -114,13 +119,13 @@ img2 = Image.open('bus.jpg') | |||
imgs = [img1, img2] # batched list of images | |||
# Inference | |||
prediction = model(imgs, size=640) # includes NMS | |||
result = model(imgs) | |||
``` | |||
## Training | |||
Download [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) and run command below. Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices). | |||
Run commands below to reproduce results on [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh) dataset (dataset auto-downloads on first use). Training times for YOLOv5s/m/l/x are 2/4/6/8 days on a single V100 (multi-GPU times faster). Use the largest `--batch-size` your GPU allows (batch sizes shown for 16 GB devices). | |||
```bash | |||
$ python train.py --data coco.yaml --cfg yolov5s.yaml --weights '' --batch-size 64 | |||
yolov5m 40 |
@@ -18,15 +18,15 @@ test: ../coco/test-dev2017.txt # 20288 of 40670 images, submit to https://compe | |||
nc: 80 | |||
# class names | |||
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', | |||
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', | |||
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', | |||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', | |||
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', | |||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', | |||
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', | |||
'hair drier', 'toothbrush'] | |||
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', | |||
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', | |||
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', | |||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', | |||
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', | |||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', | |||
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', | |||
'hair drier', 'toothbrush' ] | |||
# Print classes | |||
# with open('data/coco.yaml') as f: |
@@ -17,12 +17,12 @@ val: ../coco128/images/train2017/ # 128 images | |||
nc: 80 | |||
# class names | |||
names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', | |||
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', | |||
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', | |||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', | |||
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', | |||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', | |||
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', | |||
'hair drier', 'toothbrush'] | |||
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', | |||
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', | |||
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', | |||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', | |||
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', | |||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', | |||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', | |||
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', | |||
'hair drier', 'toothbrush' ] |
@@ -17,5 +17,5 @@ val: ../VOC/images/val/ # 4952 images | |||
nc: 20 | |||
# class names | |||
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', | |||
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] | |||
names: [ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', | |||
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' ] |
@@ -30,7 +30,7 @@ class Conv(nn.Module): | |||
super(Conv, self).__init__() | |||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) | |||
self.bn = nn.BatchNorm2d(c2) | |||
self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) | |||
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) | |||
def forward(self, x): | |||
return self.act(self.bn(self.conv(x))) | |||
@@ -105,9 +105,39 @@ class Focus(nn.Module): | |||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups | |||
super(Focus, self).__init__() | |||
self.conv = Conv(c1 * 4, c2, k, s, p, g, act) | |||
# self.contract = Contract(gain=2) | |||
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) | |||
return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) | |||
# return self.conv(self.contract(x)) | |||
class Contract(nn.Module): | |||
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40) | |||
def __init__(self, gain=2): | |||
super().__init__() | |||
self.gain = gain | |||
def forward(self, x): | |||
N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain' | |||
s = self.gain | |||
x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2) | |||
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40) | |||
return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40) | |||
class Expand(nn.Module): | |||
# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160) | |||
def __init__(self, gain=2): | |||
super().__init__() | |||
self.gain = gain | |||
def forward(self, x): | |||
N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain' | |||
s = self.gain | |||
x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80) | |||
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2) | |||
return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160) | |||
class Concat(nn.Module): | |||
@@ -253,20 +283,13 @@ class Detections: | |||
return x | |||
class Flatten(nn.Module): | |||
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions | |||
@staticmethod | |||
def forward(x): | |||
return x.view(x.size(0), -1) | |||
class Classify(nn.Module): | |||
# Classification head, i.e. x(b,c1,20,20) to x(b,c2) | |||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups | |||
super(Classify, self).__init__() | |||
self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) | |||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1) | |||
self.flat = Flatten() | |||
self.flat = nn.Flatten() | |||
def forward(self, x): | |||
z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list |
@@ -105,8 +105,8 @@ class Ensemble(nn.ModuleList): | |||
for module in self: | |||
y.append(module(x, augment)[0]) | |||
# y = torch.stack(y).max(0)[0] # max ensemble | |||
# y = torch.cat(y, 1) # nms ensemble | |||
y = torch.stack(y).mean(0) # mean ensemble | |||
# y = torch.stack(y).mean(0) # mean ensemble | |||
y = torch.cat(y, 1) # nms ensemble | |||
return y, None # inference, train output | |||
@@ -0,0 +1,58 @@ | |||
# Default YOLOv5 anchors for COCO data | |||
# P5 ------------------------------------------------------------------------------------------------------------------- | |||
# P5-640: | |||
anchors_p5_640: | |||
- [ 10,13, 16,30, 33,23 ] # P3/8 | |||
- [ 30,61, 62,45, 59,119 ] # P4/16 | |||
- [ 116,90, 156,198, 373,326 ] # P5/32 | |||
# P6 ------------------------------------------------------------------------------------------------------------------- | |||
# P6-640: thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387 | |||
anchors_p6_640: | |||
- [ 9,11, 21,19, 17,41 ] # P3/8 | |||
- [ 43,32, 39,70, 86,64 ] # P4/16 | |||
- [ 65,131, 134,130, 120,265 ] # P5/32 | |||
- [ 282,180, 247,354, 512,387 ] # P6/64 | |||
# P6-1280: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792 | |||
anchors_p6_1280: | |||
- [ 19,27, 44,40, 38,94 ] # P3/8 | |||
- [ 96,68, 86,152, 180,137 ] # P4/16 | |||
- [ 140,301, 303,264, 238,542 ] # P5/32 | |||
- [ 436,615, 739,380, 925,792 ] # P6/64 | |||
# P6-1920: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187 | |||
anchors_p6_1920: | |||
- [ 28,41, 67,59, 57,141 ] # P3/8 | |||
- [ 144,103, 129,227, 270,205 ] # P4/16 | |||
- [ 209,452, 455,396, 358,812 ] # P5/32 | |||
- [ 653,922, 1109,570, 1387,1187 ] # P6/64 | |||
# P7 ------------------------------------------------------------------------------------------------------------------- | |||
# P7-640: thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372 | |||
anchors_p7_640: | |||
- [ 11,11, 13,30, 29,20 ] # P3/8 | |||
- [ 30,46, 61,38, 39,92 ] # P4/16 | |||
- [ 78,80, 146,66, 79,163 ] # P5/32 | |||
- [ 149,150, 321,143, 157,303 ] # P6/64 | |||
- [ 257,402, 359,290, 524,372 ] # P7/128 | |||
# P7-1280: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818 | |||
anchors_p7_1280: | |||
- [ 19,22, 54,36, 32,77 ] # P3/8 | |||
- [ 70,83, 138,71, 75,173 ] # P4/16 | |||
- [ 165,159, 148,334, 375,151 ] # P5/32 | |||
- [ 334,317, 251,626, 499,474 ] # P6/64 | |||
- [ 750,326, 534,814, 1079,818 ] # P7/128 | |||
# P7-1920: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227 | |||
anchors_p7_1920: | |||
- [ 29,34, 81,55, 47,115 ] # P3/8 | |||
- [ 105,124, 207,107, 113,259 ] # P4/16 | |||
- [ 247,238, 222,500, 563,227 ] # P5/32 | |||
- [ 501,476, 376,939, 749,711 ] # P6/64 | |||
- [ 1126,489, 801,1222, 1618,1227 ] # P7/128 |
@@ -0,0 +1,54 @@ | |||
# parameters | |||
nc: 80 # number of classes | |||
depth_multiple: 1.0 # model depth multiple | |||
width_multiple: 1.0 # layer channel multiple | |||
# anchors | |||
anchors: 3 | |||
# YOLOv5 backbone | |||
backbone: | |||
# [from, number, module, args] | |||
[ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 | |||
[ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 | |||
[ -1, 3, C3, [ 128 ] ], | |||
[ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 | |||
[ -1, 9, C3, [ 256 ] ], | |||
[ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 | |||
[ -1, 9, C3, [ 512 ] ], | |||
[ -1, 1, Conv, [ 1024, 3, 2 ] ], # 7-P5/32 | |||
[ -1, 1, SPP, [ 1024, [ 5, 9, 13 ] ] ], | |||
[ -1, 3, C3, [ 1024, False ] ], # 9 | |||
] | |||
# YOLOv5 head | |||
head: | |||
[ [ -1, 1, Conv, [ 512, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 | |||
[ -1, 3, C3, [ 512, False ] ], # 13 | |||
[ -1, 1, Conv, [ 256, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 | |||
[ -1, 3, C3, [ 256, False ] ], # 17 (P3/8-small) | |||
[ -1, 1, Conv, [ 128, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 2 ], 1, Concat, [ 1 ] ], # cat backbone P2 | |||
[ -1, 1, C3, [ 128, False ] ], # 21 (P2/4-xsmall) | |||
[ -1, 1, Conv, [ 128, 3, 2 ] ], | |||
[ [ -1, 18 ], 1, Concat, [ 1 ] ], # cat head P3 | |||
[ -1, 3, C3, [ 256, False ] ], # 24 (P3/8-small) | |||
[ -1, 1, Conv, [ 256, 3, 2 ] ], | |||
[ [ -1, 14 ], 1, Concat, [ 1 ] ], # cat head P4 | |||
[ -1, 3, C3, [ 512, False ] ], # 27 (P4/16-medium) | |||
[ -1, 1, Conv, [ 512, 3, 2 ] ], | |||
[ [ -1, 10 ], 1, Concat, [ 1 ] ], # cat head P5 | |||
[ -1, 3, C3, [ 1024, False ] ], # 30 (P5/32-large) | |||
[ [ 24, 27, 30 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5) | |||
] |
@@ -0,0 +1,56 @@ | |||
# parameters | |||
nc: 80 # number of classes | |||
depth_multiple: 1.0 # model depth multiple | |||
width_multiple: 1.0 # layer channel multiple | |||
# anchors | |||
anchors: 3 | |||
# YOLOv5 backbone | |||
backbone: | |||
# [from, number, module, args] | |||
[ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 | |||
[ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 | |||
[ -1, 3, C3, [ 128 ] ], | |||
[ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 | |||
[ -1, 9, C3, [ 256 ] ], | |||
[ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 | |||
[ -1, 9, C3, [ 512 ] ], | |||
[ -1, 1, Conv, [ 768, 3, 2 ] ], # 7-P5/32 | |||
[ -1, 3, C3, [ 768 ] ], | |||
[ -1, 1, Conv, [ 1024, 3, 2 ] ], # 9-P6/64 | |||
[ -1, 1, SPP, [ 1024, [ 3, 5, 7 ] ] ], | |||
[ -1, 3, C3, [ 1024, False ] ], # 11 | |||
] | |||
# YOLOv5 head | |||
head: | |||
[ [ -1, 1, Conv, [ 768, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 8 ], 1, Concat, [ 1 ] ], # cat backbone P5 | |||
[ -1, 3, C3, [ 768, False ] ], # 15 | |||
[ -1, 1, Conv, [ 512, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 | |||
[ -1, 3, C3, [ 512, False ] ], # 19 | |||
[ -1, 1, Conv, [ 256, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 | |||
[ -1, 3, C3, [ 256, False ] ], # 23 (P3/8-small) | |||
[ -1, 1, Conv, [ 256, 3, 2 ] ], | |||
[ [ -1, 20 ], 1, Concat, [ 1 ] ], # cat head P4 | |||
[ -1, 3, C3, [ 512, False ] ], # 26 (P4/16-medium) | |||
[ -1, 1, Conv, [ 512, 3, 2 ] ], | |||
[ [ -1, 16 ], 1, Concat, [ 1 ] ], # cat head P5 | |||
[ -1, 3, C3, [ 768, False ] ], # 29 (P5/32-large) | |||
[ -1, 1, Conv, [ 768, 3, 2 ] ], | |||
[ [ -1, 12 ], 1, Concat, [ 1 ] ], # cat head P6 | |||
[ -1, 3, C3, [ 1024, False ] ], # 32 (P5/64-xlarge) | |||
[ [ 23, 26, 29, 32 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6) | |||
] |
@@ -0,0 +1,67 @@ | |||
# parameters | |||
nc: 80 # number of classes | |||
depth_multiple: 1.0 # model depth multiple | |||
width_multiple: 1.0 # layer channel multiple | |||
# anchors | |||
anchors: 3 | |||
# YOLOv5 backbone | |||
backbone: | |||
# [from, number, module, args] | |||
[ [ -1, 1, Focus, [ 64, 3 ] ], # 0-P1/2 | |||
[ -1, 1, Conv, [ 128, 3, 2 ] ], # 1-P2/4 | |||
[ -1, 3, C3, [ 128 ] ], | |||
[ -1, 1, Conv, [ 256, 3, 2 ] ], # 3-P3/8 | |||
[ -1, 9, C3, [ 256 ] ], | |||
[ -1, 1, Conv, [ 512, 3, 2 ] ], # 5-P4/16 | |||
[ -1, 9, C3, [ 512 ] ], | |||
[ -1, 1, Conv, [ 768, 3, 2 ] ], # 7-P5/32 | |||
[ -1, 3, C3, [ 768 ] ], | |||
[ -1, 1, Conv, [ 1024, 3, 2 ] ], # 9-P6/64 | |||
[ -1, 3, C3, [ 1024 ] ], | |||
[ -1, 1, Conv, [ 1280, 3, 2 ] ], # 11-P7/128 | |||
[ -1, 1, SPP, [ 1280, [ 3, 5 ] ] ], | |||
[ -1, 3, C3, [ 1280, False ] ], # 13 | |||
] | |||
# YOLOv5 head | |||
head: | |||
[ [ -1, 1, Conv, [ 1024, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 10 ], 1, Concat, [ 1 ] ], # cat backbone P6 | |||
[ -1, 3, C3, [ 1024, False ] ], # 17 | |||
[ -1, 1, Conv, [ 768, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 8 ], 1, Concat, [ 1 ] ], # cat backbone P5 | |||
[ -1, 3, C3, [ 768, False ] ], # 21 | |||
[ -1, 1, Conv, [ 512, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 6 ], 1, Concat, [ 1 ] ], # cat backbone P4 | |||
[ -1, 3, C3, [ 512, False ] ], # 25 | |||
[ -1, 1, Conv, [ 256, 1, 1 ] ], | |||
[ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], | |||
[ [ -1, 4 ], 1, Concat, [ 1 ] ], # cat backbone P3 | |||
[ -1, 3, C3, [ 256, False ] ], # 29 (P3/8-small) | |||
[ -1, 1, Conv, [ 256, 3, 2 ] ], | |||
[ [ -1, 26 ], 1, Concat, [ 1 ] ], # cat head P4 | |||
[ -1, 3, C3, [ 512, False ] ], # 32 (P4/16-medium) | |||
[ -1, 1, Conv, [ 512, 3, 2 ] ], | |||
[ [ -1, 22 ], 1, Concat, [ 1 ] ], # cat head P5 | |||
[ -1, 3, C3, [ 768, False ] ], # 35 (P5/32-large) | |||
[ -1, 1, Conv, [ 768, 3, 2 ] ], | |||
[ [ -1, 18 ], 1, Concat, [ 1 ] ], # cat head P6 | |||
[ -1, 3, C3, [ 1024, False ] ], # 38 (P6/64-xlarge) | |||
[ -1, 1, Conv, [ 1024, 3, 2 ] ], | |||
[ [ -1, 14 ], 1, Concat, [ 1 ] ], # cat head P7 | |||
[ -1, 3, C3, [ 1280, False ] ], # 41 (P7/128-xxlarge) | |||
[ [ 29, 32, 35, 38, 41 ], 1, Detect, [ nc, anchors ] ], # Detect(P3, P4, P5, P6, P7) | |||
] |
@@ -1,17 +1,13 @@ | |||
import argparse | |||
import logging | |||
import math | |||
import sys | |||
from copy import deepcopy | |||
from pathlib import Path | |||
import torch | |||
import torch.nn as nn | |||
sys.path.append('./') # to run '$ python *.py' files in subdirectories | |||
logger = logging.getLogger(__name__) | |||
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C3, Concat, NMS, autoShape | |||
from models.common import * | |||
from models.experimental import MixConv2d, CrossConv | |||
from utils.autoanchor import check_anchor_order | |||
from utils.general import make_divisible, check_file, set_logging | |||
@@ -89,7 +85,7 @@ class Model(nn.Module): | |||
# Build strides, anchors | |||
m = self.model[-1] # Detect() | |||
if isinstance(m, Detect): | |||
s = 128 # 2x min stride | |||
s = 256 # 2x min stride | |||
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward | |||
m.anchors /= m.stride.view(-1, 1, 1) | |||
check_anchor_order(m) | |||
@@ -109,7 +105,7 @@ class Model(nn.Module): | |||
f = [None, 3, None] # flips (2-ud, 3-lr) | |||
y = [] # outputs | |||
for si, fi in zip(s, f): | |||
xi = scale_img(x.flip(fi) if fi else x, si) | |||
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) | |||
yi = self.forward_once(xi)[0] # forward | |||
# cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save | |||
yi[..., :4] /= si # de-scale | |||
@@ -242,13 +238,17 @@ def parse_model(d, ch): # model_dict, input_channels(3) | |||
elif m is nn.BatchNorm2d: | |||
args = [ch[f]] | |||
elif m is Concat: | |||
c2 = sum([ch[-1 if x == -1 else x + 1] for x in f]) | |||
c2 = sum([ch[x if x < 0 else x + 1] for x in f]) | |||
elif m is Detect: | |||
args.append([ch[x + 1] for x in f]) | |||
if isinstance(args[1], int): # number of anchors | |||
args[1] = [list(range(args[1] * 2))] * len(f) | |||
elif m is Contract: | |||
c2 = ch[f if f < 0 else f + 1] * args[0] ** 2 | |||
elif m is Expand: | |||
c2 = ch[f if f < 0 else f + 1] // args[0] ** 2 | |||
else: | |||
c2 = ch[f] | |||
c2 = ch[f if f < 0 else f + 1] | |||
m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module | |||
t = str(m)[8:-2].replace('__main__.', '') # module type |
@@ -14,14 +14,14 @@ backbone: | |||
# [from, number, module, args] | |||
[[-1, 1, Focus, [64, 3]], # 0-P1/2 | |||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 | |||
[-1, 3, BottleneckCSP, [128]], | |||
[-1, 3, C3, [128]], | |||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 | |||
[-1, 9, BottleneckCSP, [256]], | |||
[-1, 9, C3, [256]], | |||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 | |||
[-1, 9, BottleneckCSP, [512]], | |||
[-1, 9, C3, [512]], | |||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 | |||
[-1, 1, SPP, [1024, [5, 9, 13]]], | |||
[-1, 3, BottleneckCSP, [1024, False]], # 9 | |||
[-1, 3, C3, [1024, False]], # 9 | |||
] | |||
# YOLOv5 head | |||
@@ -29,20 +29,20 @@ head: | |||
[[-1, 1, Conv, [512, 1, 1]], | |||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | |||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 | |||
[-1, 3, BottleneckCSP, [512, False]], # 13 | |||
[-1, 3, C3, [512, False]], # 13 | |||
[-1, 1, Conv, [256, 1, 1]], | |||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | |||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 | |||
[-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) | |||
[-1, 3, C3, [256, False]], # 17 (P3/8-small) | |||
[-1, 1, Conv, [256, 3, 2]], | |||
[[-1, 14], 1, Concat, [1]], # cat head P4 | |||
[-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) | |||
[-1, 3, C3, [512, False]], # 20 (P4/16-medium) | |||
[-1, 1, Conv, [512, 3, 2]], | |||
[[-1, 10], 1, Concat, [1]], # cat head P5 | |||
[-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) | |||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large) | |||
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) | |||
] |
@@ -14,14 +14,14 @@ backbone: | |||
# [from, number, module, args] | |||
[[-1, 1, Focus, [64, 3]], # 0-P1/2 | |||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 | |||
[-1, 3, BottleneckCSP, [128]], | |||
[-1, 3, C3, [128]], | |||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 | |||
[-1, 9, BottleneckCSP, [256]], | |||
[-1, 9, C3, [256]], | |||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 | |||
[-1, 9, BottleneckCSP, [512]], | |||
[-1, 9, C3, [512]], | |||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 | |||
[-1, 1, SPP, [1024, [5, 9, 13]]], | |||
[-1, 3, BottleneckCSP, [1024, False]], # 9 | |||
[-1, 3, C3, [1024, False]], # 9 | |||
] | |||
# YOLOv5 head | |||
@@ -29,20 +29,20 @@ head: | |||
[[-1, 1, Conv, [512, 1, 1]], | |||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | |||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 | |||
[-1, 3, BottleneckCSP, [512, False]], # 13 | |||
[-1, 3, C3, [512, False]], # 13 | |||
[-1, 1, Conv, [256, 1, 1]], | |||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | |||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 | |||
[-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) | |||
[-1, 3, C3, [256, False]], # 17 (P3/8-small) | |||
[-1, 1, Conv, [256, 3, 2]], | |||
[[-1, 14], 1, Concat, [1]], # cat head P4 | |||
[-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) | |||
[-1, 3, C3, [512, False]], # 20 (P4/16-medium) | |||
[-1, 1, Conv, [512, 3, 2]], | |||
[[-1, 10], 1, Concat, [1]], # cat head P5 | |||
[-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) | |||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large) | |||
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) | |||
] |
@@ -14,14 +14,14 @@ backbone: | |||
# [from, number, module, args] | |||
[[-1, 1, Focus, [64, 3]], # 0-P1/2 | |||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 | |||
[-1, 3, BottleneckCSP, [128]], | |||
[-1, 3, C3, [128]], | |||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 | |||
[-1, 9, BottleneckCSP, [256]], | |||
[-1, 9, C3, [256]], | |||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 | |||
[-1, 9, BottleneckCSP, [512]], | |||
[-1, 9, C3, [512]], | |||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 | |||
[-1, 1, SPP, [1024, [5, 9, 13]]], | |||
[-1, 3, BottleneckCSP, [1024, False]], # 9 | |||
[-1, 3, C3, [1024, False]], # 9 | |||
] | |||
# YOLOv5 head | |||
@@ -29,20 +29,20 @@ head: | |||
[[-1, 1, Conv, [512, 1, 1]], | |||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | |||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 | |||
[-1, 3, BottleneckCSP, [512, False]], # 13 | |||
[-1, 3, C3, [512, False]], # 13 | |||
[-1, 1, Conv, [256, 1, 1]], | |||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | |||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 | |||
[-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) | |||
[-1, 3, C3, [256, False]], # 17 (P3/8-small) | |||
[-1, 1, Conv, [256, 3, 2]], | |||
[[-1, 14], 1, Concat, [1]], # cat head P4 | |||
[-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) | |||
[-1, 3, C3, [512, False]], # 20 (P4/16-medium) | |||
[-1, 1, Conv, [512, 3, 2]], | |||
[[-1, 10], 1, Concat, [1]], # cat head P5 | |||
[-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) | |||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large) | |||
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) | |||
] |
@@ -14,14 +14,14 @@ backbone: | |||
# [from, number, module, args] | |||
[[-1, 1, Focus, [64, 3]], # 0-P1/2 | |||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 | |||
[-1, 3, BottleneckCSP, [128]], | |||
[-1, 3, C3, [128]], | |||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 | |||
[-1, 9, BottleneckCSP, [256]], | |||
[-1, 9, C3, [256]], | |||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 | |||
[-1, 9, BottleneckCSP, [512]], | |||
[-1, 9, C3, [512]], | |||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 | |||
[-1, 1, SPP, [1024, [5, 9, 13]]], | |||
[-1, 3, BottleneckCSP, [1024, False]], # 9 | |||
[-1, 3, C3, [1024, False]], # 9 | |||
] | |||
# YOLOv5 head | |||
@@ -29,20 +29,20 @@ head: | |||
[[-1, 1, Conv, [512, 1, 1]], | |||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | |||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 | |||
[-1, 3, BottleneckCSP, [512, False]], # 13 | |||
[-1, 3, C3, [512, False]], # 13 | |||
[-1, 1, Conv, [256, 1, 1]], | |||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | |||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 | |||
[-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) | |||
[-1, 3, C3, [256, False]], # 17 (P3/8-small) | |||
[-1, 1, Conv, [256, 3, 2]], | |||
[[-1, 14], 1, Concat, [1]], # cat head P4 | |||
[-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) | |||
[-1, 3, C3, [512, False]], # 20 (P4/16-medium) | |||
[-1, 1, Conv, [512, 3, 2]], | |||
[[-1, 10], 1, Concat, [1]], # cat head P5 | |||
[-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) | |||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large) | |||
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) | |||
] |
@@ -104,6 +104,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
nbs = 64 # nominal batch size | |||
accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing | |||
hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay | |||
logger.info(f"Scaled weight_decay = {hyp['weight_decay']}") | |||
pg0, pg1, pg2 = [], [], [] # optimizer parameter groups | |||
for k, v in model.named_modules(): | |||
@@ -164,7 +165,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
del ckpt, state_dict | |||
# Image sizes | |||
gs = int(max(model.stride)) # grid size (max stride) | |||
gs = int(model.stride.max()) # grid size (max stride) | |||
nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj']) | |||
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples | |||
# DP mode | |||
@@ -187,7 +189,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, | |||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank, | |||
world_size=opt.world_size, workers=opt.workers, | |||
image_weights=opt.image_weights) | |||
image_weights=opt.image_weights, quad=opt.quad) | |||
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class | |||
nb = len(dataloader) # number of batches | |||
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1) | |||
@@ -214,7 +216,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) | |||
# Model parameters | |||
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset | |||
hyp['cls'] *= nc / 80. # scale hyp['cls'] to class count | |||
hyp['obj'] *= imgsz ** 2 / 640. ** 2 * 3. / nl # scale hyp['obj'] to image size and output layers | |||
model.nc = nc # attach number of classes to model | |||
model.hyp = hyp # attach hyperparameters to model | |||
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) | |||
@@ -290,6 +293,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |||
loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size | |||
if rank != -1: | |||
loss *= opt.world_size # gradient averaged between devices in DDP mode | |||
if opt.quad: | |||
loss *= 4. | |||
# Backward | |||
scaler.scale(loss).backward() | |||
@@ -458,10 +463,10 @@ if __name__ == '__main__': | |||
parser.add_argument('--project', default='runs/train', help='save to project/name') | |||
parser.add_argument('--name', default='exp', help='save to project/name') | |||
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') | |||
parser.add_argument('--quad', action='store_true', help='quad dataloader') | |||
opt = parser.parse_args() | |||
# Set DDP variables | |||
opt.total_batch_size = opt.batch_size | |||
opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 | |||
opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1 | |||
set_logging(opt.global_rank) | |||
@@ -486,6 +491,7 @@ if __name__ == '__main__': | |||
opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run | |||
# DDP mode | |||
opt.total_batch_size = opt.batch_size | |||
device = select_device(opt.device, batch_size=opt.batch_size) | |||
if opt.local_rank != -1: | |||
assert torch.cuda.device_count() > opt.local_rank |
@@ -110,6 +110,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 | |||
print('WARNING: Extremely small objects found. ' | |||
'%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0))) | |||
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels | |||
# wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 | |||
# Kmeans calculation | |||
print('Running kmeans for %g anchors on %g points...' % (n, len(wh))) |
@@ -15,6 +15,7 @@ from threading import Thread | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
from PIL import Image, ExifTags | |||
from torch.utils.data import Dataset | |||
from tqdm import tqdm | |||
@@ -55,7 +56,7 @@ def exif_size(img): | |||
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False, | |||
rank=-1, world_size=1, workers=8, image_weights=False): | |||
rank=-1, world_size=1, workers=8, image_weights=False, quad=False): | |||
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache | |||
with torch_distributed_zero_first(rank): | |||
dataset = LoadImagesAndLabels(path, imgsz, batch_size, | |||
@@ -79,7 +80,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa | |||
num_workers=nw, | |||
sampler=sampler, | |||
pin_memory=True, | |||
collate_fn=LoadImagesAndLabels.collate_fn) | |||
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn) | |||
return dataloader, dataset | |||
@@ -578,6 +579,32 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
l[:, 0] = i # add target image index for build_targets() | |||
return torch.stack(img, 0), torch.cat(label, 0), path, shapes | |||
@staticmethod | |||
def collate_fn4(batch): | |||
img, label, path, shapes = zip(*batch) # transposed | |||
n = len(shapes) // 4 | |||
img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n] | |||
ho = torch.tensor([[0., 0, 0, 1, 0, 0]]) | |||
wo = torch.tensor([[0., 0, 1, 0, 0, 0]]) | |||
s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale | |||
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW | |||
i *= 4 | |||
if random.random() < 0.5: | |||
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[ | |||
0].type(img[i].type()) | |||
l = label[i] | |||
else: | |||
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2) | |||
l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s | |||
img4.append(im) | |||
label4.append(l) | |||
for i, l in enumerate(label4): | |||
l[:, 0] = i # add target image index for build_targets() | |||
return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4 | |||
# Ancillary functions -------------------------------------------------------------------------------------------------- | |||
def load_image(self, index): | |||
@@ -617,7 +644,7 @@ def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5): | |||
def load_mosaic(self, index): | |||
# loads images in a mosaic | |||
# loads images in a 4-mosaic | |||
labels4 = [] | |||
s = self.img_size | |||
@@ -674,6 +701,80 @@ def load_mosaic(self, index): | |||
return img4, labels4 | |||
def load_mosaic9(self, index): | |||
# loads images in a 9-mosaic | |||
labels9 = [] | |||
s = self.img_size | |||
indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(8)] # 8 additional image indices | |||
for i, index in enumerate(indices): | |||
# Load image | |||
img, _, (h, w) = load_image(self, index) | |||
# place img in img9 | |||
if i == 0: # center | |||
img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles | |||
h0, w0 = h, w | |||
c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates | |||
elif i == 1: # top | |||
c = s, s - h, s + w, s | |||
elif i == 2: # top right | |||
c = s + wp, s - h, s + wp + w, s | |||
elif i == 3: # right | |||
c = s + w0, s, s + w0 + w, s + h | |||
elif i == 4: # bottom right | |||
c = s + w0, s + hp, s + w0 + w, s + hp + h | |||
elif i == 5: # bottom | |||
c = s + w0 - w, s + h0, s + w0, s + h0 + h | |||
elif i == 6: # bottom left | |||
c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h | |||
elif i == 7: # left | |||
c = s - w, s + h0 - h, s, s + h0 | |||
elif i == 8: # top left | |||
c = s - w, s + h0 - hp - h, s, s + h0 - hp | |||
padx, pady = c[:2] | |||
x1, y1, x2, y2 = [max(x, 0) for x in c] # allocate coords | |||
# Labels | |||
x = self.labels[index] | |||
labels = x.copy() | |||
if x.size > 0: # Normalized xywh to pixel xyxy format | |||
labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padx | |||
labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + pady | |||
labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padx | |||
labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + pady | |||
labels9.append(labels) | |||
# Image | |||
img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax] | |||
hp, wp = h, w # height, width previous | |||
# Offset | |||
yc, xc = [int(random.uniform(0, s)) for x in self.mosaic_border] # mosaic center x, y | |||
img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s] | |||
# Concat/clip labels | |||
if len(labels9): | |||
labels9 = np.concatenate(labels9, 0) | |||
labels9[:, [1, 3]] -= xc | |||
labels9[:, [2, 4]] -= yc | |||
np.clip(labels9[:, 1:], 0, 2 * s, out=labels9[:, 1:]) # use with random_perspective | |||
# img9, labels9 = replicate(img9, labels9) # replicate | |||
# Augment | |||
img9, labels9 = random_perspective(img9, labels9, | |||
degrees=self.hyp['degrees'], | |||
translate=self.hyp['translate'], | |||
scale=self.hyp['scale'], | |||
shear=self.hyp['shear'], | |||
perspective=self.hyp['perspective'], | |||
border=self.mosaic_border) # border to remove | |||
return img9, labels9 | |||
def replicate(img, labels): | |||
# Replicate labels | |||
h, w = img.shape[:2] | |||
@@ -811,12 +912,12 @@ def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shea | |||
return img, targets | |||
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1): # box1(4,n), box2(4,n) | |||
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) | |||
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio | |||
w1, h1 = box1[2] - box1[0], box1[3] - box1[1] | |||
w2, h2 = box2[2] - box2[0], box2[3] - box2[1] | |||
ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16)) # aspect ratio | |||
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr) # candidates | |||
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio | |||
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates | |||
def cutout(image, labels): |
@@ -281,6 +281,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non | |||
# Settings | |||
min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height | |||
max_det = 300 # maximum number of detections per image | |||
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() | |||
time_limit = 10.0 # seconds to quit after | |||
redundant = True # require redundant detections | |||
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) | |||
@@ -328,13 +329,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non | |||
# if not torch.isfinite(x).all(): | |||
# x = x[torch.isfinite(x).all(1)] | |||
# If none remain process next image | |||
# Check shape | |||
n = x.shape[0] # number of boxes | |||
if not n: | |||
if not n: # no boxes | |||
continue | |||
# Sort by confidence | |||
# x = x[x[:, 4].argsort(descending=True)] | |||
elif n > max_nms: # excess boxes | |||
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence | |||
# Batched NMS | |||
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes | |||
@@ -352,6 +352,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non | |||
output[xi] = x[i] | |||
if (time.time() - t) > time_limit: | |||
print(f'WARNING: NMS time limit {time_limit}s exceeded') | |||
break # time limit exceeded | |||
return output |
@@ -6,6 +6,7 @@ import subprocess | |||
import time | |||
from pathlib import Path | |||
import requests | |||
import torch | |||
@@ -21,21 +22,14 @@ def attempt_download(weights): | |||
file = Path(weights).name.lower() | |||
msg = weights + ' missing, try downloading from https://github.com/ultralytics/yolov5/releases/' | |||
models = ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt'] # available models | |||
redundant = False # offer second download option | |||
if file in models and not os.path.isfile(weights): | |||
# Google Drive | |||
# d = {'yolov5s.pt': '1R5T6rIyy3lLwgFXNms8whc-387H0tMQO', | |||
# 'yolov5m.pt': '1vobuEExpWQVpXExsJ2w-Mbf3HJjWkQJr', | |||
# 'yolov5l.pt': '1hrlqD1Wdei7UT4OgT785BEk1JwnSvNEV', | |||
# 'yolov5x.pt': '1mM8aZJlWTxOg7BZJvNUMrTnA2AbeCVzS'} | |||
# r = gdrive_download(id=d[file], name=weights) if file in d else 1 | |||
# if r == 0 and os.path.exists(weights) and os.path.getsize(weights) > 1E6: # check | |||
# return | |||
response = requests.get('https://api.github.com/repos/ultralytics/yolov5/releases/latest').json() # github api | |||
assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...] | |||
redundant = False # second download option | |||
if file in assets and not os.path.isfile(weights): | |||
try: # GitHub | |||
url = 'https://github.com/ultralytics/yolov5/releases/download/v3.1/' + file | |||
tag = response['tag_name'] # i.e. 'v1.0' | |||
url = f'https://github.com/ultralytics/yolov5/releases/download/{tag}/{file}' | |||
print('Downloading %s to %s...' % (url, weights)) | |||
torch.hub.download_url_to_file(url, weights) | |||
assert os.path.exists(weights) and os.path.getsize(weights) > 1E6 # check | |||
@@ -53,10 +47,9 @@ def attempt_download(weights): | |||
return | |||
def gdrive_download(id='1uH2BylpFxHKEGXKL6wJJlsgMU2YEjxuc', name='tmp.zip'): | |||
# Downloads a file from Google Drive. from utils.google_utils import *; gdrive_download() | |||
def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', name='tmp.zip'): | |||
# Downloads a file from Google Drive. from yolov5.utils.google_utils import *; gdrive_download() | |||
t = time.time() | |||
print('Downloading https://drive.google.com/uc?export=download&id=%s as %s... ' % (id, name), end='') | |||
os.remove(name) if os.path.exists(name) else None # remove existing | |||
os.remove('cookie') if os.path.exists('cookie') else None |
@@ -106,7 +106,7 @@ def compute_loss(p, targets, model): # predictions, targets, model | |||
# Losses | |||
nt = 0 # number of targets | |||
no = len(p) # number of outputs | |||
balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1] # P3-5 or P3-6 | |||
balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7 | |||
for i, pi in enumerate(p): # layer index, layer predictions | |||
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx | |||
tobj = torch.zeros_like(pi[..., 0], device=device) # target obj | |||
@@ -140,7 +140,7 @@ def compute_loss(p, targets, model): # predictions, targets, model | |||
s = 3 / no # output count scaling | |||
lbox *= h['box'] * s | |||
lobj *= h['obj'] * s * (1.4 if no == 4 else 1.) | |||
lobj *= h['obj'] | |||
lcls *= h['cls'] * s | |||
bs = tobj.shape[0] # batch size | |||
@@ -223,7 +223,7 @@ def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() | |||
plt.savefig('targets.jpg', dpi=200) | |||
def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt() | |||
def plot_study_txt(path='study/', x=None): # from utils.plots import *; plot_study_txt() | |||
# Plot study.txt generated by test.py | |||
fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True) | |||
ax = ax.ravel() | |||
@@ -246,7 +246,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx | |||
ax2.grid() | |||
ax2.set_xlim(0, 30) | |||
ax2.set_ylim(28, 50) | |||
ax2.set_ylim(29, 51) | |||
ax2.set_yticks(np.arange(30, 55, 5)) | |||
ax2.set_xlabel('GPU Speed (ms/img)') | |||
ax2.set_ylabel('COCO AP val') |
@@ -225,8 +225,8 @@ def load_classifier(name='resnet101', n=2): | |||
return model | |||
def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio | |||
# scales img(bs,3,y,x) by ratio | |||
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) | |||
# scales img(bs,3,y,x) by ratio constrained to gs-multiple | |||
if ratio == 1.0: | |||
return img | |||
else: | |||
@@ -234,7 +234,6 @@ def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio | |||
s = (int(h * ratio), int(w * ratio)) # new size | |||
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize | |||
if not same_shape: # pad/crop img | |||
gs = 32 # (pixels) grid size | |||
h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)] | |||
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean | |||