
Update 4 main ops for paths and .run() (#3715)

* Add yolov5/ to path

* rename functions to run()

* cleanup

* rename fix

* CI fix

* cleanup find models/export.py
* modify Dataloader
Glenn Jocher (GitHub) committed 3 years ago
parent commit 1f69d12591
7 changed files with 130 additions and 100 deletions

  1. .github/workflows/ci-testing.yml (+1 -1)
  2. .github/workflows/greetings.yml (+1 -1)
  3. detect.py (+35 -25)
  4. models/export.py → export.py (+16 -16)
  5. test.py (+41 -31)
  6. train.py (+34 -24)
  7. tutorial.ipynb (+2 -2)
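
For context (not part of the commit): a minimal usage sketch of the unified run() entry point these changes introduce. The argument names and defaults are copied from the new detect.py signature shown below; the clone location and weights file are assumptions.

# Minimal sketch, assuming a local clone in ./yolov5 and a yolov5s.pt weights file.
import sys
from pathlib import Path

sys.path.append(Path('yolov5').absolute().as_posix())  # mirror the scripts' own path bootstrap

import detect  # yolov5/detect.py, which now exposes run() instead of detect()

detect.run(weights='yolov5s.pt',  # model.pt path(s)
           source='data/images',  # file/dir/URL/glob, 0 for webcam
           imgsz=640,             # inference size (pixels)
           conf_thres=0.25)       # confidence threshold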

.github/workflows/ci-testing.yml (+1 -1)

@@ -74,5 +74,5 @@ jobs:

python hubconf.py # hub
python models/yolo.py --cfg ${{ matrix.model }}.yaml # inspect
- python models/export.py --img 128 --batch 1 --weights ${{ matrix.model }}.pt # export
+ python export.py --img 128 --batch 1 --weights ${{ matrix.model }}.pt # export
shell: bash

.github/workflows/greetings.yml (+1 -1)

@@ -52,5 +52,5 @@ jobs:

![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)

- If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/models/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.
+ If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

detect.py (+35 -25)

@@ -1,4 +1,11 @@
"""Run inference with a YOLOv5 model on images, videos, directories, streams

Usage:
$ python path/to/detect.py --source path/to/img.jpg --weights yolov5s.pt --img 640
"""

import argparse
+ import sys
import time
from pathlib import Path

@@ -6,6 +13,9 @@ import cv2
import torch
import torch.backends.cudnn as cudnn

+ FILE = Path(__file__).absolute()
+ sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
+
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
@@ -15,30 +25,30 @@ from utils.torch_utils import select_device, load_classifier, time_synchronized


@torch.no_grad()
- def detect(weights='yolov5s.pt', # model.pt path(s)
- source='data/images', # file/dir/URL/glob, 0 for webcam
- imgsz=640, # inference size (pixels)
- conf_thres=0.25, # confidence threshold
- iou_thres=0.45, # NMS IOU threshold
- max_det=1000, # maximum detections per image
- device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
- view_img=False, # show results
- save_txt=False, # save results to *.txt
- save_conf=False, # save confidences in --save-txt labels
- save_crop=False, # save cropped prediction boxes
- nosave=False, # do not save images/videos
- classes=None, # filter by class: --class 0, or --class 0 2 3
- agnostic_nms=False, # class-agnostic NMS
- augment=False, # augmented inference
- update=False, # update all models
- project='runs/detect', # save results to project/name
- name='exp', # save results to project/name
- exist_ok=False, # existing project/name ok, do not increment
- line_thickness=3, # bounding box thickness (pixels)
- hide_labels=False, # hide labels
- hide_conf=False, # hide confidences
- half=False, # use FP16 half-precision inference
- ):
+ def run(weights='yolov5s.pt', # model.pt path(s)
+ source='data/images', # file/dir/URL/glob, 0 for webcam
+ imgsz=640, # inference size (pixels)
+ conf_thres=0.25, # confidence threshold
+ iou_thres=0.45, # NMS IOU threshold
+ max_det=1000, # maximum detections per image
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
+ view_img=False, # show results
+ save_txt=False, # save results to *.txt
+ save_conf=False, # save confidences in --save-txt labels
+ save_crop=False, # save cropped prediction boxes
+ nosave=False, # do not save images/videos
+ classes=None, # filter by class: --class 0, or --class 0 2 3
+ agnostic_nms=False, # class-agnostic NMS
+ augment=False, # augmented inference
+ update=False, # update all models
+ project='runs/detect', # save results to project/name
+ name='exp', # save results to project/name
+ exist_ok=False, # existing project/name ok, do not increment
+ line_thickness=3, # bounding box thickness (pixels)
+ hide_labels=False, # hide labels
+ hide_conf=False, # hide confidences
+ half=False, # use FP16 half-precision inference
+ ):
save_img = not nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
('rtsp://', 'rtmp://', 'http://', 'https://'))
@@ -204,7 +214,7 @@ def parse_opt():
def main(opt):
print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
check_requirements(exclude=('tensorboard', 'thop'))
- detect(**vars(opt))
+ run(**vars(opt))


if __name__ == "__main__":

models/export.py → export.py (+16 -16)

@@ -1,7 +1,7 @@
"""Export a YOLOv5 *.pt model to TorchScript, ONNX, CoreML formats

Usage:
- $ python path/to/models/export.py --weights yolov5s.pt --img 640 --batch 1
+ $ python path/to/export.py --weights yolov5s.pt --img 640 --batch 1
"""

import argparse
@@ -14,7 +14,7 @@ import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).absolute()
- sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path
+ sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path

from models.common import Conv
from models.yolo import Detect
@@ -24,19 +24,19 @@ from utils.general import colorstr, check_img_size, check_requirements, file_siz
from utils.torch_utils import select_device


- def export(weights='./yolov5s.pt', # weights path
- img_size=(640, 640), # image (height, width)
- batch_size=1, # batch size
- device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
- include=('torchscript', 'onnx', 'coreml'), # include formats
- half=False, # FP16 half-precision export
- inplace=False, # set YOLOv5 Detect() inplace=True
- train=False, # model.train() mode
- optimize=False, # TorchScript: optimize for mobile
- dynamic=False, # ONNX: dynamic axes
- simplify=False, # ONNX: simplify model
- opset_version=12, # ONNX: opset version
- ):
+ def run(weights='./yolov5s.pt', # weights path
+ img_size=(640, 640), # image (height, width)
+ batch_size=1, # batch size
+ device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
+ include=('torchscript', 'onnx', 'coreml'), # include formats
+ half=False, # FP16 half-precision export
+ inplace=False, # set YOLOv5 Detect() inplace=True
+ train=False, # model.train() mode
+ optimize=False, # TorchScript: optimize for mobile
+ dynamic=False, # ONNX: dynamic axes
+ simplify=False, # ONNX: simplify model
+ opset_version=12, # ONNX: opset version
+ ):
t = time.time()
include = [x.lower() for x in include]
img_size *= 2 if len(img_size) == 1 else 1 # expand
@@ -165,7 +165,7 @@ def parse_opt():
def main(opt):
set_logging()
print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
- export(**vars(opt))
+ run(**vars(opt))


if __name__ == "__main__":
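
A side note on the sys.path line above (an illustration, not part of the commit): Path(__file__).parents[0] is the directory containing the script, so once export.py moves from models/ to the repository root, parents[0] rather than parents[1] points at yolov5/. detect.py, test.py and train.py gain the same parents[0] line in this commit. The example paths below are assumptions.

# Illustration only; '/home/user/yolov5' stands in for wherever the repository is cloned.
from pathlib import Path

old_location = Path('/home/user/yolov5/models/export.py')
new_location = Path('/home/user/yolov5/export.py')

print(old_location.parents[0])  # /home/user/yolov5/models
print(old_location.parents[1])  # /home/user/yolov5  <- what the old code appended
print(new_location.parents[0])  # /home/user/yolov5  <- what the new code appends

# With the repository root on sys.path, imports such as
# `from models.common import Conv` resolve regardless of the working directory.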

test.py (+41 -31)

@@ -1,6 +1,13 @@
"""Test a trained YOLOv5 model accuracy on a custom dataset

Usage:
$ python path/to/test.py --data coco128.yaml --weights yolov5s.pt --img 640
"""

import argparse
import json
import os
+ import sys
from pathlib import Path
from threading import Thread

@@ -9,6 +16,9 @@ import torch
import yaml
from tqdm import tqdm

+ FILE = Path(__file__).absolute()
+ sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
+
from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
@@ -19,32 +29,32 @@ from utils.torch_utils import select_device, time_synchronized


@torch.no_grad()
- def test(data,
- weights=None, # model.pt path(s)
- batch_size=32, # batch size
- imgsz=640, # inference size (pixels)
- conf_thres=0.001, # confidence threshold
- iou_thres=0.6, # NMS IoU threshold
- task='val', # train, val, test, speed or study
- device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
- single_cls=False, # treat as single-class dataset
- augment=False, # augmented inference
- verbose=False, # verbose output
- save_txt=False, # save results to *.txt
- save_hybrid=False, # save label+prediction hybrid results to *.txt
- save_conf=False, # save confidences in --save-txt labels
- save_json=False, # save a cocoapi-compatible JSON results file
- project='runs/test', # save to project/name
- name='exp', # save to project/name
- exist_ok=False, # existing project/name ok, do not increment
- half=True, # use FP16 half-precision inference
- model=None,
- dataloader=None,
- save_dir=Path(''),
- plots=True,
- wandb_logger=None,
- compute_loss=None,
- ):
+ def run(data,
+ weights=None, # model.pt path(s)
+ batch_size=32, # batch size
+ imgsz=640, # inference size (pixels)
+ conf_thres=0.001, # confidence threshold
+ iou_thres=0.6, # NMS IoU threshold
+ task='val', # train, val, test, speed or study
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
+ single_cls=False, # treat as single-class dataset
+ augment=False, # augmented inference
+ verbose=False, # verbose output
+ save_txt=False, # save results to *.txt
+ save_hybrid=False, # save label+prediction hybrid results to *.txt
+ save_conf=False, # save confidences in --save-txt labels
+ save_json=False, # save a cocoapi-compatible JSON results file
+ project='runs/test', # save to project/name
+ name='exp', # save to project/name
+ exist_ok=False, # existing project/name ok, do not increment
+ half=True, # use FP16 half-precision inference
+ model=None,
+ dataloader=None,
+ save_dir=Path(''),
+ plots=True,
+ wandb_logger=None,
+ compute_loss=None,
+ ):
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
@@ -327,12 +337,12 @@ def main(opt):
check_requirements(exclude=('tensorboard', 'thop'))

if opt.task in ('train', 'val', 'test'): # run normally
- test(**vars(opt))
+ run(**vars(opt))

elif opt.task == 'speed': # speed benchmarks
for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
- test(opt.data, weights=w, batch_size=opt.batch_size, imgsz=opt.imgsz, conf_thres=.25, iou_thres=.45,
- save_json=False, plots=False)
+ run(opt.data, weights=w, batch_size=opt.batch_size, imgsz=opt.imgsz, conf_thres=.25, iou_thres=.45,
+ save_json=False, plots=False)

elif opt.task == 'study': # run over a range of settings and save/plot
# python test.py --task study --data coco.yaml --iou 0.7 --weights yolov5s.pt yolov5m.pt yolov5l.pt yolov5x.pt
@@ -342,8 +352,8 @@ def main(opt):
y = [] # y axis
for i in x: # img-size
print(f'\nRunning {f} point {i}...')
- r, _, t = test(opt.data, weights=w, batch_size=opt.batch_size, imgsz=i, conf_thres=opt.conf_thres,
- iou_thres=opt.iou_thres, save_json=opt.save_json, plots=False)
+ r, _, t = run(opt.data, weights=w, batch_size=opt.batch_size, imgsz=i, conf_thres=opt.conf_thres,
+ iou_thres=opt.iou_thres, save_json=opt.save_json, plots=False)
y.append(r + t) # results and times
np.savetxt(f, y, fmt='%10.4g') # save
os.system('zip -r study.zip study_*.txt')

train.py (+34 -24)

@@ -1,8 +1,15 @@
"""Train a YOLOv5 model on a custom dataset

Usage:
$ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
"""

import argparse
import logging
import math
import os
import random
+ import sys
import time
import warnings
from copy import deepcopy
@@ -22,6 +29,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

+ FILE = Path(__file__).absolute()
+ sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
+
import test # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
@@ -89,7 +99,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# W&B
opt.hyp = hyp # add hyperparameters
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
- run_id = run_id if opt.resume else None # start fresh run if transfer learning
+ run_id = run_id if opt.resume else None # start fresh run if transfer learning
wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
loggers['wandb'] = wandb_logger.wandb
if loggers['wandb']:
@@ -375,18 +385,18 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
final_epoch = epoch + 1 == epochs
if not notest or final_epoch: # Calculate mAP
wandb_logger.current_epoch = epoch + 1
- results, maps, _ = test.test(data_dict,
- batch_size=batch_size // WORLD_SIZE * 2,
- imgsz=imgsz_test,
- model=ema.ema,
- single_cls=single_cls,
- dataloader=testloader,
- save_dir=save_dir,
- save_json=is_coco and final_epoch,
- verbose=nc < 50 and final_epoch,
- plots=plots and final_epoch,
- wandb_logger=wandb_logger,
- compute_loss=compute_loss)
+ results, maps, _ = test.run(data_dict,
+ batch_size=batch_size // WORLD_SIZE * 2,
+ imgsz=imgsz_test,
+ model=ema.ema,
+ single_cls=single_cls,
+ dataloader=testloader,
+ save_dir=save_dir,
+ save_json=is_coco and final_epoch,
+ verbose=nc < 50 and final_epoch,
+ plots=plots and final_epoch,
+ wandb_logger=wandb_logger,
+ compute_loss=compute_loss)

# Write
with open(results_file, 'a') as f:
@@ -443,17 +453,17 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if not evolve:
if is_coco: # COCO dataset
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
- results, _, _ = test.test(data,
- batch_size=batch_size // WORLD_SIZE * 2,
- imgsz=imgsz_test,
- conf_thres=0.001,
- iou_thres=0.7,
- model=attempt_load(m, device).half(),
- single_cls=single_cls,
- dataloader=testloader,
- save_dir=save_dir,
- save_json=True,
- plots=False)
+ results, _, _ = test.run(data,
+ batch_size=batch_size // WORLD_SIZE * 2,
+ imgsz=imgsz_test,
+ conf_thres=0.001,
+ iou_thres=0.7,
+ model=attempt_load(m, device).half(),
+ single_cls=single_cls,
+ dataloader=testloader,
+ save_dir=save_dir,
+ save_json=True,
+ plots=False)

# Strip optimizers
for f in last, best:

tutorial.ipynb (+2 -2)

@@ -1125,7 +1125,7 @@
"\n",
"![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg)\n",
"\n",
"If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/models/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.\n"
"If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), testing ([test.py](https://github.com/ultralytics/yolov5/blob/master/test.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.\n"
]
},
{
@@ -1212,7 +1212,7 @@
" done\n",
" python hubconf.py # hub\n",
" python models/yolo.py --cfg $m.yaml # inspect\n",
" python models/export.py --weights $m.pt --img 640 --batch 1 # export\n",
" python export.py --weights $m.pt --img 640 --batch 1 # export\n",
"done"
],
"execution_count": null,
