Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

663 lines
34KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset.
  4. Models and datasets download automatically from the latest YOLOv5 release.
  5. Models: https://github.com/ultralytics/yolov5/tree/master/models
  6. Datasets: https://github.com/ultralytics/yolov5/tree/master/data
  7. Tutorial: https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data
  8. Usage:
  9. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640 # from pretrained (RECOMMENDED)
  10. $ python path/to/train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640 # from scratch
  11. """
  12. import argparse
  13. import math
  14. import os
  15. import random
  16. import sys
  17. import time
  18. from copy import deepcopy
  19. from datetime import datetime
  20. from pathlib import Path
  21. import numpy as np
  22. import torch
  23. import torch.distributed as dist
  24. import torch.nn as nn
  25. import yaml
  26. from torch.cuda import amp
  27. from torch.nn.parallel import DistributedDataParallel as DDP
  28. from torch.optim import SGD, Adam, AdamW, lr_scheduler
  29. from tqdm import tqdm
# Resolve this file's location and make its directory importable, so the project-local imports below
# (`val`, `models.*`, `utils.*`) can be found through sys.path.
FILE = Path(__file__).resolve()  # absolute, symlink-resolved path of this file
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative (to the current working directory)
  35. import val # for end-of-epoch mAP
  36. from models.experimental import attempt_load
  37. from models.yolo import Model
  38. from utils.autoanchor import check_anchors
  39. from utils.autobatch import check_train_batch_size
  40. from utils.callbacks import Callbacks
  41. from utils.datasets import create_dataloader
  42. from utils.downloads import attempt_download
  43. from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
  44. check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
  45. intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
  46. print_args, print_mutation, strip_optimizer)
  47. from utils.loggers import Loggers
  48. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  49. from utils.loss import ComputeLoss
  50. from utils.metrics import fitness
  51. from utils.plots import plot_evolve, plot_labels
  52. from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
  53. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  54. RANK = int(os.getenv('RANK', -1))
  55. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  56. def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
  57. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
  58. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  59. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  60. # Directories
  61. w = save_dir / 'weights' # weights dir
  62. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
  63. last, best = w / 'last.pt', w / 'best.pt'
  64. # Hyperparameters
  65. if isinstance(hyp, str):
  66. with open(hyp, errors='ignore') as f:
  67. hyp = yaml.safe_load(f) # load hyps dict
  68. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  69. # Save run settings
  70. if not evolve:
  71. with open(save_dir / 'hyp.yaml', 'w') as f:
  72. yaml.safe_dump(hyp, f, sort_keys=False)
  73. with open(save_dir / 'opt.yaml', 'w') as f:
  74. yaml.safe_dump(vars(opt), f, sort_keys=False)
  75. # Loggers
  76. data_dict = None
  77. if RANK in [-1, 0]:
  78. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  79. if loggers.wandb:
  80. data_dict = loggers.wandb.data_dict
  81. if resume:
  82. weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
  83. # Register actions
  84. for k in methods(loggers):
  85. callbacks.register_action(k, callback=getattr(loggers, k))
  86. # Config
  87. plots = not evolve # create plots
  88. cuda = device.type != 'cpu'
  89. init_seeds(1 + RANK)
  90. with torch_distributed_zero_first(LOCAL_RANK):
  91. data_dict = data_dict or check_dataset(data) # check if None
  92. train_path, val_path = data_dict['train'], data_dict['val']
  93. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  94. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  95. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  96. is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
  97. # Model
  98. check_suffix(weights, '.pt') # check weights
  99. pretrained = weights.endswith('.pt')
  100. if pretrained:
  101. with torch_distributed_zero_first(LOCAL_RANK):
  102. weights = attempt_download(weights) # download if not found locally
  103. ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
  104. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  105. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  106. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  107. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  108. model.load_state_dict(csd, strict=False) # load
  109. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  110. else:
  111. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  112. # Freeze
  113. freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
  114. for k, v in model.named_parameters():
  115. v.requires_grad = True # train all layers
  116. if any(x in k for x in freeze):
  117. LOGGER.info(f'freezing {k}')
  118. v.requires_grad = False
  119. # Image size
  120. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  121. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  122. # Batch size
  123. if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
  124. batch_size = check_train_batch_size(model, imgsz)
  125. loggers.on_params_update({"batch_size": batch_size})
  126. # Optimizer
  127. nbs = 64 # nominal batch size
  128. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  129. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  130. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  131. g0, g1, g2 = [], [], [] # optimizer parameter groups
  132. for v in model.modules():
  133. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  134. g2.append(v.bias)
  135. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
  136. g0.append(v.weight)
  137. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  138. g1.append(v.weight)
  139. if opt.optimizer == 'Adam':
  140. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  141. elif opt.optimizer == 'AdamW':
  142. optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  143. else:
  144. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  145. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  146. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  147. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  148. f"{len(g0)} weight (no decay), {len(g1)} weight, {len(g2)} bias")
  149. del g0, g1, g2
  150. # Scheduler
  151. if opt.cos_lr:
  152. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  153. else:
  154. lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  155. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  156. # EMA
  157. ema = ModelEMA(model) if RANK in [-1, 0] else None
  158. # Resume
  159. start_epoch, best_fitness = 0, 0.0
  160. if pretrained:
  161. # Optimizer
  162. if ckpt['optimizer'] is not None:
  163. optimizer.load_state_dict(ckpt['optimizer'])
  164. best_fitness = ckpt['best_fitness']
  165. # EMA
  166. if ema and ckpt.get('ema'):
  167. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  168. ema.updates = ckpt['updates']
  169. # Epochs
  170. start_epoch = ckpt['epoch'] + 1
  171. if resume:
  172. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  173. if epochs < start_epoch:
  174. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  175. epochs += ckpt['epoch'] # finetune additional epochs
  176. del ckpt, csd
  177. # DP mode
  178. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  179. LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
  180. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  181. model = torch.nn.DataParallel(model)
  182. # SyncBatchNorm
  183. if opt.sync_bn and cuda and RANK != -1:
  184. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  185. LOGGER.info('Using SyncBatchNorm()')
  186. # Trainloader
  187. train_loader, dataset = create_dataloader(train_path,
  188. imgsz,
  189. batch_size // WORLD_SIZE,
  190. gs,
  191. single_cls,
  192. hyp=hyp,
  193. augment=True,
  194. cache=None if opt.cache == 'val' else opt.cache,
  195. rect=opt.rect,
  196. rank=LOCAL_RANK,
  197. workers=workers,
  198. image_weights=opt.image_weights,
  199. quad=opt.quad,
  200. prefix=colorstr('train: '),
  201. shuffle=True)
  202. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  203. nb = len(train_loader) # number of batches
  204. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  205. # Process 0
  206. if RANK in [-1, 0]:
  207. val_loader = create_dataloader(val_path,
  208. imgsz,
  209. batch_size // WORLD_SIZE * 2,
  210. gs,
  211. single_cls,
  212. hyp=hyp,
  213. cache=None if noval else opt.cache,
  214. rect=True,
  215. rank=-1,
  216. workers=workers * 2,
  217. pad=0.5,
  218. prefix=colorstr('val: '))[0]
  219. if not resume:
  220. labels = np.concatenate(dataset.labels, 0)
  221. # c = torch.tensor(labels[:, 0]) # classes
  222. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  223. # model._initialize_biases(cf.to(device))
  224. if plots:
  225. plot_labels(labels, names, save_dir)
  226. # Anchors
  227. if not opt.noautoanchor:
  228. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  229. model.half().float() # pre-reduce anchor precision
  230. callbacks.run('on_pretrain_routine_end')
  231. # DDP mode
  232. if cuda and RANK != -1:
  233. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  234. # Model attributes
  235. nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
  236. hyp['box'] *= 3 / nl # scale to layers
  237. hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
  238. hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
  239. hyp['label_smoothing'] = opt.label_smoothing
  240. model.nc = nc # attach number of classes to model
  241. model.hyp = hyp # attach hyperparameters to model
  242. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  243. model.names = names
  244. # Start training
  245. t0 = time.time()
  246. nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations)
  247. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  248. last_opt_step = -1
  249. maps = np.zeros(nc) # mAP per class
  250. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  251. scheduler.last_epoch = start_epoch - 1 # do not move
  252. scaler = amp.GradScaler(enabled=cuda)
  253. stopper = EarlyStopping(patience=opt.patience)
  254. compute_loss = ComputeLoss(model) # init loss class
  255. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  256. f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
  257. f"Logging results to {colorstr('bold', save_dir)}\n"
  258. f'Starting training for {epochs} epochs...')
  259. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  260. model.train()
  261. # Update image weights (optional, single-GPU only)
  262. if opt.image_weights:
  263. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  264. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  265. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  266. # Update mosaic border (optional)
  267. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  268. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  269. mloss = torch.zeros(3, device=device) # mean losses
  270. if RANK != -1:
  271. train_loader.sampler.set_epoch(epoch)
  272. pbar = enumerate(train_loader)
  273. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  274. if RANK in [-1, 0]:
  275. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  276. optimizer.zero_grad()
  277. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  278. ni = i + nb * epoch # number integrated batches (since train start)
  279. imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
  280. # Warmup
  281. if ni <= nw:
  282. xi = [0, nw] # x interp
  283. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  284. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  285. for j, x in enumerate(optimizer.param_groups):
  286. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  287. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  288. if 'momentum' in x:
  289. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  290. # Multi-scale
  291. if opt.multi_scale:
  292. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  293. sf = sz / max(imgs.shape[2:]) # scale factor
  294. if sf != 1:
  295. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  296. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  297. # Forward
  298. with amp.autocast(enabled=cuda):
  299. pred = model(imgs) # forward
  300. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  301. if RANK != -1:
  302. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  303. if opt.quad:
  304. loss *= 4.
  305. # Backward
  306. scaler.scale(loss).backward()
  307. # Optimize
  308. if ni - last_opt_step >= accumulate:
  309. scaler.step(optimizer) # optimizer.step
  310. scaler.update()
  311. optimizer.zero_grad()
  312. if ema:
  313. ema.update(model)
  314. last_opt_step = ni
  315. # Log
  316. if RANK in [-1, 0]:
  317. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  318. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  319. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) %
  320. (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  321. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
  322. if callbacks.stop_training:
  323. return
  324. # end batch ------------------------------------------------------------------------------------------------
  325. # Scheduler
  326. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  327. scheduler.step()
  328. if RANK in [-1, 0]:
  329. # mAP
  330. callbacks.run('on_train_epoch_end', epoch=epoch)
  331. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  332. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
  333. if not noval or final_epoch: # Calculate mAP
  334. results, maps, _ = val.run(data_dict,
  335. batch_size=batch_size // WORLD_SIZE * 2,
  336. imgsz=imgsz,
  337. model=ema.ema,
  338. single_cls=single_cls,
  339. dataloader=val_loader,
  340. save_dir=save_dir,
  341. plots=False,
  342. callbacks=callbacks,
  343. compute_loss=compute_loss)
  344. # Update best mAP
  345. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  346. if fi > best_fitness:
  347. best_fitness = fi
  348. log_vals = list(mloss) + list(results) + lr
  349. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
  350. # Save model
  351. if (not nosave) or (final_epoch and not evolve): # if save
  352. ckpt = {
  353. 'epoch': epoch,
  354. 'best_fitness': best_fitness,
  355. 'model': deepcopy(de_parallel(model)).half(),
  356. 'ema': deepcopy(ema.ema).half(),
  357. 'updates': ema.updates,
  358. 'optimizer': optimizer.state_dict(),
  359. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
  360. 'date': datetime.now().isoformat()}
  361. # Save last, best and delete
  362. torch.save(ckpt, last)
  363. if best_fitness == fi:
  364. torch.save(ckpt, best)
  365. if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
  366. torch.save(ckpt, w / f'epoch{epoch}.pt')
  367. del ckpt
  368. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
  369. # Stop Single-GPU
  370. if RANK == -1 and stopper(epoch=epoch, fitness=fi):
  371. break
  372. # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
  373. # stop = stopper(epoch=epoch, fitness=fi)
  374. # if RANK == 0:
  375. # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
  376. # Stop DPP
  377. # with torch_distributed_zero_first(RANK):
  378. # if stop:
  379. # break # must break all DDP ranks
  380. # end epoch ----------------------------------------------------------------------------------------------------
  381. # end training -----------------------------------------------------------------------------------------------------
  382. if RANK in [-1, 0]:
  383. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  384. for f in last, best:
  385. if f.exists():
  386. strip_optimizer(f) # strip optimizers
  387. if f is best:
  388. LOGGER.info(f'\nValidating {f}...')
  389. results, _, _ = val.run(
  390. data_dict,
  391. batch_size=batch_size // WORLD_SIZE * 2,
  392. imgsz=imgsz,
  393. model=attempt_load(f, device).half(),
  394. iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
  395. single_cls=single_cls,
  396. dataloader=val_loader,
  397. save_dir=save_dir,
  398. save_json=is_coco,
  399. verbose=True,
  400. plots=True,
  401. callbacks=callbacks,
  402. compute_loss=compute_loss) # val best model with plots
  403. if is_coco:
  404. callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
  405. callbacks.run('on_train_end', last, best, plots, epoch, results)
  406. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
  407. torch.cuda.empty_cache()
  408. return results
  409. def parse_opt(known=False):
  410. parser = argparse.ArgumentParser()
  411. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
  412. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  413. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  414. parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
  415. parser.add_argument('--epochs', type=int, default=300)
  416. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
  417. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  418. parser.add_argument('--rect', action='store_true', help='rectangular training')
  419. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  420. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  421. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  422. parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
  423. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  424. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  425. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  426. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  427. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  428. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  429. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  430. parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
  431. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  432. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
  433. parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
  434. parser.add_argument('--name', default='exp', help='save to project/name')
  435. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  436. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  437. parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
  438. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  439. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  440. parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
  441. parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
  442. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  443. # Weights & Biases arguments
  444. parser.add_argument('--entity', default=None, help='W&B: Entity')
  445. parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
  446. parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
  447. parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
  448. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  449. return opt
  450. def main(opt, callbacks=Callbacks()):
  451. # Checks
  452. if RANK in [-1, 0]:
  453. print_args(vars(opt))
  454. check_git_status()
  455. check_requirements(exclude=['thop'])
  456. # Resume
  457. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  458. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  459. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  460. with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
  461. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  462. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  463. LOGGER.info(f'Resuming training from {ckpt}')
  464. else:
  465. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
  466. check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
  467. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  468. if opt.evolve:
  469. if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve
  470. opt.project = str(ROOT / 'runs/evolve')
  471. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  472. if opt.name == 'cfg':
  473. opt.name = Path(opt.cfg).stem # use model.yaml as name
  474. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  475. # DDP mode
  476. device = select_device(opt.device, batch_size=opt.batch_size)
  477. if LOCAL_RANK != -1:
  478. msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
  479. assert not opt.image_weights, f'--image-weights {msg}'
  480. assert not opt.evolve, f'--evolve {msg}'
  481. assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
  482. assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
  483. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  484. torch.cuda.set_device(LOCAL_RANK)
  485. device = torch.device('cuda', LOCAL_RANK)
  486. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
  487. # Train
  488. if not opt.evolve:
  489. train(opt.hyp, opt, device, callbacks)
  490. if WORLD_SIZE > 1 and RANK == 0:
  491. LOGGER.info('Destroying process group... ')
  492. dist.destroy_process_group()
  493. # Evolve hyperparameters (optional)
  494. else:
  495. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  496. meta = {
  497. 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  498. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  499. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  500. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  501. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  502. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  503. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  504. 'box': (1, 0.02, 0.2), # box loss gain
  505. 'cls': (1, 0.2, 4.0), # cls loss gain
  506. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  507. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  508. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  509. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  510. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  511. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  512. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  513. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  514. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  515. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  516. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  517. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  518. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  519. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  520. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  521. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  522. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  523. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  524. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  525. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  526. with open(opt.hyp, errors='ignore') as f:
  527. hyp = yaml.safe_load(f) # load hyps dict
  528. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  529. hyp['anchors'] = 3
  530. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  531. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  532. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  533. if opt.bucket:
  534. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}') # download evolve.csv if exists
  535. for _ in range(opt.evolve): # generations to evolve
  536. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  537. # Select parent(s)
  538. parent = 'single' # parent selection method: 'single' or 'weighted'
  539. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  540. n = min(5, len(x)) # number of previous results to consider
  541. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  542. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  543. if parent == 'single' or len(x) == 1:
  544. # x = x[random.randint(0, n - 1)] # random selection
  545. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  546. elif parent == 'weighted':
  547. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  548. # Mutate
  549. mp, s = 0.8, 0.2 # mutation probability, sigma
  550. npr = np.random
  551. npr.seed(int(time.time()))
  552. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  553. ng = len(meta)
  554. v = np.ones(ng)
  555. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  556. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  557. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  558. hyp[k] = float(x[i + 7] * v[i]) # mutate
  559. # Constrain to limits
  560. for k, v in meta.items():
  561. hyp[k] = max(hyp[k], v[1]) # lower limit
  562. hyp[k] = min(hyp[k], v[2]) # upper limit
  563. hyp[k] = round(hyp[k], 5) # significant digits
  564. # Train mutation
  565. results = train(hyp.copy(), opt, device, callbacks)
  566. callbacks = Callbacks()
  567. # Write mutation results
  568. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  569. # Plot results
  570. plot_evolve(evolve_csv)
  571. LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
  572. f"Results saved to {colorstr('bold', save_dir)}\n"
  573. f'Usage example: $ python train.py --hyp {evolve_yaml}')
  574. def run(**kwargs):
  575. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  576. opt = parse_opt(True)
  577. for k, v in kwargs.items():
  578. setattr(opt, k, v)
  579. main(opt)
  580. return opt
# Script entry point: parse the command-line options and start training (or evolution) via main().
if __name__ == "__main__":
    opt = parse_opt()  # strict parsing: parse_args() is used because `known` defaults to False
    main(opt)