Nie możesz wybrać więcej, niż 25 tematów Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

637 lines
33KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset.
  4. Models and datasets download automatically from the latest YOLOv5 release.
  5. Models: https://github.com/ultralytics/yolov5/tree/master/models
  6. Datasets: https://github.com/ultralytics/yolov5/tree/master/data
  7. Tutorial: https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data
  8. Usage:
  9. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640 # from pretrained (RECOMMENDED)
  10. $ python path/to/train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640 # from scratch
  11. """
  12. import argparse
  13. import math
  14. import os
  15. import random
  16. import sys
  17. import time
  18. from copy import deepcopy
  19. from datetime import datetime
  20. from pathlib import Path
  21. import numpy as np
  22. import torch
  23. import torch.distributed as dist
  24. import torch.nn as nn
  25. import yaml
  26. from torch.cuda import amp
  27. from torch.nn.parallel import DistributedDataParallel as DDP
  28. from torch.optim import SGD, Adam, AdamW, lr_scheduler
  29. from tqdm import tqdm
  30. FILE = Path(__file__).resolve()
  31. ROOT = FILE.parents[0] # YOLOv5 root directory
  32. if str(ROOT) not in sys.path:
  33. sys.path.append(str(ROOT)) # add ROOT to PATH
  34. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  35. import val # for end-of-epoch mAP
  36. from models.experimental import attempt_load
  37. from models.yolo import Model
  38. from utils.autoanchor import check_anchors
  39. from utils.autobatch import check_train_batch_size
  40. from utils.callbacks import Callbacks
  41. from utils.datasets import create_dataloader
  42. from utils.downloads import attempt_download
  43. from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
  44. check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
  45. intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
  46. print_args, print_mutation, strip_optimizer)
  47. from utils.loggers import Loggers
  48. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  49. from utils.loss import ComputeLoss
  50. from utils.metrics import fitness
  51. from utils.plots import plot_evolve, plot_labels
  52. from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
  53. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  54. RANK = int(os.getenv('RANK', -1))
  55. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  56. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  57. opt,
  58. device,
  59. callbacks
  60. ):
  61. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
  62. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  63. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  64. # Directories
  65. w = save_dir / 'weights' # weights dir
  66. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
  67. last, best = w / 'last.pt', w / 'best.pt'
  68. # Hyperparameters
  69. if isinstance(hyp, str):
  70. with open(hyp, errors='ignore') as f:
  71. hyp = yaml.safe_load(f) # load hyps dict
  72. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  73. # Save run settings
  74. if not evolve:
  75. with open(save_dir / 'hyp.yaml', 'w') as f:
  76. yaml.safe_dump(hyp, f, sort_keys=False)
  77. with open(save_dir / 'opt.yaml', 'w') as f:
  78. yaml.safe_dump(vars(opt), f, sort_keys=False)
  79. # Loggers
  80. data_dict = None
  81. if RANK in [-1, 0]:
  82. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  83. if loggers.wandb:
  84. data_dict = loggers.wandb.data_dict
  85. if resume:
  86. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
  87. # Register actions
  88. for k in methods(loggers):
  89. callbacks.register_action(k, callback=getattr(loggers, k))
  90. # Config
  91. plots = not evolve # create plots
  92. cuda = device.type != 'cpu'
  93. init_seeds(1 + RANK)
  94. with torch_distributed_zero_first(LOCAL_RANK):
  95. data_dict = data_dict or check_dataset(data) # check if None
  96. train_path, val_path = data_dict['train'], data_dict['val']
  97. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  98. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  99. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  100. is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
  101. # Model
  102. check_suffix(weights, '.pt') # check weights
  103. pretrained = weights.endswith('.pt')
  104. if pretrained:
  105. with torch_distributed_zero_first(LOCAL_RANK):
  106. weights = attempt_download(weights) # download if not found locally
  107. ckpt = torch.load(weights, map_location=device) # load checkpoint
  108. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  109. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  110. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  111. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  112. model.load_state_dict(csd, strict=False) # load
  113. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  114. else:
  115. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  116. # Freeze
  117. freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
  118. for k, v in model.named_parameters():
  119. v.requires_grad = True # train all layers
  120. if any(x in k for x in freeze):
  121. LOGGER.info(f'freezing {k}')
  122. v.requires_grad = False
  123. # Image size
  124. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  125. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  126. # Batch size
  127. if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
  128. batch_size = check_train_batch_size(model, imgsz)
  129. loggers.on_params_update({"batch_size": batch_size})
  130. # Optimizer
  131. nbs = 64 # nominal batch size
  132. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  133. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  134. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  135. g0, g1, g2 = [], [], [] # optimizer parameter groups
  136. for v in model.modules():
  137. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  138. g2.append(v.bias)
  139. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
  140. g0.append(v.weight)
  141. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  142. g1.append(v.weight)
  143. if opt.optimizer == 'Adam':
  144. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  145. elif opt.optimizer == 'AdamW':
  146. optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  147. else:
  148. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  149. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  150. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  151. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  152. f"{len(g0)} weight (no decay), {len(g1)} weight, {len(g2)} bias")
  153. del g0, g1, g2
  154. # Scheduler
  155. if opt.linear_lr:
  156. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  157. else:
  158. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  159. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  160. # EMA
  161. ema = ModelEMA(model) if RANK in [-1, 0] else None
  162. # Resume
  163. start_epoch, best_fitness = 0, 0.0
  164. if pretrained:
  165. # Optimizer
  166. if ckpt['optimizer'] is not None:
  167. optimizer.load_state_dict(ckpt['optimizer'])
  168. best_fitness = ckpt['best_fitness']
  169. # EMA
  170. if ema and ckpt.get('ema'):
  171. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  172. ema.updates = ckpt['updates']
  173. # Epochs
  174. start_epoch = ckpt['epoch'] + 1
  175. if resume:
  176. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  177. if epochs < start_epoch:
  178. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  179. epochs += ckpt['epoch'] # finetune additional epochs
  180. del ckpt, csd
  181. # DP mode
  182. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  183. LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
  184. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  185. model = torch.nn.DataParallel(model)
  186. # SyncBatchNorm
  187. if opt.sync_bn and cuda and RANK != -1:
  188. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  189. LOGGER.info('Using SyncBatchNorm()')
  190. # Trainloader
  191. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  192. hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
  193. workers=workers, image_weights=opt.image_weights, quad=opt.quad,
  194. prefix=colorstr('train: '), shuffle=True)
  195. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  196. nb = len(train_loader) # number of batches
  197. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  198. # Process 0
  199. if RANK in [-1, 0]:
  200. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
  201. hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
  202. workers=workers, pad=0.5,
  203. prefix=colorstr('val: '))[0]
  204. if not resume:
  205. labels = np.concatenate(dataset.labels, 0)
  206. # c = torch.tensor(labels[:, 0]) # classes
  207. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  208. # model._initialize_biases(cf.to(device))
  209. if plots:
  210. plot_labels(labels, names, save_dir)
  211. # Anchors
  212. if not opt.noautoanchor:
  213. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  214. model.half().float() # pre-reduce anchor precision
  215. callbacks.run('on_pretrain_routine_end')
  216. # DDP mode
  217. if cuda and RANK != -1:
  218. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  219. # Model attributes
  220. nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
  221. hyp['box'] *= 3 / nl # scale to layers
  222. hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
  223. hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
  224. hyp['label_smoothing'] = opt.label_smoothing
  225. model.nc = nc # attach number of classes to model
  226. model.hyp = hyp # attach hyperparameters to model
  227. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  228. model.names = names
  229. # Start training
  230. t0 = time.time()
  231. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  232. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  233. last_opt_step = -1
  234. maps = np.zeros(nc) # mAP per class
  235. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  236. scheduler.last_epoch = start_epoch - 1 # do not move
  237. scaler = amp.GradScaler(enabled=cuda)
  238. stopper = EarlyStopping(patience=opt.patience)
  239. compute_loss = ComputeLoss(model) # init loss class
  240. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  241. f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
  242. f"Logging results to {colorstr('bold', save_dir)}\n"
  243. f'Starting training for {epochs} epochs...')
  244. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  245. model.train()
  246. # Update image weights (optional, single-GPU only)
  247. if opt.image_weights:
  248. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  249. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  250. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  251. # Update mosaic border (optional)
  252. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  253. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  254. mloss = torch.zeros(3, device=device) # mean losses
  255. if RANK != -1:
  256. train_loader.sampler.set_epoch(epoch)
  257. pbar = enumerate(train_loader)
  258. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  259. if RANK in [-1, 0]:
  260. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  261. optimizer.zero_grad()
  262. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  263. ni = i + nb * epoch # number integrated batches (since train start)
  264. imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
  265. # Warmup
  266. if ni <= nw:
  267. xi = [0, nw] # x interp
  268. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  269. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  270. for j, x in enumerate(optimizer.param_groups):
  271. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  272. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  273. if 'momentum' in x:
  274. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  275. # Multi-scale
  276. if opt.multi_scale:
  277. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  278. sf = sz / max(imgs.shape[2:]) # scale factor
  279. if sf != 1:
  280. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  281. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  282. # Forward
  283. with amp.autocast(enabled=cuda):
  284. pred = model(imgs) # forward
  285. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  286. if RANK != -1:
  287. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  288. if opt.quad:
  289. loss *= 4.
  290. # Backward
  291. scaler.scale(loss).backward()
  292. # Optimize
  293. if ni - last_opt_step >= accumulate:
  294. scaler.step(optimizer) # optimizer.step
  295. scaler.update()
  296. optimizer.zero_grad()
  297. if ema:
  298. ema.update(model)
  299. last_opt_step = ni
  300. # Log
  301. if RANK in [-1, 0]:
  302. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  303. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  304. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
  305. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  306. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
  307. # end batch ------------------------------------------------------------------------------------------------
  308. # Scheduler
  309. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  310. scheduler.step()
  311. if RANK in [-1, 0]:
  312. # mAP
  313. callbacks.run('on_train_epoch_end', epoch=epoch)
  314. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  315. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
  316. if not noval or final_epoch: # Calculate mAP
  317. results, maps, _ = val.run(data_dict,
  318. batch_size=batch_size // WORLD_SIZE * 2,
  319. imgsz=imgsz,
  320. model=ema.ema,
  321. single_cls=single_cls,
  322. dataloader=val_loader,
  323. save_dir=save_dir,
  324. plots=False,
  325. callbacks=callbacks,
  326. compute_loss=compute_loss)
  327. # Update best mAP
  328. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  329. if fi > best_fitness:
  330. best_fitness = fi
  331. log_vals = list(mloss) + list(results) + lr
  332. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
  333. # Save model
  334. if (not nosave) or (final_epoch and not evolve): # if save
  335. ckpt = {'epoch': epoch,
  336. 'best_fitness': best_fitness,
  337. 'model': deepcopy(de_parallel(model)).half(),
  338. 'ema': deepcopy(ema.ema).half(),
  339. 'updates': ema.updates,
  340. 'optimizer': optimizer.state_dict(),
  341. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
  342. 'date': datetime.now().isoformat()}
  343. # Save last, best and delete
  344. torch.save(ckpt, last)
  345. if best_fitness == fi:
  346. torch.save(ckpt, best)
  347. if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
  348. torch.save(ckpt, w / f'epoch{epoch}.pt')
  349. del ckpt
  350. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
  351. # Stop Single-GPU
  352. if RANK == -1 and stopper(epoch=epoch, fitness=fi):
  353. break
  354. # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
  355. # stop = stopper(epoch=epoch, fitness=fi)
  356. # if RANK == 0:
  357. # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
  358. # Stop DPP
  359. # with torch_distributed_zero_first(RANK):
  360. # if stop:
  361. # break # must break all DDP ranks
  362. # end epoch ----------------------------------------------------------------------------------------------------
  363. # end training -----------------------------------------------------------------------------------------------------
  364. if RANK in [-1, 0]:
  365. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  366. for f in last, best:
  367. if f.exists():
  368. strip_optimizer(f) # strip optimizers
  369. if f is best:
  370. LOGGER.info(f'\nValidating {f}...')
  371. results, _, _ = val.run(data_dict,
  372. batch_size=batch_size // WORLD_SIZE * 2,
  373. imgsz=imgsz,
  374. model=attempt_load(f, device).half(),
  375. iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
  376. single_cls=single_cls,
  377. dataloader=val_loader,
  378. save_dir=save_dir,
  379. save_json=is_coco,
  380. verbose=True,
  381. plots=True,
  382. callbacks=callbacks,
  383. compute_loss=compute_loss) # val best model with plots
  384. if is_coco:
  385. callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
  386. callbacks.run('on_train_end', last, best, plots, epoch, results)
  387. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
  388. torch.cuda.empty_cache()
  389. return results
  390. def parse_opt(known=False):
  391. parser = argparse.ArgumentParser()
  392. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
  393. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  394. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  395. parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  396. parser.add_argument('--epochs', type=int, default=300)
  397. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
  398. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  399. parser.add_argument('--rect', action='store_true', help='rectangular training')
  400. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  401. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  402. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  403. parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
  404. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  405. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  406. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  407. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  408. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  409. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  410. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  411. parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
  412. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  413. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
  414. parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
  415. parser.add_argument('--name', default='exp', help='save to project/name')
  416. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  417. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  418. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  419. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  420. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  421. parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
  422. parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
  423. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  424. # Weights & Biases arguments
  425. parser.add_argument('--entity', default=None, help='W&B: Entity')
  426. parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
  427. parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
  428. parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
  429. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  430. return opt
  431. def main(opt, callbacks=Callbacks()):
  432. # Checks
  433. if RANK in [-1, 0]:
  434. print_args(FILE.stem, opt)
  435. check_git_status()
  436. check_requirements(exclude=['thop'])
  437. # Resume
  438. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  439. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  440. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  441. with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
  442. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  443. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  444. LOGGER.info(f'Resuming training from {ckpt}')
  445. else:
  446. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
  447. check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
  448. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  449. if opt.evolve:
  450. opt.project = str(ROOT / 'runs/evolve')
  451. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  452. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  453. # DDP mode
  454. device = select_device(opt.device, batch_size=opt.batch_size)
  455. if LOCAL_RANK != -1:
  456. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  457. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  458. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  459. assert not opt.evolve, '--evolve argument is not compatible with DDP training'
  460. torch.cuda.set_device(LOCAL_RANK)
  461. device = torch.device('cuda', LOCAL_RANK)
  462. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
  463. # Train
  464. if not opt.evolve:
  465. train(opt.hyp, opt, device, callbacks)
  466. if WORLD_SIZE > 1 and RANK == 0:
  467. LOGGER.info('Destroying process group... ')
  468. dist.destroy_process_group()
  469. # Evolve hyperparameters (optional)
  470. else:
  471. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  472. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  473. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  474. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  475. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  476. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  477. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  478. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  479. 'box': (1, 0.02, 0.2), # box loss gain
  480. 'cls': (1, 0.2, 4.0), # cls loss gain
  481. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  482. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  483. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  484. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  485. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  486. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  487. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  488. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  489. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  490. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  491. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  492. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  493. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  494. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  495. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  496. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  497. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  498. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  499. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  500. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  501. with open(opt.hyp, errors='ignore') as f:
  502. hyp = yaml.safe_load(f) # load hyps dict
  503. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  504. hyp['anchors'] = 3
  505. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  506. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  507. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  508. if opt.bucket:
  509. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists
  510. for _ in range(opt.evolve): # generations to evolve
  511. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  512. # Select parent(s)
  513. parent = 'single' # parent selection method: 'single' or 'weighted'
  514. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  515. n = min(5, len(x)) # number of previous results to consider
  516. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  517. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  518. if parent == 'single' or len(x) == 1:
  519. # x = x[random.randint(0, n - 1)] # random selection
  520. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  521. elif parent == 'weighted':
  522. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  523. # Mutate
  524. mp, s = 0.8, 0.2 # mutation probability, sigma
  525. npr = np.random
  526. npr.seed(int(time.time()))
  527. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  528. ng = len(meta)
  529. v = np.ones(ng)
  530. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  531. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  532. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  533. hyp[k] = float(x[i + 7] * v[i]) # mutate
  534. # Constrain to limits
  535. for k, v in meta.items():
  536. hyp[k] = max(hyp[k], v[1]) # lower limit
  537. hyp[k] = min(hyp[k], v[2]) # upper limit
  538. hyp[k] = round(hyp[k], 5) # significant digits
  539. # Train mutation
  540. results = train(hyp.copy(), opt, device, callbacks)
  541. callbacks = Callbacks()
  542. # Write mutation results
  543. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  544. # Plot results
  545. plot_evolve(evolve_csv)
  546. LOGGER.info(f'Hyperparameter evolution finished\n'
  547. f"Results saved to {colorstr('bold', save_dir)}\n"
  548. f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')
  549. def run(**kwargs):
  550. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  551. opt = parse_opt(True)
  552. for k, v in kwargs.items():
  553. setattr(opt, k, v)
  554. main(opt)
  555. if __name__ == "__main__":
  556. opt = parse_opt()
  557. main(opt)