You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

586 lines
30KB

  1. """Train a YOLOv5 model on a custom dataset
  2. Usage:
  3. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
  4. """
  5. import argparse
  6. import logging
  7. import os
  8. import random
  9. import sys
  10. import time
  11. from copy import deepcopy
  12. from pathlib import Path
  13. import math
  14. import numpy as np
  15. import torch
  16. import torch.distributed as dist
  17. import torch.nn as nn
  18. import yaml
  19. from torch.cuda import amp
  20. from torch.nn.parallel import DistributedDataParallel as DDP
  21. from torch.optim import Adam, SGD, lr_scheduler
  22. from tqdm import tqdm
  23. FILE = Path(__file__).absolute()
  24. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  25. import val # for end-of-epoch mAP
  26. from models.experimental import attempt_load
  27. from models.yolo import Model
  28. from utils.autoanchor import check_anchors
  29. from utils.datasets import create_dataloader
  30. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  31. strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  32. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  33. from utils.google_utils import attempt_download
  34. from utils.loss import ComputeLoss
  35. from utils.plots import plot_labels, plot_evolution
  36. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
  37. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  38. from utils.metrics import fitness
  39. from utils.loggers import Loggers
  40. LOGGER = logging.getLogger(__name__)
  41. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  42. RANK = int(os.getenv('RANK', -1))
  43. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  44. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  45. opt,
  46. device,
  47. ):
  48. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, = \
  49. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  50. opt.resume, opt.noval, opt.nosave, opt.workers
  51. # Directories
  52. w = save_dir / 'weights' # weights dir
  53. w.mkdir(parents=True, exist_ok=True) # make dir
  54. last, best = w / 'last.pt', w / 'best.pt'
  55. # Hyperparameters
  56. if isinstance(hyp, str):
  57. with open(hyp) as f:
  58. hyp = yaml.safe_load(f) # load hyps dict
  59. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  60. # Save run settings
  61. with open(save_dir / 'hyp.yaml', 'w') as f:
  62. yaml.safe_dump(hyp, f, sort_keys=False)
  63. with open(save_dir / 'opt.yaml', 'w') as f:
  64. yaml.safe_dump(vars(opt), f, sort_keys=False)
  65. # Config
  66. plots = not evolve # create plots
  67. cuda = device.type != 'cpu'
  68. init_seeds(1 + RANK)
  69. with open(data) as f:
  70. data_dict = yaml.safe_load(f) # data dict
  71. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  72. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  73. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  74. is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
  75. # Loggers
  76. if RANK in [-1, 0]:
  77. loggers = Loggers(save_dir, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict
  78. if loggers.wandb and resume:
  79. weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict
  80. # Model
  81. pretrained = weights.endswith('.pt')
  82. if pretrained:
  83. with torch_distributed_zero_first(RANK):
  84. weights = attempt_download(weights) # download if not found locally
  85. ckpt = torch.load(weights, map_location=device) # load checkpoint
  86. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  87. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  88. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  89. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  90. model.load_state_dict(csd, strict=False) # load
  91. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  92. else:
  93. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  94. with torch_distributed_zero_first(RANK):
  95. check_dataset(data_dict) # check
  96. train_path, val_path = data_dict['train'], data_dict['val']
  97. # Freeze
  98. freeze = [] # parameter names to freeze (full or partial)
  99. for k, v in model.named_parameters():
  100. v.requires_grad = True # train all layers
  101. if any(x in k for x in freeze):
  102. print(f'freezing {k}')
  103. v.requires_grad = False
  104. # Optimizer
  105. nbs = 64 # nominal batch size
  106. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  107. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  108. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  109. g0, g1, g2 = [], [], [] # optimizer parameter groups
  110. for v in model.modules():
  111. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  112. g2.append(v.bias)
  113. if isinstance(v, nn.BatchNorm2d): # weight with decay
  114. g0.append(v.weight)
  115. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight without decay
  116. g1.append(v.weight)
  117. if opt.adam:
  118. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  119. else:
  120. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  121. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  122. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  123. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  124. f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
  125. del g0, g1, g2
  126. # Scheduler
  127. if opt.linear_lr:
  128. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  129. else:
  130. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  131. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  132. # EMA
  133. ema = ModelEMA(model) if RANK in [-1, 0] else None
  134. # Resume
  135. start_epoch, best_fitness = 0, 0.0
  136. if pretrained:
  137. # Optimizer
  138. if ckpt['optimizer'] is not None:
  139. optimizer.load_state_dict(ckpt['optimizer'])
  140. best_fitness = ckpt['best_fitness']
  141. # EMA
  142. if ema and ckpt.get('ema'):
  143. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  144. ema.updates = ckpt['updates']
  145. # Epochs
  146. start_epoch = ckpt['epoch'] + 1
  147. if resume:
  148. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  149. if epochs < start_epoch:
  150. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  151. epochs += ckpt['epoch'] # finetune additional epochs
  152. del ckpt, csd
  153. # Image sizes
  154. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  155. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  156. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  157. # DP mode
  158. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  159. logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
  160. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  161. model = torch.nn.DataParallel(model)
  162. # SyncBatchNorm
  163. if opt.sync_bn and cuda and RANK != -1:
  164. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  165. LOGGER.info('Using SyncBatchNorm()')
  166. # Trainloader
  167. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  168. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
  169. workers=workers, image_weights=opt.image_weights, quad=opt.quad,
  170. prefix=colorstr('train: '))
  171. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  172. nb = len(train_loader) # number of batches
  173. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  174. # Process 0
  175. if RANK in [-1, 0]:
  176. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
  177. hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
  178. workers=workers, pad=0.5,
  179. prefix=colorstr('val: '))[0]
  180. if not resume:
  181. labels = np.concatenate(dataset.labels, 0)
  182. # c = torch.tensor(labels[:, 0]) # classes
  183. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  184. # model._initialize_biases(cf.to(device))
  185. if plots:
  186. plot_labels(labels, names, save_dir, loggers)
  187. # Anchors
  188. if not opt.noautoanchor:
  189. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  190. model.half().float() # pre-reduce anchor precision
  191. # DDP mode
  192. if cuda and RANK != -1:
  193. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  194. # Model parameters
  195. hyp['box'] *= 3. / nl # scale to layers
  196. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  197. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  198. hyp['label_smoothing'] = opt.label_smoothing
  199. model.nc = nc # attach number of classes to model
  200. model.hyp = hyp # attach hyperparameters to model
  201. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  202. model.names = names
  203. # Start training
  204. t0 = time.time()
  205. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  206. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  207. last_opt_step = -1
  208. maps = np.zeros(nc) # mAP per class
  209. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  210. scheduler.last_epoch = start_epoch - 1 # do not move
  211. scaler = amp.GradScaler(enabled=cuda)
  212. compute_loss = ComputeLoss(model) # init loss class
  213. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  214. f'Using {train_loader.num_workers} dataloader workers\n'
  215. f'Logging results to {save_dir}\n'
  216. f'Starting training for {epochs} epochs...')
  217. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  218. model.train()
  219. # Update image weights (optional)
  220. if opt.image_weights:
  221. # Generate indices
  222. if RANK in [-1, 0]:
  223. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  224. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  225. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  226. # Broadcast if DDP
  227. if RANK != -1:
  228. indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
  229. dist.broadcast(indices, 0)
  230. if RANK != 0:
  231. dataset.indices = indices.cpu().numpy()
  232. # Update mosaic border
  233. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  234. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  235. mloss = torch.zeros(3, device=device) # mean losses
  236. if RANK != -1:
  237. train_loader.sampler.set_epoch(epoch)
  238. pbar = enumerate(train_loader)
  239. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  240. if RANK in [-1, 0]:
  241. pbar = tqdm(pbar, total=nb) # progress bar
  242. optimizer.zero_grad()
  243. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  244. ni = i + nb * epoch # number integrated batches (since train start)
  245. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  246. # Warmup
  247. if ni <= nw:
  248. xi = [0, nw] # x interp
  249. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  250. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  251. for j, x in enumerate(optimizer.param_groups):
  252. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  253. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  254. if 'momentum' in x:
  255. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  256. # Multi-scale
  257. if opt.multi_scale:
  258. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  259. sf = sz / max(imgs.shape[2:]) # scale factor
  260. if sf != 1:
  261. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  262. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  263. # Forward
  264. with amp.autocast(enabled=cuda):
  265. pred = model(imgs) # forward
  266. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  267. if RANK != -1:
  268. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  269. if opt.quad:
  270. loss *= 4.
  271. # Backward
  272. scaler.scale(loss).backward()
  273. # Optimize
  274. if ni - last_opt_step >= accumulate:
  275. scaler.step(optimizer) # optimizer.step
  276. scaler.update()
  277. optimizer.zero_grad()
  278. if ema:
  279. ema.update(model)
  280. last_opt_step = ni
  281. # Log
  282. if RANK in [-1, 0]:
  283. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  284. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  285. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
  286. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  287. loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)
  288. # end batch ------------------------------------------------------------------------------------------------
  289. # Scheduler
  290. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  291. scheduler.step()
  292. if RANK in [-1, 0]:
  293. # mAP
  294. loggers.on_train_epoch_end(epoch)
  295. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  296. final_epoch = epoch + 1 == epochs
  297. if not noval or final_epoch: # Calculate mAP
  298. results, maps, _ = val.run(data_dict,
  299. batch_size=batch_size // WORLD_SIZE * 2,
  300. imgsz=imgsz,
  301. model=ema.ema,
  302. single_cls=single_cls,
  303. dataloader=val_loader,
  304. save_dir=save_dir,
  305. save_json=is_coco and final_epoch,
  306. verbose=nc < 50 and final_epoch,
  307. plots=plots and final_epoch,
  308. loggers=loggers,
  309. compute_loss=compute_loss)
  310. # Update best mAP
  311. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  312. if fi > best_fitness:
  313. best_fitness = fi
  314. loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)
  315. # Save model
  316. if (not nosave) or (final_epoch and not evolve): # if save
  317. ckpt = {'epoch': epoch,
  318. 'best_fitness': best_fitness,
  319. 'model': deepcopy(de_parallel(model)).half(),
  320. 'ema': deepcopy(ema.ema).half(),
  321. 'updates': ema.updates,
  322. 'optimizer': optimizer.state_dict(),
  323. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}
  324. # Save last, best and delete
  325. torch.save(ckpt, last)
  326. if best_fitness == fi:
  327. torch.save(ckpt, best)
  328. del ckpt
  329. loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
  330. # end epoch ----------------------------------------------------------------------------------------------------
  331. # end training -----------------------------------------------------------------------------------------------------
  332. if RANK in [-1, 0]:
  333. LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
  334. if not evolve:
  335. if is_coco: # COCO dataset
  336. for m in [last, best] if best.exists() else [last]: # speed, mAP tests
  337. results, _, _ = val.run(data_dict,
  338. batch_size=batch_size // WORLD_SIZE * 2,
  339. imgsz=imgsz,
  340. model=attempt_load(m, device).half(),
  341. iou_thres=0.7, # NMS IoU threshold for best pycocotools results
  342. single_cls=single_cls,
  343. dataloader=val_loader,
  344. save_dir=save_dir,
  345. save_json=True,
  346. plots=False)
  347. # Strip optimizers
  348. for f in last, best:
  349. if f.exists():
  350. strip_optimizer(f) # strip optimizers
  351. loggers.on_train_end(last, best, plots)
  352. torch.cuda.empty_cache()
  353. return results
  354. def parse_opt(known=False):
  355. parser = argparse.ArgumentParser()
  356. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  357. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  358. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
  359. parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  360. parser.add_argument('--epochs', type=int, default=300)
  361. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  362. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  363. parser.add_argument('--rect', action='store_true', help='rectangular training')
  364. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  365. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  366. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  367. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  368. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  369. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  370. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  371. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  372. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  373. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  374. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  375. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  376. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  377. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  378. parser.add_argument('--project', default='runs/train', help='save to project/name')
  379. parser.add_argument('--entity', default=None, help='W&B entity')
  380. parser.add_argument('--name', default='exp', help='save to project/name')
  381. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  382. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  383. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  384. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  385. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  386. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  387. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
  388. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
  389. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  390. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  391. return opt
  392. def main(opt):
  393. set_logging(RANK)
  394. if RANK in [-1, 0]:
  395. print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  396. check_git_status()
  397. check_requirements(exclude=['thop'])
  398. # Resume
  399. if opt.resume and not check_wandb_resume(opt): # resume an interrupted run
  400. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  401. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  402. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  403. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  404. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  405. LOGGER.info(f'Resuming training from {ckpt}')
  406. else:
  407. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  408. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  409. opt.name = 'evolve' if opt.evolve else opt.name
  410. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
  411. # DDP mode
  412. device = select_device(opt.device, batch_size=opt.batch_size)
  413. if LOCAL_RANK != -1:
  414. from datetime import timedelta
  415. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  416. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  417. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  418. assert not opt.evolve, '--evolve argument is not compatible with DDP training'
  419. assert not opt.sync_bn, '--sync-bn known training issue, see https://github.com/ultralytics/yolov5/issues/3998'
  420. torch.cuda.set_device(LOCAL_RANK)
  421. device = torch.device('cuda', LOCAL_RANK)
  422. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=60))
  423. # Train
  424. if not opt.evolve:
  425. train(opt.hyp, opt, device)
  426. if WORLD_SIZE > 1 and RANK == 0:
  427. _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
  428. # Evolve hyperparameters (optional)
  429. else:
  430. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  431. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  432. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  433. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  434. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  435. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  436. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  437. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  438. 'box': (1, 0.02, 0.2), # box loss gain
  439. 'cls': (1, 0.2, 4.0), # cls loss gain
  440. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  441. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  442. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  443. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  444. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  445. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  446. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  447. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  448. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  449. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  450. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  451. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  452. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  453. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  454. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  455. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  456. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  457. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  458. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  459. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  460. with open(opt.hyp) as f:
  461. hyp = yaml.safe_load(f) # load hyps dict
  462. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  463. hyp['anchors'] = 3
  464. opt.noval, opt.nosave = True, True # only val/save final epoch
  465. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  466. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  467. if opt.bucket:
  468. os.system(f'gsutil cp gs://{opt.bucket}/evolve.txt .') # download evolve.txt if exists
  469. for _ in range(opt.evolve): # generations to evolve
  470. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  471. # Select parent(s)
  472. parent = 'single' # parent selection method: 'single' or 'weighted'
  473. x = np.loadtxt('evolve.txt', ndmin=2)
  474. n = min(5, len(x)) # number of previous results to consider
  475. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  476. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  477. if parent == 'single' or len(x) == 1:
  478. # x = x[random.randint(0, n - 1)] # random selection
  479. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  480. elif parent == 'weighted':
  481. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  482. # Mutate
  483. mp, s = 0.8, 0.2 # mutation probability, sigma
  484. npr = np.random
  485. npr.seed(int(time.time()))
  486. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  487. ng = len(meta)
  488. v = np.ones(ng)
  489. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  490. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  491. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  492. hyp[k] = float(x[i + 7] * v[i]) # mutate
  493. # Constrain to limits
  494. for k, v in meta.items():
  495. hyp[k] = max(hyp[k], v[1]) # lower limit
  496. hyp[k] = min(hyp[k], v[2]) # upper limit
  497. hyp[k] = round(hyp[k], 5) # significant digits
  498. # Train mutation
  499. results = train(hyp.copy(), opt, device)
  500. # Write mutation results
  501. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  502. # Plot results
  503. plot_evolution(yaml_file)
  504. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  505. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
  506. def run(**kwargs):
  507. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  508. opt = parse_opt(True)
  509. for k, v in kwargs.items():
  510. setattr(opt, k, v)
  511. main(opt)
  512. if __name__ == "__main__":
  513. opt = parse_opt()
  514. main(opt)