You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

628 lines
32KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset
  4. Usage:
  5. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
  6. """
  7. import argparse
  8. import math
  9. import os
  10. import random
  11. import sys
  12. import time
  13. from copy import deepcopy
  14. from datetime import datetime
  15. from pathlib import Path
  16. import numpy as np
  17. import torch
  18. import torch.distributed as dist
  19. import torch.nn as nn
  20. import yaml
  21. from torch.cuda import amp
  22. from torch.nn.parallel import DistributedDataParallel as DDP
  23. from torch.optim import SGD, Adam, lr_scheduler
  24. from tqdm import tqdm
  25. FILE = Path(__file__).resolve()
  26. ROOT = FILE.parents[0] # YOLOv5 root directory
  27. if str(ROOT) not in sys.path:
  28. sys.path.append(str(ROOT)) # add ROOT to PATH
  29. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  30. import val # for end-of-epoch mAP
  31. from models.experimental import attempt_load
  32. from models.yolo import Model
  33. from utils.autoanchor import check_anchors
  34. from utils.autobatch import check_train_batch_size
  35. from utils.callbacks import Callbacks
  36. from utils.datasets import create_dataloader
  37. from utils.downloads import attempt_download
  38. from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
  39. check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
  40. intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
  41. print_args, print_mutation, strip_optimizer)
  42. from utils.loggers import Loggers
  43. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  44. from utils.loss import ComputeLoss
  45. from utils.metrics import fitness
  46. from utils.plots import plot_evolve, plot_labels
  47. from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
  48. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  49. RANK = int(os.getenv('RANK', -1))
  50. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  51. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  52. opt,
  53. device,
  54. callbacks
  55. ):
  56. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
  57. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  58. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  59. # Directories
  60. w = save_dir / 'weights' # weights dir
  61. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
  62. last, best = w / 'last.pt', w / 'best.pt'
  63. # Hyperparameters
  64. if isinstance(hyp, str):
  65. with open(hyp, errors='ignore') as f:
  66. hyp = yaml.safe_load(f) # load hyps dict
  67. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  68. # Save run settings
  69. if not evolve:
  70. with open(save_dir / 'hyp.yaml', 'w') as f:
  71. yaml.safe_dump(hyp, f, sort_keys=False)
  72. with open(save_dir / 'opt.yaml', 'w') as f:
  73. yaml.safe_dump(vars(opt), f, sort_keys=False)
  74. # Loggers
  75. data_dict = None
  76. if RANK in [-1, 0]:
  77. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  78. if loggers.wandb:
  79. data_dict = loggers.wandb.data_dict
  80. if resume:
  81. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
  82. # Register actions
  83. for k in methods(loggers):
  84. callbacks.register_action(k, callback=getattr(loggers, k))
  85. # Config
  86. plots = not evolve # create plots
  87. cuda = device.type != 'cpu'
  88. init_seeds(1 + RANK)
  89. with torch_distributed_zero_first(LOCAL_RANK):
  90. data_dict = data_dict or check_dataset(data) # check if None
  91. train_path, val_path = data_dict['train'], data_dict['val']
  92. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  93. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  94. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  95. is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
  96. # Model
  97. check_suffix(weights, '.pt') # check weights
  98. pretrained = weights.endswith('.pt')
  99. if pretrained:
  100. with torch_distributed_zero_first(LOCAL_RANK):
  101. weights = attempt_download(weights) # download if not found locally
  102. ckpt = torch.load(weights, map_location=device) # load checkpoint
  103. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  104. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  105. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  106. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  107. model.load_state_dict(csd, strict=False) # load
  108. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  109. else:
  110. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  111. # Freeze
  112. freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
  113. for k, v in model.named_parameters():
  114. v.requires_grad = True # train all layers
  115. if any(x in k for x in freeze):
  116. LOGGER.info(f'freezing {k}')
  117. v.requires_grad = False
  118. # Image size
  119. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  120. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  121. # Batch size
  122. if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
  123. batch_size = check_train_batch_size(model, imgsz)
  124. loggers.on_params_update({"batch_size": batch_size})
  125. # Optimizer
  126. nbs = 64 # nominal batch size
  127. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  128. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  129. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  130. g0, g1, g2 = [], [], [] # optimizer parameter groups
  131. for v in model.modules():
  132. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  133. g2.append(v.bias)
  134. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
  135. g0.append(v.weight)
  136. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  137. g1.append(v.weight)
  138. if opt.adam:
  139. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  140. else:
  141. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  142. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  143. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  144. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  145. f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
  146. del g0, g1, g2
  147. # Scheduler
  148. if opt.linear_lr:
  149. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  150. else:
  151. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  152. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  153. # EMA
  154. ema = ModelEMA(model) if RANK in [-1, 0] else None
  155. # Resume
  156. start_epoch, best_fitness = 0, 0.0
  157. if pretrained:
  158. # Optimizer
  159. if ckpt['optimizer'] is not None:
  160. optimizer.load_state_dict(ckpt['optimizer'])
  161. best_fitness = ckpt['best_fitness']
  162. # EMA
  163. if ema and ckpt.get('ema'):
  164. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  165. ema.updates = ckpt['updates']
  166. # Epochs
  167. start_epoch = ckpt['epoch'] + 1
  168. if resume:
  169. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  170. if epochs < start_epoch:
  171. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  172. epochs += ckpt['epoch'] # finetune additional epochs
  173. del ckpt, csd
  174. # DP mode
  175. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  176. LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
  177. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  178. model = torch.nn.DataParallel(model)
  179. # SyncBatchNorm
  180. if opt.sync_bn and cuda and RANK != -1:
  181. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  182. LOGGER.info('Using SyncBatchNorm()')
  183. # Trainloader
  184. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  185. hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
  186. workers=workers, image_weights=opt.image_weights, quad=opt.quad,
  187. prefix=colorstr('train: '), shuffle=True)
  188. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  189. nb = len(train_loader) # number of batches
  190. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  191. # Process 0
  192. if RANK in [-1, 0]:
  193. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
  194. hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
  195. workers=workers, pad=0.5,
  196. prefix=colorstr('val: '))[0]
  197. if not resume:
  198. labels = np.concatenate(dataset.labels, 0)
  199. # c = torch.tensor(labels[:, 0]) # classes
  200. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  201. # model._initialize_biases(cf.to(device))
  202. if plots:
  203. plot_labels(labels, names, save_dir)
  204. # Anchors
  205. if not opt.noautoanchor:
  206. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  207. model.half().float() # pre-reduce anchor precision
  208. callbacks.run('on_pretrain_routine_end')
  209. # DDP mode
  210. if cuda and RANK != -1:
  211. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  212. # Model attributes
  213. nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
  214. hyp['box'] *= 3 / nl # scale to layers
  215. hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
  216. hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
  217. hyp['label_smoothing'] = opt.label_smoothing
  218. model.nc = nc # attach number of classes to model
  219. model.hyp = hyp # attach hyperparameters to model
  220. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  221. model.names = names
  222. # Start training
  223. t0 = time.time()
  224. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  225. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  226. last_opt_step = -1
  227. maps = np.zeros(nc) # mAP per class
  228. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  229. scheduler.last_epoch = start_epoch - 1 # do not move
  230. scaler = amp.GradScaler(enabled=cuda)
  231. stopper = EarlyStopping(patience=opt.patience)
  232. compute_loss = ComputeLoss(model) # init loss class
  233. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  234. f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
  235. f"Logging results to {colorstr('bold', save_dir)}\n"
  236. f'Starting training for {epochs} epochs...')
  237. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  238. model.train()
  239. # Update image weights (optional, single-GPU only)
  240. if opt.image_weights:
  241. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  242. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  243. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  244. # Update mosaic border (optional)
  245. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  246. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  247. mloss = torch.zeros(3, device=device) # mean losses
  248. if RANK != -1:
  249. train_loader.sampler.set_epoch(epoch)
  250. pbar = enumerate(train_loader)
  251. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  252. if RANK in [-1, 0]:
  253. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  254. optimizer.zero_grad()
  255. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  256. ni = i + nb * epoch # number integrated batches (since train start)
  257. imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
  258. # Warmup
  259. if ni <= nw:
  260. xi = [0, nw] # x interp
  261. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  262. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  263. for j, x in enumerate(optimizer.param_groups):
  264. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  265. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  266. if 'momentum' in x:
  267. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  268. # Multi-scale
  269. if opt.multi_scale:
  270. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  271. sf = sz / max(imgs.shape[2:]) # scale factor
  272. if sf != 1:
  273. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  274. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  275. # Forward
  276. with amp.autocast(enabled=cuda):
  277. pred = model(imgs) # forward
  278. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  279. if RANK != -1:
  280. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  281. if opt.quad:
  282. loss *= 4.
  283. # Backward
  284. scaler.scale(loss).backward()
  285. # Optimize
  286. if ni - last_opt_step >= accumulate:
  287. scaler.step(optimizer) # optimizer.step
  288. scaler.update()
  289. optimizer.zero_grad()
  290. if ema:
  291. ema.update(model)
  292. last_opt_step = ni
  293. # Log
  294. if RANK in [-1, 0]:
  295. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  296. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  297. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
  298. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  299. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
  300. # end batch ------------------------------------------------------------------------------------------------
  301. # Scheduler
  302. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  303. scheduler.step()
  304. if RANK in [-1, 0]:
  305. # mAP
  306. callbacks.run('on_train_epoch_end', epoch=epoch)
  307. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  308. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
  309. if not noval or final_epoch: # Calculate mAP
  310. results, maps, _ = val.run(data_dict,
  311. batch_size=batch_size // WORLD_SIZE * 2,
  312. imgsz=imgsz,
  313. model=ema.ema,
  314. single_cls=single_cls,
  315. dataloader=val_loader,
  316. save_dir=save_dir,
  317. plots=False,
  318. callbacks=callbacks,
  319. compute_loss=compute_loss)
  320. # Update best mAP
  321. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  322. if fi > best_fitness:
  323. best_fitness = fi
  324. log_vals = list(mloss) + list(results) + lr
  325. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
  326. # Save model
  327. if (not nosave) or (final_epoch and not evolve): # if save
  328. ckpt = {'epoch': epoch,
  329. 'best_fitness': best_fitness,
  330. 'model': deepcopy(de_parallel(model)).half(),
  331. 'ema': deepcopy(ema.ema).half(),
  332. 'updates': ema.updates,
  333. 'optimizer': optimizer.state_dict(),
  334. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
  335. 'date': datetime.now().isoformat()}
  336. # Save last, best and delete
  337. torch.save(ckpt, last)
  338. if best_fitness == fi:
  339. torch.save(ckpt, best)
  340. if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
  341. torch.save(ckpt, w / f'epoch{epoch}.pt')
  342. del ckpt
  343. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
  344. # Stop Single-GPU
  345. if RANK == -1 and stopper(epoch=epoch, fitness=fi):
  346. break
  347. # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
  348. # stop = stopper(epoch=epoch, fitness=fi)
  349. # if RANK == 0:
  350. # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
  351. # Stop DPP
  352. # with torch_distributed_zero_first(RANK):
  353. # if stop:
  354. # break # must break all DDP ranks
  355. # end epoch ----------------------------------------------------------------------------------------------------
  356. # end training -----------------------------------------------------------------------------------------------------
  357. if RANK in [-1, 0]:
  358. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  359. for f in last, best:
  360. if f.exists():
  361. strip_optimizer(f) # strip optimizers
  362. if f is best:
  363. LOGGER.info(f'\nValidating {f}...')
  364. results, _, _ = val.run(data_dict,
  365. batch_size=batch_size // WORLD_SIZE * 2,
  366. imgsz=imgsz,
  367. model=attempt_load(f, device).half(),
  368. iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
  369. single_cls=single_cls,
  370. dataloader=val_loader,
  371. save_dir=save_dir,
  372. save_json=is_coco,
  373. verbose=True,
  374. plots=True,
  375. callbacks=callbacks,
  376. compute_loss=compute_loss) # val best model with plots
  377. if is_coco:
  378. callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
  379. callbacks.run('on_train_end', last, best, plots, epoch, results)
  380. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
  381. torch.cuda.empty_cache()
  382. return results
  383. def parse_opt(known=False):
  384. parser = argparse.ArgumentParser()
  385. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
  386. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  387. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  388. parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  389. parser.add_argument('--epochs', type=int, default=300)
  390. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
  391. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  392. parser.add_argument('--rect', action='store_true', help='rectangular training')
  393. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  394. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  395. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  396. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  397. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  398. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  399. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  400. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  401. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  402. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  403. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  404. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  405. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  406. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
  407. parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
  408. parser.add_argument('--name', default='exp', help='save to project/name')
  409. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  410. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  411. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  412. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  413. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  414. parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
  415. parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
  416. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  417. # Weights & Biases arguments
  418. parser.add_argument('--entity', default=None, help='W&B: Entity')
  419. parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
  420. parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
  421. parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
  422. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  423. return opt
  424. def main(opt, callbacks=Callbacks()):
  425. # Checks
  426. if RANK in [-1, 0]:
  427. print_args(FILE.stem, opt)
  428. check_git_status()
  429. check_requirements(exclude=['thop'])
  430. # Resume
  431. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  432. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  433. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  434. with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
  435. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  436. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  437. LOGGER.info(f'Resuming training from {ckpt}')
  438. else:
  439. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
  440. check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
  441. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  442. if opt.evolve:
  443. opt.project = str(ROOT / 'runs/evolve')
  444. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  445. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  446. # DDP mode
  447. device = select_device(opt.device, batch_size=opt.batch_size)
  448. if LOCAL_RANK != -1:
  449. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  450. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  451. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  452. assert not opt.evolve, '--evolve argument is not compatible with DDP training'
  453. torch.cuda.set_device(LOCAL_RANK)
  454. device = torch.device('cuda', LOCAL_RANK)
  455. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
  456. # Train
  457. if not opt.evolve:
  458. train(opt.hyp, opt, device, callbacks)
  459. if WORLD_SIZE > 1 and RANK == 0:
  460. LOGGER.info('Destroying process group... ')
  461. dist.destroy_process_group()
  462. # Evolve hyperparameters (optional)
  463. else:
  464. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  465. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  466. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  467. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  468. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  469. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  470. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  471. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  472. 'box': (1, 0.02, 0.2), # box loss gain
  473. 'cls': (1, 0.2, 4.0), # cls loss gain
  474. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  475. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  476. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  477. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  478. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  479. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  480. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  481. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  482. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  483. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  484. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  485. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  486. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  487. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  488. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  489. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  490. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  491. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  492. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  493. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  494. with open(opt.hyp, errors='ignore') as f:
  495. hyp = yaml.safe_load(f) # load hyps dict
  496. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  497. hyp['anchors'] = 3
  498. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  499. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  500. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  501. if opt.bucket:
  502. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists
  503. for _ in range(opt.evolve): # generations to evolve
  504. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  505. # Select parent(s)
  506. parent = 'single' # parent selection method: 'single' or 'weighted'
  507. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  508. n = min(5, len(x)) # number of previous results to consider
  509. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  510. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  511. if parent == 'single' or len(x) == 1:
  512. # x = x[random.randint(0, n - 1)] # random selection
  513. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  514. elif parent == 'weighted':
  515. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  516. # Mutate
  517. mp, s = 0.8, 0.2 # mutation probability, sigma
  518. npr = np.random
  519. npr.seed(int(time.time()))
  520. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  521. ng = len(meta)
  522. v = np.ones(ng)
  523. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  524. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  525. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  526. hyp[k] = float(x[i + 7] * v[i]) # mutate
  527. # Constrain to limits
  528. for k, v in meta.items():
  529. hyp[k] = max(hyp[k], v[1]) # lower limit
  530. hyp[k] = min(hyp[k], v[2]) # upper limit
  531. hyp[k] = round(hyp[k], 5) # significant digits
  532. # Train mutation
  533. results = train(hyp.copy(), opt, device, callbacks)
  534. # Write mutation results
  535. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  536. # Plot results
  537. plot_evolve(evolve_csv)
  538. LOGGER.info(f'Hyperparameter evolution finished\n'
  539. f"Results saved to {colorstr('bold', save_dir)}\n"
  540. f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')
  541. def run(**kwargs):
  542. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  543. opt = parse_opt(True)
  544. for k, v in kwargs.items():
  545. setattr(opt, k, v)
  546. main(opt)
  547. if __name__ == "__main__":
  548. opt = parse_opt()
  549. main(opt)