
Rename `test.py` to `val.py` (#4000)

Glenn Jocher · committed 3 years ago
commit 720aaa65c8
11 changed files with 77 additions and 78 deletions
1. .github/ISSUE_TEMPLATE/bug-report.md  (+1 -1)
2. .github/workflows/ci-testing.yml  (+3 -3)
3. .github/workflows/greetings.yml  (+1 -1)
4. README.md  (+4 -4)
5. models/yolo.py  (+0 -1)
6. train.py  (+34 -34)
7. tutorial.ipynb  (+18 -18)
8. utils/augmentations.py  (+1 -1)
9. utils/general.py  (+1 -1)
10. utils/plots.py  (+4 -4)
11. val.py  (+10 -10)

.github/ISSUE_TEMPLATE/bug-report.md  (+1 -1)

  - **Common dataset**: coco.yaml or coco128.yaml
  - **Common environment**: Colab, Google Cloud, or Docker image. See https://github.com/ultralytics/yolov5#environments
- If this is a custom dataset/training question you **must include** your `train*.jpg`, `test*.jpg` and `results.png` figures, or we can not help you. You can generate these with `utils.plot_results()`.
+ If this is a custom dataset/training question you **must include** your `train*.jpg`, `val*.jpg` and `results.png` figures, or we can not help you. You can generate these with `utils.plot_results()`.

  ## 🐛 Bug

.github/workflows/ci-testing.yml  (+3 -3)

  # detect
  python detect.py --weights ${{ matrix.model }}.pt --device $di
  python detect.py --weights runs/train/exp/weights/last.pt --device $di
- # test
- python test.py --img 128 --batch 16 --weights ${{ matrix.model }}.pt --device $di
- python test.py --img 128 --batch 16 --weights runs/train/exp/weights/last.pt --device $di
+ # val
+ python val.py --img 128 --batch 16 --weights ${{ matrix.model }}.pt --device $di
+ python val.py --img 128 --batch 16 --weights runs/train/exp/weights/last.pt --device $di

  python hubconf.py # hub
  python models/yolo.py --cfg ${{ matrix.model }}.yaml # inspect

.github/workflows/greetings.yml  (+1 -1)



  ![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)

- If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.
+ If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([val.py](https://github.com/ultralytics/yolov5/blob/master/val.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

README.md  (+4 -4)

  * GPU Speed measures end-to-end time per image averaged over 5000 COCO val2017 images using a V100 GPU with batch size 32, and includes image preprocessing, PyTorch FP16 inference, postprocessing and NMS.
  * EfficientDet data from [google/automl](https://github.com/google/automl) at batch size 8.
- * **Reproduce** by `python test.py --task study --data coco.yaml --iou 0.7 --weights yolov5s6.pt yolov5m6.pt yolov5l6.pt yolov5x6.pt`
+ * **Reproduce** by `python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5s6.pt yolov5m6.pt yolov5l6.pt yolov5x6.pt`
  </details>

  <summary>Table Notes (click to expand)</summary>
  * AP<sup>test</sup> denotes COCO [test-dev2017](http://cocodataset.org/#upload) server results, all other AP results denote val2017 accuracy.
- * AP values are for single-model single-scale unless otherwise noted. **Reproduce mAP** by `python test.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65`
- * Speed<sub>GPU</sub> averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) V100 instance, and includes FP16 inference, postprocessing and NMS. **Reproduce speed** by `python test.py --data coco.yaml --img 640 --conf 0.25 --iou 0.45`
+ * AP values are for single-model single-scale unless otherwise noted. **Reproduce mAP** by `python val.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65`
+ * Speed<sub>GPU</sub> averaged over 5000 COCO val2017 images using a GCP [n1-standard-16](https://cloud.google.com/compute/docs/machine-types#n1_standard_machine_types) V100 instance, and includes FP16 inference, postprocessing and NMS. **Reproduce speed** by `python val.py --data coco.yaml --img 640 --conf 0.25 --iou 0.45`
  * All checkpoints are trained to 300 epochs with default settings and hyperparameters (no autoaugmentation).
- * Test Time Augmentation ([TTA](https://github.com/ultralytics/yolov5/issues/303)) includes reflection and scale augmentation. **Reproduce TTA** by `python test.py --data coco.yaml --img 1536 --iou 0.7 --augment`
+ * Test Time Augmentation ([TTA](https://github.com/ultralytics/yolov5/issues/303)) includes reflection and scale augmentation. **Reproduce TTA** by `python val.py --data coco.yaml --img 1536 --iou 0.7 --augment`
  </details>





models/yolo.py  (+0 -1)

  # tb_writer = SummaryWriter('.')
  # logger.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
  # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph
- # tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard

train.py  (+34 -34)

  FILE = Path(__file__).absolute()
  sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path

- import test # for end-of-epoch mAP
+ import val # for end-of-epoch mAP
  from models.experimental import attempt_load
  from models.yolo import Model
  from utils.autoanchor import check_anchors
  opt,
  device,
  ):
- save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
+ save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, = \
  opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
- opt.resume, opt.notest, opt.nosave, opt.workers
+ opt.resume, opt.noval, opt.nosave, opt.workers

  # Directories
  save_dir = Path(save_dir)
  with torch_distributed_zero_first(RANK):
  check_dataset(data_dict) # check
  train_path = data_dict['train']
- test_path = data_dict['val']
+ val_path = data_dict['val']

  # Freeze
  freeze = [] # parameter names to freeze (full or partial)
  # Image sizes
  gs = max(int(model.stride.max()), 32) # grid size (max stride)
  nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
- imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
+ imgsz, imgsz_val = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples

  # DP mode
  if cuda and RANK == -1 and torch.cuda.device_count() > 1:

  # Process 0
  if RANK in [-1, 0]:
- testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
- hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
+ valloader = create_dataloader(val_path, imgsz_val, batch_size // WORLD_SIZE * 2, gs, single_cls,
+ hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
  workers=workers,
  pad=0.5, prefix=colorstr('val: '))[0]

  scheduler.last_epoch = start_epoch - 1 # do not move
  scaler = amp.GradScaler(enabled=cuda)
  compute_loss = ComputeLoss(model) # init loss class
- logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
+ logger.info(f'Image sizes {imgsz} train, {imgsz_val} val\n'
  f'Using {dataloader.num_workers} dataloader workers\n'
  f'Logging results to {save_dir}\n'
  f'Starting training for {epochs} epochs...')
  # mAP
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  final_epoch = epoch + 1 == epochs
- if not notest or final_epoch: # Calculate mAP
+ if not noval or final_epoch: # Calculate mAP
  wandb_logger.current_epoch = epoch + 1
- results, maps, _ = test.run(data_dict,
- batch_size=batch_size // WORLD_SIZE * 2,
- imgsz=imgsz_test,
- model=ema.ema,
- single_cls=single_cls,
- dataloader=testloader,
- save_dir=save_dir,
- save_json=is_coco and final_epoch,
- verbose=nc < 50 and final_epoch,
- plots=plots and final_epoch,
- wandb_logger=wandb_logger,
- compute_loss=compute_loss)
+ results, maps, _ = val.run(data_dict,
+ batch_size=batch_size // WORLD_SIZE * 2,
+ imgsz=imgsz_val,
+ model=ema.ema,
+ single_cls=single_cls,
+ dataloader=valloader,
+ save_dir=save_dir,
+ save_json=is_coco and final_epoch,
+ verbose=nc < 50 and final_epoch,
+ plots=plots and final_epoch,
+ wandb_logger=wandb_logger,
+ compute_loss=compute_loss)

  # Write
  with open(results_file, 'a') as f:
  if not evolve:
  if is_coco: # COCO dataset
  for m in [last, best] if best.exists() else [last]: # speed, mAP tests
- results, _, _ = test.run(data_dict,
- batch_size=batch_size // WORLD_SIZE * 2,
- imgsz=imgsz_test,
- model=attempt_load(m, device).half(),
- single_cls=single_cls,
- dataloader=testloader,
- save_dir=save_dir,
- save_json=True,
- plots=False)
+ results, _, _ = val.run(data_dict,
+ batch_size=batch_size // WORLD_SIZE * 2,
+ imgsz=imgsz_val,
+ model=attempt_load(m, device).half(),
+ single_cls=single_cls,
+ dataloader=valloader,
+ save_dir=save_dir,
+ save_json=True,
+ plots=False)

  # Strip optimizers
  for f in last, best:
  parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  parser.add_argument('--epochs', type=int, default=300)
  parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
- parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
+ parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, val] image sizes')
  parser.add_argument('--rect', action='store_true', help='rectangular training')
  parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
- parser.add_argument('--notest', action='store_true', help='only test final epoch')
+ parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
- opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
+ opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, val)
  opt.name = 'evolve' if opt.evolve else opt.name
  opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))

  if 'anchors' not in hyp: # anchors commented in hyp.yaml
  hyp['anchors'] = 3
  assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
- opt.notest, opt.nosave = True, True # only test/save final epoch
+ opt.noval, opt.nosave = True, True # only val/save final epoch
  # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  if opt.bucket:
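
In effect, train.py now imports the renamed module and gates end-of-epoch evaluation on the new `--noval` flag (formerly `--notest`). Condensed from the hunks above into a single sketch (several keyword arguments omitted for brevity; it assumes the surrounding training-loop variables noval, final_epoch, data_dict, batch_size, WORLD_SIZE, imgsz_val, ema, valloader, save_dir and compute_loss are in scope as shown in the diff), the call reads roughly:

    import val  # for end-of-epoch mAP (was: import test)

    if not noval or final_epoch:  # evaluate every epoch unless --noval is set; always on the final epoch
        results, maps, _ = val.run(data_dict,                               # dataset dict from check_dataset()
                                   batch_size=batch_size // WORLD_SIZE * 2,
                                   imgsz=imgsz_val,                         # was imgsz_test
                                   model=ema.ema,                           # evaluate the EMA weights
                                   dataloader=valloader,                    # was testloader
                                   save_dir=save_dir,
                                   compute_loss=compute_loss)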

tutorial.ipynb  (+18 -18)

"id": "0eq1SMWl6Sfn" "id": "0eq1SMWl6Sfn"
}, },
"source": [ "source": [
"# 2. Test\n",
"Test a model's accuracy on [COCO](https://cocodataset.org/#home) val or test-dev datasets. Models are downloaded automatically from the [latest YOLOv5 release](https://github.com/ultralytics/yolov5/releases). To show results by class use the `--verbose` flag. Note that `pycocotools` metrics may be ~1% better than the equivalent repo metrics, as is visible below, due to slight differences in mAP computation."
"# 2. Validate\n",
"Validate a model's accuracy on [COCO](https://cocodataset.org/#home) val or test-dev datasets. Models are downloaded automatically from the [latest YOLOv5 release](https://github.com/ultralytics/yolov5/releases). To show results by class use the `--verbose` flag. Note that `pycocotools` metrics may be ~1% better than the equivalent repo metrics, as is visible below, due to slight differences in mAP computation."
] ]
}, },
{ {
}, },
"source": [ "source": [
"# Run YOLOv5x on COCO val2017\n", "# Run YOLOv5x on COCO val2017\n",
"!python test.py --weights yolov5x.pt --data coco.yaml --img 640 --iou 0.65 --half"
"!python val.py --weights yolov5x.pt --data coco.yaml --img 640 --iou 0.65 --half"
], ],
"execution_count": null, "execution_count": null,
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Namespace(augment=False, batch_size=32, conf_thres=0.001, data='./data/coco.yaml', device='', exist_ok=False, half=True, img_size=640, iou_thres=0.65, name='exp', project='runs/test', save_conf=False, save_hybrid=False, save_json=True, save_txt=False, single_cls=False, task='val', verbose=False, weights=['yolov5x.pt'])\n",
"Namespace(augment=False, batch_size=32, conf_thres=0.001, data='./data/coco.yaml', device='', exist_ok=False, half=True, img_size=640, iou_thres=0.65, name='exp', project='runs/val', save_conf=False, save_hybrid=False, save_json=True, save_txt=False, single_cls=False, task='val', verbose=False, weights=['yolov5x.pt'])\n",
"YOLOv5 🚀 v5.0-157-gc6b51f4 torch 1.8.1+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16160.5MB)\n", "YOLOv5 🚀 v5.0-157-gc6b51f4 torch 1.8.1+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16160.5MB)\n",
"\n", "\n",
"Downloading https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5x.pt to yolov5x.pt...\n", "Downloading https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5x.pt to yolov5x.pt...\n",
" all 5000 36335 0.746 0.626 0.68 0.49\n", " all 5000 36335 0.746 0.626 0.68 0.49\n",
"Speed: 5.3/1.5/6.8 ms inference/NMS/total per 640x640 image at batch-size 32\n", "Speed: 5.3/1.5/6.8 ms inference/NMS/total per 640x640 image at batch-size 32\n",
"\n", "\n",
"Evaluating pycocotools mAP... saving runs/test/exp/yolov5x_predictions.json...\n",
"Evaluating pycocotools mAP... saving runs/val/exp/yolov5x_predictions.json...\n",
"loading annotations into memory...\n", "loading annotations into memory...\n",
"Done (t=0.44s)\n", "Done (t=0.44s)\n",
"creating index...\n", "creating index...\n",
" Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.524\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.524\n",
" Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.735\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.735\n",
" Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.827\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.827\n",
"Results saved to runs/test/exp\n"
"Results saved to runs/val/exp\n"
], ],
"name": "stdout" "name": "stdout"
} }
}, },
"source": [ "source": [
"# Run YOLOv5s on COCO test-dev2017 using --task test\n", "# Run YOLOv5s on COCO test-dev2017 using --task test\n",
"!python test.py --weights yolov5s.pt --data coco.yaml --task test"
"!python val.py --weights yolov5s.pt --data coco.yaml --task test"
], ],
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
"Plotting labels... \n", "Plotting labels... \n",
"\n", "\n",
"\u001b[34m\u001b[1mautoanchor: \u001b[0mAnalyzing anchors... anchors/target = 4.26, Best Possible Recall (BPR) = 0.9946\n", "\u001b[34m\u001b[1mautoanchor: \u001b[0mAnalyzing anchors... anchors/target = 4.26, Best Possible Recall (BPR) = 0.9946\n",
"Image sizes 640 train, 640 test\n",
"Image sizes 640 train, 640 val\n",
"Using 2 dataloader workers\n", "Using 2 dataloader workers\n",
"Logging results to runs/train/exp\n", "Logging results to runs/train/exp\n",
"Starting training for 3 epochs...\n", "Starting training for 3 epochs...\n",
"source": [ "source": [
"## Local Logging\n", "## Local Logging\n",
"\n", "\n",
"All results are logged by default to `runs/train`, with a new experiment directory created for each new training as `runs/train/exp2`, `runs/train/exp3`, etc. View train and test jpgs to see mosaics, labels, predictions and augmentation effects. Note a **Mosaic Dataloader** is used for training (shown below), a new concept developed by Ultralytics and first featured in [YOLOv4](https://arxiv.org/abs/2004.10934)."
"All results are logged by default to `runs/train`, with a new experiment directory created for each new training as `runs/train/exp2`, `runs/train/exp3`, etc. View train and val jpgs to see mosaics, labels, predictions and augmentation effects. Note a **Mosaic Dataloader** is used for training (shown below), a new concept developed by Ultralytics and first featured in [YOLOv4](https://arxiv.org/abs/2004.10934)."
] ]
}, },
{ {
}, },
"source": [ "source": [
"Image(filename='runs/train/exp/train_batch0.jpg', width=800) # train batch 0 mosaics and labels\n", "Image(filename='runs/train/exp/train_batch0.jpg', width=800) # train batch 0 mosaics and labels\n",
"Image(filename='runs/train/exp/test_batch0_labels.jpg', width=800) # test batch 0 labels\n",
"Image(filename='runs/train/exp/test_batch0_pred.jpg', width=800) # test batch 0 predictions"
"Image(filename='runs/train/exp/test_batch0_labels.jpg', width=800) # val batch 0 labels\n",
"Image(filename='runs/train/exp/test_batch0_pred.jpg', width=800) # val batch 0 predictions"
], ],
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
"`train_batch0.jpg` shows train batch 0 mosaics and labels\n", "`train_batch0.jpg` shows train batch 0 mosaics and labels\n",
"\n", "\n",
"> <img src=\"https://user-images.githubusercontent.com/26833433/124931217-4826f080-e002-11eb-87b9-ae0925a8c94b.jpg\" width=\"700\"> \n", "> <img src=\"https://user-images.githubusercontent.com/26833433/124931217-4826f080-e002-11eb-87b9-ae0925a8c94b.jpg\" width=\"700\"> \n",
"`test_batch0_labels.jpg` shows test batch 0 labels\n",
"`test_batch0_labels.jpg` shows val batch 0 labels\n",
"\n", "\n",
"> <img src=\"https://user-images.githubusercontent.com/26833433/124931209-46f5c380-e002-11eb-9bd5-7a3de2be9851.jpg\" width=\"700\"> \n", "> <img src=\"https://user-images.githubusercontent.com/26833433/124931209-46f5c380-e002-11eb-9bd5-7a3de2be9851.jpg\" width=\"700\"> \n",
"`test_batch0_pred.jpg` shows test batch 0 _predictions_"
"`test_batch0_pred.jpg` shows val batch 0 _predictions_"
] ]
}, },
{ {
"\n", "\n",
"![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)\n", "![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)\n",
"\n", "\n",
"If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.\n"
"If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([val.py](https://github.com/ultralytics/yolov5/blob/master/val.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.\n"
] ]
}, },
{ {
"source": [ "source": [
"# Reproduce\n", "# Reproduce\n",
"for x in 'yolov5s', 'yolov5m', 'yolov5l', 'yolov5x':\n", "for x in 'yolov5s', 'yolov5m', 'yolov5l', 'yolov5x':\n",
" !python test.py --weights {x}.pt --data coco.yaml --img 640 --conf 0.25 --iou 0.45 # speed\n",
" !python test.py --weights {x}.pt --data coco.yaml --img 640 --conf 0.001 --iou 0.65 # mAP"
" !python val.py --weights {x}.pt --data coco.yaml --img 640 --conf 0.25 --iou 0.45 # speed\n",
" !python val.py --weights {x}.pt --data coco.yaml --img 640 --conf 0.001 --iou 0.65 # mAP"
], ],
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
" for d in 0 cpu; do # devices\n", " for d in 0 cpu; do # devices\n",
" python detect.py --weights $m.pt --device $d # detect official\n", " python detect.py --weights $m.pt --device $d # detect official\n",
" python detect.py --weights runs/train/exp/weights/best.pt --device $d # detect custom\n", " python detect.py --weights runs/train/exp/weights/best.pt --device $d # detect custom\n",
" python test.py --weights $m.pt --device $d # test official\n",
" python test.py --weights runs/train/exp/weights/best.pt --device $d # test custom\n",
" python val.py --weights $m.pt --device $d # val official\n",
" python val.py --weights runs/train/exp/weights/best.pt --device $d # val custom\n",
" done\n", " done\n",
" python hubconf.py # hub\n", " python hubconf.py # hub\n",
" python models/yolo.py --cfg $m.yaml # inspect\n", " python models/yolo.py --cfg $m.yaml # inspect\n",

utils/augmentations.py  (+1 -1)



  # Scale ratio (new / old)
  r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
- if not scaleup: # only scale down, do not scale up (for better test mAP)
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
  r = min(r, 1.0)

  # Compute padding

utils/general.py  (+1 -1)

  for j, a in enumerate(d): # per item
  cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  im = cv2.resize(cutout, (224, 224)) # BGR
- # cv2.imwrite('test%i.jpg' % j, cutout)
+ # cv2.imwrite('example%i.jpg' % j, cutout)

  im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32

utils/plots.py  (+4 -4)

  plt.close()

- def plot_test_txt(): # from utils.plots import *; plot_test()
- # Plot test.txt histograms
- x = np.loadtxt('test.txt', dtype=np.float32)
+ def plot_val_txt(): # from utils.plots import *; plot_val()
+ # Plot val.txt histograms
+ x = np.loadtxt('val.txt', dtype=np.float32)
  box = xyxy2xywh(x[:, :4])
  cx, cy = box[:, 0], box[:, 1]

  def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
- # Plot study.txt generated by test.py
+ # Plot study.txt generated by val.py
  plot2 = False # plot additional results
  if plot2:
  ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()

test.py → val.py  (+10 -10)

"""Test a trained YOLOv5 model accuracy on a custom dataset
"""Validate a trained YOLOv5 model accuracy on a custom dataset


Usage: Usage:
$ python path/to/test.py --data coco128.yaml --weights yolov5s.pt --img 640
$ python path/to/val.py --data coco128.yaml --weights yolov5s.pt --img 640
""" """


import argparse import argparse
save_hybrid=False, # save label+prediction hybrid results to *.txt save_hybrid=False, # save label+prediction hybrid results to *.txt
save_conf=False, # save confidences in --save-txt labels save_conf=False, # save confidences in --save-txt labels
save_json=False, # save a cocoapi-compatible JSON results file save_json=False, # save a cocoapi-compatible JSON results file
project='runs/test', # save to project/name
project='runs/val', # save to project/name
name='exp', # save to project/name name='exp', # save to project/name
exist_ok=False, # existing project/name ok, do not increment exist_ok=False, # existing project/name ok, do not increment
half=True, # use FP16 half-precision inference half=True, # use FP16 half-precision inference


# Plot images # Plot images
if plots and batch_i < 3: if plots and batch_i < 3:
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
f = save_dir / f'val_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start() Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start() Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start()


# Compute statistics # Compute statistics
if plots: if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb_logger and wandb_logger.wandb: if wandb_logger and wandb_logger.wandb:
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('val*.jpg'))]
wandb_logger.log({"Validation": val_batches}) wandb_logger.log({"Validation": val_batches})
if wandb_images: if wandb_images:
wandb_logger.log({"Bounding Box Debugger/Images": wandb_images}) wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})




def parse_opt(): def parse_opt():
parser = argparse.ArgumentParser(prog='test.py')
parser = argparse.ArgumentParser(prog='val.py')
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path') parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--batch-size', type=int, default=32, help='batch size') parser.add_argument('--batch-size', type=int, default=32, help='batch size')
parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt') parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file') parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
parser.add_argument('--project', default='runs/test', help='save to project/name')
parser.add_argument('--project', default='runs/val', help='save to project/name')
parser.add_argument('--name', default='exp', help='save to project/name') parser.add_argument('--name', default='exp', help='save to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')


def main(opt): def main(opt):
set_logging() set_logging()
print(colorstr('test: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
print(colorstr('val: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop')) check_requirements(exclude=('tensorboard', 'thop'))


if opt.task in ('train', 'val', 'test'): # run normally if opt.task in ('train', 'val', 'test'): # run normally
save_json=False, plots=False) save_json=False, plots=False)


elif opt.task == 'study': # run over a range of settings and save/plot elif opt.task == 'study': # run over a range of settings and save/plot
# python test.py --task study --data coco.yaml --iou 0.7 --weights yolov5s.pt yolov5m.pt yolov5l.pt yolov5x.pt
# python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5s.pt yolov5m.pt yolov5l.pt yolov5x.pt
x = list(range(256, 1536 + 128, 128)) # x axis (image sizes) x = list(range(256, 1536 + 128, 128)) # x axis (image sizes)
for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]: for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt' # filename to save to f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt' # filename to save to
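
The command-line interface is otherwise unchanged, e.g. `python val.py --data coco128.yaml --weights yolov5s.pt --img 640` as in the usage string above, with results now saved under runs/val/ rather than runs/test/. For programmatic use, a minimal sketch, assuming val.run() accepts the keyword arguments seen in the train.py hunk and in the signature/parse_opt excerpts above, might look like:

    import val  # the renamed module (formerly test.py)

    # Illustrative standalone call, not the authoritative API; argument names
    # mirror the parse_opt flags and run() defaults shown in the diff above.
    results, maps, _ = val.run(data='data/coco128.yaml',  # dataset.yaml path
                               weights='yolov5s.pt',      # model.pt path(s)
                               batch_size=32,
                               imgsz=640,
                               half=True,                 # FP16 half-precision inference
                               project='runs/val')        # save to project/name (new default)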
