# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Train a YOLOv5 model on a custom dataset.

Models and datasets download automatically from the latest YOLOv5 release.
Models: https://github.com/ultralytics/yolov5/tree/master/models
Datasets: https://github.com/ultralytics/yolov5/tree/master/data
Tutorial: https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data

Usage:
    $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640  # from pretrained (RECOMMENDED)
    $ python path/to/train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640  # from scratch
"""

import argparse
import math
import os
import random
import sys
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD, Adam, AdamW, lr_scheduler
from tqdm import tqdm

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

import val  # for end-of-epoch mAP
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks
from utils.datasets import create_dataloader
from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
                           check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
                           intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
                           print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
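# RANK is the global process index and LOCAL_RANK the per-node index set by torch.distributed.run;
# WORLD_SIZE is the total process count. All three fall back to single-process defaults (-1/-1/1)
# when the script is launched directly.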


def train(hyp,  # path/to/hyp.yaml or hyp dictionary
          opt,
          device,
          callbacks
          ):
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze

    # Directories
    w = save_dir / 'weights'  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    last, best = w / 'last.pt', w / 'best.pt'

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))

    # Save run settings
    if not evolve:
        with open(save_dir / 'hyp.yaml', 'w') as f:
            yaml.safe_dump(hyp, f, sort_keys=False)
        with open(save_dir / 'opt.yaml', 'w') as f:
            yaml.safe_dump(vars(opt), f, sort_keys=False)

    # Loggers
    data_dict = None
    if RANK in [-1, 0]:
        loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)  # loggers instance
        if loggers.wandb:
            data_dict = loggers.wandb.data_dict
            if resume:
                weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size

        # Register actions
        for k in methods(loggers):
            callbacks.register_action(k, callback=getattr(loggers, k))

    # Config
    plots = not evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(1 + RANK)
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data)  # check if None
    train_path, val_path = data_dict['train'], data_dict['val']
    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
    is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')  # COCO dataset

    # Model
    check_suffix(weights, '.pt')  # check weights
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
        model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(csd, strict=False)  # load
        LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
    else:
        model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
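    # intersect_dicts() keeps only checkpoint tensors whose key and shape match the freshly built model,
    # so the same path covers exact resumes and transfer learning to a new nc, where the class-dependent
    # Detect() tensors are re-initialized rather than transferred (hence the 'Transferred x/y' report).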

    # Freeze
    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False
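    # e.g. --freeze 10 expands to model.0. through model.9. (the backbone of the standard configs),
    # while the default --freeze 0 yields an empty range, so every layer trains.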

    # Image size
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple

    # Batch size
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
        batch_size = check_train_batch_size(model, imgsz)
        loggers.on_params_update({"batch_size": batch_size})

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
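    # e.g. with --batch-size 16: accumulate = round(64 / 16) = 4, so gradients from 4 mini-batches are
    # summed per optimizer step and weight_decay is scaled by 16 * 4 / 64 = 1.0 (unchanged); a batch
    # size of 128 gives accumulate = 1 and doubles weight_decay, compensating for decay being applied
    # at fewer optimizer steps per epoch.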

    g0, g1, g2 = [], [], []  # optimizer parameter groups
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
            g2.append(v.bias)
        if isinstance(v, nn.BatchNorm2d):  # weight (no decay)
            g0.append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
            g1.append(v.weight)

    if opt.optimizer == 'Adam':
        optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    elif opt.optimizer == 'AdamW':
        optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']})  # add g1 with weight_decay
    optimizer.add_param_group({'params': g2})  # add g2 (biases)
    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
                f"{len(g0)} weight (no decay), {len(g1)} weight, {len(g2)} bias")
    del g0, g1, g2
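    # Group order matters downstream: group 0 = BatchNorm weights (no decay), group 1 = other weights
    # (with decay), group 2 = biases; the warmup loop below identifies the bias group by index (j == 2).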

    # Scheduler
    if opt.cos_lr:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    else:
        lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)
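    # Both schedules run from lf(0) = 1.0 (full lr0) to lf(epochs) = hyp['lrf'] (final lr = lr0 * lrf);
    # e.g. lr0=0.01 and lrf=0.01 decay the lr towards 1e-4. one_cycle() (utils/general.py) traces a
    # cosine between the same endpoints instead of a straight line.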

    # EMA
    ema = ModelEMA(model) if RANK in [-1, 0] else None
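    # ModelEMA (utils/torch_utils.py) maintains an exponential moving average of the weights; the EMA
    # copy (ema.ema) is what gets validated and checkpointed below, as it is typically a little more
    # stable and accurate than the raw training weights.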

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # EMA
        if ema and ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            ema.updates = ckpt['updates']

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if resume:
            assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
        if epochs < start_epoch:
            LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, csd

    # DP mode
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
        LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
                       'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and RANK != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        LOGGER.info('Using SyncBatchNorm()')

    # Trainloader
    train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
                                              hyp=hyp, augment=True, cache=None if opt.cache == 'val' else opt.cache,
                                              rect=opt.rect, rank=LOCAL_RANK, workers=workers,
                                              image_weights=opt.image_weights, quad=opt.quad,
                                              prefix=colorstr('train: '), shuffle=True)
    mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
    nb = len(train_loader)  # number of batches
    assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

    # Process 0
    if RANK in [-1, 0]:
        val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
                                       hyp=hyp, cache=None if noval else opt.cache,
                                       rect=True, rank=-1, workers=workers * 2, pad=0.5,
                                       prefix=colorstr('val: '))[0]

        if not resume:
            labels = np.concatenate(dataset.labels, 0)
            # c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, names, save_dir)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
            model.half().float()  # pre-reduce anchor precision

        callbacks.run('on_pretrain_routine_end')

    # DDP mode
    if cuda and RANK != -1:
        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

    # Model attributes
    nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
    hyp['box'] *= 3 / nl  # scale to layers
    hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
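    # e.g. the standard P5 models have nl = 3, so at nc=80 and imgsz=640 all three gains are unchanged;
    # a single-class dataset scales hyp['cls'] by 1/80, and imgsz=1280 scales hyp['obj'] by 4.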
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    last_opt_step = -1
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    stopper = EarlyStopping(patience=opt.patience)
    compute_loss = ComputeLoss(model)  # init loss class
    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
                f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
                f"Logging results to {colorstr('bold', save_dir)}\n"
                f'Starting training for {epochs} epochs...')
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx

        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(3, device=device)  # mean losses
        if RANK != -1:
            train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(train_loader)
        LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
        if RANK in [-1, 0]:
            pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
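            # e.g. with warmup_bias_lr=0.1 and lr0=0.01, the bias lr ramps down 0.1 -> ~0.01 while the
            # other groups ramp up 0.0 -> ~0.01 over the first nw iterations; momentum rises from
            # hyp['warmup_momentum'] to hyp['momentum'] on the same linear schedule.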

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
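            # e.g. at imgsz=640 with gs=32, sz is drawn from [320, 960] in steps of 32, so each batch is
            # resampled to a random resolution between 0.5x and 1.5x of the nominal training size.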

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni - last_opt_step >= accumulate:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni
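            # The optimizer (and EMA) only step every `accumulate` batches, so the effective batch size
            # stays near nbs=64 regardless of --batch-size; in between, scaled gradients simply add up.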

            # Log
            if RANK in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                    f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
                callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
                if callbacks.stop_training:
                    return
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

        if RANK in [-1, 0]:
            # mAP
            callbacks.run('on_train_epoch_end', epoch=epoch)
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
            if not noval or final_epoch:  # Calculate mAP
                results, maps, _ = val.run(data_dict,
                                           batch_size=batch_size // WORLD_SIZE * 2,
                                           imgsz=imgsz,
                                           model=ema.ema,
                                           single_cls=single_cls,
                                           dataloader=val_loader,
                                           save_dir=save_dir,
                                           plots=False,
                                           callbacks=callbacks,
                                           compute_loss=compute_loss)

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
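            # fitness() (utils/metrics.py) collapses the four metrics into one scalar; its default weights
            # [0.0, 0.0, 0.1, 0.9] heavily favour mAP@.5:.95, so best.pt effectively tracks that metric.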
            if fi > best_fitness:
                best_fitness = fi
            log_vals = list(mloss) + list(results) + lr
            callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'model': deepcopy(de_parallel(model)).half(),
                        'ema': deepcopy(ema.ema).half(),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
                        'date': datetime.now().isoformat()}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
                    torch.save(ckpt, w / f'epoch{epoch}.pt')
                del ckpt
                callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)

            # Stop Single-GPU
            if RANK == -1 and stopper(epoch=epoch, fitness=fi):
                break

            # Stop DDP TODO: known issues https://github.com/ultralytics/yolov5/pull/4576
            # stop = stopper(epoch=epoch, fitness=fi)
            # if RANK == 0:
            #     dist.broadcast_object_list([stop], 0)  # broadcast 'stop' to all ranks

        # Stop DDP
        # with torch_distributed_zero_first(RANK):
        #     if stop:
        #         break  # must break all DDP ranks

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------

    if RANK in [-1, 0]:
        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
                                            imgsz=imgsz,
                                            model=attempt_load(f, device).half(),
                                            iou_thres=0.65 if is_coco else 0.60,  # best pycocotools results at 0.65
                                            single_cls=single_cls,
                                            dataloader=val_loader,
                                            save_dir=save_dir,
                                            save_json=is_coco,
                                            verbose=True,
                                            plots=True,
                                            callbacks=callbacks,
                                            compute_loss=compute_loss)  # val best model with plots
                    if is_coco:
                        callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)

        callbacks.run('on_train_end', last, best, plots, epoch, results)
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

    torch.cuda.empty_cache()
    return results


def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--noval', action='store_true', help='only validate final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
    parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
    parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
    parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
    parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
    parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')

    # Weights & Biases arguments
    parser.add_argument('--entity', default=None, help='W&B: Entity')
    parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
    parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')

    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt


def main(opt, callbacks=Callbacks()):
    # Checks
    if RANK in [-1, 0]:
        print_args(FILE.stem, opt)
        check_git_status()
        check_requirements(exclude=['thop'])

    # Resume
    if opt.resume and not check_wandb_resume(opt) and not opt.evolve:  # resume an interrupted run
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
            opt = argparse.Namespace(**yaml.safe_load(f))  # replace
        opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
        LOGGER.info(f'Resuming training from {ckpt}')
    else:
        opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
            check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        if opt.evolve:
            if opt.project == str(ROOT / 'runs/train'):  # if default project name, rename to runs/evolve
                opt.project = str(ROOT / 'runs/evolve')
            opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
        opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))

    # DDP mode
    device = select_device(opt.device, batch_size=opt.batch_size)
    if LOCAL_RANK != -1:
        msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
        assert not opt.image_weights, f'--image-weights {msg}'
        assert not opt.evolve, f'--evolve {msg}'
        assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
        assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device('cuda', LOCAL_RANK)
        dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")

    # Train
    if not opt.evolve:
        train(opt.hyp, opt, device, callbacks)
        if WORLD_SIZE > 1 and RANK == 0:
            LOGGER.info('Destroying process group... ')
            dist.destroy_process_group()

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                'mixup': (1, 0.0, 1.0),  # image mixup (probability)
                'copy_paste': (1, 0.0, 1.0)}  # segment copy-paste (probability)
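        # Each meta entry is (mutation gain, lower limit, upper limit): keys with gain 0 ('iou_t',
        # 'fl_gamma', 'perspective', 'fliplr') are inherited unchanged from the parent, since their
        # mutation factor v stays 1 in the loop below; larger gains allow larger per-generation drift.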

        with open(opt.hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
            if 'anchors' not in hyp:  # anchors commented in hyp.yaml
                hyp['anchors'] = 3
        opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
        if opt.bucket:
            os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}')  # download evolve.csv if exists

        for _ in range(opt.evolve):  # generations to evolve
            if evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([meta[k][0] for k in hyp.keys()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 7] * v[i])  # mutate
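                # x[i + 7]: evolve.csv stores 7 result columns (P, R, mAP@.5, mAP@.5:.95 and the three
                # val losses) before the hyperparameter columns, so parent hyp values start at column 7.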

            # Constrain to limits
            for k, v in meta.items():
                hyp[k] = max(hyp[k], v[1])  # lower limit
                hyp[k] = min(hyp[k], v[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device, callbacks)
            callbacks = Callbacks()

            # Write mutation results
            print_mutation(results, hyp.copy(), save_dir, opt.bucket)

        # Plot results
        plot_evolve(evolve_csv)
        LOGGER.info(f'Hyperparameter evolution finished after {opt.evolve} generations\n'
                    f"Results saved to {colorstr('bold', save_dir)}\n"
                    f'Usage example: $ python train.py --hyp {evolve_yaml}')


def run(**kwargs):
    # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
    opt = parse_opt(True)
    for k, v in kwargs.items():
        setattr(opt, k, v)
    main(opt)
    return opt


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)