Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

661 rinda
34KB

  1. """Train a YOLOv5 model on a custom dataset
  2. Usage:
  3. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
  4. """
  5. import argparse
  6. import logging
  7. import os
  8. import random
  9. import sys
  10. import time
  11. import warnings
  12. from copy import deepcopy
  13. from pathlib import Path
  14. from threading import Thread
  15. import math
  16. import numpy as np
  17. import torch.distributed as dist
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. import torch.optim as optim
  21. import torch.optim.lr_scheduler as lr_scheduler
  22. import torch.utils.data
  23. import yaml
  24. from torch.cuda import amp
  25. from torch.nn.parallel import DistributedDataParallel as DDP
  26. from torch.utils.tensorboard import SummaryWriter
  27. from tqdm import tqdm
  28. FILE = Path(__file__).absolute()
  29. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  30. import test # for end-of-epoch mAP
  31. from models.experimental import attempt_load
  32. from models.yolo import Model
  33. from utils.autoanchor import check_anchors
  34. from utils.datasets import create_dataloader
  35. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  36. strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  37. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  38. from utils.google_utils import attempt_download
  39. from utils.loss import ComputeLoss
  40. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  41. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
  42. from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
  43. from utils.metrics import fitness
  44. logger = logging.getLogger(__name__)
  45. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  46. RANK = int(os.getenv('RANK', -1))
  47. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  48. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  49. opt,
  50. device,
  51. ):
  52. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
  53. opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  54. opt.resume, opt.notest, opt.nosave, opt.workers
  55. # Directories
  56. save_dir = Path(save_dir)
  57. wdir = save_dir / 'weights'
  58. wdir.mkdir(parents=True, exist_ok=True) # make dir
  59. last = wdir / 'last.pt'
  60. best = wdir / 'best.pt'
  61. results_file = save_dir / 'results.txt'
  62. # Hyperparameters
  63. if isinstance(hyp, str):
  64. with open(hyp) as f:
  65. hyp = yaml.safe_load(f) # load hyps dict
  66. logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  67. # Save run settings
  68. with open(save_dir / 'hyp.yaml', 'w') as f:
  69. yaml.safe_dump(hyp, f, sort_keys=False)
  70. with open(save_dir / 'opt.yaml', 'w') as f:
  71. yaml.safe_dump(vars(opt), f, sort_keys=False)
  72. # Configure
  73. plots = not evolve # create plots
  74. cuda = device.type != 'cpu'
  75. init_seeds(1 + RANK)
  76. with open(data) as f:
  77. data_dict = yaml.safe_load(f) # data dict
  78. # Loggers
  79. loggers = {'wandb': None, 'tb': None} # loggers dict
  80. if RANK in [-1, 0]:
  81. # TensorBoard
  82. if not evolve:
  83. prefix = colorstr('tensorboard: ')
  84. logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
  85. loggers['tb'] = SummaryWriter(str(save_dir))
  86. # W&B
  87. opt.hyp = hyp # add hyperparameters
  88. run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
  89. run_id = run_id if opt.resume else None # start fresh run if transfer learning
  90. wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
  91. loggers['wandb'] = wandb_logger.wandb
  92. if loggers['wandb']:
  93. data_dict = wandb_logger.data_dict
  94. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming
  95. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  96. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  97. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data) # check
  98. is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
  99. # Model
  100. pretrained = weights.endswith('.pt')
  101. if pretrained:
  102. with torch_distributed_zero_first(RANK):
  103. weights = attempt_download(weights) # download if not found locally
  104. ckpt = torch.load(weights, map_location=device) # load checkpoint
  105. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  106. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  107. state_dict = ckpt['model'].float().state_dict() # to FP32
  108. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  109. model.load_state_dict(state_dict, strict=False) # load
  110. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  111. else:
  112. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  113. with torch_distributed_zero_first(RANK):
  114. check_dataset(data_dict) # check
  115. train_path = data_dict['train']
  116. test_path = data_dict['val']
  117. # Freeze
  118. freeze = [] # parameter names to freeze (full or partial)
  119. for k, v in model.named_parameters():
  120. v.requires_grad = True # train all layers
  121. if any(x in k for x in freeze):
  122. print('freezing %s' % k)
  123. v.requires_grad = False
  124. # Optimizer
  125. nbs = 64 # nominal batch size
  126. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  127. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  128. logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  129. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  130. for k, v in model.named_modules():
  131. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  132. pg2.append(v.bias) # biases
  133. if isinstance(v, nn.BatchNorm2d):
  134. pg0.append(v.weight) # no decay
  135. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  136. pg1.append(v.weight) # apply decay
  137. if opt.adam:
  138. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  139. else:
  140. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  141. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  142. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  143. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  144. del pg0, pg1, pg2
  145. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  146. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  147. if opt.linear_lr:
  148. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  149. else:
  150. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  151. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  152. # plot_lr_scheduler(optimizer, scheduler, epochs)
  153. # EMA
  154. ema = ModelEMA(model) if RANK in [-1, 0] else None
  155. # Resume
  156. start_epoch, best_fitness = 0, 0.0
  157. if pretrained:
  158. # Optimizer
  159. if ckpt['optimizer'] is not None:
  160. optimizer.load_state_dict(ckpt['optimizer'])
  161. best_fitness = ckpt['best_fitness']
  162. # EMA
  163. if ema and ckpt.get('ema'):
  164. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  165. ema.updates = ckpt['updates']
  166. # Results
  167. if ckpt.get('training_results') is not None:
  168. results_file.write_text(ckpt['training_results']) # write results.txt
  169. # Epochs
  170. start_epoch = ckpt['epoch'] + 1
  171. if resume:
  172. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  173. if epochs < start_epoch:
  174. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  175. (weights, ckpt['epoch'], epochs))
  176. epochs += ckpt['epoch'] # finetune additional epochs
  177. del ckpt, state_dict
  178. # Image sizes
  179. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  180. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  181. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  182. # DP mode
  183. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  184. logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
  185. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  186. model = torch.nn.DataParallel(model)
  187. # SyncBatchNorm
  188. if opt.sync_bn and cuda and RANK != -1:
  189. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  190. logger.info('Using SyncBatchNorm()')
  191. # Trainloader
  192. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  193. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
  194. workers=workers,
  195. image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
  196. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  197. nb = len(dataloader) # number of batches
  198. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
  199. # Process 0
  200. if RANK in [-1, 0]:
  201. testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
  202. hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
  203. workers=workers,
  204. pad=0.5, prefix=colorstr('val: '))[0]
  205. if not resume:
  206. labels = np.concatenate(dataset.labels, 0)
  207. c = torch.tensor(labels[:, 0]) # classes
  208. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  209. # model._initialize_biases(cf.to(device))
  210. if plots:
  211. plot_labels(labels, names, save_dir, loggers)
  212. if loggers['tb']:
  213. loggers['tb'].add_histogram('classes', c, 0) # TensorBoard
  214. # Anchors
  215. if not opt.noautoanchor:
  216. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  217. model.half().float() # pre-reduce anchor precision
  218. # DDP mode
  219. if cuda and RANK != -1:
  220. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  221. # Model parameters
  222. hyp['box'] *= 3. / nl # scale to layers
  223. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  224. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  225. hyp['label_smoothing'] = opt.label_smoothing
  226. model.nc = nc # attach number of classes to model
  227. model.hyp = hyp # attach hyperparameters to model
  228. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  229. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  230. model.names = names
  231. # Start training
  232. t0 = time.time()
  233. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  234. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  235. last_opt_step = -1
  236. maps = np.zeros(nc) # mAP per class
  237. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  238. scheduler.last_epoch = start_epoch - 1 # do not move
  239. scaler = amp.GradScaler(enabled=cuda)
  240. compute_loss = ComputeLoss(model) # init loss class
  241. logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
  242. f'Using {dataloader.num_workers} dataloader workers\n'
  243. f'Logging results to {save_dir}\n'
  244. f'Starting training for {epochs} epochs...')
  245. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  246. model.train()
  247. # Update image weights (optional)
  248. if opt.image_weights:
  249. # Generate indices
  250. if RANK in [-1, 0]:
  251. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  252. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  253. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  254. # Broadcast if DDP
  255. if RANK != -1:
  256. indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
  257. dist.broadcast(indices, 0)
  258. if RANK != 0:
  259. dataset.indices = indices.cpu().numpy()
  260. # Update mosaic border
  261. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  262. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  263. mloss = torch.zeros(4, device=device) # mean losses
  264. if RANK != -1:
  265. dataloader.sampler.set_epoch(epoch)
  266. pbar = enumerate(dataloader)
  267. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
  268. if RANK in [-1, 0]:
  269. pbar = tqdm(pbar, total=nb) # progress bar
  270. optimizer.zero_grad()
  271. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  272. ni = i + nb * epoch # number integrated batches (since train start)
  273. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  274. # Warmup
  275. if ni <= nw:
  276. xi = [0, nw] # x interp
  277. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  278. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  279. for j, x in enumerate(optimizer.param_groups):
  280. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  281. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  282. if 'momentum' in x:
  283. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  284. # Multi-scale
  285. if opt.multi_scale:
  286. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  287. sf = sz / max(imgs.shape[2:]) # scale factor
  288. if sf != 1:
  289. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  290. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  291. # Forward
  292. with amp.autocast(enabled=cuda):
  293. pred = model(imgs) # forward
  294. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  295. if RANK != -1:
  296. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  297. if opt.quad:
  298. loss *= 4.
  299. # Backward
  300. scaler.scale(loss).backward()
  301. # Optimize
  302. if ni - last_opt_step >= accumulate:
  303. scaler.step(optimizer) # optimizer.step
  304. scaler.update()
  305. optimizer.zero_grad()
  306. if ema:
  307. ema.update(model)
  308. last_opt_step = ni
  309. # Print
  310. if RANK in [-1, 0]:
  311. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  312. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  313. s = ('%10s' * 2 + '%10.4g' * 6) % (
  314. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
  315. pbar.set_description(s)
  316. # Plot
  317. if plots and ni < 3:
  318. f = save_dir / f'train_batch{ni}.jpg' # filename
  319. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  320. if loggers['tb'] and ni == 0: # TensorBoard
  321. with warnings.catch_warnings():
  322. warnings.simplefilter('ignore') # suppress jit trace warning
  323. loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
  324. elif plots and ni == 10 and loggers['wandb']:
  325. wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
  326. save_dir.glob('train*.jpg') if x.exists()]})
  327. # end batch ------------------------------------------------------------------------------------------------
  328. # Scheduler
  329. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  330. scheduler.step()
  331. # DDP process 0 or single-GPU
  332. if RANK in [-1, 0]:
  333. # mAP
  334. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  335. final_epoch = epoch + 1 == epochs
  336. if not notest or final_epoch: # Calculate mAP
  337. wandb_logger.current_epoch = epoch + 1
  338. results, maps, _ = test.run(data_dict,
  339. batch_size=batch_size // WORLD_SIZE * 2,
  340. imgsz=imgsz_test,
  341. model=ema.ema,
  342. single_cls=single_cls,
  343. dataloader=testloader,
  344. save_dir=save_dir,
  345. save_json=is_coco and final_epoch,
  346. verbose=nc < 50 and final_epoch,
  347. plots=plots and final_epoch,
  348. wandb_logger=wandb_logger,
  349. compute_loss=compute_loss)
  350. # Write
  351. with open(results_file, 'a') as f:
  352. f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
  353. # Log
  354. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  355. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  356. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  357. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  358. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  359. if loggers['tb']:
  360. loggers['tb'].add_scalar(tag, x, epoch) # TensorBoard
  361. if loggers['wandb']:
  362. wandb_logger.log({tag: x}) # W&B
  363. # Update best mAP
  364. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  365. if fi > best_fitness:
  366. best_fitness = fi
  367. wandb_logger.end_epoch(best_result=best_fitness == fi)
  368. # Save model
  369. if (not nosave) or (final_epoch and not evolve): # if save
  370. ckpt = {'epoch': epoch,
  371. 'best_fitness': best_fitness,
  372. 'training_results': results_file.read_text(),
  373. 'model': deepcopy(de_parallel(model)).half(),
  374. 'ema': deepcopy(ema.ema).half(),
  375. 'updates': ema.updates,
  376. 'optimizer': optimizer.state_dict(),
  377. 'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
  378. # Save last, best and delete
  379. torch.save(ckpt, last)
  380. if best_fitness == fi:
  381. torch.save(ckpt, best)
  382. if loggers['wandb']:
  383. if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
  384. wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
  385. del ckpt
  386. # end epoch ----------------------------------------------------------------------------------------------------
  387. # end training -----------------------------------------------------------------------------------------------------
  388. if RANK in [-1, 0]:
  389. logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
  390. if plots:
  391. plot_results(save_dir=save_dir) # save as results.png
  392. if loggers['wandb']:
  393. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  394. wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
  395. if (save_dir / f).exists()]})
  396. if not evolve:
  397. if is_coco: # COCO dataset
  398. for m in [last, best] if best.exists() else [last]: # speed, mAP tests
  399. results, _, _ = test.run(data_dict,
  400. batch_size=batch_size // WORLD_SIZE * 2,
  401. imgsz=imgsz_test,
  402. model=attempt_load(m, device).half(),
  403. single_cls=single_cls,
  404. dataloader=testloader,
  405. save_dir=save_dir,
  406. save_json=True,
  407. plots=False)
  408. # Strip optimizers
  409. for f in last, best:
  410. if f.exists():
  411. strip_optimizer(f) # strip optimizers
  412. if loggers['wandb']: # Log the stripped model
  413. loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
  414. name='run_' + wandb_logger.wandb_run.id + '_model',
  415. aliases=['latest', 'best', 'stripped'])
  416. wandb_logger.finish_run()
  417. torch.cuda.empty_cache()
  418. return results
  419. def parse_opt(known=False):
  420. parser = argparse.ArgumentParser()
  421. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  422. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  423. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
  424. parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  425. parser.add_argument('--epochs', type=int, default=300)
  426. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  427. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  428. parser.add_argument('--rect', action='store_true', help='rectangular training')
  429. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  430. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  431. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  432. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  433. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  434. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  435. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  436. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  437. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  438. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  439. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  440. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  441. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  442. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  443. parser.add_argument('--project', default='runs/train', help='save to project/name')
  444. parser.add_argument('--entity', default=None, help='W&B entity')
  445. parser.add_argument('--name', default='exp', help='save to project/name')
  446. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  447. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  448. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  449. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  450. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  451. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  452. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
  453. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
  454. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  455. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  456. return opt
  457. def main(opt):
  458. set_logging(RANK)
  459. if RANK in [-1, 0]:
  460. print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  461. check_git_status()
  462. check_requirements(exclude=['thop'])
  463. # Resume
  464. wandb_run = check_wandb_resume(opt)
  465. if opt.resume and not wandb_run: # resume an interrupted run
  466. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  467. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  468. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  469. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  470. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  471. logger.info('Resuming training from %s' % ckpt)
  472. else:
  473. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  474. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  475. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  476. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  477. opt.name = 'evolve' if opt.evolve else opt.name
  478. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
  479. # DDP mode
  480. device = select_device(opt.device, batch_size=opt.batch_size)
  481. if LOCAL_RANK != -1:
  482. from datetime import timedelta
  483. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  484. torch.cuda.set_device(LOCAL_RANK)
  485. device = torch.device('cuda', LOCAL_RANK)
  486. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=60))
  487. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  488. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  489. # Train
  490. if not opt.evolve:
  491. train(opt.hyp, opt, device)
  492. if WORLD_SIZE > 1 and RANK == 0:
  493. _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
  494. # Evolve hyperparameters (optional)
  495. else:
  496. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  497. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  498. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  499. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  500. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  501. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  502. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  503. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  504. 'box': (1, 0.02, 0.2), # box loss gain
  505. 'cls': (1, 0.2, 4.0), # cls loss gain
  506. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  507. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  508. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  509. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  510. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  511. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  512. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  513. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  514. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  515. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  516. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  517. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  518. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  519. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  520. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  521. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  522. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  523. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  524. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  525. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  526. with open(opt.hyp) as f:
  527. hyp = yaml.safe_load(f) # load hyps dict
  528. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  529. hyp['anchors'] = 3
  530. assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
  531. opt.notest, opt.nosave = True, True # only test/save final epoch
  532. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  533. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  534. if opt.bucket:
  535. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  536. for _ in range(opt.evolve): # generations to evolve
  537. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  538. # Select parent(s)
  539. parent = 'single' # parent selection method: 'single' or 'weighted'
  540. x = np.loadtxt('evolve.txt', ndmin=2)
  541. n = min(5, len(x)) # number of previous results to consider
  542. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  543. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  544. if parent == 'single' or len(x) == 1:
  545. # x = x[random.randint(0, n - 1)] # random selection
  546. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  547. elif parent == 'weighted':
  548. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  549. # Mutate
  550. mp, s = 0.8, 0.2 # mutation probability, sigma
  551. npr = np.random
  552. npr.seed(int(time.time()))
  553. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  554. ng = len(meta)
  555. v = np.ones(ng)
  556. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  557. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  558. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  559. hyp[k] = float(x[i + 7] * v[i]) # mutate
  560. # Constrain to limits
  561. for k, v in meta.items():
  562. hyp[k] = max(hyp[k], v[1]) # lower limit
  563. hyp[k] = min(hyp[k], v[2]) # upper limit
  564. hyp[k] = round(hyp[k], 5) # significant digits
  565. # Train mutation
  566. results = train(hyp.copy(), opt, device)
  567. # Write mutation results
  568. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  569. # Plot results
  570. plot_evolution(yaml_file)
  571. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  572. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
  573. def run(**kwargs):
  574. # Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
  575. opt = parse_opt(True)
  576. for k, v in kwargs.items():
  577. setattr(opt, k, v)
  578. main(opt)
if __name__ == "__main__":
    # Script entry point: parse the command line strictly (unknown flags are an error)
    # and hand the resulting options to main().
    opt = parse_opt()
    main(opt)