Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

667 lines
34KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset.
  4. Models and datasets download automatically from the latest YOLOv5 release.
  5. Models: https://github.com/ultralytics/yolov5/tree/master/models
  6. Datasets: https://github.com/ultralytics/yolov5/tree/master/data
  7. Tutorial: https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data
  8. Usage:
  9. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640 # from pretrained (RECOMMENDED)
  10. $ python path/to/train.py --data coco128.yaml --weights '' --cfg yolov5s.yaml --img 640 # from scratch
  11. """
  12. import argparse
  13. import math
  14. import os
  15. import random
  16. import sys
  17. import time
  18. from copy import deepcopy
  19. from datetime import datetime
  20. from pathlib import Path
  21. import numpy as np
  22. import torch
  23. import torch.distributed as dist
  24. import torch.nn as nn
  25. import yaml
  26. from torch.nn.parallel import DistributedDataParallel as DDP
  27. from torch.optim import SGD, Adam, AdamW, lr_scheduler
  28. from tqdm import tqdm
  29. FILE = Path(__file__).resolve()
  30. ROOT = FILE.parents[0] # YOLOv5 root directory
  31. if str(ROOT) not in sys.path:
  32. sys.path.append(str(ROOT)) # add ROOT to PATH
  33. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  34. import val # for end-of-epoch mAP
  35. from models.experimental import attempt_load
  36. from models.yolo import Model
  37. from utils.autoanchor import check_anchors
  38. from utils.autobatch import check_train_batch_size
  39. from utils.callbacks import Callbacks
  40. from utils.dataloaders import create_dataloader
  41. from utils.downloads import attempt_download
  42. from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
  43. check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
  44. increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
  45. labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
  46. from utils.loggers import Loggers
  47. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  48. from utils.loss import ComputeLoss
  49. from utils.metrics import fitness
  50. from utils.plots import plot_evolve, plot_labels
  51. from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
  52. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  53. RANK = int(os.getenv('RANK', -1))
  54. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  55. def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
  56. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
  57. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  58. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  59. callbacks.run('on_pretrain_routine_start')
  60. # Directories
  61. w = save_dir / 'weights' # weights dir
  62. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
  63. last, best = w / 'last.pt', w / 'best.pt'
  64. # Hyperparameters
  65. if isinstance(hyp, str):
  66. with open(hyp, errors='ignore') as f:
  67. hyp = yaml.safe_load(f) # load hyps dict
  68. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  69. # Save run settings
  70. if not evolve:
  71. with open(save_dir / 'hyp.yaml', 'w') as f:
  72. yaml.safe_dump(hyp, f, sort_keys=False)
  73. with open(save_dir / 'opt.yaml', 'w') as f:
  74. yaml.safe_dump(vars(opt), f, sort_keys=False)
  75. # Loggers
  76. data_dict = None
  77. if RANK in {-1, 0}:
  78. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  79. if loggers.wandb:
  80. data_dict = loggers.wandb.data_dict
  81. if resume:
  82. weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
  83. # Register actions
  84. for k in methods(loggers):
  85. callbacks.register_action(k, callback=getattr(loggers, k))
  86. # Config
  87. plots = not evolve and not opt.noplots # create plots
  88. cuda = device.type != 'cpu'
  89. init_seeds(1 + RANK)
  90. with torch_distributed_zero_first(LOCAL_RANK):
  91. data_dict = data_dict or check_dataset(data) # check if None
  92. train_path, val_path = data_dict['train'], data_dict['val']
  93. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  94. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  95. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  96. is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
  97. # Model
  98. check_suffix(weights, '.pt') # check weights
  99. pretrained = weights.endswith('.pt')
  100. if pretrained:
  101. with torch_distributed_zero_first(LOCAL_RANK):
  102. weights = attempt_download(weights) # download if not found locally
  103. ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
  104. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  105. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  106. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  107. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  108. model.load_state_dict(csd, strict=False) # load
  109. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  110. else:
  111. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  112. amp = check_amp(model) # check AMP
  113. # Freeze
  114. freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
  115. for k, v in model.named_parameters():
  116. v.requires_grad = True # train all layers
  117. if any(x in k for x in freeze):
  118. LOGGER.info(f'freezing {k}')
  119. v.requires_grad = False
  120. # Image size
  121. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  122. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  123. # Batch size
  124. if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
  125. batch_size = check_train_batch_size(model, imgsz, amp)
  126. loggers.on_params_update({"batch_size": batch_size})
  127. # Optimizer
  128. nbs = 64 # nominal batch size
  129. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  130. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  131. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  132. g = [], [], [] # optimizer parameter groups
  133. bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
  134. for v in model.modules():
  135. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  136. g[2].append(v.bias)
  137. if isinstance(v, bn): # weight (no decay)
  138. g[1].append(v.weight)
  139. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  140. g[0].append(v.weight)
  141. if opt.optimizer == 'Adam':
  142. optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  143. elif opt.optimizer == 'AdamW':
  144. optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  145. else:
  146. optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  147. optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']}) # add g0 with weight_decay
  148. optimizer.add_param_group({'params': g[1]}) # add g1 (BatchNorm2d weights)
  149. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  150. f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
  151. del g
  152. # Scheduler
  153. if opt.cos_lr:
  154. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  155. else:
  156. lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  157. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  158. # EMA
  159. ema = ModelEMA(model) if RANK in {-1, 0} else None
  160. # Resume
  161. start_epoch, best_fitness = 0, 0.0
  162. if pretrained:
  163. # Optimizer
  164. if ckpt['optimizer'] is not None:
  165. optimizer.load_state_dict(ckpt['optimizer'])
  166. best_fitness = ckpt['best_fitness']
  167. # EMA
  168. if ema and ckpt.get('ema'):
  169. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  170. ema.updates = ckpt['updates']
  171. # Epochs
  172. start_epoch = ckpt['epoch'] + 1
  173. if resume:
  174. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  175. if epochs < start_epoch:
  176. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  177. epochs += ckpt['epoch'] # finetune additional epochs
  178. del ckpt, csd
  179. # DP mode
  180. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  181. LOGGER.warning('WARNING: DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
  182. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  183. model = torch.nn.DataParallel(model)
  184. # SyncBatchNorm
  185. if opt.sync_bn and cuda and RANK != -1:
  186. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  187. LOGGER.info('Using SyncBatchNorm()')
  188. # Trainloader
  189. train_loader, dataset = create_dataloader(train_path,
  190. imgsz,
  191. batch_size // WORLD_SIZE,
  192. gs,
  193. single_cls,
  194. hyp=hyp,
  195. augment=True,
  196. cache=None if opt.cache == 'val' else opt.cache,
  197. rect=opt.rect,
  198. rank=LOCAL_RANK,
  199. workers=workers,
  200. image_weights=opt.image_weights,
  201. quad=opt.quad,
  202. prefix=colorstr('train: '),
  203. shuffle=True)
  204. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  205. nb = len(train_loader) # number of batches
  206. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  207. # Process 0
  208. if RANK in {-1, 0}:
  209. val_loader = create_dataloader(val_path,
  210. imgsz,
  211. batch_size // WORLD_SIZE * 2,
  212. gs,
  213. single_cls,
  214. hyp=hyp,
  215. cache=None if noval else opt.cache,
  216. rect=True,
  217. rank=-1,
  218. workers=workers * 2,
  219. pad=0.5,
  220. prefix=colorstr('val: '))[0]
  221. if not resume:
  222. labels = np.concatenate(dataset.labels, 0)
  223. # c = torch.tensor(labels[:, 0]) # classes
  224. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  225. # model._initialize_biases(cf.to(device))
  226. if plots:
  227. plot_labels(labels, names, save_dir)
  228. # Anchors
  229. if not opt.noautoanchor:
  230. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  231. model.half().float() # pre-reduce anchor precision
  232. callbacks.run('on_pretrain_routine_end')
  233. # DDP mode
  234. if cuda and RANK != -1:
  235. if check_version(torch.__version__, '1.11.0'):
  236. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
  237. else:
  238. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  239. # Model attributes
  240. nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
  241. hyp['box'] *= 3 / nl # scale to layers
  242. hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
  243. hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
  244. hyp['label_smoothing'] = opt.label_smoothing
  245. model.nc = nc # attach number of classes to model
  246. model.hyp = hyp # attach hyperparameters to model
  247. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  248. model.names = names
  249. # Start training
  250. t0 = time.time()
  251. nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations)
  252. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  253. last_opt_step = -1
  254. maps = np.zeros(nc) # mAP per class
  255. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  256. scheduler.last_epoch = start_epoch - 1 # do not move
  257. scaler = torch.cuda.amp.GradScaler(enabled=amp)
  258. stopper, stop = EarlyStopping(patience=opt.patience), False
  259. compute_loss = ComputeLoss(model) # init loss class
  260. callbacks.run('on_train_start')
  261. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  262. f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
  263. f"Logging results to {colorstr('bold', save_dir)}\n"
  264. f'Starting training for {epochs} epochs...')
  265. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  266. callbacks.run('on_train_epoch_start')
  267. model.train()
  268. # Update image weights (optional, single-GPU only)
  269. if opt.image_weights:
  270. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  271. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  272. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  273. # Update mosaic border (optional)
  274. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  275. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  276. mloss = torch.zeros(3, device=device) # mean losses
  277. if RANK != -1:
  278. train_loader.sampler.set_epoch(epoch)
  279. pbar = enumerate(train_loader)
  280. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  281. if RANK in {-1, 0}:
  282. pbar = tqdm(pbar, total=nb, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
  283. optimizer.zero_grad()
  284. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  285. callbacks.run('on_train_batch_start')
  286. ni = i + nb * epoch # number integrated batches (since train start)
  287. imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
  288. # Warmup
  289. if ni <= nw:
  290. xi = [0, nw] # x interp
  291. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  292. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  293. for j, x in enumerate(optimizer.param_groups):
  294. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  295. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
  296. if 'momentum' in x:
  297. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  298. # Multi-scale
  299. if opt.multi_scale:
  300. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  301. sf = sz / max(imgs.shape[2:]) # scale factor
  302. if sf != 1:
  303. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  304. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  305. # Forward
  306. with torch.cuda.amp.autocast(amp):
  307. pred = model(imgs) # forward
  308. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  309. if RANK != -1:
  310. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  311. if opt.quad:
  312. loss *= 4.
  313. # Backward
  314. scaler.scale(loss).backward()
  315. # Optimize
  316. if ni - last_opt_step >= accumulate:
  317. scaler.step(optimizer) # optimizer.step
  318. scaler.update()
  319. optimizer.zero_grad()
  320. if ema:
  321. ema.update(model)
  322. last_opt_step = ni
  323. # Log
  324. if RANK in {-1, 0}:
  325. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  326. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  327. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) %
  328. (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  329. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots)
  330. if callbacks.stop_training:
  331. return
  332. # end batch ------------------------------------------------------------------------------------------------
  333. # Scheduler
  334. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  335. scheduler.step()
  336. if RANK in {-1, 0}:
  337. # mAP
  338. callbacks.run('on_train_epoch_end', epoch=epoch)
  339. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  340. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
  341. if not noval or final_epoch: # Calculate mAP
  342. results, maps, _ = val.run(data_dict,
  343. batch_size=batch_size // WORLD_SIZE * 2,
  344. imgsz=imgsz,
  345. model=ema.ema,
  346. single_cls=single_cls,
  347. dataloader=val_loader,
  348. save_dir=save_dir,
  349. plots=False,
  350. callbacks=callbacks,
  351. compute_loss=compute_loss)
  352. # Update best mAP
  353. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  354. stop = stopper(epoch=epoch, fitness=fi) # early stop check
  355. if fi > best_fitness:
  356. best_fitness = fi
  357. log_vals = list(mloss) + list(results) + lr
  358. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
  359. # Save model
  360. if (not nosave) or (final_epoch and not evolve): # if save
  361. ckpt = {
  362. 'epoch': epoch,
  363. 'best_fitness': best_fitness,
  364. 'model': deepcopy(de_parallel(model)).half(),
  365. 'ema': deepcopy(ema.ema).half(),
  366. 'updates': ema.updates,
  367. 'optimizer': optimizer.state_dict(),
  368. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
  369. 'date': datetime.now().isoformat()}
  370. # Save last, best and delete
  371. torch.save(ckpt, last)
  372. if best_fitness == fi:
  373. torch.save(ckpt, best)
  374. if opt.save_period > 0 and epoch % opt.save_period == 0:
  375. torch.save(ckpt, w / f'epoch{epoch}.pt')
  376. del ckpt
  377. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
  378. # EarlyStopping
  379. if RANK != -1: # if DDP training
  380. broadcast_list = [stop if RANK == 0 else None]
  381. dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
  382. if RANK != 0:
  383. stop = broadcast_list[0]
  384. if stop:
  385. break # must break all DDP ranks
  386. # end epoch ----------------------------------------------------------------------------------------------------
  387. # end training -----------------------------------------------------------------------------------------------------
  388. if RANK in {-1, 0}:
  389. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  390. for f in last, best:
  391. if f.exists():
  392. strip_optimizer(f) # strip optimizers
  393. if f is best:
  394. LOGGER.info(f'\nValidating {f}...')
  395. results, _, _ = val.run(
  396. data_dict,
  397. batch_size=batch_size // WORLD_SIZE * 2,
  398. imgsz=imgsz,
  399. model=attempt_load(f, device).half(),
  400. iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
  401. single_cls=single_cls,
  402. dataloader=val_loader,
  403. save_dir=save_dir,
  404. save_json=is_coco,
  405. verbose=True,
  406. plots=plots,
  407. callbacks=callbacks,
  408. compute_loss=compute_loss) # val best model with plots
  409. if is_coco:
  410. callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
  411. callbacks.run('on_train_end', last, best, plots, epoch, results)
  412. torch.cuda.empty_cache()
  413. return results
  414. def parse_opt(known=False):
  415. parser = argparse.ArgumentParser()
  416. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
  417. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  418. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  419. parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
  420. parser.add_argument('--epochs', type=int, default=300)
  421. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
  422. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  423. parser.add_argument('--rect', action='store_true', help='rectangular training')
  424. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  425. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  426. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  427. parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
  428. parser.add_argument('--noplots', action='store_true', help='save no plot files')
  429. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  430. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  431. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  432. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  433. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  434. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  435. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  436. parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
  437. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  438. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
  439. parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
  440. parser.add_argument('--name', default='exp', help='save to project/name')
  441. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  442. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  443. parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
  444. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  445. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  446. parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
  447. parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
  448. parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
  449. # Weights & Biases arguments
  450. parser.add_argument('--entity', default=None, help='W&B: Entity')
  451. parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='W&B: Upload data, "val" option')
  452. parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
  453. parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
  454. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  455. return opt
  456. def main(opt, callbacks=Callbacks()):
  457. # Checks
  458. if RANK in {-1, 0}:
  459. print_args(vars(opt))
  460. check_git_status()
  461. check_requirements(exclude=['thop'])
  462. # Resume
  463. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  464. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  465. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  466. with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
  467. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  468. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  469. LOGGER.info(f'Resuming training from {ckpt}')
  470. else:
  471. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
  472. check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
  473. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  474. if opt.evolve:
  475. if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve
  476. opt.project = str(ROOT / 'runs/evolve')
  477. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  478. if opt.name == 'cfg':
  479. opt.name = Path(opt.cfg).stem # use model.yaml as name
  480. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  481. # DDP mode
  482. device = select_device(opt.device, batch_size=opt.batch_size)
  483. if LOCAL_RANK != -1:
  484. msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
  485. assert not opt.image_weights, f'--image-weights {msg}'
  486. assert not opt.evolve, f'--evolve {msg}'
  487. assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
  488. assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
  489. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  490. torch.cuda.set_device(LOCAL_RANK)
  491. device = torch.device('cuda', LOCAL_RANK)
  492. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
  493. # Train
  494. if not opt.evolve:
  495. train(opt.hyp, opt, device, callbacks)
  496. if WORLD_SIZE > 1 and RANK == 0:
  497. LOGGER.info('Destroying process group... ')
  498. dist.destroy_process_group()
  499. # Evolve hyperparameters (optional)
  500. else:
  501. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  502. meta = {
  503. 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  504. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  505. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  506. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  507. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  508. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  509. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  510. 'box': (1, 0.02, 0.2), # box loss gain
  511. 'cls': (1, 0.2, 4.0), # cls loss gain
  512. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  513. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  514. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  515. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  516. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  517. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  518. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  519. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  520. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  521. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  522. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  523. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  524. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  525. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  526. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  527. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  528. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  529. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  530. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  531. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  532. with open(opt.hyp, errors='ignore') as f:
  533. hyp = yaml.safe_load(f) # load hyps dict
  534. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  535. hyp['anchors'] = 3
  536. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  537. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  538. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  539. if opt.bucket:
  540. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}') # download evolve.csv if exists
  541. for _ in range(opt.evolve): # generations to evolve
  542. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  543. # Select parent(s)
  544. parent = 'single' # parent selection method: 'single' or 'weighted'
  545. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  546. n = min(5, len(x)) # number of previous results to consider
  547. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  548. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  549. if parent == 'single' or len(x) == 1:
  550. # x = x[random.randint(0, n - 1)] # random selection
  551. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  552. elif parent == 'weighted':
  553. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  554. # Mutate
  555. mp, s = 0.8, 0.2 # mutation probability, sigma
  556. npr = np.random
  557. npr.seed(int(time.time()))
  558. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  559. ng = len(meta)
  560. v = np.ones(ng)
  561. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  562. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  563. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  564. hyp[k] = float(x[i + 7] * v[i]) # mutate
  565. # Constrain to limits
  566. for k, v in meta.items():
  567. hyp[k] = max(hyp[k], v[1]) # lower limit
  568. hyp[k] = min(hyp[k], v[2]) # upper limit
  569. hyp[k] = round(hyp[k], 5) # significant digits
  570. # Train mutation
  571. results = train(hyp.copy(), opt, device, callbacks)
  572. callbacks = Callbacks()
  573. # Write mutation results
  574. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  575. # Plot results
  576. plot_evolve(evolve_csv)
  577. LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
  578. f"Results saved to {colorstr('bold', save_dir)}\n"
  579. f'Usage example: $ python train.py --hyp {evolve_yaml}')
  580. def run(**kwargs):
  581. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  582. opt = parse_opt(True)
  583. for k, v in kwargs.items():
  584. setattr(opt, k, v)
  585. main(opt)
  586. return opt
  587. if __name__ == "__main__":
  588. opt = parse_opt()
  589. main(opt)