无人机视角的行人小目标检测 (Pedestrian small-object detection from a drone's viewpoint)
Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

631 lines
32KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Train a YOLOv5 model on a custom dataset
  4. Usage:
  5. $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
  6. """
  7. import argparse
  8. import logging
  9. import math
  10. import os
  11. import random
  12. import sys
  13. import time
  14. from copy import deepcopy
  15. from datetime import datetime
  16. from pathlib import Path
  17. import numpy as np
  18. import torch
  19. import torch.distributed as dist
  20. import torch.nn as nn
  21. import yaml
  22. from torch.cuda import amp
  23. from torch.nn.parallel import DistributedDataParallel as DDP
  24. from torch.optim import SGD, Adam, lr_scheduler
  25. from tqdm import tqdm
  26. FILE = Path(__file__).resolve()
  27. ROOT = FILE.parents[0] # YOLOv5 root directory
  28. if str(ROOT) not in sys.path:
  29. sys.path.append(str(ROOT)) # add ROOT to PATH
  30. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  31. import val # for end-of-epoch mAP
  32. from models.experimental import attempt_load
  33. from models.yolo import Model
  34. from utils.autoanchor import check_anchors
  35. from utils.autobatch import check_train_batch_size
  36. from utils.callbacks import Callbacks
  37. from utils.datasets import create_dataloader
  38. from utils.downloads import attempt_download
  39. from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
  40. check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
  41. labels_to_class_weights, labels_to_image_weights, methods, one_cycle, print_args,
  42. print_mutation, strip_optimizer)
  43. from utils.loggers import Loggers
  44. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  45. from utils.loss import ComputeLoss
  46. from utils.metrics import fitness
  47. from utils.plots import plot_evolve, plot_labels
  48. from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device,
  49. torch_distributed_zero_first)
  50. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  51. RANK = int(os.getenv('RANK', -1))
  52. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  53. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  54. opt,
  55. device,
  56. callbacks
  57. ):
  58. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
  59. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  60. opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
  61. # Directories
  62. w = save_dir / 'weights' # weights dir
  63. (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
  64. last, best = w / 'last.pt', w / 'best.pt'
  65. # Hyperparameters
  66. if isinstance(hyp, str):
  67. with open(hyp, errors='ignore') as f:
  68. hyp = yaml.safe_load(f) # load hyps dict
  69. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  70. # Save run settings
  71. with open(save_dir / 'hyp.yaml', 'w') as f:
  72. yaml.safe_dump(hyp, f, sort_keys=False)
  73. with open(save_dir / 'opt.yaml', 'w') as f:
  74. yaml.safe_dump(vars(opt), f, sort_keys=False)
  75. data_dict = None
  76. # Loggers
  77. if RANK in [-1, 0]:
  78. loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
  79. if loggers.wandb:
  80. data_dict = loggers.wandb.data_dict
  81. if resume:
  82. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
  83. # Register actions
  84. for k in methods(loggers):
  85. callbacks.register_action(k, callback=getattr(loggers, k))
  86. # Config
  87. plots = not evolve # create plots
  88. cuda = device.type != 'cpu'
  89. init_seeds(1 + RANK)
  90. with torch_distributed_zero_first(LOCAL_RANK):
  91. data_dict = data_dict or check_dataset(data) # check if None
  92. train_path, val_path = data_dict['train'], data_dict['val']
  93. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  94. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  95. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  96. is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
  97. # Model
  98. check_suffix(weights, '.pt') # check weights
  99. pretrained = weights.endswith('.pt')
  100. if pretrained:
  101. with torch_distributed_zero_first(LOCAL_RANK):
  102. weights = attempt_download(weights) # download if not found locally
  103. ckpt = torch.load(weights, map_location=device) # load checkpoint
  104. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  105. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  106. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  107. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  108. model.load_state_dict(csd, strict=False) # load
  109. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  110. else:
  111. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  112. # Freeze
  113. freeze = [f'model.{x}.' for x in range(freeze)] # layers to freeze
  114. for k, v in model.named_parameters():
  115. v.requires_grad = True # train all layers
  116. if any(x in k for x in freeze):
  117. LOGGER.info(f'freezing {k}')
  118. v.requires_grad = False
  119. # Image size
  120. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  121. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  122. # Batch size
  123. if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
  124. batch_size = check_train_batch_size(model, imgsz)
  125. # Optimizer
  126. nbs = 64 # nominal batch size
  127. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  128. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  129. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  130. g0, g1, g2 = [], [], [] # optimizer parameter groups
  131. for v in model.modules():
  132. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  133. g2.append(v.bias)
  134. if isinstance(v, nn.BatchNorm2d): # weight (no decay)
  135. g0.append(v.weight)
  136. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  137. g1.append(v.weight)
  138. if opt.adam:
  139. optimizer = Adam(g0, lr=3e-4, betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  140. else:
  141. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  142. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  143. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  144. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  145. f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
  146. del g0, g1, g2
  147. # Scheduler
  148. if opt.linear_lr:
  149. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  150. else:
  151. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  152. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  153. # EMA
  154. ema = ModelEMA(model) if RANK in [-1, 0] else None
  155. # Resume
  156. start_epoch, best_fitness = 0, 0.0
  157. if pretrained:
  158. # Optimizer
  159. if ckpt['optimizer'] is not None:
  160. optimizer.load_state_dict(ckpt['optimizer'])
  161. best_fitness = ckpt['best_fitness']
  162. # EMA
  163. if ema and ckpt.get('ema'):
  164. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  165. ema.updates = ckpt['updates']
  166. # Epochs
  167. start_epoch = ckpt['epoch'] + 1
  168. if resume:
  169. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  170. if epochs < start_epoch:
  171. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  172. epochs += ckpt['epoch'] # finetune additional epochs
  173. del ckpt, csd
  174. # DP mode
  175. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  176. logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
  177. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  178. model = torch.nn.DataParallel(model)
  179. # SyncBatchNorm
  180. if opt.sync_bn and cuda and RANK != -1:
  181. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  182. LOGGER.info('Using SyncBatchNorm()')
  183. # Trainloader
  184. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  185. hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK,
  186. workers=workers, image_weights=opt.image_weights, quad=opt.quad,
  187. prefix=colorstr('train: '))
  188. mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class
  189. nb = len(train_loader) # number of batches
  190. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  191. # Process 0
  192. if RANK in [-1, 0]:
  193. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
  194. hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
  195. workers=workers, pad=0.5,
  196. prefix=colorstr('val: '))[0]
  197. if not resume:
  198. labels = np.concatenate(dataset.labels, 0)
  199. # c = torch.tensor(labels[:, 0]) # classes
  200. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  201. # model._initialize_biases(cf.to(device))
  202. if plots:
  203. plot_labels(labels, names, save_dir)
  204. # Anchors
  205. if not opt.noautoanchor:
  206. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  207. model.half().float() # pre-reduce anchor precision
  208. callbacks.run('on_pretrain_routine_end')
  209. # DDP mode
  210. if cuda and RANK != -1:
  211. # model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  212. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
  213. # Model parameters
  214. nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
  215. hyp['box'] *= 3 / nl # scale to layers
  216. hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
  217. hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
  218. hyp['label_smoothing'] = opt.label_smoothing
  219. model.nc = nc # attach number of classes to model
  220. model.hyp = hyp # attach hyperparameters to model
  221. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  222. model.names = names
  223. # Start training
  224. t0 = time.time()
  225. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  226. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  227. last_opt_step = -1
  228. maps = np.zeros(nc) # mAP per class
  229. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  230. scheduler.last_epoch = start_epoch - 1 # do not move
  231. scaler = amp.GradScaler(enabled=cuda)
  232. stopper = EarlyStopping(patience=opt.patience)
  233. compute_loss = ComputeLoss(model) # init loss class
  234. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  235. f'Using {train_loader.num_workers} dataloader workers\n'
  236. f"Logging results to {colorstr('bold', save_dir)}\n"
  237. f'Starting training for {epochs} epochs...')
  238. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  239. model.train()
  240. # Update image weights (optional, single-GPU only)
  241. if opt.image_weights:
  242. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  243. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  244. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  245. # Update mosaic border (optional)
  246. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  247. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  248. mloss = torch.zeros(3, device=device) # mean losses
  249. if RANK != -1:
  250. train_loader.sampler.set_epoch(epoch)
  251. pbar = enumerate(train_loader)
  252. LOGGER.info(('\n' + '%10s' * 7) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'labels', 'img_size'))
  253. if RANK in [-1, 0]:
  254. pbar = tqdm(pbar, total=nb) # progress bar
  255. optimizer.zero_grad()
  256. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  257. ni = i + nb * epoch # number integrated batches (since train start)
  258. imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
  259. # Warmup
  260. if ni <= nw:
  261. xi = [0, nw] # x interp
  262. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  263. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  264. for j, x in enumerate(optimizer.param_groups):
  265. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  266. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  267. if 'momentum' in x:
  268. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  269. # Multi-scale
  270. if opt.multi_scale:
  271. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  272. sf = sz / max(imgs.shape[2:]) # scale factor
  273. if sf != 1:
  274. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  275. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  276. # Forward
  277. with amp.autocast(enabled=cuda):
  278. pred = model(imgs) # forward
  279. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  280. if RANK != -1:
  281. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  282. if opt.quad:
  283. loss *= 4.
  284. # Backward
  285. scaler.scale(loss).backward()
  286. # Optimize
  287. if ni - last_opt_step >= accumulate:
  288. scaler.step(optimizer) # optimizer.step
  289. scaler.update()
  290. optimizer.zero_grad()
  291. if ema:
  292. ema.update(model)
  293. last_opt_step = ni
  294. # Log
  295. if RANK in [-1, 0]:
  296. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  297. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  298. pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
  299. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
  300. callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
  301. # end batch ------------------------------------------------------------------------------------------------
  302. # Scheduler
  303. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  304. scheduler.step()
  305. if RANK in [-1, 0]:
  306. # mAP
  307. callbacks.run('on_train_epoch_end', epoch=epoch)
  308. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  309. final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
  310. if not noval or final_epoch: # Calculate mAP
  311. results, maps, _ = val.run(data_dict,
  312. batch_size=batch_size // WORLD_SIZE * 2,
  313. imgsz=imgsz,
  314. model=ema.ema,
  315. single_cls=single_cls,
  316. dataloader=val_loader,
  317. save_dir=save_dir,
  318. plots=False,
  319. callbacks=callbacks,
  320. compute_loss=compute_loss)
  321. # Update best mAP
  322. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  323. if fi > best_fitness:
  324. best_fitness = fi
  325. log_vals = list(mloss) + list(results) + lr
  326. callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
  327. # Save model
  328. if (not nosave) or (final_epoch and not evolve): # if save
  329. ckpt = {'epoch': epoch,
  330. 'best_fitness': best_fitness,
  331. 'model': deepcopy(de_parallel(model)).half(),
  332. 'ema': deepcopy(ema.ema).half(),
  333. 'updates': ema.updates,
  334. 'optimizer': optimizer.state_dict(),
  335. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
  336. 'date': datetime.now().isoformat()}
  337. # Save last, best and delete
  338. torch.save(ckpt, last)
  339. if best_fitness == fi:
  340. torch.save(ckpt, best)
  341. if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
  342. torch.save(ckpt, w / f'epoch{epoch}.pt')
  343. del ckpt
  344. callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
  345. # Stop Single-GPU
  346. if RANK == -1 and stopper(epoch=epoch, fitness=fi):
  347. break
  348. # Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
  349. # stop = stopper(epoch=epoch, fitness=fi)
  350. # if RANK == 0:
  351. # dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks
  352. # Stop DPP
  353. # with torch_distributed_zero_first(RANK):
  354. # if stop:
  355. # break # must break all DDP ranks
  356. # end epoch ----------------------------------------------------------------------------------------------------
  357. # end training -----------------------------------------------------------------------------------------------------
  358. if RANK in [-1, 0]:
  359. LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
  360. for f in last, best:
  361. if f.exists():
  362. strip_optimizer(f) # strip optimizers
  363. if f is best:
  364. LOGGER.info(f'\nValidating {f}...')
  365. results, _, _ = val.run(data_dict,
  366. batch_size=batch_size // WORLD_SIZE * 2,
  367. imgsz=imgsz,
  368. model=attempt_load(f, device).half(),
  369. iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
  370. single_cls=single_cls,
  371. dataloader=val_loader,
  372. save_dir=save_dir,
  373. save_json=is_coco,
  374. verbose=True,
  375. plots=True,
  376. callbacks=callbacks,
  377. compute_loss=compute_loss) # val best model with plots
  378. if is_coco:
  379. callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
  380. callbacks.run('on_train_end', last, best, plots, epoch, results)
  381. LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
  382. torch.cuda.empty_cache()
  383. return results
  384. def parse_opt(known=False):
  385. parser = argparse.ArgumentParser()
  386. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='initial weights path')
  387. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  388. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
  389. parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  390. parser.add_argument('--epochs', type=int, default=300)
  391. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
  392. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  393. parser.add_argument('--rect', action='store_true', help='rectangular training')
  394. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  395. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  396. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  397. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  398. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  399. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  400. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
  401. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  402. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  403. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  404. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  405. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  406. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  407. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  408. parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
  409. parser.add_argument('--name', default='exp', help='save to project/name')
  410. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  411. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  412. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  413. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  414. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  415. parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
  416. parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
  417. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  418. # Weights & Biases arguments
  419. parser.add_argument('--entity', default=None, help='W&B: Entity')
  420. parser.add_argument('--upload_dataset', action='store_true', help='W&B: Upload dataset as artifact table')
  421. parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
  422. parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
  423. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  424. return opt
  425. def main(opt, callbacks=Callbacks()):
  426. # Checks
  427. if RANK in [-1, 0]:
  428. print_args(FILE.stem, opt)
  429. check_git_status()
  430. check_requirements(exclude=['thop'])
  431. # Resume
  432. if opt.resume and not check_wandb_resume(opt) and not opt.evolve: # resume an interrupted run
  433. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  434. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  435. with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
  436. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  437. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  438. LOGGER.info(f'Resuming training from {ckpt}')
  439. else:
  440. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
  441. check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
  442. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  443. if opt.evolve:
  444. opt.project = str(ROOT / 'runs/evolve')
  445. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  446. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  447. # DDP mode
  448. device = select_device(opt.device, batch_size=opt.batch_size)
  449. if LOCAL_RANK != -1:
  450. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  451. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  452. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  453. assert not opt.evolve, '--evolve argument is not compatible with DDP training'
  454. torch.cuda.set_device(LOCAL_RANK)
  455. device = torch.device('cuda', LOCAL_RANK)
  456. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
  457. # Train
  458. if not opt.evolve:
  459. train(opt.hyp, opt, device, callbacks)
  460. if WORLD_SIZE > 1 and RANK == 0:
  461. LOGGER.info('Destroying process group... ')
  462. dist.destroy_process_group()
  463. # Evolve hyperparameters (optional)
  464. else:
  465. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  466. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  467. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  468. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  469. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  470. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  471. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  472. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  473. 'box': (1, 0.02, 0.2), # box loss gain
  474. 'cls': (1, 0.2, 4.0), # cls loss gain
  475. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  476. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  477. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  478. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  479. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  480. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  481. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  482. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  483. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  484. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  485. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  486. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  487. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  488. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  489. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  490. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  491. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  492. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  493. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  494. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  495. with open(opt.hyp, errors='ignore') as f:
  496. hyp = yaml.safe_load(f) # load hyps dict
  497. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  498. hyp['anchors'] = 3
  499. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  500. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  501. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  502. if opt.bucket:
  503. os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {save_dir}') # download evolve.csv if exists
  504. for _ in range(opt.evolve): # generations to evolve
  505. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  506. # Select parent(s)
  507. parent = 'single' # parent selection method: 'single' or 'weighted'
  508. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  509. n = min(5, len(x)) # number of previous results to consider
  510. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  511. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  512. if parent == 'single' or len(x) == 1:
  513. # x = x[random.randint(0, n - 1)] # random selection
  514. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  515. elif parent == 'weighted':
  516. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  517. # Mutate
  518. mp, s = 0.8, 0.2 # mutation probability, sigma
  519. npr = np.random
  520. npr.seed(int(time.time()))
  521. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  522. ng = len(meta)
  523. v = np.ones(ng)
  524. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  525. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  526. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  527. hyp[k] = float(x[i + 7] * v[i]) # mutate
  528. # Constrain to limits
  529. for k, v in meta.items():
  530. hyp[k] = max(hyp[k], v[1]) # lower limit
  531. hyp[k] = min(hyp[k], v[2]) # upper limit
  532. hyp[k] = round(hyp[k], 5) # significant digits
  533. # Train mutation
  534. results = train(hyp.copy(), opt, device, callbacks)
  535. # Write mutation results
  536. print_mutation(results, hyp.copy(), save_dir, opt.bucket)
  537. # Plot results
  538. plot_evolve(evolve_csv)
  539. LOGGER.info(f'Hyperparameter evolution finished\n'
  540. f"Results saved to {colorstr('bold', save_dir)}\n"
  541. f'Use best hyperparameters example: $ python train.py --hyp {evolve_yaml}')
  542. def run(**kwargs):
  543. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  544. opt = parse_opt(True)
  545. for k, v in kwargs.items():
  546. setattr(opt, k, v)
  547. main(opt)
  548. if __name__ == "__main__":
  549. opt = parse_opt()
  550. main(opt)