Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

604 lines
30KB

"""Train a YOLOv5 model on a custom dataset.

Usage:
    $ python path/to/train.py --data coco128.yaml --weights yolov5s.pt --img 640
"""
  5. import argparse
  6. import logging
  7. import os
  8. import random
  9. import sys
  10. import time
  11. from copy import deepcopy
  12. from pathlib import Path
  13. from threading import Thread
  14. import math
  15. import numpy as np
  16. import torch
  17. import torch.distributed as dist
  18. import torch.nn as nn
  19. import yaml
  20. from torch.cuda import amp
  21. from torch.nn.parallel import DistributedDataParallel as DDP
  22. from torch.optim import Adam, SGD, lr_scheduler
  23. from tqdm import tqdm
  24. FILE = Path(__file__).absolute()
  25. sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
  26. import val # for end-of-epoch mAP
  27. from models.experimental import attempt_load
  28. from models.yolo import Model
  29. from utils.autoanchor import check_anchors
  30. from utils.datasets import create_dataloader
  31. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  32. strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  33. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  34. from utils.google_utils import attempt_download
  35. from utils.loss import ComputeLoss
  36. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  37. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
  38. from utils.loggers.wandb.wandb_utils import check_wandb_resume
  39. from utils.metrics import fitness
  40. from utils.loggers import Loggers
  41. LOGGER = logging.getLogger(__name__)
  42. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  43. RANK = int(os.getenv('RANK', -1))
  44. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  45. def train(hyp, # path/to/hyp.yaml or hyp dictionary
  46. opt,
  47. device,
  48. ):
  49. save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, = \
  50. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
  51. opt.resume, opt.noval, opt.nosave, opt.workers
  52. # Directories
  53. w = save_dir / 'weights' # weights dir
  54. w.mkdir(parents=True, exist_ok=True) # make dir
  55. last, best, results_file = w / 'last.pt', w / 'best.pt', save_dir / 'results.txt'
  56. # Hyperparameters
  57. if isinstance(hyp, str):
  58. with open(hyp) as f:
  59. hyp = yaml.safe_load(f) # load hyps dict
  60. LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  61. # Save run settings
  62. with open(save_dir / 'hyp.yaml', 'w') as f:
  63. yaml.safe_dump(hyp, f, sort_keys=False)
  64. with open(save_dir / 'opt.yaml', 'w') as f:
  65. yaml.safe_dump(vars(opt), f, sort_keys=False)
  66. # Config
  67. plots = not evolve # create plots
  68. cuda = device.type != 'cpu'
  69. init_seeds(1 + RANK)
  70. with open(data) as f:
  71. data_dict = yaml.safe_load(f) # data dict
  72. nc = 1 if single_cls else int(data_dict['nc']) # number of classes
  73. names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  74. assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
  75. is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
  76. # Loggers
  77. if RANK in [-1, 0]:
  78. loggers = Loggers(save_dir, results_file, weights, opt, hyp, data_dict, LOGGER).start() # loggers dict
  79. if loggers.wandb and resume:
  80. weights, epochs, hyp, data_dict = opt.weights, opt.epochs, opt.hyp, loggers.wandb.data_dict
  81. # Model
  82. pretrained = weights.endswith('.pt')
  83. if pretrained:
  84. with torch_distributed_zero_first(RANK):
  85. weights = attempt_download(weights) # download if not found locally
  86. ckpt = torch.load(weights, map_location=device) # load checkpoint
  87. model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  88. exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
  89. csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
  90. csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
  91. model.load_state_dict(csd, strict=False) # load
  92. LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
  93. else:
  94. model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  95. with torch_distributed_zero_first(RANK):
  96. check_dataset(data_dict) # check
  97. train_path, val_path = data_dict['train'], data_dict['val']
  98. # Freeze
  99. freeze = [] # parameter names to freeze (full or partial)
  100. for k, v in model.named_parameters():
  101. v.requires_grad = True # train all layers
  102. if any(x in k for x in freeze):
  103. print(f'freezing {k}')
  104. v.requires_grad = False
  105. # Optimizer
  106. nbs = 64 # nominal batch size
  107. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  108. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  109. LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  110. g0, g1, g2 = [], [], [] # optimizer parameter groups
  111. for v in model.modules():
  112. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias
  113. g2.append(v.bias)
  114. if isinstance(v, nn.BatchNorm2d): # weight with decay
  115. g0.append(v.weight)
  116. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight without decay
  117. g1.append(v.weight)
  118. if opt.adam:
  119. optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  120. else:
  121. optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  122. optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']}) # add g1 with weight_decay
  123. optimizer.add_param_group({'params': g2}) # add g2 (biases)
  124. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
  125. f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
  126. del g0, g1, g2
  127. # Scheduler
  128. if opt.linear_lr:
  129. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  130. else:
  131. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  132. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
  133. # EMA
  134. ema = ModelEMA(model) if RANK in [-1, 0] else None
  135. # Resume
  136. start_epoch, best_fitness = 0, 0.0
  137. if pretrained:
  138. # Optimizer
  139. if ckpt['optimizer'] is not None:
  140. optimizer.load_state_dict(ckpt['optimizer'])
  141. best_fitness = ckpt['best_fitness']
  142. # EMA
  143. if ema and ckpt.get('ema'):
  144. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  145. ema.updates = ckpt['updates']
  146. # Results
  147. if ckpt.get('training_results') is not None:
  148. results_file.write_text(ckpt['training_results']) # write results.txt
  149. # Epochs
  150. start_epoch = ckpt['epoch'] + 1
  151. if resume:
  152. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
  153. if epochs < start_epoch:
  154. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  155. epochs += ckpt['epoch'] # finetune additional epochs
  156. del ckpt, csd
  157. # Image sizes
  158. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  159. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  160. imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
  161. # DP mode
  162. if cuda and RANK == -1 and torch.cuda.device_count() > 1:
  163. logging.warning('DP not recommended, instead use torch.distributed.run for best DDP Multi-GPU results.\n'
  164. 'See Multi-GPU Tutorial at https://github.com/ultralytics/yolov5/issues/475 to get started.')
  165. model = torch.nn.DataParallel(model)
  166. # SyncBatchNorm
  167. if opt.sync_bn and cuda and RANK != -1:
  168. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  169. LOGGER.info('Using SyncBatchNorm()')
  170. # Trainloader
  171. train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
  172. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
  173. workers=workers, image_weights=opt.image_weights, quad=opt.quad,
  174. prefix=colorstr('train: '))
  175. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  176. nb = len(train_loader) # number of batches
  177. assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
  178. # Process 0
  179. if RANK in [-1, 0]:
  180. val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
  181. hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
  182. workers=workers, pad=0.5,
  183. prefix=colorstr('val: '))[0]
  184. if not resume:
  185. labels = np.concatenate(dataset.labels, 0)
  186. # c = torch.tensor(labels[:, 0]) # classes
  187. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  188. # model._initialize_biases(cf.to(device))
  189. if plots:
  190. plot_labels(labels, names, save_dir, loggers)
  191. # Anchors
  192. if not opt.noautoanchor:
  193. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  194. model.half().float() # pre-reduce anchor precision
  195. # DDP mode
  196. if cuda and RANK != -1:
  197. model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  198. # Model parameters
  199. hyp['box'] *= 3. / nl # scale to layers
  200. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  201. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  202. hyp['label_smoothing'] = opt.label_smoothing
  203. model.nc = nc # attach number of classes to model
  204. model.hyp = hyp # attach hyperparameters to model
  205. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  206. model.names = names
  207. # Start training
  208. t0 = time.time()
  209. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  210. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  211. last_opt_step = -1
  212. maps = np.zeros(nc) # mAP per class
  213. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  214. scheduler.last_epoch = start_epoch - 1 # do not move
  215. scaler = amp.GradScaler(enabled=cuda)
  216. compute_loss = ComputeLoss(model) # init loss class
  217. LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
  218. f'Using {train_loader.num_workers} dataloader workers\n'
  219. f'Logging results to {save_dir}\n'
  220. f'Starting training for {epochs} epochs...')
  221. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  222. model.train()
  223. # Update image weights (optional)
  224. if opt.image_weights:
  225. # Generate indices
  226. if RANK in [-1, 0]:
  227. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  228. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  229. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  230. # Broadcast if DDP
  231. if RANK != -1:
  232. indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
  233. dist.broadcast(indices, 0)
  234. if RANK != 0:
  235. dataset.indices = indices.cpu().numpy()
  236. # Update mosaic border
  237. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  238. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  239. mloss = torch.zeros(4, device=device) # mean losses
  240. if RANK != -1:
  241. train_loader.sampler.set_epoch(epoch)
  242. pbar = enumerate(train_loader)
  243. LOGGER.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
  244. if RANK in [-1, 0]:
  245. pbar = tqdm(pbar, total=nb) # progress bar
  246. optimizer.zero_grad()
  247. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  248. ni = i + nb * epoch # number integrated batches (since train start)
  249. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  250. # Warmup
  251. if ni <= nw:
  252. xi = [0, nw] # x interp
  253. # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  254. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  255. for j, x in enumerate(optimizer.param_groups):
  256. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  257. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  258. if 'momentum' in x:
  259. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  260. # Multi-scale
  261. if opt.multi_scale:
  262. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  263. sf = sz / max(imgs.shape[2:]) # scale factor
  264. if sf != 1:
  265. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  266. imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  267. # Forward
  268. with amp.autocast(enabled=cuda):
  269. pred = model(imgs) # forward
  270. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  271. if RANK != -1:
  272. loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
  273. if opt.quad:
  274. loss *= 4.
  275. # Backward
  276. scaler.scale(loss).backward()
  277. # Optimize
  278. if ni - last_opt_step >= accumulate:
  279. scaler.step(optimizer) # optimizer.step
  280. scaler.update()
  281. optimizer.zero_grad()
  282. if ema:
  283. ema.update(model)
  284. last_opt_step = ni
  285. # Print
  286. if RANK in [-1, 0]:
  287. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  288. mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
  289. s = ('%10s' * 2 + '%10.4g' * 6) % (
  290. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
  291. pbar.set_description(s)
  292. # Plot
  293. if plots:
  294. if ni < 3:
  295. f = save_dir / f'train_batch{ni}.jpg' # filename
  296. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  297. loggers.on_train_batch_end(ni, model, imgs)
  298. # end batch ------------------------------------------------------------------------------------------------
  299. # Scheduler
  300. lr = [x['lr'] for x in optimizer.param_groups] # for loggers
  301. scheduler.step()
  302. if RANK in [-1, 0]:
  303. # mAP
  304. loggers.on_train_epoch_end(epoch)
  305. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
  306. final_epoch = epoch + 1 == epochs
  307. if not noval or final_epoch: # Calculate mAP
  308. results, maps, _ = val.run(data_dict,
  309. batch_size=batch_size // WORLD_SIZE * 2,
  310. imgsz=imgsz,
  311. model=ema.ema,
  312. single_cls=single_cls,
  313. dataloader=val_loader,
  314. save_dir=save_dir,
  315. save_json=is_coco and final_epoch,
  316. verbose=nc < 50 and final_epoch,
  317. plots=plots and final_epoch,
  318. loggers=loggers,
  319. compute_loss=compute_loss)
  320. # Update best mAP
  321. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  322. if fi > best_fitness:
  323. best_fitness = fi
  324. loggers.on_train_val_end(mloss, results, lr, epoch, s, best_fitness, fi)
  325. # Save model
  326. if (not nosave) or (final_epoch and not evolve): # if save
  327. ckpt = {'epoch': epoch,
  328. 'best_fitness': best_fitness,
  329. 'training_results': results_file.read_text(),
  330. 'model': deepcopy(de_parallel(model)).half(),
  331. 'ema': deepcopy(ema.ema).half(),
  332. 'updates': ema.updates,
  333. 'optimizer': optimizer.state_dict(),
  334. 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None}
  335. # Save last, best and delete
  336. torch.save(ckpt, last)
  337. if best_fitness == fi:
  338. torch.save(ckpt, best)
  339. del ckpt
  340. loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
  341. # end epoch ----------------------------------------------------------------------------------------------------
  342. # end training -----------------------------------------------------------------------------------------------------
  343. if RANK in [-1, 0]:
  344. LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
  345. if plots:
  346. plot_results(save_dir=save_dir) # save as results.png
  347. if not evolve:
  348. if is_coco: # COCO dataset
  349. for m in [last, best] if best.exists() else [last]: # speed, mAP tests
  350. results, _, _ = val.run(data_dict,
  351. batch_size=batch_size // WORLD_SIZE * 2,
  352. imgsz=imgsz,
  353. model=attempt_load(m, device).half(),
  354. iou_thres=0.7, # NMS IoU threshold for best pycocotools results
  355. single_cls=single_cls,
  356. dataloader=val_loader,
  357. save_dir=save_dir,
  358. save_json=True,
  359. plots=False)
  360. # Strip optimizers
  361. for f in last, best:
  362. if f.exists():
  363. strip_optimizer(f) # strip optimizers
  364. loggers.on_train_end(last, best)
  365. torch.cuda.empty_cache()
  366. return results
  367. def parse_opt(known=False):
  368. parser = argparse.ArgumentParser()
  369. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  370. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  371. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='dataset.yaml path')
  372. parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
  373. parser.add_argument('--epochs', type=int, default=300)
  374. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  375. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  376. parser.add_argument('--rect', action='store_true', help='rectangular training')
  377. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  378. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  379. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  380. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  381. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  382. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  383. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  384. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  385. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  386. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  387. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  388. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  389. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  390. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  391. parser.add_argument('--project', default='runs/train', help='save to project/name')
  392. parser.add_argument('--entity', default=None, help='W&B entity')
  393. parser.add_argument('--name', default='exp', help='save to project/name')
  394. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  395. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  396. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  397. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  398. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  399. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  400. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
  401. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
  402. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  403. opt = parser.parse_known_args()[0] if known else parser.parse_args()
  404. return opt
  405. def main(opt):
  406. set_logging(RANK)
  407. if RANK in [-1, 0]:
  408. print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  409. check_git_status()
  410. check_requirements(exclude=['thop'])
  411. # Resume
  412. if opt.resume and not check_wandb_resume(opt): # resume an interrupted run
  413. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  414. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  415. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  416. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  417. opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
  418. LOGGER.info(f'Resuming training from {ckpt}')
  419. else:
  420. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  421. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  422. opt.name = 'evolve' if opt.evolve else opt.name
  423. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
  424. # DDP mode
  425. device = select_device(opt.device, batch_size=opt.batch_size)
  426. if LOCAL_RANK != -1:
  427. from datetime import timedelta
  428. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  429. assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
  430. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  431. assert not opt.evolve, '--evolve argument is not compatible with DDP training'
  432. assert not opt.sync_bn, '--sync-bn known training issue, see https://github.com/ultralytics/yolov5/issues/3998'
  433. torch.cuda.set_device(LOCAL_RANK)
  434. device = torch.device('cuda', LOCAL_RANK)
  435. dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=60))
  436. # Train
  437. if not opt.evolve:
  438. train(opt.hyp, opt, device)
  439. if WORLD_SIZE > 1 and RANK == 0:
  440. _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
  441. # Evolve hyperparameters (optional)
  442. else:
  443. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  444. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  445. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  446. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  447. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  448. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  449. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  450. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  451. 'box': (1, 0.02, 0.2), # box loss gain
  452. 'cls': (1, 0.2, 4.0), # cls loss gain
  453. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  454. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  455. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  456. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  457. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  458. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  459. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  460. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  461. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  462. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  463. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  464. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  465. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  466. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  467. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  468. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  469. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  470. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  471. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  472. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  473. with open(opt.hyp) as f:
  474. hyp = yaml.safe_load(f) # load hyps dict
  475. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  476. hyp['anchors'] = 3
  477. opt.noval, opt.nosave = True, True # only val/save final epoch
  478. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  479. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  480. if opt.bucket:
  481. os.system(f'gsutil cp gs://{opt.bucket}/evolve.txt .') # download evolve.txt if exists
  482. for _ in range(opt.evolve): # generations to evolve
  483. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  484. # Select parent(s)
  485. parent = 'single' # parent selection method: 'single' or 'weighted'
  486. x = np.loadtxt('evolve.txt', ndmin=2)
  487. n = min(5, len(x)) # number of previous results to consider
  488. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  489. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  490. if parent == 'single' or len(x) == 1:
  491. # x = x[random.randint(0, n - 1)] # random selection
  492. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  493. elif parent == 'weighted':
  494. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  495. # Mutate
  496. mp, s = 0.8, 0.2 # mutation probability, sigma
  497. npr = np.random
  498. npr.seed(int(time.time()))
  499. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  500. ng = len(meta)
  501. v = np.ones(ng)
  502. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  503. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  504. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  505. hyp[k] = float(x[i + 7] * v[i]) # mutate
  506. # Constrain to limits
  507. for k, v in meta.items():
  508. hyp[k] = max(hyp[k], v[1]) # lower limit
  509. hyp[k] = min(hyp[k], v[2]) # upper limit
  510. hyp[k] = round(hyp[k], 5) # significant digits
  511. # Train mutation
  512. results = train(hyp.copy(), opt, device)
  513. # Write mutation results
  514. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  515. # Plot results
  516. plot_evolution(yaml_file)
  517. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  518. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
  519. def run(**kwargs):
  520. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  521. opt = parse_opt(True)
  522. for k, v in kwargs.items():
  523. setattr(opt, k, v)
  524. main(opt)
  525. if __name__ == "__main__":
  526. opt = parse_opt()
  527. main(opt)