You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

612 lines
31KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset
  4. Usage:
  5. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
  6. """
  7. import argparse
  8. import logging
  9. import math
  10. import os
  11. import random
  12. import sys
  13. import time
  14. from copy import deepcopy
  15. from pathlib import Path
  16. import numpy as np
  17. import torch
  18. import torch.distributed as dist
  19. import torch.nn as nn
  20. import yaml
  21. from torch.cuda import amp
  22. from torch.nn.parallel import DistributedDataParallel as DDP
  23. from torch.optim import Adam, SGD, lr_scheduler
  24. from tqdm import tqdm
  25. FILE = Path(__file__).resolve()
  26. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  27. import val # for end-of-epoch mAP
  28. from models.experimental import attempt_load
  29. from models.yolo import Model
  30. from utils.autoanchor import check_anchors
  31. from utils.datasets import create_dataloader
  32. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  33. strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
  34. check_file, check_yaml, check_suffix, print_mutation, set_logging, one_cycle, colorstr, methods
  35. from utils.downloads import attempt_download
  36. from utils.loss import ComputeLoss
  37. from utils.plots import plot_labels, plot_evolve
  38. from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \
  39. torch_distributed_zero_first
  40. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  41. from utils.metrics import fitness
  42. from utils.loggers import Loggers
  43. from utils.callbacks import Callbacks
  44. LOGGER = logging.getLogger(__name__)
  45. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  46. RANK = int(os.getenv('RANK', -1))
  47. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  48. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  49. opt,
  50. device,
  51. callbacks
  52. ):
  53. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
  54. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  55. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  56. # Directories
  57. w = save_dir / 'weights' # weights dir
  58. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
  59. last, best = w / 'last.pt', w / 'best.pt'
  60. # Hyperparameters
  61. if isinstance(hyp, str):
  62. with open(hyp) as f:
  63. hyp = yaml.safe_load(f) # load hyps dict
  64. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  65. # Save run settings
  66. with open(save_dir / 'hyp.yaml', 'w') as f:
  67. yaml.safe_dump(hyp, f, sort_keys=False)
  68. with open(save_dir / 'opt.yaml', 'w') as f:
  69. yaml.safe_dump(vars(opt), f, sort_keys=False)
  70. data_dict = None
  71. # Loggers
  72. if RANK in [-1, 0]:
  73. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  74. if loggers.wandb:
  75. data_dict = loggers.wandb.data_dict
  76. if resume:
  77. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
  78. # Register actions
  79. for k in methods(loggers):
  80. callbacks.register_action(k, callback=getattr(loggers, k))
  81. # Config
  82. plots = not evolve # create plots
  83. cuda = device.type != 'cpu'
  84. init_seeds(1 + RANK)
  85. with torch_distributed_zero_first(RANK):
  86. data_dict = data_dict or check_dataset(data) # check if None
  87. train_path, val_path = data_dict['train'], data_dict['val']
  88. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  89. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  90. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  91. is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
  92. # Model
  93. check_suffix(weights, '.pt') # check weights
  94. pretrained = weights.endswith('.pt')
  95. if pretrained:
  96. with torch_distributed_zero_first(RANK):
  97. weights = attempt_download(weights) # download if not found locally
  98. ckpt = torch.load(weights, map_location=device) # load checkpoint
  99. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  100. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  101. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  102. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  103. model.load_state_dict(csd, strict=False) # load
  104. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  105. else:
  106. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  107. # Freeze
  108. freeze = [f'model.{x}.' for x in range(freeze)] # layers to freeze
  109. for k, v in model.named_parameters():
  110. v.requires_grad = True # train all layers
  111. if any(x in k for x in freeze):
  112. print(f'freezing {k}')
  113. v.requires_grad = False
  114. # Optimizer
  115. nbs = 64 # nominal batch size
  116. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  117. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  118. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  119. g0, g1, g2 = [], [], [] # optimizer parameter groups
  120. for v in model.modules():
  121. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  122. g2.append(v.bias)
  123. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
  124. g0.append(v.weight)
  125. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  126. g1.append(v.weight)
  127. if opt.adam:
  128. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  129. else:
  130. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  131. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  132. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  133. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  134. f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
  135. del g0, g1, g2
  136. # Scheduler
  137. if opt.linear_lr:
  138. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  139. else:
  140. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  141. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  142. # EMA
  143. ema = ModelEMA(model) if RANK in [-1, 0] else None
  144. # Resume
  145. start_epoch, best_fitness = 0, 0.0
  146. if pretrained:
  147. # Optimizer
  148. if ckpt['optimizer'] is not None:
  149. optimizer.load_state_dict(ckpt['optimizer'])
  150. best_fitness = ckpt['best_fitness']
  151. # EMA
  152. if ema and ckpt.get('ema'):
  153. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  154. ema.updates = ckpt['updates']
  155. # Epochs
  156. start_epoch = ckpt['epoch'] + 1
  157. if resume:
  158. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  159. if epochs < start_epoch:
  160. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  161. epochs += ckpt['epoch'] # finetune additional epochs
  162. del ckpt, csd
  163. # Image sizes
  164. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  165. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  166. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  167. # DP mode
  168. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  169. logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
  170. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  171. model = torch.nn.DataParallel(model)
  172. # SyncBatchNorm
  173. if opt.sync_bn and cuda and RANK != -1:
  174. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  175. LOGGER.info('Using SyncBatchNorm()')
  176. # Trainloader
  177. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  178. hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
  179. workers=workers, image_weights=opt.image_weights, quad=opt.quad,
  180. prefix=colorstr('train: '))
  181. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  182. nb = len(train_loader) # number of batches
  183. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  184. # Process 0
  185. if RANK in [-1, 0]:
  186. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
  187. hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
  188. workers=workers, pad=0.5,
  189. prefix=colorstr('val: '))[0]
  190. if not resume:
  191. labels = np.concatenate(dataset.labels, 0)
  192. # c = torch.tensor(labels[:, 0]) # classes
  193. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  194. # model._initialize_biases(cf.to(device))
  195. if plots:
  196. plot_labels(labels, names, save_dir)
  197. # Anchors
  198. if not opt.noautoanchor:
  199. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  200. model.half().float() # pre-reduce anchor precision
  201. callbacks.run('on_pretrain_routine_end')
  202. # DDP mode
  203. if cuda and RANK != -1:
  204. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  205. # Model parameters
  206. hyp['box'] *= 3. / nl # scale to layers
  207. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  208. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  209. hyp['label_smoothing'] = opt.label_smoothing
  210. model.nc = nc # attach number of classes to model
  211. model.hyp = hyp # attach hyperparameters to model
  212. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  213. model.names = names
  214. # Start training
  215. t0 = time.time()
  216. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  217. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  218. last_opt_step = -1
  219. maps = np.zeros(nc) # mAP per class
  220. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  221. scheduler.last_epoch = start_epoch - 1 # do not move
  222. scaler = amp.GradScaler(enabled=cuda)
  223. stopper = EarlyStopping(patience=opt.patience)
  224. compute_loss = ComputeLoss(model) # init loss class
  225. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  226. f'Using {train_loader.num_workers} dataloader workers\n'
  227. f"Logging results to {colorstr('bold', save_dir)}\n"
  228. f'Starting training for {epochs} epochs...')
  229. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  230. model.train()
  231. # Update image weights (optional, single-GPU only)
  232. if opt.image_weights:
  233. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  234. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  235. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  236. # Update mosaic border (optional)
  237. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  238. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  239. mloss = torch.zeros(3, device=device) # mean losses
  240. if RANK != -1:
  241. train_loader.sampler.set_epoch(epoch)
  242. pbar = enumerate(train_loader)
  243. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  244. if RANK in [-1, 0]:
  245. pbar = tqdm(pbar, total=nb) # progress bar
  246. optimizer.zero_grad()
  247. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  248. ni = i + nb * epoch # number integrated batches (since train start)
  249. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  250. # Warmup
  251. if ni <= nw:
  252. xi = [0, nw] # x interp
  253. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  254. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  255. for j, x in enumerate(optimizer.param_groups):
  256. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  257. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  258. if 'momentum' in x:
  259. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  260. # Multi-scale
  261. if opt.multi_scale:
  262. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  263. sf = sz / max(imgs.shape[2:]) # scale factor
  264. if sf != 1:
  265. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  266. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  267. # Forward
  268. with amp.autocast(enabled=cuda):
  269. pred = model(imgs) # forward
  270. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  271. if RANK != -1:
  272. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  273. if opt.quad:
  274. loss *= 4.
  275. # Backward
  276. scaler.scale(loss).backward()
  277. # Optimize
  278. if ni - last_opt_step >= accumulate:
  279. scaler.step(optimizer) # optimizer.step
  280. scaler.update()
  281. optimizer.zero_grad()
  282. if ema:
  283. ema.update(model)
  284. last_opt_step = ni
  285. # Log
  286. if RANK in [-1, 0]:
  287. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  288. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  289. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
  290. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  291. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
  292. # end batch ------------------------------------------------------------------------------------------------
  293. # Scheduler
  294. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  295. scheduler.step()
  296. if RANK in [-1, 0]:
  297. # mAP
  298. callbacks.run('on_train_epoch_end', epoch=epoch)
  299. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  300. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
  301. if not noval or final_epoch: # Calculate mAP
  302. results, maps, _ = val.run(data_dict,
  303. batch_size=batch_size // WORLD_SIZE * 2,
  304. imgsz=imgsz,
  305. model=ema.ema,
  306. single_cls=single_cls,
  307. dataloader=val_loader,
  308. save_dir=save_dir,
  309. save_json=is_coco and final_epoch,
  310. verbose=nc < 50 and final_epoch,
  311. plots=plots and final_epoch,
  312. callbacks=callbacks,
  313. compute_loss=compute_loss)
  314. # Update best mAP
  315. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  316. if fi > best_fitness:
  317. best_fitness = fi
  318. log_vals = list(mloss) + list(results) + lr
  319. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
  320. # Save model
  321. if (not nosave) or (final_epoch and not evolve): # if save
  322. ckpt = {'epoch': epoch,
  323. 'best_fitness': best_fitness,
  324. 'model': deepcopy(de_parallel(model)).half(),
  325. 'ema': deepcopy(ema.ema).half(),
  326. 'updates': ema.updates,
  327. 'optimizer': optimizer.state_dict(),
  328. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}
  329. # Save last, best and delete
  330. torch.save(ckpt, last)
  331. if best_fitness == fi:
  332. torch.save(ckpt, best)
  333. del ckpt
  334. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
  335. # Stop Single-GPU
  336. if RANK == -1 and stopper(epoch=epoch, fitness=fi):
  337. break
  338. # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
  339. # stop = stopper(epoch=epoch, fitness=fi)
  340. # if RANK == 0:
  341. # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
  342. # Stop DPP
  343. # with torch_distributed_zero_first(RANK):
  344. # if stop:
  345. # break # must break all DDP ranks
  346. # end epoch ----------------------------------------------------------------------------------------------------
  347. # end training -----------------------------------------------------------------------------------------------------
  348. if RANK in [-1, 0]:
  349. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  350. if not evolve:
  351. if is_coco: # COCO dataset
  352. for m in [last, best] if best.exists() else [last]: # speed, mAP tests
  353. results, _, _ = val.run(data_dict,
  354. batch_size=batch_size // WORLD_SIZE * 2,
  355. imgsz=imgsz,
  356. model=attempt_load(m, device).half(),
  357. iou_thres=0.7, # NMS IoU threshold for best pycocotools results
  358. single_cls=single_cls,
  359. dataloader=val_loader,
  360. save_dir=save_dir,
  361. save_json=True,
  362. plots=False)
  363. # Strip optimizers
  364. for f in last, best:
  365. if f.exists():
  366. strip_optimizer(f) # strip optimizers
  367. callbacks.run('on_train_end', last, best, plots, epoch)
  368. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
  369. torch.cuda.empty_cache()
  370. return results
  371. def parse_opt(known=False):
  372. parser = argparse.ArgumentParser()
  373. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  374. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  375. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
  376. parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  377. parser.add_argument('--epochs', type=int, default=300)
  378. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  379. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  380. parser.add_argument('--rect', action='store_true', help='rectangular training')
  381. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  382. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  383. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  384. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  385. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  386. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  387. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  388. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  389. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  390. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  391. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  392. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  393. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  394. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  395. parser.add_argument('--project', default='runs/train', help='save to project/name')
  396. parser.add_argument('--entity', default=None, help='W&B entity')
  397. parser.add_argument('--name', default='exp', help='save to project/name')
  398. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  399. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  400. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  401. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  402. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  403. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  404. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
  405. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
  406. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  407. parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
  408. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  409. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  410. return opt
  411. def main(opt, callbacks=Callbacks()):
  412. # Checks
  413. set_logging(RANK)
  414. if RANK in [-1, 0]:
  415. print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  416. check_git_status()
  417. check_requirements(requirements=FILE.parent / 'requirements.txt', exclude=['thop'])
  418. # Resume
  419. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  420. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  421. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  422. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  423. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  424. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  425. LOGGER.info(f'Resuming training from {ckpt}')
  426. else:
  427. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp) # check YAMLs
  428. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  429. if opt.evolve:
  430. opt.project = 'runs/evolve'
  431. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  432. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  433. # DDP mode
  434. device = select_device(opt.device, batch_size=opt.batch_size)
  435. if LOCAL_RANK != -1:
  436. from datetime import timedelta
  437. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  438. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  439. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  440. assert not opt.evolve, '--evolve argument is not compatible with DDP training'
  441. torch.cuda.set_device(LOCAL_RANK)
  442. device = torch.device('cuda', LOCAL_RANK)
  443. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
  444. # Train
  445. if not opt.evolve:
  446. train(opt.hyp, opt, device, callbacks)
  447. if WORLD_SIZE > 1 and RANK == 0:
  448. _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
  449. # Evolve hyperparameters (optional)
  450. else:
  451. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  452. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  453. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  454. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  455. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  456. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  457. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  458. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  459. 'box': (1, 0.02, 0.2), # box loss gain
  460. 'cls': (1, 0.2, 4.0), # cls loss gain
  461. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  462. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  463. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  464. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  465. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  466. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  467. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  468. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  469. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  470. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  471. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  472. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  473. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  474. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  475. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  476. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  477. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  478. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  479. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  480. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  481. with open(opt.hyp) as f:
  482. hyp = yaml.safe_load(f) # load hyps dict
  483. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  484. hyp['anchors'] = 3
  485. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  486. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  487. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  488. if opt.bucket:
  489. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists
  490. for _ in range(opt.evolve): # generations to evolve
  491. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  492. # Select parent(s)
  493. parent = 'single' # parent selection method: 'single' or 'weighted'
  494. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  495. n = min(5, len(x)) # number of previous results to consider
  496. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  497. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  498. if parent == 'single' or len(x) == 1:
  499. # x = x[random.randint(0, n - 1)] # random selection
  500. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  501. elif parent == 'weighted':
  502. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  503. # Mutate
  504. mp, s = 0.8, 0.2 # mutation probability, sigma
  505. npr = np.random
  506. npr.seed(int(time.time()))
  507. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  508. ng = len(meta)
  509. v = np.ones(ng)
  510. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  511. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  512. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  513. hyp[k] = float(x[i + 7] * v[i]) # mutate
  514. # Constrain to limits
  515. for k, v in meta.items():
  516. hyp[k] = max(hyp[k], v[1]) # lower limit
  517. hyp[k] = min(hyp[k], v[2]) # upper limit
  518. hyp[k] = round(hyp[k], 5) # significant digits
  519. # Train mutation
  520. results = train(hyp.copy(), opt, device, callbacks)
  521. # Write mutation results
  522. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  523. # Plot results
  524. plot_evolve(evolve_csv)
  525. print(f'Hyperparameter evolution finished\n'
  526. f"Results saved to {colorstr('bold', save_dir)}\n"
  527. f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')
  528. def run(**kwargs):
  529. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  530. opt = parse_opt(True)
  531. for k, v in kwargs.items():
  532. setattr(opt, k, v)
  533. main(opt)
  534. if __name__ == "__main__":
  535. opt = parse_opt()
  536. main(opt)