@@ -0,0 +1,13 @@ | |||
--- | |||
name: "❓Question" | |||
about: Ask a general question | |||
title: '' | |||
labels: question | |||
assignees: '' | |||
--- | |||
## ❔Question | |||
## Additional context |
@@ -41,9 +41,13 @@ $ pip install -U -r requirements.txt | |||
## Tutorials | |||
* [Notebook](https://github.com/ultralytics/yolov5/blob/master/tutorial.ipynb) <a href="https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a> | |||
* [Kaggle](https://www.kaggle.com/ultralytics/yolov5-tutorial) | |||
* [Train Custom Data](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data) | |||
* [Google Cloud Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart) | |||
* [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) ![Docker Pulls](https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker) | |||
* [PyTorch Hub](https://github.com/ultralytics/yolov5/issues/36) | |||
* [ONNX and TorchScript Export](https://github.com/ultralytics/yolov5/issues/251) | |||
* [Test-Time Augmentation (TTA)](https://github.com/ultralytics/yolov5/issues/303) | |||
* [Google Cloud Quickstart](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart) | |||
* [Docker Quickstart](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) ![Docker Pulls](https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker) | |||
## Inference |
@@ -2,7 +2,7 @@ import argparse | |||
import torch.backends.cudnn as cudnn | |||
from utils import google_utils | |||
from models.experimental import * | |||
from utils.datasets import * | |||
from utils.utils import * | |||
@@ -20,12 +20,8 @@ def detect(save_img=False): | |||
half = device.type != 'cpu' # half precision only supported on CUDA | |||
# Load model | |||
google_utils.attempt_download(weights) | |||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32 | |||
# torch.save(torch.load(weights, map_location=device), weights) # update model if SourceChangeWarning | |||
# model.fuse() | |||
model.to(device).eval() | |||
imgsz = check_img_size(imgsz, s=model.model[-1].stride.max()) # check img_size | |||
model = attempt_load(weights, map_location=device) # load FP32 model | |||
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size | |||
if half: | |||
model.half() # to FP16 | |||
@@ -123,10 +119,11 @@ def detect(save_img=False): | |||
if isinstance(vid_writer, cv2.VideoWriter): | |||
vid_writer.release() # release previous video writer | |||
fourcc = 'mp4v' # output video codec | |||
fps = vid_cap.get(cv2.CAP_PROP_FPS) | |||
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |||
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |||
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*opt.fourcc), fps, (w, h)) | |||
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)) | |||
vid_writer.write(im0) | |||
if save_txt or save_img: | |||
@@ -139,26 +136,26 @@ def detect(save_img=False): | |||
if __name__ == '__main__': | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path') | |||
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') | |||
parser.add_argument('--source', type=str, default='inference/images', help='source') # file/folder, 0 for webcam | |||
parser.add_argument('--output', type=str, default='inference/output', help='output folder') # output folder | |||
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') | |||
parser.add_argument('--conf-thres', type=float, default=0.4, help='object confidence threshold') | |||
parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS') | |||
parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec (verify ffmpeg support)') | |||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | |||
parser.add_argument('--view-img', action='store_true', help='display results') | |||
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') | |||
parser.add_argument('--classes', nargs='+', type=int, help='filter by class') | |||
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') | |||
parser.add_argument('--augment', action='store_true', help='augmented inference') | |||
parser.add_argument('--update', action='store_true', help='update all models') | |||
opt = parser.parse_args() | |||
print(opt) | |||
with torch.no_grad(): | |||
detect() | |||
# # Update all models | |||
# for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']: | |||
# detect() | |||
# create_pretrained(opt.weights, opt.weights) | |||
if opt.update: # update all models (to fix SourceChangeWarning) | |||
for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']: | |||
detect() | |||
create_pretrained(opt.weights, opt.weights) | |||
else: | |||
detect() |
@@ -1,6 +1,7 @@ | |||
# This file contains experimental modules | |||
from models.common import * | |||
from utils import google_utils | |||
class CrossConv(nn.Module): | |||
@@ -107,3 +108,34 @@ class MixConv2d(nn.Module): | |||
def forward(self, x): | |||
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) | |||
class Ensemble(nn.ModuleList): | |||
# Ensemble of models | |||
def __init__(self): | |||
super(Ensemble, self).__init__() | |||
def forward(self, x, augment=False): | |||
y = [] | |||
for module in self: | |||
y.append(module(x, augment)[0]) | |||
# y = torch.stack(y).max(0)[0] # max ensemble | |||
# y = torch.cat(y, 1) # nms ensemble | |||
y = torch.stack(y).mean(0) # mean ensemble | |||
return y, None # inference, train output | |||
def attempt_load(weights, map_location=None): | |||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a | |||
model = Ensemble() | |||
for w in weights if isinstance(weights, list) else [weights]: | |||
google_utils.attempt_download(w) | |||
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model | |||
if len(model) == 1: | |||
return model[-1] # return model | |||
else: | |||
print('Ensemble created with %s\n' % weights) | |||
for k in ['names', 'stride']: | |||
setattr(model, k, getattr(model[-1], k)) | |||
return model # return ensemble |
@@ -61,7 +61,8 @@ if __name__ == '__main__': | |||
import coremltools as ct | |||
print('\nStarting CoreML export with coremltools %s...' % ct.__version__) | |||
model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape)]) # convert | |||
# convert model from torchscript and apply pixel scaling as per detect.py | |||
model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1/255.0, bias=[0, 0, 0])]) | |||
f = opt.weights.replace('.pt', '.mlmodel') # filename | |||
model.save(f) | |||
print('CoreML export success, saved as %s' % f) |
@@ -48,6 +48,7 @@ class Model(nn.Module): | |||
if type(model_cfg) is dict: | |||
self.md = model_cfg # model dict | |||
else: # is *.yaml | |||
import yaml # for torch hub | |||
with open(model_cfg) as f: | |||
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict | |||
@@ -141,14 +142,14 @@ class Model(nn.Module): | |||
# print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights | |||
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers | |||
print('Fusing layers...') | |||
print('Fusing layers... ', end='') | |||
for m in self.model.modules(): | |||
if type(m) is Conv: | |||
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv | |||
m.bn = None # remove batchnorm | |||
m.forward = m.fuseforward # update forward | |||
torch_utils.model_info(self) | |||
return self | |||
def parse_model(md, ch): # model_dict, input_channels(3) | |||
print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments')) |
@@ -2,7 +2,7 @@ | |||
Cython | |||
numpy==1.17 | |||
opencv-python | |||
torch>=1.4 | |||
torch>=1.5.1 | |||
matplotlib | |||
pillow | |||
tensorboard |
@@ -1,9 +1,8 @@ | |||
import argparse | |||
import json | |||
from utils import google_utils | |||
from models.experimental import * | |||
from utils.datasets import * | |||
from utils.utils import * | |||
def test(data, | |||
@@ -18,32 +17,29 @@ def test(data, | |||
verbose=False, | |||
model=None, | |||
dataloader=None, | |||
save_dir='', | |||
merge=False): | |||
# Initialize/load model and set device | |||
if model is None: | |||
training = False | |||
training = model is not None | |||
if training: # called by train.py | |||
device = next(model.parameters()).device # get model device | |||
else: # called directly | |||
device = torch_utils.select_device(opt.device, batch_size=batch_size) | |||
merge = opt.merge # use Merge NMS | |||
# Remove previous | |||
for f in glob.glob('test_batch*.jpg'): | |||
for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')): | |||
os.remove(f) | |||
# Load model | |||
google_utils.attempt_download(weights) | |||
model = torch.load(weights, map_location=device)['model'].float() # load to FP32 | |||
torch_utils.model_info(model) | |||
model.fuse() | |||
model.to(device) | |||
imgsz = check_img_size(imgsz, s=model.model[-1].stride.max()) # check img_size | |||
model = attempt_load(weights, map_location=device) # load FP32 model | |||
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size | |||
# Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99 | |||
# if device.type != 'cpu' and torch.cuda.device_count() > 1: | |||
# model = nn.DataParallel(model) | |||
else: # called by train.py | |||
training = True | |||
device = next(model.parameters()).device # get model device | |||
# Half | |||
half = device.type != 'cpu' and torch.cuda.device_count() == 1 # half precision only supported on single-GPU | |||
if half: | |||
@@ -58,12 +54,11 @@ def test(data, | |||
niou = iouv.numel() | |||
# Dataloader | |||
if dataloader is None: # not training | |||
merge = opt.merge # use Merge NMS | |||
if not training: | |||
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img | |||
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once | |||
path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images | |||
dataloader = create_dataloader(path, imgsz, batch_size, int(max(model.stride)), opt, | |||
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, | |||
hyp=None, augment=False, cache=False, pad=0.5, rect=True)[0] | |||
seen = 0 | |||
@@ -163,10 +158,10 @@ def test(data, | |||
# Plot images | |||
if batch_i < 1: | |||
f = 'test_batch%g_gt.jpg' % batch_i # filename | |||
plot_images(img, targets, paths, f, names) # ground truth | |||
f = 'test_batch%g_pred.jpg' % batch_i | |||
plot_images(img, output_to_target(output, width, height), paths, f, names) # predictions | |||
f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i) # filename | |||
plot_images(img, targets, paths, str(f), names) # ground truth | |||
f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i) | |||
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions | |||
# Compute statistics | |||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy | |||
@@ -196,7 +191,7 @@ def test(data, | |||
if save_json and map50 and len(jdict): | |||
imgIds = [int(Path(x).stem.split('_')[-1]) for x in dataloader.dataset.img_files] | |||
f = 'detections_val2017_%s_results.json' % \ | |||
(weights.split(os.sep)[-1].replace('.pt', '') if weights else '') # filename | |||
(weights.split(os.sep)[-1].replace('.pt', '') if isinstance(weights, str) else '') # filename | |||
print('\nCOCO mAP with pycocotools... saving %s...' % f) | |||
with open(f, 'w') as file: | |||
json.dump(jdict, file) | |||
@@ -229,7 +224,7 @@ def test(data, | |||
if __name__ == '__main__': | |||
parser = argparse.ArgumentParser(prog='test.py') | |||
parser.add_argument('--weights', type=str, default='weights/yolov5s.pt', help='model.pt path') | |||
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') | |||
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path') | |||
parser.add_argument('--batch-size', type=int, default=32, help='size of each image batch') | |||
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') |
@@ -20,15 +20,10 @@ except: | |||
print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex') | |||
mixed_precision = False # not installed | |||
wdir = 'weights' + os.sep # weights dir | |||
os.makedirs(wdir, exist_ok=True) | |||
last = wdir + 'last.pt' | |||
best = wdir + 'best.pt' | |||
results_file = 'results.txt' | |||
# Hyperparameters | |||
hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3) | |||
'momentum': 0.937, # SGD momentum | |||
hyp = {'optimizer': 'SGD', # ['adam', 'SGD', None] if none, default is SGD | |||
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3) | |||
'momentum': 0.937, # SGD momentum/Adam beta1 | |||
'weight_decay': 5e-4, # optimizer weight decay | |||
'giou': 0.05, # giou loss gain | |||
'cls': 0.58, # cls loss gain | |||
@@ -45,21 +40,24 @@ hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3) | |||
'translate': 0.0, # image translation (+/- fraction) | |||
'scale': 0.5, # image scale (+/- gain) | |||
'shear': 0.0} # image shear (+/- deg) | |||
print(hyp) | |||
# Overwrite hyp with hyp*.txt (optional) | |||
f = glob.glob('hyp*.txt') | |||
if f: | |||
print('Using %s' % f[0]) | |||
for k, v in zip(hyp.keys(), np.loadtxt(f[0])): | |||
hyp[k] = v | |||
# Print focal loss if gamma > 0 | |||
if hyp['fl_gamma']: | |||
print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma']) | |||
def train(hyp): | |||
print(f'Hyperparameters {hyp}') | |||
log_dir = tb_writer.log_dir # run directory | |||
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory | |||
os.makedirs(wdir, exist_ok=True) | |||
last = wdir + 'last.pt' | |||
best = wdir + 'best.pt' | |||
results_file = log_dir + os.sep + 'results.txt' | |||
# Save run settings | |||
with open(Path(log_dir) / 'hyp.yaml', 'w') as f: | |||
yaml.dump(hyp, f, sort_keys=False) | |||
with open(Path(log_dir) / 'opt.yaml', 'w') as f: | |||
yaml.dump(vars(opt), f, sort_keys=False) | |||
def train(hyp): | |||
epochs = opt.epochs # 300 | |||
batch_size = opt.batch_size # 64 | |||
weights = opt.weights # initial training weights | |||
@@ -70,14 +68,15 @@ def train(hyp): | |||
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict | |||
train_path = data_dict['train'] | |||
test_path = data_dict['val'] | |||
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes | |||
nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names | |||
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check | |||
# Remove previous results | |||
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file): | |||
os.remove(f) | |||
# Create model | |||
model = Model(opt.cfg, nc=data_dict['nc']).to(device) | |||
model = Model(opt.cfg, nc=nc).to(device) | |||
# Image sizes | |||
gs = int(max(model.stride)) # grid size (max stride) | |||
@@ -97,15 +96,20 @@ def train(hyp): | |||
else: | |||
pg0.append(v) # all else | |||
optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \ | |||
optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True) | |||
if hyp['optimizer'] == 'adam': # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR | |||
optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum | |||
else: | |||
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True) | |||
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay | |||
optimizer.add_param_group({'params': pg2}) # add pg2 (biases) | |||
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0))) | |||
del pg0, pg1, pg2 | |||
# Scheduler https://arxiv.org/pdf/1812.01187.pdf | |||
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine | |||
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) | |||
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0))) | |||
del pg0, pg1, pg2 | |||
# plot_lr_scheduler(optimizer, scheduler, epochs, save_dir=log_dir) | |||
# Load Model | |||
google_utils.attempt_download(weights) | |||
@@ -147,12 +151,7 @@ def train(hyp): | |||
if mixed_precision: | |||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0) | |||
scheduler.last_epoch = start_epoch - 1 # do not move | |||
# https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822 | |||
# plot_lr_scheduler(optimizer, scheduler, epochs) | |||
# Initialize distributed training | |||
# Distributed training | |||
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available(): | |||
dist.init_process_group(backend='nccl', # distributed backend | |||
init_method='tcp://127.0.0.1:9999', # init method | |||
@@ -165,6 +164,7 @@ def train(hyp): | |||
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, | |||
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect) | |||
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class | |||
nb = len(dataloader) # number of batches | |||
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg) | |||
# Testloader | |||
@@ -177,15 +177,15 @@ def train(hyp): | |||
model.hyp = hyp # attach hyperparameters to model | |||
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou) | |||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights | |||
model.names = data_dict['names'] | |||
model.names = names | |||
# Class frequency | |||
labels = np.concatenate(dataset.labels, 0) | |||
c = torch.tensor(labels[:, 0]) # classes | |||
# cf = torch.bincount(c.long(), minlength=nc) + 1. | |||
# model._initialize_biases(cf.to(device)) | |||
plot_labels(labels, save_dir=log_dir) | |||
if tb_writer: | |||
plot_labels(labels) | |||
tb_writer.add_histogram('classes', c, 0) | |||
# Check anchors | |||
@@ -193,14 +193,14 @@ def train(hyp): | |||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) | |||
# Exponential moving average | |||
ema = torch_utils.ModelEMA(model) | |||
ema = torch_utils.ModelEMA(model, updates=start_epoch * nb / accumulate) | |||
# Start training | |||
t0 = time.time() | |||
nb = len(dataloader) # number of batches | |||
n_burn = max(3 * nb, 1e3) # burn-in iterations, max(3 epochs, 1k iterations) | |||
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations) | |||
maps = np.zeros(nc) # mAP per class | |||
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification' | |||
scheduler.last_epoch = start_epoch - 1 # do not move | |||
print('Image sizes %g train, %g test' % (imgsz, imgsz_test)) | |||
print('Using %g dataloader workers' % dataloader.num_workers) | |||
print('Starting training for %g epochs...' % epochs) | |||
@@ -225,9 +225,9 @@ def train(hyp): | |||
ni = i + nb * epoch # number integrated batches (since train start) | |||
imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0 | |||
# Burn-in | |||
if ni <= n_burn: | |||
xi = [0, n_burn] # x interp | |||
# Warmup | |||
if ni <= nw: | |||
xi = [0, nw] # x interp | |||
# model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou) | |||
accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round()) | |||
for j, x in enumerate(optimizer.param_groups): | |||
@@ -275,7 +275,7 @@ def train(hyp): | |||
# Plot | |||
if ni < 3: | |||
f = 'train_batch%g.jpg' % ni # filename | |||
f = str(Path(log_dir) / ('train_batch%g.jpg' % ni)) # filename | |||
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f) | |||
if tb_writer and result is not None: | |||
tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) | |||
@@ -296,7 +296,8 @@ def train(hyp): | |||
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'), | |||
model=ema.ema, | |||
single_cls=opt.single_cls, | |||
dataloader=testloader) | |||
dataloader=testloader, | |||
save_dir=log_dir) | |||
# Write | |||
with open(results_file, 'a') as f: | |||
@@ -348,7 +349,7 @@ def train(hyp): | |||
# Finish | |||
if not opt.evolve: | |||
plot_results() # save as results.png | |||
plot_results(save_dir=log_dir) # save as results.png | |||
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) | |||
dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None | |||
torch.cuda.empty_cache() | |||
@@ -358,13 +359,15 @@ def train(hyp): | |||
if __name__ == '__main__': | |||
check_git_status() | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path') | |||
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path') | |||
parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)') | |||
parser.add_argument('--epochs', type=int, default=300) | |||
parser.add_argument('--batch-size', type=int, default=16) | |||
parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='*.cfg path') | |||
parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path') | |||
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes') | |||
parser.add_argument('--rect', action='store_true', help='rectangular training') | |||
parser.add_argument('--resume', action='store_true', help='resume training from last.pt') | |||
parser.add_argument('--resume', nargs='?', const='get_last', default=False, | |||
help='resume from given path/to/last.pt, or most recent run if blank.') | |||
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') | |||
parser.add_argument('--notest', action='store_true', help='only test final epoch') | |||
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check') | |||
@@ -374,13 +377,17 @@ if __name__ == '__main__': | |||
parser.add_argument('--weights', type=str, default='', help='initial weights path') | |||
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') | |||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | |||
parser.add_argument('--adam', action='store_true', help='use adam optimizer') | |||
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%') | |||
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') | |||
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') | |||
opt = parser.parse_args() | |||
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run | |||
if last and not opt.weights: | |||
print(f'Resuming training from {last}') | |||
opt.weights = last if opt.resume and not opt.weights else opt.weights | |||
opt.cfg = check_file(opt.cfg) # check file | |||
opt.data = check_file(opt.data) # check file | |||
opt.hyp = check_file(opt.hyp) if opt.hyp else '' # check file | |||
print(opt) | |||
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) | |||
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size) | |||
@@ -389,8 +396,12 @@ if __name__ == '__main__': | |||
# Train | |||
if not opt.evolve: | |||
tb_writer = SummaryWriter(comment=opt.name) | |||
print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/') | |||
tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name)) | |||
if opt.hyp: # update hyps | |||
with open(opt.hyp) as f: | |||
hyp.update(yaml.load(f, Loader=yaml.FullLoader)) | |||
train(hyp) | |||
# Evolve hyperparameters (optional) |
@@ -26,6 +26,11 @@ for orientation in ExifTags.TAGS.keys(): | |||
break | |||
def get_hash(files): | |||
# Returns a single hash value of a list of files | |||
return sum(os.path.getsize(f) for f in files) | |||
def exif_size(img): | |||
# Returns exif-corrected PIL size | |||
s = img.size # (width, height) | |||
@@ -48,7 +53,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa | |||
rect=rect, # rectangular training | |||
cache_images=cache, | |||
single_cls=opt.single_cls, | |||
stride=stride, | |||
stride=int(stride), | |||
pad=pad) | |||
batch_size = min(batch_size, len(dataset)) | |||
@@ -280,19 +285,21 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, | |||
cache_images=False, single_cls=False, stride=32, pad=0.0): | |||
try: | |||
path = str(Path(path)) # os-agnostic | |||
parent = str(Path(path).parent) + os.sep | |||
if os.path.isfile(path): # file | |||
with open(path, 'r') as f: | |||
f = f.read().splitlines() | |||
f = [x.replace('./', parent) if x.startswith('./') else x for x in f] # local to global path | |||
elif os.path.isdir(path): # folder | |||
f = glob.iglob(path + os.sep + '*.*') | |||
else: | |||
raise Exception('%s does not exist' % path) | |||
f = [] # image files | |||
for p in path if isinstance(path, list) else [path]: | |||
p = str(Path(p)) # os-agnostic | |||
parent = str(Path(p).parent) + os.sep | |||
if os.path.isfile(p): # file | |||
with open(p, 'r') as t: | |||
t = t.read().splitlines() | |||
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path | |||
elif os.path.isdir(p): # folder | |||
f += glob.iglob(p + os.sep + '*.*') | |||
else: | |||
raise Exception('%s does not exist' % p) | |||
self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats] | |||
except: | |||
raise Exception('Error loading data from %s. See %s' % (path, help_url)) | |||
except Exception as e: | |||
raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url)) | |||
n = len(self.img_files) | |||
assert n > 0, 'No images found in %s. See %s' % (path, help_url) | |||
@@ -311,20 +318,22 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
self.stride = stride | |||
# Define labels | |||
self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') | |||
for x in self.img_files] | |||
# Read image shapes (wh) | |||
sp = path.replace('.txt', '') + '.shapes' # shapefile path | |||
try: | |||
with open(sp, 'r') as f: # read existing shapefile | |||
s = [x.split() for x in f.read().splitlines()] | |||
assert len(s) == n, 'Shapefile out of sync' | |||
except: | |||
s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')] | |||
np.savetxt(sp, s, fmt='%g') # overwrites existing (if any) | |||
self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in | |||
self.img_files] | |||
# Check cache | |||
cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels | |||
if os.path.isfile(cache_path): | |||
cache = torch.load(cache_path) # load | |||
if cache['hash'] != get_hash(self.label_files + self.img_files): # dataset changed | |||
cache = self.cache_labels(cache_path) # re-cache | |||
else: | |||
cache = self.cache_labels(cache_path) # cache | |||
self.shapes = np.array(s, dtype=np.float64) | |||
# Get labels | |||
labels, shapes = zip(*[cache[x] for x in self.img_files]) | |||
self.shapes = np.array(shapes, dtype=np.float64) | |||
self.labels = list(labels) | |||
# Rectangular Training https://github.com/ultralytics/yolov3/issues/232 | |||
if self.rect: | |||
@@ -350,33 +359,11 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride | |||
# Cache labels | |||
self.imgs = [None] * n | |||
self.labels = [np.zeros((0, 5), dtype=np.float32)] * n | |||
create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False | |||
nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate | |||
np_labels_path = str(Path(self.label_files[0]).parent) + '.npy' # saved labels in *.npy file | |||
if os.path.isfile(np_labels_path): | |||
s = np_labels_path # print string | |||
x = np.load(np_labels_path, allow_pickle=True) | |||
if len(x) == n: | |||
self.labels = x | |||
labels_loaded = True | |||
else: | |||
s = path.replace('images', 'labels') | |||
pbar = tqdm(self.label_files) | |||
for i, file in enumerate(pbar): | |||
if labels_loaded: | |||
l = self.labels[i] | |||
# np.savetxt(file, l, '%g') # save *.txt from *.npy file | |||
else: | |||
try: | |||
with open(file, 'r') as f: | |||
l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) | |||
except: | |||
nm += 1 # print('missing labels for image %s' % self.img_files[i]) # file missing | |||
continue | |||
l = self.labels[i] # label | |||
if l.shape[0]: | |||
assert l.shape[1] == 5, '> 5 label columns: %s' % file | |||
assert (l >= 0).all(), 'negative labels: %s' % file | |||
@@ -422,15 +409,13 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
ne += 1 # print('empty labels for image %s' % self.img_files[i]) # file empty | |||
# os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i])) # remove | |||
pbar.desc = 'Caching labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % ( | |||
s, nf, nm, ne, nd, n) | |||
assert nf > 0 or n == 20288, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url) | |||
if not labels_loaded and n > 1000: | |||
print('Saving labels to %s for faster future loading' % np_labels_path) | |||
np.save(np_labels_path, self.labels) # save for next time | |||
pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % ( | |||
cache_path, nf, nm, ne, nd, n) | |||
assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url) | |||
# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM) | |||
if cache_images: # if training | |||
self.imgs = [None] * n | |||
if cache_images: | |||
gb = 0 # Gigabytes of cached images | |||
pbar = tqdm(range(len(self.img_files)), desc='Caching images') | |||
self.img_hw0, self.img_hw = [None] * n, [None] * n | |||
@@ -439,15 +424,30 @@ class LoadImagesAndLabels(Dataset): # for training/testing | |||
gb += self.imgs[i].nbytes | |||
pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9) | |||
# Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3 | |||
detect_corrupted_images = False | |||
if detect_corrupted_images: | |||
from skimage import io # conda install -c conda-forge scikit-image | |||
for file in tqdm(self.img_files, desc='Detecting corrupted images'): | |||
try: | |||
_ = io.imread(file) | |||
except: | |||
print('Corrupted image detected: %s' % file) | |||
def cache_labels(self, path='labels.cache'): | |||
# Cache dataset labels, check images and read shapes | |||
x = {} # dict | |||
pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files)) | |||
for (img, label) in pbar: | |||
try: | |||
l = [] | |||
image = Image.open(img) | |||
image.verify() # PIL verify | |||
# _ = io.imread(img) # skimage verify (from skimage import io) | |||
shape = exif_size(image) # image size | |||
if os.path.isfile(label): | |||
with open(label, 'r') as f: | |||
l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) # labels | |||
if len(l) == 0: | |||
l = np.zeros((0, 5), dtype=np.float32) | |||
x[img] = [l, shape] | |||
except Exception as e: | |||
x[img] = None | |||
print('WARNING: %s: %s' % (img, e)) | |||
x['hash'] = get_hash(self.label_files + self.img_files) | |||
torch.save(x, path) # save for next time | |||
return x | |||
def __len__(self): | |||
return len(self.img_files) | |||
@@ -679,8 +679,8 @@ def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale | |||
dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding | |||
elif scaleFill: # stretch | |||
dw, dh = 0.0, 0.0 | |||
new_unpad = new_shape | |||
ratio = new_shape[0] / shape[1], new_shape[1] / shape[0] # width, height ratios | |||
new_unpad = (new_shape[1], new_shape[0]) | |||
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios | |||
dw /= 2 # divide padding into 2 sides | |||
dh /= 2 |
@@ -76,16 +76,36 @@ def find_modules(model, mclass=nn.Conv2d): | |||
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)] | |||
def sparsity(model): | |||
# Return global model sparsity | |||
a, b = 0., 0. | |||
for p in model.parameters(): | |||
a += p.numel() | |||
b += (p == 0).sum() | |||
return b / a | |||
def prune(model, amount=0.3): | |||
# Prune model to requested global sparsity | |||
import torch.nn.utils.prune as prune | |||
print('Pruning model... ', end='') | |||
for name, m in model.named_modules(): | |||
if isinstance(m, nn.Conv2d): | |||
prune.l1_unstructured(m, name='weight', amount=amount) # prune | |||
prune.remove(m, 'weight') # make permanent | |||
print(' %.3g global sparsity' % sparsity(model)) | |||
def fuse_conv_and_bn(conv, bn): | |||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/ | |||
with torch.no_grad(): | |||
# init | |||
fusedconv = torch.nn.Conv2d(conv.in_channels, | |||
conv.out_channels, | |||
kernel_size=conv.kernel_size, | |||
stride=conv.stride, | |||
padding=conv.padding, | |||
bias=True) | |||
fusedconv = nn.Conv2d(conv.in_channels, | |||
conv.out_channels, | |||
kernel_size=conv.kernel_size, | |||
stride=conv.stride, | |||
padding=conv.padding, | |||
bias=True).to(conv.weight.device) | |||
# prepare filters | |||
w_conv = conv.weight.clone().view(conv.out_channels, -1) | |||
@@ -93,10 +113,7 @@ def fuse_conv_and_bn(conv, bn): | |||
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) | |||
# prepare spatial bias | |||
if conv.bias is not None: | |||
b_conv = conv.bias | |||
else: | |||
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) | |||
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias | |||
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) | |||
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) | |||
@@ -139,8 +156,8 @@ def load_classifier(name='resnet101', n=2): | |||
# Reshape output to n classes | |||
filters = model.fc.weight.shape[1] | |||
model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True) | |||
model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True) | |||
model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True) | |||
model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True) | |||
model.fc.out_features = n | |||
return model | |||
@@ -174,15 +191,11 @@ class ModelEMA: | |||
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. | |||
""" | |||
def __init__(self, model, decay=0.9999, device=''): | |||
def __init__(self, model, decay=0.9999, updates=0): | |||
# Create EMA | |||
self.ema = deepcopy(model.module if is_parallel(model) else model) # FP32 EMA | |||
self.ema.eval() | |||
self.updates = 0 # number of EMA updates | |||
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA | |||
self.updates = updates # number of EMA updates | |||
self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs) | |||
self.device = device # perform ema on different device from model if set | |||
if device: | |||
self.ema.to(device) | |||
for p in self.ema.parameters(): | |||
p.requires_grad_(False) | |||
@@ -37,6 +37,12 @@ def init_seeds(seed=0): | |||
torch_utils.init_seeds(seed=seed) | |||
def get_latest_run(search_dir='./runs'): | |||
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from) | |||
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) | |||
return max(last_list, key=os.path.getctime) | |||
def check_git_status(): | |||
# Suggest 'git pull' if repo is out of date | |||
if platform in ['linux', 'darwin']: | |||
@@ -173,7 +179,7 @@ def xywh2xyxy(x): | |||
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): | |||
# Rescale coords (xyxy) from img1_shape to img0_shape | |||
if ratio_pad is None: # calculate from img0_shape | |||
gain = max(img1_shape) / max(img0_shape) # gain = old / new | |||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new | |||
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding | |||
else: | |||
gain = ratio_pad[0][0] | |||
@@ -898,6 +904,16 @@ def output_to_target(output, width, height): | |||
return np.array(targets) | |||
def increment_dir(dir, comment=''): | |||
# Increments a directory runs/exp1 --> runs/exp2_comment | |||
n = 0 # number | |||
d = sorted(glob.glob(dir + '*')) # directories | |||
if len(d): | |||
d = d[-1].replace(dir, '') | |||
n = int(d[:d.find('_')] if '_' in d else d) + 1 # increment | |||
return dir + str(n) + ('_' + comment if comment else '') | |||
# Plotting functions --------------------------------------------------------------------------------------------------- | |||
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): | |||
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy | |||
@@ -1028,7 +1044,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max | |||
return mosaic | |||
def plot_lr_scheduler(optimizer, scheduler, epochs=300): | |||
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''): | |||
# Plot LR simulating training for full epochs | |||
optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals | |||
y = [] | |||
@@ -1042,7 +1058,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300): | |||
plt.xlim(0, epochs) | |||
plt.ylim(0) | |||
plt.tight_layout() | |||
plt.savefig('LR.png', dpi=200) | |||
plt.savefig(Path(save_dir) / 'LR.png', dpi=200) | |||
def plot_test_txt(): # from utils.utils import *; plot_test() | |||
@@ -1107,7 +1123,7 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st | |||
plt.savefig(f.replace('.txt', '.png'), dpi=200) | |||
def plot_labels(labels): | |||
def plot_labels(labels, save_dir=''): | |||
# plot dataset labels | |||
c, b = labels[:, 0], labels[:, 1:].transpose() # classees, boxes | |||
@@ -1128,7 +1144,7 @@ def plot_labels(labels): | |||
ax[2].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet') | |||
ax[2].set_xlabel('width') | |||
ax[2].set_ylabel('height') | |||
plt.savefig('labels.png', dpi=200) | |||
plt.savefig(Path(save_dir) / 'labels.png', dpi=200) | |||
plt.close() | |||
@@ -1174,7 +1190,8 @@ def plot_results_overlay(start=0, stop=0): # from utils.utils import *; plot_re | |||
fig.savefig(f.replace('.txt', '.png'), dpi=200) | |||
def plot_results(start=0, stop=0, bucket='', id=(), labels=()): # from utils.utils import *; plot_results() | |||
def plot_results(start=0, stop=0, bucket='', id=(), labels=(), | |||
save_dir=''): # from utils.utils import *; plot_results() | |||
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training | |||
fig, ax = plt.subplots(2, 5, figsize=(12, 6)) | |||
ax = ax.ravel() | |||
@@ -1184,7 +1201,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=()): # from utils.ut | |||
os.system('rm -rf storage.googleapis.com') | |||
files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id] | |||
else: | |||
files = glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt') | |||
files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt') | |||
for fi, f in enumerate(files): | |||
try: | |||
results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T | |||
@@ -1205,4 +1222,4 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=()): # from utils.ut | |||
fig.tight_layout() | |||
ax[1].legend() | |||
fig.savefig('results.png', dpi=200) | |||
fig.savefig(Path(save_dir) / 'results.png', dpi=200) |