You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

657 lines
33KB

  1. """Train a YOLOv5 model on a custom dataset
  2. Usage:
  3. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
  4. """
  5. import argparse
  6. import logging
  7. import math
  8. import os
  9. import random
  10. import sys
  11. import time
  12. import warnings
  13. from copy import deepcopy
  14. from pathlib import Path
  15. from threading import Thread
  16. import numpy as np
  17. import torch.distributed as dist
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. import torch.optim as optim
  21. import torch.optim.lr_scheduler as lr_scheduler
  22. import torch.utils.data
  23. import yaml
  24. from torch.cuda import amp
  25. from torch.nn.parallel import DistributedDataParallel as DDP
  26. from torch.utils.tensorboard import SummaryWriter
  27. from tqdm import tqdm
  28. FILE = Path(__file__).absolute()
  29. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  30. import test # for end-of-epoch mAP
  31. from models.experimental import attempt_load
  32. from models.yolo import Model
  33. from utils.autoanchor import check_anchors
  34. from utils.datasets import create_dataloader
  35. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  36. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  37. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  38. from utils.google_utils import attempt_download
  39. from utils.loss import ComputeLoss
  40. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  41. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
  42. from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
  43. logger = logging.getLogger(__name__)
  44. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  45. RANK = int(os.getenv('RANK', -1))
  46. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  47. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  48. opt,
  49. device,
  50. ):
  51. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
  52. opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  53. opt.resume, opt.notest, opt.nosave, opt.workers
  54. # Directories
  55. save_dir = Path(save_dir)
  56. wdir = save_dir / 'weights'
  57. wdir.mkdir(parents=True, exist_ok=True) # make dir
  58. last = wdir / 'last.pt'
  59. best = wdir / 'best.pt'
  60. results_file = save_dir / 'results.txt'
  61. # Hyperparameters
  62. if isinstance(hyp, str):
  63. with open(hyp) as f:
  64. hyp = yaml.safe_load(f) # load hyps dict
  65. logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  66. # Save run settings
  67. with open(save_dir / 'hyp.yaml', 'w') as f:
  68. yaml.safe_dump(hyp, f, sort_keys=False)
  69. with open(save_dir / 'opt.yaml', 'w') as f:
  70. yaml.safe_dump(vars(opt), f, sort_keys=False)
  71. # Configure
  72. plots = not evolve # create plots
  73. cuda = device.type != 'cpu'
  74. init_seeds(2 + RANK)
  75. with open(data) as f:
  76. data_dict = yaml.safe_load(f) # data dict
  77. # Loggers
  78. loggers = {'wandb': None, 'tb': None} # loggers dict
  79. if RANK in [-1, 0]:
  80. # TensorBoard
  81. if not evolve:
  82. prefix = colorstr('tensorboard: ')
  83. logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
  84. loggers['tb'] = SummaryWriter(str(save_dir))
  85. # W&B
  86. opt.hyp = hyp # add hyperparameters
  87. run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
  88. run_id = run_id if opt.resume else None # start fresh run if transfer learning
  89. wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
  90. loggers['wandb'] = wandb_logger.wandb
  91. if loggers['wandb']:
  92. data_dict = wandb_logger.data_dict
  93. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming
  94. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  95. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  96. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data) # check
  97. is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
  98. # Model
  99. pretrained = weights.endswith('.pt')
  100. if pretrained:
  101. with torch_distributed_zero_first(RANK):
  102. weights = attempt_download(weights) # download if not found locally
  103. ckpt = torch.load(weights, map_location=device) # load checkpoint
  104. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  105. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  106. state_dict = ckpt['model'].float().state_dict() # to FP32
  107. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  108. model.load_state_dict(state_dict, strict=False) # load
  109. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  110. else:
  111. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  112. with torch_distributed_zero_first(RANK):
  113. check_dataset(data_dict) # check
  114. train_path = data_dict['train']
  115. test_path = data_dict['val']
  116. # Freeze
  117. freeze = [] # parameter names to freeze (full or partial)
  118. for k, v in model.named_parameters():
  119. v.requires_grad = True # train all layers
  120. if any(x in k for x in freeze):
  121. print('freezing %s' % k)
  122. v.requires_grad = False
  123. # Optimizer
  124. nbs = 64 # nominal batch size
  125. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  126. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  127. logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  128. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  129. for k, v in model.named_modules():
  130. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  131. pg2.append(v.bias) # biases
  132. if isinstance(v, nn.BatchNorm2d):
  133. pg0.append(v.weight) # no decay
  134. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  135. pg1.append(v.weight) # apply decay
  136. if opt.adam:
  137. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  138. else:
  139. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  140. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  141. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  142. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  143. del pg0, pg1, pg2
  144. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  145. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  146. if opt.linear_lr:
  147. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  148. else:
  149. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  150. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  151. # plot_lr_scheduler(optimizer, scheduler, epochs)
  152. # EMA
  153. ema = ModelEMA(model) if RANK in [-1, 0] else None
  154. # Resume
  155. start_epoch, best_fitness = 0, 0.0
  156. if pretrained:
  157. # Optimizer
  158. if ckpt['optimizer'] is not None:
  159. optimizer.load_state_dict(ckpt['optimizer'])
  160. best_fitness = ckpt['best_fitness']
  161. # EMA
  162. if ema and ckpt.get('ema'):
  163. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  164. ema.updates = ckpt['updates']
  165. # Results
  166. if ckpt.get('training_results') is not None:
  167. results_file.write_text(ckpt['training_results']) # write results.txt
  168. # Epochs
  169. start_epoch = ckpt['epoch'] + 1
  170. if resume:
  171. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  172. if epochs < start_epoch:
  173. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  174. (weights, ckpt['epoch'], epochs))
  175. epochs += ckpt['epoch'] # finetune additional epochs
  176. del ckpt, state_dict
  177. # Image sizes
  178. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  179. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  180. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  181. # DP mode
  182. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  183. logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
  184. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  185. model = torch.nn.DataParallel(model)
  186. # SyncBatchNorm
  187. if opt.sync_bn and cuda and RANK != -1:
  188. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  189. logger.info('Using SyncBatchNorm()')
  190. # Trainloader
  191. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  192. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
  193. workers=workers,
  194. image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
  195. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  196. nb = len(dataloader) # number of batches
  197. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
  198. # Process 0
  199. if RANK in [-1, 0]:
  200. testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
  201. hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
  202. workers=workers,
  203. pad=0.5, prefix=colorstr('val: '))[0]
  204. if not resume:
  205. labels = np.concatenate(dataset.labels, 0)
  206. c = torch.tensor(labels[:, 0]) # classes
  207. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  208. # model._initialize_biases(cf.to(device))
  209. if plots:
  210. plot_labels(labels, names, save_dir, loggers)
  211. if loggers['tb']:
  212. loggers['tb'].add_histogram('classes', c, 0) # TensorBoard
  213. # Anchors
  214. if not opt.noautoanchor:
  215. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  216. model.half().float() # pre-reduce anchor precision
  217. # DDP mode
  218. if cuda and RANK != -1:
  219. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  220. # Model parameters
  221. hyp['box'] *= 3. / nl # scale to layers
  222. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  223. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  224. hyp['label_smoothing'] = opt.label_smoothing
  225. model.nc = nc # attach number of classes to model
  226. model.hyp = hyp # attach hyperparameters to model
  227. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  228. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  229. model.names = names
  230. # Start training
  231. t0 = time.time()
  232. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  233. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  234. maps = np.zeros(nc) # mAP per class
  235. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  236. scheduler.last_epoch = start_epoch - 1 # do not move
  237. scaler = amp.GradScaler(enabled=cuda)
  238. compute_loss = ComputeLoss(model) # init loss class
  239. logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
  240. f'Using {dataloader.num_workers} dataloader workers\n'
  241. f'Logging results to {save_dir}\n'
  242. f'Starting training for {epochs} epochs...')
  243. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  244. model.train()
  245. # Update image weights (optional)
  246. if opt.image_weights:
  247. # Generate indices
  248. if RANK in [-1, 0]:
  249. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  250. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  251. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  252. # Broadcast if DDP
  253. if RANK != -1:
  254. indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
  255. dist.broadcast(indices, 0)
  256. if RANK != 0:
  257. dataset.indices = indices.cpu().numpy()
  258. # Update mosaic border
  259. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  260. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  261. mloss = torch.zeros(4, device=device) # mean losses
  262. if RANK != -1:
  263. dataloader.sampler.set_epoch(epoch)
  264. pbar = enumerate(dataloader)
  265. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
  266. if RANK in [-1, 0]:
  267. pbar = tqdm(pbar, total=nb) # progress bar
  268. optimizer.zero_grad()
  269. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  270. ni = i + nb * epoch # number integrated batches (since train start)
  271. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  272. # Warmup
  273. if ni <= nw:
  274. xi = [0, nw] # x interp
  275. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  276. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  277. for j, x in enumerate(optimizer.param_groups):
  278. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  279. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  280. if 'momentum' in x:
  281. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  282. # Multi-scale
  283. if opt.multi_scale:
  284. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  285. sf = sz / max(imgs.shape[2:]) # scale factor
  286. if sf != 1:
  287. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  288. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  289. # Forward
  290. with amp.autocast(enabled=cuda):
  291. pred = model(imgs) # forward
  292. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  293. if RANK != -1:
  294. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  295. if opt.quad:
  296. loss *= 4.
  297. # Backward
  298. scaler.scale(loss).backward()
  299. # Optimize
  300. if ni % accumulate == 0:
  301. scaler.step(optimizer) # optimizer.step
  302. scaler.update()
  303. optimizer.zero_grad()
  304. if ema:
  305. ema.update(model)
  306. # Print
  307. if RANK in [-1, 0]:
  308. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  309. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  310. s = ('%10s' * 2 + '%10.4g' * 6) % (
  311. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
  312. pbar.set_description(s)
  313. # Plot
  314. if plots and ni < 3:
  315. f = save_dir / f'train_batch{ni}.jpg' # filename
  316. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  317. if loggers['tb'] and ni == 0: # TensorBoard
  318. with warnings.catch_warnings():
  319. warnings.simplefilter('ignore') # suppress jit trace warning
  320. loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
  321. elif plots and ni == 10 and loggers['wandb']:
  322. wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
  323. save_dir.glob('train*.jpg') if x.exists()]})
  324. # end batch ------------------------------------------------------------------------------------------------
  325. # Scheduler
  326. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  327. scheduler.step()
  328. # DDP process 0 or single-GPU
  329. if RANK in [-1, 0]:
  330. # mAP
  331. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  332. final_epoch = epoch + 1 == epochs
  333. if not notest or final_epoch: # Calculate mAP
  334. wandb_logger.current_epoch = epoch + 1
  335. results, maps, _ = test.run(data_dict,
  336. batch_size=batch_size // WORLD_SIZE * 2,
  337. imgsz=imgsz_test,
  338. model=ema.ema,
  339. single_cls=single_cls,
  340. dataloader=testloader,
  341. save_dir=save_dir,
  342. save_json=is_coco and final_epoch,
  343. verbose=nc < 50 and final_epoch,
  344. plots=plots and final_epoch,
  345. wandb_logger=wandb_logger,
  346. compute_loss=compute_loss)
  347. # Write
  348. with open(results_file, 'a') as f:
  349. f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
  350. # Log
  351. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  352. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  353. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  354. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  355. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  356. if loggers['tb']:
  357. loggers['tb'].add_scalar(tag, x, epoch) # TensorBoard
  358. if loggers['wandb']:
  359. wandb_logger.log({tag: x}) # W&B
  360. # Update best mAP
  361. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  362. if fi > best_fitness:
  363. best_fitness = fi
  364. wandb_logger.end_epoch(best_result=best_fitness == fi)
  365. # Save model
  366. if (not nosave) or (final_epoch and not evolve): # if save
  367. ckpt = {'epoch': epoch,
  368. 'best_fitness': best_fitness,
  369. 'training_results': results_file.read_text(),
  370. 'model': deepcopy(de_parallel(model)).half(),
  371. 'ema': deepcopy(ema.ema).half(),
  372. 'updates': ema.updates,
  373. 'optimizer': optimizer.state_dict(),
  374. 'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
  375. # Save last, best and delete
  376. torch.save(ckpt, last)
  377. if best_fitness == fi:
  378. torch.save(ckpt, best)
  379. if loggers['wandb']:
  380. if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
  381. wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
  382. del ckpt
  383. # end epoch ----------------------------------------------------------------------------------------------------
  384. # end training -----------------------------------------------------------------------------------------------------
  385. if RANK in [-1, 0]:
  386. logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
  387. if plots:
  388. plot_results(save_dir=save_dir) # save as results.png
  389. if loggers['wandb']:
  390. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  391. wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
  392. if (save_dir / f).exists()]})
  393. if not evolve:
  394. if is_coco: # COCO dataset
  395. for m in [last, best] if best.exists() else [last]: # speed, mAP tests
  396. results, _, _ = test.run(data_dict,
  397. batch_size=batch_size // WORLD_SIZE * 2,
  398. imgsz=imgsz_test,
  399. conf_thres=0.001,
  400. iou_thres=0.7,
  401. model=attempt_load(m, device).half(),
  402. single_cls=single_cls,
  403. dataloader=testloader,
  404. save_dir=save_dir,
  405. save_json=True,
  406. plots=False)
  407. # Strip optimizers
  408. for f in last, best:
  409. if f.exists():
  410. strip_optimizer(f) # strip optimizers
  411. if loggers['wandb']: # Log the stripped model
  412. loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
  413. name='run_' + wandb_logger.wandb_run.id + '_model',
  414. aliases=['latest', 'best', 'stripped'])
  415. wandb_logger.finish_run()
  416. torch.cuda.empty_cache()
  417. return results
  418. def parse_opt(known=False):
  419. parser = argparse.ArgumentParser()
  420. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  421. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  422. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
  423. parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  424. parser.add_argument('--epochs', type=int, default=300)
  425. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  426. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  427. parser.add_argument('--rect', action='store_true', help='rectangular training')
  428. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  429. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  430. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  431. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  432. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  433. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  434. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  435. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  436. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  437. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  438. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  439. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  440. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  441. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  442. parser.add_argument('--project', default='runs/train', help='save to project/name')
  443. parser.add_argument('--entity', default=None, help='W&B entity')
  444. parser.add_argument('--name', default='exp', help='save to project/name')
  445. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  446. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  447. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  448. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  449. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  450. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  451. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
  452. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
  453. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  454. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  455. return opt
  456. def main(opt):
  457. set_logging(RANK)
  458. if RANK in [-1, 0]:
  459. print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  460. check_git_status()
  461. check_requirements(exclude=['thop'])
  462. # Resume
  463. wandb_run = check_wandb_resume(opt)
  464. if opt.resume and not wandb_run: # resume an interrupted run
  465. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  466. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  467. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  468. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  469. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  470. logger.info('Resuming training from %s' % ckpt)
  471. else:
  472. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  473. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  474. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  475. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  476. opt.name = 'evolve' if opt.evolve else opt.name
  477. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve))
  478. # DDP mode
  479. device = select_device(opt.device, batch_size=opt.batch_size)
  480. if LOCAL_RANK != -1:
  481. from datetime import timedelta
  482. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  483. torch.cuda.set_device(LOCAL_RANK)
  484. device = torch.device('cuda', LOCAL_RANK)
  485. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=60))
  486. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  487. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  488. # Train
  489. if not opt.evolve:
  490. train(opt.hyp, opt, device)
  491. if WORLD_SIZE > 1 and RANK == 0:
  492. _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
  493. # Evolve hyperparameters (optional)
  494. else:
  495. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  496. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  497. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  498. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  499. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  500. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  501. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  502. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  503. 'box': (1, 0.02, 0.2), # box loss gain
  504. 'cls': (1, 0.2, 4.0), # cls loss gain
  505. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  506. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  507. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  508. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  509. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  510. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  511. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  512. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  513. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  514. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  515. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  516. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  517. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  518. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  519. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  520. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  521. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  522. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  523. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  524. with open(opt.hyp) as f:
  525. hyp = yaml.safe_load(f) # load hyps dict
  526. assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
  527. opt.notest, opt.nosave = True, True # only test/save final epoch
  528. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  529. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  530. if opt.bucket:
  531. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  532. for _ in range(300): # generations to evolve
  533. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  534. # Select parent(s)
  535. parent = 'single' # parent selection method: 'single' or 'weighted'
  536. x = np.loadtxt('evolve.txt', ndmin=2)
  537. n = min(5, len(x)) # number of previous results to consider
  538. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  539. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  540. if parent == 'single' or len(x) == 1:
  541. # x = x[random.randint(0, n - 1)] # random selection
  542. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  543. elif parent == 'weighted':
  544. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  545. # Mutate
  546. mp, s = 0.8, 0.2 # mutation probability, sigma
  547. npr = np.random
  548. npr.seed(int(time.time()))
  549. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  550. ng = len(meta)
  551. v = np.ones(ng)
  552. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  553. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  554. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  555. hyp[k] = float(x[i + 7] * v[i]) # mutate
  556. # Constrain to limits
  557. for k, v in meta.items():
  558. hyp[k] = max(hyp[k], v[1]) # lower limit
  559. hyp[k] = min(hyp[k], v[2]) # upper limit
  560. hyp[k] = round(hyp[k], 5) # significant digits
  561. # Train mutation
  562. results = train(hyp.copy(), opt, device)
  563. # Write mutation results
  564. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  565. # Plot results
  566. plot_evolution(yaml_file)
  567. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  568. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
  569. def run(**kwargs):
  570. # Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
  571. opt = parse_opt(True)
  572. for k, v in kwargs.items():
  573. setattr(opt, k, v)
  574. main(opt)
  575. if __name__ == "__main__":
  576. opt = parse_opt()
  577. main(opt)