Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

667 lines
34KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset.
  4. Models and datasets download automatically from the latest YOLOv5 release.
  5. Models: https://github.com/ultralytics/yolov5/tree/master/models
  6. Datasets: https://github.com/ultralytics/yolov5/tree/master/data
  7. Tutorial: https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data
  8. Usage:
  9. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640 # from pretrained (RECOMMENDED)
  10. $ python path/to/train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640 # from scratch
  11. """
  12. import argparse
  13. import math
  14. import os
  15. import random
  16. import sys
  17. import time
  18. from copy import deepcopy
  19. from datetime import datetime
  20. from pathlib import Path
  21. import numpy as np
  22. import torch
  23. import torch.distributed as dist
  24. import torch.nn as nn
  25. import yaml
  26. from torch.cuda import amp
  27. from torch.nn.parallel import DistributedDataParallel as DDP
  28. from torch.optim import SGD, Adam, AdamW, lr_scheduler
  29. from tqdm.auto import tqdm
  30. FILE = Path(__file__).resolve()
  31. ROOT = FILE.parents[0] # YOLOv5 root directory
  32. if str(ROOT) not in sys.path:
  33. sys.path.append(str(ROOT)) # add ROOT to PATH
  34. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  35. import val # for end-of-epoch mAP
  36. from models.experimental import attempt_load
  37. from models.yolo import Model
  38. from utils.autoanchor import check_anchors
  39. from utils.autobatch import check_train_batch_size
  40. from utils.callbacks import Callbacks
  41. from utils.datasets import create_dataloader
  42. from utils.downloads import attempt_download
  43. from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
  44. check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
  45. intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, one_cycle,
  46. print_args, print_mutation, strip_optimizer)
  47. from utils.loggers import Loggers
  48. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  49. from utils.loss import ComputeLoss
  50. from utils.metrics import fitness
  51. from utils.plots import plot_evolve, plot_labels
  52. from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
  53. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  54. RANK = int(os.getenv('RANK', -1))
  55. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  56. def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
  57. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
  58. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  59. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  60. callbacks.run('on_pretrain_routine_start')
  61. # Directories
  62. w = save_dir / 'weights' # weights dir
  63. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
  64. last, best = w / 'last.pt', w / 'best.pt'
  65. # Hyperparameters
  66. if isinstance(hyp, str):
  67. with open(hyp, errors='ignore') as f:
  68. hyp = yaml.safe_load(f) # load hyps dict
  69. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  70. # Save run settings
  71. if not evolve:
  72. with open(save_dir / 'hyp.yaml', 'w') as f:
  73. yaml.safe_dump(hyp, f, sort_keys=False)
  74. with open(save_dir / 'opt.yaml', 'w') as f:
  75. yaml.safe_dump(vars(opt), f, sort_keys=False)
  76. # Loggers
  77. data_dict = None
  78. if RANK in [-1, 0]:
  79. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  80. if loggers.wandb:
  81. data_dict = loggers.wandb.data_dict
  82. if resume:
  83. weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
  84. # Register actions
  85. for k in methods(loggers):
  86. callbacks.register_action(k, callback=getattr(loggers, k))
  87. # Config
  88. plots = not evolve # create plots
  89. cuda = device.type != 'cpu'
  90. init_seeds(1 + RANK)
  91. with torch_distributed_zero_first(LOCAL_RANK):
  92. data_dict = data_dict or check_dataset(data) # check if None
  93. train_path, val_path = data_dict['train'], data_dict['val']
  94. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  95. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  96. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  97. is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
  98. # Model
  99. check_suffix(weights, '.pt') # check weights
  100. pretrained = weights.endswith('.pt')
  101. if pretrained:
  102. with torch_distributed_zero_first(LOCAL_RANK):
  103. weights = attempt_download(weights) # download if not found locally
  104. ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
  105. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  106. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  107. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  108. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  109. model.load_state_dict(csd, strict=False) # load
  110. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  111. else:
  112. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  113. # Freeze
  114. freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
  115. for k, v in model.named_parameters():
  116. v.requires_grad = True # train all layers
  117. if any(x in k for x in freeze):
  118. LOGGER.info(f'freezing {k}')
  119. v.requires_grad = False
  120. # Image size
  121. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  122. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  123. # Batch size
  124. if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
  125. batch_size = check_train_batch_size(model, imgsz)
  126. loggers.on_params_update({"batch_size": batch_size})
  127. # Optimizer
  128. nbs = 64 # nominal batch size
  129. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  130. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  131. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  132. g0, g1, g2 = [], [], [] # optimizer parameter groups
  133. for v in model.modules():
  134. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  135. g2.append(v.bias)
  136. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
  137. g0.append(v.weight)
  138. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  139. g1.append(v.weight)
  140. if opt.optimizer == 'Adam':
  141. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  142. elif opt.optimizer == 'AdamW':
  143. optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  144. else:
  145. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  146. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  147. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  148. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  149. f"{len(g0)} weight (no decay), {len(g1)} weight, {len(g2)} bias")
  150. del g0, g1, g2
  151. # Scheduler
  152. if opt.cos_lr:
  153. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  154. else:
  155. lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  156. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  157. # EMA
  158. ema = ModelEMA(model) if RANK in [-1, 0] else None
  159. # Resume
  160. start_epoch, best_fitness = 0, 0.0
  161. if pretrained:
  162. # Optimizer
  163. if ckpt['optimizer'] is not None:
  164. optimizer.load_state_dict(ckpt['optimizer'])
  165. best_fitness = ckpt['best_fitness']
  166. # EMA
  167. if ema and ckpt.get('ema'):
  168. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  169. ema.updates = ckpt['updates']
  170. # Epochs
  171. start_epoch = ckpt['epoch'] + 1
  172. if resume:
  173. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  174. if epochs < start_epoch:
  175. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  176. epochs += ckpt['epoch'] # finetune additional epochs
  177. del ckpt, csd
  178. # DP mode
  179. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  180. LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
  181. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  182. model = torch.nn.DataParallel(model)
  183. # SyncBatchNorm
  184. if opt.sync_bn and cuda and RANK != -1:
  185. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  186. LOGGER.info('Using SyncBatchNorm()')
  187. # Trainloader
  188. train_loader, dataset = create_dataloader(train_path,
  189. imgsz,
  190. batch_size // WORLD_SIZE,
  191. gs,
  192. single_cls,
  193. hyp=hyp,
  194. augment=True,
  195. cache=None if opt.cache == 'val' else opt.cache,
  196. rect=opt.rect,
  197. rank=LOCAL_RANK,
  198. workers=workers,
  199. image_weights=opt.image_weights,
  200. quad=opt.quad,
  201. prefix=colorstr('train: '),
  202. shuffle=True)
  203. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  204. nb = len(train_loader) # number of batches
  205. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  206. # Process 0
  207. if RANK in [-1, 0]:
  208. val_loader = create_dataloader(val_path,
  209. imgsz,
  210. batch_size // WORLD_SIZE * 2,
  211. gs,
  212. single_cls,
  213. hyp=hyp,
  214. cache=None if noval else opt.cache,
  215. rect=True,
  216. rank=-1,
  217. workers=workers * 2,
  218. pad=0.5,
  219. prefix=colorstr('val: '))[0]
  220. if not resume:
  221. labels = np.concatenate(dataset.labels, 0)
  222. # c = torch.tensor(labels[:, 0]) # classes
  223. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  224. # model._initialize_biases(cf.to(device))
  225. if plots:
  226. plot_labels(labels, names, save_dir)
  227. # Anchors
  228. if not opt.noautoanchor:
  229. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  230. model.half().float() # pre-reduce anchor precision
  231. callbacks.run('on_pretrain_routine_end')
  232. # DDP mode
  233. if cuda and RANK != -1:
  234. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  235. # Model attributes
  236. nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
  237. hyp['box'] *= 3 / nl # scale to layers
  238. hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
  239. hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
  240. hyp['label_smoothing'] = opt.label_smoothing
  241. model.nc = nc # attach number of classes to model
  242. model.hyp = hyp # attach hyperparameters to model
  243. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  244. model.names = names
  245. # Start training
  246. t0 = time.time()
  247. nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations)
  248. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  249. last_opt_step = -1
  250. maps = np.zeros(nc) # mAP per class
  251. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  252. scheduler.last_epoch = start_epoch - 1 # do not move
  253. scaler = amp.GradScaler(enabled=cuda)
  254. stopper = EarlyStopping(patience=opt.patience)
  255. compute_loss = ComputeLoss(model) # init loss class
  256. callbacks.run('on_train_start')
  257. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  258. f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
  259. f"Logging results to {colorstr('bold', save_dir)}\n"
  260. f'Starting training for {epochs} epochs...')
  261. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  262. callbacks.run('on_train_epoch_start')
  263. model.train()
  264. # Update image weights (optional, single-GPU only)
  265. if opt.image_weights:
  266. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  267. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  268. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  269. # Update mosaic border (optional)
  270. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  271. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  272. mloss = torch.zeros(3, device=device) # mean losses
  273. if RANK != -1:
  274. train_loader.sampler.set_epoch(epoch)
  275. pbar = enumerate(train_loader)
  276. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  277. if RANK in [-1, 0]:
  278. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  279. optimizer.zero_grad()
  280. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  281. callbacks.run('on_train_batch_start')
  282. ni = i + nb * epoch # number integrated batches (since train start)
  283. imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
  284. # Warmup
  285. if ni <= nw:
  286. xi = [0, nw] # x interp
  287. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  288. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  289. for j, x in enumerate(optimizer.param_groups):
  290. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  291. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  292. if 'momentum' in x:
  293. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  294. # Multi-scale
  295. if opt.multi_scale:
  296. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  297. sf = sz / max(imgs.shape[2:]) # scale factor
  298. if sf != 1:
  299. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  300. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  301. # Forward
  302. with amp.autocast(enabled=cuda):
  303. pred = model(imgs) # forward
  304. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  305. if RANK != -1:
  306. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  307. if opt.quad:
  308. loss *= 4.
  309. # Backward
  310. scaler.scale(loss).backward()
  311. # Optimize
  312. if ni - last_opt_step >= accumulate:
  313. scaler.step(optimizer) # optimizer.step
  314. scaler.update()
  315. optimizer.zero_grad()
  316. if ema:
  317. ema.update(model)
  318. last_opt_step = ni
  319. # Log
  320. if RANK in [-1, 0]:
  321. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  322. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  323. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) %
  324. (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  325. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
  326. if callbacks.stop_training:
  327. return
  328. # end batch ------------------------------------------------------------------------------------------------
  329. # Scheduler
  330. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  331. scheduler.step()
  332. if RANK in [-1, 0]:
  333. # mAP
  334. callbacks.run('on_train_epoch_end', epoch=epoch)
  335. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  336. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
  337. if not noval or final_epoch: # Calculate mAP
  338. results, maps, _ = val.run(data_dict,
  339. batch_size=batch_size // WORLD_SIZE * 2,
  340. imgsz=imgsz,
  341. model=ema.ema,
  342. single_cls=single_cls,
  343. dataloader=val_loader,
  344. save_dir=save_dir,
  345. plots=False,
  346. callbacks=callbacks,
  347. compute_loss=compute_loss)
  348. # Update best mAP
  349. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  350. if fi > best_fitness:
  351. best_fitness = fi
  352. log_vals = list(mloss) + list(results) + lr
  353. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
  354. # Save model
  355. if (not nosave) or (final_epoch and not evolve): # if save
  356. ckpt = {
  357. 'epoch': epoch,
  358. 'best_fitness': best_fitness,
  359. 'model': deepcopy(de_parallel(model)).half(),
  360. 'ema': deepcopy(ema.ema).half(),
  361. 'updates': ema.updates,
  362. 'optimizer': optimizer.state_dict(),
  363. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
  364. 'date': datetime.now().isoformat()}
  365. # Save last, best and delete
  366. torch.save(ckpt, last)
  367. if best_fitness == fi:
  368. torch.save(ckpt, best)
  369. if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
  370. torch.save(ckpt, w / f'epoch{epoch}.pt')
  371. del ckpt
  372. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
  373. # Stop Single-GPU
  374. if RANK == -1 and stopper(epoch=epoch, fitness=fi):
  375. break
  376. # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
  377. # stop = stopper(epoch=epoch, fitness=fi)
  378. # if RANK == 0:
  379. # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
  380. # Stop DPP
  381. # with torch_distributed_zero_first(RANK):
  382. # if stop:
  383. # break # must break all DDP ranks
  384. # end epoch ----------------------------------------------------------------------------------------------------
  385. # end training -----------------------------------------------------------------------------------------------------
  386. if RANK in [-1, 0]:
  387. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  388. for f in last, best:
  389. if f.exists():
  390. strip_optimizer(f) # strip optimizers
  391. if f is best:
  392. LOGGER.info(f'\nValidating {f}...')
  393. results, _, _ = val.run(
  394. data_dict,
  395. batch_size=batch_size // WORLD_SIZE * 2,
  396. imgsz=imgsz,
  397. model=attempt_load(f, device).half(),
  398. iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
  399. single_cls=single_cls,
  400. dataloader=val_loader,
  401. save_dir=save_dir,
  402. save_json=is_coco,
  403. verbose=True,
  404. plots=True,
  405. callbacks=callbacks,
  406. compute_loss=compute_loss) # val best model with plots
  407. if is_coco:
  408. callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
  409. callbacks.run('on_train_end', last, best, plots, epoch, results)
  410. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
  411. torch.cuda.empty_cache()
  412. return results
  413. def parse_opt(known=False):
  414. parser = argparse.ArgumentParser()
  415. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
  416. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  417. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  418. parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
  419. parser.add_argument('--epochs', type=int, default=300)
  420. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
  421. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  422. parser.add_argument('--rect', action='store_true', help='rectangular training')
  423. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  424. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  425. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  426. parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
  427. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  428. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  429. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  430. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  431. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  432. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  433. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  434. parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
  435. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  436. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
  437. parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
  438. parser.add_argument('--name', default='exp', help='save to project/name')
  439. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  440. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  441. parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
  442. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  443. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  444. parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
  445. parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
  446. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  447. # Weights & Biases arguments
  448. parser.add_argument('--entity', default=None, help='W&B: Entity')
  449. parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
  450. parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
  451. parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
  452. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  453. return opt
  454. def main(opt, callbacks=Callbacks()):
  455. # Checks
  456. if RANK in [-1, 0]:
  457. print_args(vars(opt))
  458. check_git_status()
  459. check_requirements(exclude=['thop'])
  460. # Resume
  461. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  462. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  463. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  464. with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
  465. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  466. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  467. LOGGER.info(f'Resuming training from {ckpt}')
  468. else:
  469. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
  470. check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
  471. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  472. if opt.evolve:
  473. if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve
  474. opt.project = str(ROOT / 'runs/evolve')
  475. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  476. if opt.name == 'cfg':
  477. opt.name = Path(opt.cfg).stem # use model.yaml as name
  478. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  479. # DDP mode
  480. device = select_device(opt.device, batch_size=opt.batch_size)
  481. if LOCAL_RANK != -1:
  482. msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
  483. assert not opt.image_weights, f'--image-weights {msg}'
  484. assert not opt.evolve, f'--evolve {msg}'
  485. assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
  486. assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
  487. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  488. torch.cuda.set_device(LOCAL_RANK)
  489. device = torch.device('cuda', LOCAL_RANK)
  490. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
  491. # Train
  492. if not opt.evolve:
  493. train(opt.hyp, opt, device, callbacks)
  494. if WORLD_SIZE > 1 and RANK == 0:
  495. LOGGER.info('Destroying process group... ')
  496. dist.destroy_process_group()
  497. # Evolve hyperparameters (optional)
  498. else:
  499. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  500. meta = {
  501. 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  502. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  503. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  504. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  505. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  506. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  507. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  508. 'box': (1, 0.02, 0.2), # box loss gain
  509. 'cls': (1, 0.2, 4.0), # cls loss gain
  510. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  511. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  512. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  513. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  514. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  515. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  516. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  517. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  518. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  519. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  520. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  521. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  522. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  523. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  524. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  525. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  526. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  527. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  528. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  529. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  530. with open(opt.hyp, errors='ignore') as f:
  531. hyp = yaml.safe_load(f) # load hyps dict
  532. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  533. hyp['anchors'] = 3
  534. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  535. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  536. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  537. if opt.bucket:
  538. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}') # download evolve.csv if exists
  539. for _ in range(opt.evolve): # generations to evolve
  540. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  541. # Select parent(s)
  542. parent = 'single' # parent selection method: 'single' or 'weighted'
  543. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  544. n = min(5, len(x)) # number of previous results to consider
  545. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  546. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  547. if parent == 'single' or len(x) == 1:
  548. # x = x[random.randint(0, n - 1)] # random selection
  549. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  550. elif parent == 'weighted':
  551. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  552. # Mutate
  553. mp, s = 0.8, 0.2 # mutation probability, sigma
  554. npr = np.random
  555. npr.seed(int(time.time()))
  556. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  557. ng = len(meta)
  558. v = np.ones(ng)
  559. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  560. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  561. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  562. hyp[k] = float(x[i + 7] * v[i]) # mutate
  563. # Constrain to limits
  564. for k, v in meta.items():
  565. hyp[k] = max(hyp[k], v[1]) # lower limit
  566. hyp[k] = min(hyp[k], v[2]) # upper limit
  567. hyp[k] = round(hyp[k], 5) # significant digits
  568. # Train mutation
  569. results = train(hyp.copy(), opt, device, callbacks)
  570. callbacks = Callbacks()
  571. # Write mutation results
  572. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  573. # Plot results
  574. plot_evolve(evolve_csv)
  575. LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
  576. f"Results saved to {colorstr('bold', save_dir)}\n"
  577. f'Usage example: $ python train.py --hyp {evolve_yaml}')
  578. def run(**kwargs):
  579. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  580. opt = parse_opt(True)
  581. for k, v in kwargs.items():
  582. setattr(opt, k, v)
  583. main(opt)
  584. return opt
  585. if __name__ == "__main__":
  586. opt = parse_opt()
  587. main(opt)