You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

621 lines
31KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset
  4. Usage:
  5. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
  6. """
  7. import argparse
  8. import logging
  9. import math
  10. import os
  11. import random
  12. import sys
  13. import time
  14. from copy import deepcopy
  15. from pathlib import Path
  16. import numpy as np
  17. import torch
  18. import torch.distributed as dist
  19. import torch.nn as nn
  20. import yaml
  21. from torch.cuda import amp
  22. from torch.nn.parallel import DistributedDataParallel as DDP
  23. from torch.optim import Adam, SGD, lr_scheduler
  24. from tqdm import tqdm
  25. FILE = Path(__file__).resolve()
  26. ROOT = FILE.parents[0] # YOLOv5 root directory
  27. if str(ROOT) not in sys.path:
  28. sys.path.append(str(ROOT)) # add ROOT to PATH
  29. ROOT = ROOT.relative_to(Path.cwd()) # relative
  30. import val # for end-of-epoch mAP
  31. from models.experimental import attempt_load
  32. from models.yolo import Model
  33. from utils.autoanchor import check_anchors
  34. from utils.datasets import create_dataloader
  35. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  36. strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
  37. check_file, check_yaml, check_suffix, print_args, print_mutation, set_logging, one_cycle, colorstr, methods
  38. from utils.downloads import attempt_download
  39. from utils.loss import ComputeLoss
  40. from utils.plots import plot_labels, plot_evolve
  41. from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \
  42. torch_distributed_zero_first
  43. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  44. from utils.metrics import fitness
  45. from utils.loggers import Loggers
  46. from utils.callbacks import Callbacks
  47. LOGGER = logging.getLogger(__name__)
  48. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  49. RANK = int(os.getenv('RANK', -1))
  50. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  51. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  52. opt,
  53. device,
  54. callbacks
  55. ):
  56. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
  57. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  58. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  59. # Directories
  60. w = save_dir / 'weights' # weights dir
  61. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
  62. last, best = w / 'last.pt', w / 'best.pt'
  63. # Hyperparameters
  64. if isinstance(hyp, str):
  65. with open(hyp) as f:
  66. hyp = yaml.safe_load(f) # load hyps dict
  67. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  68. # Save run settings
  69. with open(save_dir / 'hyp.yaml', 'w') as f:
  70. yaml.safe_dump(hyp, f, sort_keys=False)
  71. with open(save_dir / 'opt.yaml', 'w') as f:
  72. yaml.safe_dump(vars(opt), f, sort_keys=False)
  73. data_dict = None
  74. # Loggers
  75. if RANK in [-1, 0]:
  76. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  77. if loggers.wandb:
  78. data_dict = loggers.wandb.data_dict
  79. if resume:
  80. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
  81. # Register actions
  82. for k in methods(loggers):
  83. callbacks.register_action(k, callback=getattr(loggers, k))
  84. # Config
  85. plots = not evolve # create plots
  86. cuda = device.type != 'cpu'
  87. init_seeds(1 + RANK)
  88. with torch_distributed_zero_first(RANK):
  89. data_dict = data_dict or check_dataset(data) # check if None
  90. train_path, val_path = data_dict['train'], data_dict['val']
  91. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  92. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  93. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  94. is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
  95. # Model
  96. check_suffix(weights, '.pt') # check weights
  97. pretrained = weights.endswith('.pt')
  98. if pretrained:
  99. with torch_distributed_zero_first(RANK):
  100. weights = attempt_download(weights) # download if not found locally
  101. ckpt = torch.load(weights, map_location=device) # load checkpoint
  102. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  103. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  104. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  105. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  106. model.load_state_dict(csd, strict=False) # load
  107. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  108. else:
  109. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  110. # Freeze
  111. freeze = [f'model.{x}.' for x in range(freeze)] # layers to freeze
  112. for k, v in model.named_parameters():
  113. v.requires_grad = True # train all layers
  114. if any(x in k for x in freeze):
  115. print(f'freezing {k}')
  116. v.requires_grad = False
  117. # Optimizer
  118. nbs = 64 # nominal batch size
  119. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  120. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  121. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  122. g0, g1, g2 = [], [], [] # optimizer parameter groups
  123. for v in model.modules():
  124. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  125. g2.append(v.bias)
  126. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
  127. g0.append(v.weight)
  128. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  129. g1.append(v.weight)
  130. if opt.adam:
  131. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  132. else:
  133. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  134. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  135. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  136. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  137. f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
  138. del g0, g1, g2
  139. # Scheduler
  140. if opt.linear_lr:
  141. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  142. else:
  143. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  144. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  145. # EMA
  146. ema = ModelEMA(model) if RANK in [-1, 0] else None
  147. # Resume
  148. start_epoch, best_fitness = 0, 0.0
  149. if pretrained:
  150. # Optimizer
  151. if ckpt['optimizer'] is not None:
  152. optimizer.load_state_dict(ckpt['optimizer'])
  153. best_fitness = ckpt['best_fitness']
  154. # EMA
  155. if ema and ckpt.get('ema'):
  156. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  157. ema.updates = ckpt['updates']
  158. # Epochs
  159. start_epoch = ckpt['epoch'] + 1
  160. if resume:
  161. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  162. if epochs < start_epoch:
  163. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  164. epochs += ckpt['epoch'] # finetune additional epochs
  165. del ckpt, csd
  166. # Image sizes
  167. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  168. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  169. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  170. # DP mode
  171. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  172. logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
  173. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  174. model = torch.nn.DataParallel(model)
  175. # SyncBatchNorm
  176. if opt.sync_bn and cuda and RANK != -1:
  177. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  178. LOGGER.info('Using SyncBatchNorm()')
  179. # Trainloader
  180. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  181. hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
  182. workers=workers, image_weights=opt.image_weights, quad=opt.quad,
  183. prefix=colorstr('train: '))
  184. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  185. nb = len(train_loader) # number of batches
  186. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  187. # Process 0
  188. if RANK in [-1, 0]:
  189. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
  190. hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
  191. workers=workers, pad=0.5,
  192. prefix=colorstr('val: '))[0]
  193. if not resume:
  194. labels = np.concatenate(dataset.labels, 0)
  195. # c = torch.tensor(labels[:, 0]) # classes
  196. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  197. # model._initialize_biases(cf.to(device))
  198. if plots:
  199. plot_labels(labels, names, save_dir)
  200. # Anchors
  201. if not opt.noautoanchor:
  202. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  203. model.half().float() # pre-reduce anchor precision
  204. callbacks.run('on_pretrain_routine_end')
  205. # DDP mode
  206. if cuda and RANK != -1:
  207. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  208. # Model parameters
  209. hyp['box'] *= 3. / nl # scale to layers
  210. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  211. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  212. hyp['label_smoothing'] = opt.label_smoothing
  213. model.nc = nc # attach number of classes to model
  214. model.hyp = hyp # attach hyperparameters to model
  215. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  216. model.names = names
  217. # Start training
  218. t0 = time.time()
  219. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  220. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  221. last_opt_step = -1
  222. maps = np.zeros(nc) # mAP per class
  223. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  224. scheduler.last_epoch = start_epoch - 1 # do not move
  225. scaler = amp.GradScaler(enabled=cuda)
  226. stopper = EarlyStopping(patience=opt.patience)
  227. compute_loss = ComputeLoss(model) # init loss class
  228. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  229. f'Using {train_loader.num_workers} dataloader workers\n'
  230. f"Logging results to {colorstr('bold', save_dir)}\n"
  231. f'Starting training for {epochs} epochs...')
  232. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  233. model.train()
  234. # Update image weights (optional, single-GPU only)
  235. if opt.image_weights:
  236. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  237. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  238. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  239. # Update mosaic border (optional)
  240. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  241. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  242. mloss = torch.zeros(3, device=device) # mean losses
  243. if RANK != -1:
  244. train_loader.sampler.set_epoch(epoch)
  245. pbar = enumerate(train_loader)
  246. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  247. if RANK in [-1, 0]:
  248. pbar = tqdm(pbar, total=nb) # progress bar
  249. optimizer.zero_grad()
  250. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  251. ni = i + nb * epoch # number integrated batches (since train start)
  252. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  253. # Warmup
  254. if ni <= nw:
  255. xi = [0, nw] # x interp
  256. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  257. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  258. for j, x in enumerate(optimizer.param_groups):
  259. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  260. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  261. if 'momentum' in x:
  262. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  263. # Multi-scale
  264. if opt.multi_scale:
  265. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  266. sf = sz / max(imgs.shape[2:]) # scale factor
  267. if sf != 1:
  268. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  269. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  270. # Forward
  271. with amp.autocast(enabled=cuda):
  272. pred = model(imgs) # forward
  273. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  274. if RANK != -1:
  275. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  276. if opt.quad:
  277. loss *= 4.
  278. # Backward
  279. scaler.scale(loss).backward()
  280. # Optimize
  281. if ni - last_opt_step >= accumulate:
  282. scaler.step(optimizer) # optimizer.step
  283. scaler.update()
  284. optimizer.zero_grad()
  285. if ema:
  286. ema.update(model)
  287. last_opt_step = ni
  288. # Log
  289. if RANK in [-1, 0]:
  290. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  291. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  292. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
  293. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  294. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
  295. # end batch ------------------------------------------------------------------------------------------------
  296. # Scheduler
  297. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  298. scheduler.step()
  299. if RANK in [-1, 0]:
  300. # mAP
  301. callbacks.run('on_train_epoch_end', epoch=epoch)
  302. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  303. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
  304. if not noval or final_epoch: # Calculate mAP
  305. results, maps, _ = val.run(data_dict,
  306. batch_size=batch_size // WORLD_SIZE * 2,
  307. imgsz=imgsz,
  308. model=ema.ema,
  309. single_cls=single_cls,
  310. dataloader=val_loader,
  311. save_dir=save_dir,
  312. plots=False,
  313. callbacks=callbacks,
  314. compute_loss=compute_loss)
  315. # Update best mAP
  316. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  317. if fi > best_fitness:
  318. best_fitness = fi
  319. log_vals = list(mloss) + list(results) + lr
  320. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
  321. # Save model
  322. if (not nosave) or (final_epoch and not evolve): # if save
  323. ckpt = {'epoch': epoch,
  324. 'best_fitness': best_fitness,
  325. 'model': deepcopy(de_parallel(model)).half(),
  326. 'ema': deepcopy(ema.ema).half(),
  327. 'updates': ema.updates,
  328. 'optimizer': optimizer.state_dict(),
  329. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}
  330. # Save last, best and delete
  331. torch.save(ckpt, last)
  332. if best_fitness == fi:
  333. torch.save(ckpt, best)
  334. if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
  335. torch.save(ckpt, w / f'epoch{epoch}.pt')
  336. del ckpt
  337. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
  338. # Stop Single-GPU
  339. if RANK == -1 and stopper(epoch=epoch, fitness=fi):
  340. break
  341. # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
  342. # stop = stopper(epoch=epoch, fitness=fi)
  343. # if RANK == 0:
  344. # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
  345. # Stop DPP
  346. # with torch_distributed_zero_first(RANK):
  347. # if stop:
  348. # break # must break all DDP ranks
  349. # end epoch ----------------------------------------------------------------------------------------------------
  350. # end training -----------------------------------------------------------------------------------------------------
  351. if RANK in [-1, 0]:
  352. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  353. for f in last, best:
  354. if f.exists():
  355. strip_optimizer(f) # strip optimizers
  356. if f is best:
  357. LOGGER.info(f'\nValidating {f}...')
  358. results, _, _ = val.run(data_dict,
  359. batch_size=batch_size // WORLD_SIZE * 2,
  360. imgsz=imgsz,
  361. model=attempt_load(f, device).half(),
  362. iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
  363. single_cls=single_cls,
  364. dataloader=val_loader,
  365. save_dir=save_dir,
  366. save_json=is_coco,
  367. verbose=True,
  368. plots=True,
  369. callbacks=callbacks,
  370. compute_loss=compute_loss) # val best model with plots
  371. callbacks.run('on_train_end', last, best, plots, epoch)
  372. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
  373. torch.cuda.empty_cache()
  374. return results
  375. def parse_opt(known=False):
  376. parser = argparse.ArgumentParser()
  377. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
  378. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  379. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  380. parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  381. parser.add_argument('--epochs', type=int, default=300)
  382. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  383. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  384. parser.add_argument('--rect', action='store_true', help='rectangular training')
  385. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  386. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  387. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  388. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  389. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  390. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  391. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  392. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  393. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  394. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  395. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  396. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  397. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  398. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  399. parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
  400. parser.add_argument('--name', default='exp', help='save to project/name')
  401. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  402. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  403. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  404. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  405. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  406. parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
  407. parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
  408. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  409. # Weights & Biases arguments
  410. parser.add_argument('--entity', default=None, help='W&B: Entity')
  411. parser.add_argument('--upload_dataset', action='store_true', help='W&B: Upload dataset as artifact table')
  412. parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
  413. parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
  414. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  415. return opt
  416. def main(opt, callbacks=Callbacks()):
  417. # Checks
  418. set_logging(RANK)
  419. if RANK in [-1, 0]:
  420. print_args(FILE.stem, opt)
  421. check_git_status()
  422. check_requirements(exclude=['thop'])
  423. # Resume
  424. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  425. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  426. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  427. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  428. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  429. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  430. LOGGER.info(f'Resuming training from {ckpt}')
  431. else:
  432. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
  433. check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
  434. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  435. if opt.evolve:
  436. opt.project = str(ROOT / 'runs/evolve')
  437. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  438. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  439. # DDP mode
  440. device = select_device(opt.device, batch_size=opt.batch_size)
  441. if LOCAL_RANK != -1:
  442. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  443. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  444. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  445. assert not opt.evolve, '--evolve argument is not compatible with DDP training'
  446. torch.cuda.set_device(LOCAL_RANK)
  447. device = torch.device('cuda', LOCAL_RANK)
  448. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
  449. # Train
  450. if not opt.evolve:
  451. train(opt.hyp, opt, device, callbacks)
  452. if WORLD_SIZE > 1 and RANK == 0:
  453. LOGGER.info('Destroying process group... ')
  454. dist.destroy_process_group()
  455. # Evolve hyperparameters (optional)
  456. else:
  457. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  458. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  459. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  460. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  461. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  462. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  463. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  464. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  465. 'box': (1, 0.02, 0.2), # box loss gain
  466. 'cls': (1, 0.2, 4.0), # cls loss gain
  467. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  468. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  469. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  470. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  471. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  472. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  473. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  474. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  475. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  476. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  477. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  478. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  479. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  480. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  481. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  482. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  483. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  484. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  485. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  486. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  487. with open(opt.hyp) as f:
  488. hyp = yaml.safe_load(f) # load hyps dict
  489. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  490. hyp['anchors'] = 3
  491. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  492. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  493. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  494. if opt.bucket:
  495. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists
  496. for _ in range(opt.evolve): # generations to evolve
  497. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  498. # Select parent(s)
  499. parent = 'single' # parent selection method: 'single' or 'weighted'
  500. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  501. n = min(5, len(x)) # number of previous results to consider
  502. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  503. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  504. if parent == 'single' or len(x) == 1:
  505. # x = x[random.randint(0, n - 1)] # random selection
  506. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  507. elif parent == 'weighted':
  508. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  509. # Mutate
  510. mp, s = 0.8, 0.2 # mutation probability, sigma
  511. npr = np.random
  512. npr.seed(int(time.time()))
  513. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  514. ng = len(meta)
  515. v = np.ones(ng)
  516. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  517. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  518. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  519. hyp[k] = float(x[i + 7] * v[i]) # mutate
  520. # Constrain to limits
  521. for k, v in meta.items():
  522. hyp[k] = max(hyp[k], v[1]) # lower limit
  523. hyp[k] = min(hyp[k], v[2]) # upper limit
  524. hyp[k] = round(hyp[k], 5) # significant digits
  525. # Train mutation
  526. results = train(hyp.copy(), opt, device, callbacks)
  527. # Write mutation results
  528. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  529. # Plot results
  530. plot_evolve(evolve_csv)
  531. print(f'Hyperparameter evolution finished\n'
  532. f"Results saved to {colorstr('bold', save_dir)}\n"
  533. f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')
  534. def run(**kwargs):
  535. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  536. opt = parse_opt(True)
  537. for k, v in kwargs.items():
  538. setattr(opt, k, v)
  539. main(opt)
  540. if __name__ == "__main__":
  541. opt = parse_opt()
  542. main(opt)