You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

601 lines
30KB

  1. """Train a YOLOv5 model on a custom dataset
  2. Usage:
  3. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
  4. """
  5. import argparse
  6. import logging
  7. import os
  8. import random
  9. import sys
  10. import time
  11. from copy import deepcopy
  12. from pathlib import Path
  13. import math
  14. import numpy as np
  15. import torch
  16. import torch.distributed as dist
  17. import torch.nn as nn
  18. import yaml
  19. from torch.cuda import amp
  20. from torch.nn.parallel import DistributedDataParallel as DDP
  21. from torch.optim import Adam, SGD, lr_scheduler
  22. from tqdm import tqdm
  23. FILE = Path(__file__).absolute()
  24. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  25. import val # for end-of-epoch mAP
  26. from models.experimental import attempt_load
  27. from models.yolo import Model
  28. from utils.autoanchor import check_anchors
  29. from utils.datasets import create_dataloader
  30. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  31. strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  32. check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
  33. from utils.downloads import attempt_download
  34. from utils.loss import ComputeLoss
  35. from utils.plots import plot_labels, plot_evolve
  36. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
  37. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  38. from utils.metrics import fitness
  39. from utils.loggers import Loggers
  40. from utils.callbacks import Callbacks
  41. LOGGER = logging.getLogger(__name__)
  42. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  43. RANK = int(os.getenv('RANK', -1))
  44. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  45. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  46. opt,
  47. device,
  48. callbacks=Callbacks()
  49. ):
  50. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
  51. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  52. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  53. # Directories
  54. w = save_dir / 'weights' # weights dir
  55. w.mkdir(parents=True, exist_ok=True) # make dir
  56. last, best = w / 'last.pt', w / 'best.pt'
  57. # Hyperparameters
  58. if isinstance(hyp, str):
  59. with open(hyp) as f:
  60. hyp = yaml.safe_load(f) # load hyps dict
  61. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  62. # Save run settings
  63. with open(save_dir / 'hyp.yaml', 'w') as f:
  64. yaml.safe_dump(hyp, f, sort_keys=False)
  65. with open(save_dir / 'opt.yaml', 'w') as f:
  66. yaml.safe_dump(vars(opt), f, sort_keys=False)
  67. data_dict = None
  68. # Loggers
  69. if RANK in [-1, 0]:
  70. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  71. if loggers.wandb:
  72. data_dict = loggers.wandb.data_dict
  73. if resume:
  74. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
  75. # Register actions
  76. for k in methods(loggers):
  77. callbacks.register_action(k, callback=getattr(loggers, k))
  78. # Config
  79. plots = not evolve # create plots
  80. cuda = device.type != 'cpu'
  81. init_seeds(1 + RANK)
  82. with torch_distributed_zero_first(RANK):
  83. data_dict = data_dict or check_dataset(data) # check if None
  84. train_path, val_path = data_dict['train'], data_dict['val']
  85. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  86. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  87. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  88. is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
  89. # Model
  90. pretrained = weights.endswith('.pt')
  91. if pretrained:
  92. with torch_distributed_zero_first(RANK):
  93. weights = attempt_download(weights) # download if not found locally
  94. ckpt = torch.load(weights, map_location=device) # load checkpoint
  95. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  96. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  97. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  98. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  99. model.load_state_dict(csd, strict=False) # load
  100. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  101. else:
  102. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  103. # Freeze
  104. freeze = [f'model.{x}.' for x in range(freeze)] # layers to freeze
  105. for k, v in model.named_parameters():
  106. v.requires_grad = True # train all layers
  107. if any(x in k for x in freeze):
  108. print(f'freezing {k}')
  109. v.requires_grad = False
  110. # Optimizer
  111. nbs = 64 # nominal batch size
  112. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  113. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  114. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  115. g0, g1, g2 = [], [], [] # optimizer parameter groups
  116. for v in model.modules():
  117. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  118. g2.append(v.bias)
  119. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
  120. g0.append(v.weight)
  121. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  122. g1.append(v.weight)
  123. if opt.adam:
  124. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  125. else:
  126. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  127. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  128. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  129. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  130. f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
  131. del g0, g1, g2
  132. # Scheduler
  133. if opt.linear_lr:
  134. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  135. else:
  136. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  137. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  138. # EMA
  139. ema = ModelEMA(model) if RANK in [-1, 0] else None
  140. # Resume
  141. start_epoch, best_fitness = 0, 0.0
  142. if pretrained:
  143. # Optimizer
  144. if ckpt['optimizer'] is not None:
  145. optimizer.load_state_dict(ckpt['optimizer'])
  146. best_fitness = ckpt['best_fitness']
  147. # EMA
  148. if ema and ckpt.get('ema'):
  149. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  150. ema.updates = ckpt['updates']
  151. # Epochs
  152. start_epoch = ckpt['epoch'] + 1
  153. if resume:
  154. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  155. if epochs < start_epoch:
  156. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  157. epochs += ckpt['epoch'] # finetune additional epochs
  158. del ckpt, csd
  159. # Image sizes
  160. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  161. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  162. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  163. # DP mode
  164. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  165. logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
  166. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  167. model = torch.nn.DataParallel(model)
  168. # SyncBatchNorm
  169. if opt.sync_bn and cuda and RANK != -1:
  170. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  171. LOGGER.info('Using SyncBatchNorm()')
  172. # Trainloader
  173. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  174. hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
  175. workers=workers, image_weights=opt.image_weights, quad=opt.quad,
  176. prefix=colorstr('train: '))
  177. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  178. nb = len(train_loader) # number of batches
  179. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  180. # Process 0
  181. if RANK in [-1, 0]:
  182. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
  183. hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
  184. workers=workers, pad=0.5,
  185. prefix=colorstr('val: '))[0]
  186. if not resume:
  187. labels = np.concatenate(dataset.labels, 0)
  188. # c = torch.tensor(labels[:, 0]) # classes
  189. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  190. # model._initialize_biases(cf.to(device))
  191. if plots:
  192. plot_labels(labels, names, save_dir)
  193. # Anchors
  194. if not opt.noautoanchor:
  195. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  196. model.half().float() # pre-reduce anchor precision
  197. callbacks.on_pretrain_routine_end()
  198. # DDP mode
  199. if cuda and RANK != -1:
  200. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  201. # Model parameters
  202. hyp['box'] *= 3. / nl # scale to layers
  203. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  204. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  205. hyp['label_smoothing'] = opt.label_smoothing
  206. model.nc = nc # attach number of classes to model
  207. model.hyp = hyp # attach hyperparameters to model
  208. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  209. model.names = names
  210. # Start training
  211. t0 = time.time()
  212. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  213. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  214. last_opt_step = -1
  215. maps = np.zeros(nc) # mAP per class
  216. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  217. scheduler.last_epoch = start_epoch - 1 # do not move
  218. scaler = amp.GradScaler(enabled=cuda)
  219. compute_loss = ComputeLoss(model) # init loss class
  220. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  221. f'Using {train_loader.num_workers} dataloader workers\n'
  222. f'Logging results to {save_dir}\n'
  223. f'Starting training for {epochs} epochs...')
  224. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  225. model.train()
  226. # Update image weights (optional)
  227. if opt.image_weights:
  228. # Generate indices
  229. if RANK in [-1, 0]:
  230. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  231. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  232. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  233. # Broadcast if DDP
  234. if RANK != -1:
  235. indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
  236. dist.broadcast(indices, 0)
  237. if RANK != 0:
  238. dataset.indices = indices.cpu().numpy()
  239. # Update mosaic border
  240. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  241. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  242. mloss = torch.zeros(3, device=device) # mean losses
  243. if RANK != -1:
  244. train_loader.sampler.set_epoch(epoch)
  245. pbar = enumerate(train_loader)
  246. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  247. if RANK in [-1, 0]:
  248. pbar = tqdm(pbar, total=nb) # progress bar
  249. optimizer.zero_grad()
  250. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  251. ni = i + nb * epoch # number integrated batches (since train start)
  252. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  253. # Warmup
  254. if ni <= nw:
  255. xi = [0, nw] # x interp
  256. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  257. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  258. for j, x in enumerate(optimizer.param_groups):
  259. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  260. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  261. if 'momentum' in x:
  262. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  263. # Multi-scale
  264. if opt.multi_scale:
  265. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  266. sf = sz / max(imgs.shape[2:]) # scale factor
  267. if sf != 1:
  268. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  269. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  270. # Forward
  271. with amp.autocast(enabled=cuda):
  272. pred = model(imgs) # forward
  273. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  274. if RANK != -1:
  275. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  276. if opt.quad:
  277. loss *= 4.
  278. # Backward
  279. scaler.scale(loss).backward()
  280. # Optimize
  281. if ni - last_opt_step >= accumulate:
  282. scaler.step(optimizer) # optimizer.step
  283. scaler.update()
  284. optimizer.zero_grad()
  285. if ema:
  286. ema.update(model)
  287. last_opt_step = ni
  288. # Log
  289. if RANK in [-1, 0]:
  290. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  291. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  292. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
  293. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  294. callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
  295. # end batch ------------------------------------------------------------------------------------------------
  296. # Scheduler
  297. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  298. scheduler.step()
  299. if RANK in [-1, 0]:
  300. # mAP
  301. callbacks.on_train_epoch_end(epoch=epoch)
  302. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  303. final_epoch = epoch + 1 == epochs
  304. if not noval or final_epoch: # Calculate mAP
  305. results, maps, _ = val.run(data_dict,
  306. batch_size=batch_size // WORLD_SIZE * 2,
  307. imgsz=imgsz,
  308. model=ema.ema,
  309. single_cls=single_cls,
  310. dataloader=val_loader,
  311. save_dir=save_dir,
  312. save_json=is_coco and final_epoch,
  313. verbose=nc < 50 and final_epoch,
  314. plots=plots and final_epoch,
  315. callbacks=callbacks,
  316. compute_loss=compute_loss)
  317. # Update best mAP
  318. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  319. if fi > best_fitness:
  320. best_fitness = fi
  321. log_vals = list(mloss) + list(results) + lr
  322. callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi)
  323. # Save model
  324. if (not nosave) or (final_epoch and not evolve): # if save
  325. ckpt = {'epoch': epoch,
  326. 'best_fitness': best_fitness,
  327. 'model': deepcopy(de_parallel(model)).half(),
  328. 'ema': deepcopy(ema.ema).half(),
  329. 'updates': ema.updates,
  330. 'optimizer': optimizer.state_dict(),
  331. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}
  332. # Save last, best and delete
  333. torch.save(ckpt, last)
  334. if best_fitness == fi:
  335. torch.save(ckpt, best)
  336. del ckpt
  337. callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
  338. # end epoch ----------------------------------------------------------------------------------------------------
  339. # end training -----------------------------------------------------------------------------------------------------
  340. if RANK in [-1, 0]:
  341. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  342. if not evolve:
  343. if is_coco: # COCO dataset
  344. for m in [last, best] if best.exists() else [last]: # speed, mAP tests
  345. results, _, _ = val.run(data_dict,
  346. batch_size=batch_size // WORLD_SIZE * 2,
  347. imgsz=imgsz,
  348. model=attempt_load(m, device).half(),
  349. iou_thres=0.7, # NMS IoU threshold for best pycocotools results
  350. single_cls=single_cls,
  351. dataloader=val_loader,
  352. save_dir=save_dir,
  353. save_json=True,
  354. plots=False)
  355. # Strip optimizers
  356. for f in last, best:
  357. if f.exists():
  358. strip_optimizer(f) # strip optimizers
  359. callbacks.on_train_end(last, best, plots, epoch)
  360. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
  361. torch.cuda.empty_cache()
  362. return results
  363. def parse_opt(known=False):
  364. parser = argparse.ArgumentParser()
  365. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  366. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  367. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
  368. parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  369. parser.add_argument('--epochs', type=int, default=300)
  370. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  371. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  372. parser.add_argument('--rect', action='store_true', help='rectangular training')
  373. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  374. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  375. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  376. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  377. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  378. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  379. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  380. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  381. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  382. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  383. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  384. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  385. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  386. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  387. parser.add_argument('--project', default='runs/train', help='save to project/name')
  388. parser.add_argument('--entity', default=None, help='W&B entity')
  389. parser.add_argument('--name', default='exp', help='save to project/name')
  390. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  391. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  392. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  393. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  394. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  395. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  396. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
  397. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
  398. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  399. parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
  400. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  401. return opt
  402. def main(opt):
  403. # Checks
  404. set_logging(RANK)
  405. if RANK in [-1, 0]:
  406. print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  407. check_git_status()
  408. check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=['thop'])
  409. # Resume
  410. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  411. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  412. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  413. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  414. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  415. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  416. LOGGER.info(f'Resuming training from {ckpt}')
  417. else:
  418. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  419. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  420. if opt.evolve:
  421. opt.project = 'runs/evolve'
  422. opt.exist_ok = opt.resume
  423. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  424. # DDP mode
  425. device = select_device(opt.device, batch_size=opt.batch_size)
  426. if LOCAL_RANK != -1:
  427. from datetime import timedelta
  428. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  429. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  430. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  431. assert not opt.evolve, '--evolve argument is not compatible with DDP training'
  432. assert not opt.sync_bn, '--sync-bn known training issue, see https://github.com/ultralytics/yolov5/issues/3998'
  433. torch.cuda.set_device(LOCAL_RANK)
  434. device = torch.device('cuda', LOCAL_RANK)
  435. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=60))
  436. # Train
  437. if not opt.evolve:
  438. train(opt.hyp, opt, device)
  439. if WORLD_SIZE > 1 and RANK == 0:
  440. _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
  441. # Evolve hyperparameters (optional)
  442. else:
  443. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  444. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  445. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  446. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  447. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  448. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  449. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  450. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  451. 'box': (1, 0.02, 0.2), # box loss gain
  452. 'cls': (1, 0.2, 4.0), # cls loss gain
  453. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  454. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  455. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  456. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  457. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  458. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  459. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  460. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  461. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  462. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  463. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  464. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  465. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  466. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  467. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  468. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  469. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  470. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  471. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  472. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  473. with open(opt.hyp) as f:
  474. hyp = yaml.safe_load(f) # load hyps dict
  475. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  476. hyp['anchors'] = 3
  477. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  478. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  479. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  480. if opt.bucket:
  481. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists
  482. for _ in range(opt.evolve): # generations to evolve
  483. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  484. # Select parent(s)
  485. parent = 'single' # parent selection method: 'single' or 'weighted'
  486. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  487. n = min(5, len(x)) # number of previous results to consider
  488. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  489. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  490. if parent == 'single' or len(x) == 1:
  491. # x = x[random.randint(0, n - 1)] # random selection
  492. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  493. elif parent == 'weighted':
  494. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  495. # Mutate
  496. mp, s = 0.8, 0.2 # mutation probability, sigma
  497. npr = np.random
  498. npr.seed(int(time.time()))
  499. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  500. ng = len(meta)
  501. v = np.ones(ng)
  502. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  503. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  504. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  505. hyp[k] = float(x[i + 7] * v[i]) # mutate
  506. # Constrain to limits
  507. for k, v in meta.items():
  508. hyp[k] = max(hyp[k], v[1]) # lower limit
  509. hyp[k] = min(hyp[k], v[2]) # upper limit
  510. hyp[k] = round(hyp[k], 5) # significant digits
  511. # Train mutation
  512. results = train(hyp.copy(), opt, device)
  513. # Write mutation results
  514. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  515. # Plot results
  516. plot_evolve(evolve_csv)
  517. print(f'Hyperparameter evolution finished\n'
  518. f"Results saved to {colorstr('bold', save_dir)}\n"
  519. f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')
  520. def run(**kwargs):
  521. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  522. opt = parse_opt(True)
  523. for k, v in kwargs.items():
  524. setattr(opt, k, v)
  525. main(opt)
  526. if __name__ == "__main__":
  527. opt = parse_opt()
  528. main(opt)