import argparse
import logging
import math
import os
import random
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import test  # import test.py to get mAP after each epoch
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel

logger = logging.getLogger(__name__)


def train(hyp, opt, device, tb_writer=None, wandb=None):
    logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    plots = not opt.evolve  # create plots
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
    names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
    logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
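
    # Gradient accumulation keeps the effective batch near the nominal nbs = 64:
    # e.g. --batch-size 16 gives accumulate = max(round(64 / 16), 1) = 4, so the optimizer
    # steps once per 4 batches and weight_decay is rescaled by 16 * 4 / 64 = 1.0 (unchanged),
    # while --batch-size 48 gives accumulate = 1 and scales weight_decay by 48 / 64 = 0.75,
    # since decay is applied once per optimizer step rather than once per batch.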

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2
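
    # Splitting parameters this way means weight decay regularizes only conv/linear
    # weights (pg1): decaying BatchNorm scale parameters (pg0) tends to hurt accuracy,
    # and biases (pg2) are kept as their own group so warmup below can drive them from
    # hyp['warmup_bias_lr'] instead of 0.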

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    if opt.linear_lr:
        lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    else:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)
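
    # With the default cosine schedule, one_cycle(1, hyp['lrf'], epochs) (utils.general)
    # returns lf(x) = ((1 - cos(x * pi / epochs)) / 2) * (lrf - 1) + 1, decaying from 1
    # at epoch 0 to hyp['lrf'] at the final epoch; LambdaLR multiplies each param group's
    # initial_lr by lf(epoch) on every scheduler.step().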

    # Logging
    if rank in [-1, 0] and wandb and wandb.run is None:
        opt.hyp = hyp  # add hyperparameters
        wandb_run = wandb.init(config=opt, resume="allow",
                               project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                               name=save_dir.stem,
                               entity=opt.entity,
                               id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
    loggers = {'wandb': wandb}  # loggers dict

    # EMA
    ema = ModelEMA(model) if rank in [-1, 0] else None
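
    # The EMA keeps an exponential moving average of the model weights (the decay ramps
    # up towards ~0.9999 as updates accumulate; see ModelEMA in utils.torch_utils). The
    # averaged copy ema.ema, not the live model, is what gets tested each epoch and saved
    # to the checkpoint, which smooths out batch-to-batch noise in the final weights.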

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # EMA
        if ema and ckpt.get('ema'):
            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
            ema.updates = ckpt['updates']

        # Results
        if ckpt.get('training_results') is not None:
            results_file.write_text(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        if epochs < start_epoch:
            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                        (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    nl = model.model[-1].nl  # number of detection layers (used for scaling hyp['obj'])
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples
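
    # Both train and test sizes must be multiples of the maximum stride gs (32 for the
    # standard P5 models), so the detection grids tile the feature maps exactly;
    # check_img_size rounds a non-conforming --img-size up to the next multiple and warns.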

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # Trainloader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                            world_size=opt.world_size, workers=opt.workers,
                                            image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                       world_size=opt.world_size, workers=opt.workers,
                                       pad=0.5, prefix=colorstr('val: '))[0]

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            if plots:
                plot_labels(labels, save_dir, loggers)
                if tb_writer:
                    tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
            model.half().float()  # pre-reduce anchor precision

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)

    # Model parameters
    hyp['box'] *= 3. / nl  # scale to layers
    hyp['cls'] *= nc / 80. * 3. / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    compute_loss = ComputeLoss(model)  # init loss class
    logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                f'Using {dataloader.num_workers} dataloader workers\n'
                f'Logging results to {save_dir}\n'
                f'Starting training for {epochs} epochs...')
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
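
            # Warmup interpolates every schedule linearly over the first nw batches:
            # the bias group's lr decays from hyp['warmup_bias_lr'] while the other
            # groups rise from 0.0, both towards initial_lr * lf(epoch); momentum ramps
            # from hyp['warmup_momentum'] to hyp['momentum'], and accumulate grows from
            # 1 to nbs / total_batch_size.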

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + gs)) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
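
            # Multi-scale draws a fresh size from [0.5 * imgsz, 1.5 * imgsz] each batch,
            # snapped down to a gs multiple: imgsz=640, gs=32 yields 320..960 in steps of
            # 32, with the whole batch resized by bilinear interpolation.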

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.
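
            # DDP averages gradients across the world_size processes, which would shrink
            # the effective gradient relative to single-GPU training on the same total
            # batch; scaling the loss back up by world_size compensates. The 4x quad gain
            # likewise offsets the quad collate returning a quarter as many (larger)
            # samples per batch (assumption based on the quad collate in utils.datasets).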

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
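
            # The optimizer only steps every `accumulate` batches, so the effective batch
            # size is batch_size * accumulate (~nbs = 64). scaler.step() unscales the AMP
            # gradients first and skips the step if it finds inf/NaN gradients, and the
            # EMA is refreshed only after a real optimizer step.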

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 6) % (
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if plots and ni < 3:
                    f = save_dir / f'train_batch{ni}.jpg'  # filename
                    Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
                    # if tb_writer:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(model, imgs)  # add model to tensorboard
                elif plots and ni == 10 and wandb:
                    wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
                                           if x.exists()]}, commit=False)

            # end batch ------------------------------------------------------------------------------------------------
        # end epoch ----------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(opt.data,
                                                 batch_size=batch_size * 2,
                                                 imgsz=imgsz_test,
                                                 model=ema.ema,
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 verbose=nc < 50 and final_epoch,
                                                 plots=plots and final_epoch,
                                                 log_imgs=opt.log_imgs if wandb else 0,
                                                 compute_loss=compute_loss)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # append metrics, val_loss
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

            # Log
            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                    'x/lr0', 'x/lr1', 'x/lr2']  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x}, step=epoch, commit=tag == tags[-1])  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                best_fitness = fi
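
            # fitness() collapses [P, R, mAP@.5, mAP@.5:.95] into one scalar; in this
            # repo's utils.general it is the weighted sum 0.1 * mAP@.5 + 0.9 * mAP@.5:.95,
            # so best.pt tracks mAP rather than precision or recall.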

            # Save model
            if (not opt.nosave) or (final_epoch and not opt.evolve):  # if save
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': results_file.read_text(),
                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
                        'ema': deepcopy(ema.ema).half(),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': wandb_run.id if wandb else None}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        final = best if best.exists() else last  # final model
        for f in last, best:
            if f.exists():
                strip_optimizer(f)
        if opt.bucket:
            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload

        # Plots
        if plots:
            plot_results(save_dir=save_dir)  # save as results.png
            if wandb:
                files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
                wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
                                       if (save_dir / f).exists()]})
                if opt.log_artifacts:
                    wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)

        # Test best.pt
        logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
            for m in (last, best) if best.exists() else (last,):  # speed, mAP tests; (last,) so the single Path is iterable
                results, _, _ = test.test(opt.data,
                                          batch_size=batch_size * 2,
                                          imgsz=imgsz_test,
                                          conf_thres=0.001,
                                          iou_thres=0.7,
                                          model=attempt_load(m, device).half(),
                                          single_cls=opt.single_cls,
                                          dataloader=testloader,
                                          save_dir=save_dir,
                                          save_json=True,
                                          plots=False)

    else:
        dist.destroy_process_group()

    wandb.run.finish() if wandb and wandb.run else None
    torch.cuda.empty_cache()
    return results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
    parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
    parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
    parser.add_argument('--project', default='runs/train', help='save to project/name')
    parser.add_argument('--entity', default=None, help='W&B entity')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--quad', action='store_true', help='quad dataloader')
    parser.add_argument('--linear-lr', action='store_true', help='linear LR')
    opt = parser.parse_args()

    # Set DDP variables
    opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
    set_logging(opt.global_rank)
    if opt.global_rank in [-1, 0]:
        check_git_status()
        check_requirements()

    # Resume
    if opt.resume:  # resume an interrupted run
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        apriori = opt.global_rank, opt.local_rank
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader))  # replace
        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
            '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
        logger.info('Resuming training from %s' % ckpt)
    else:
        # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
        opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
        opt.name = 'evolve' if opt.evolve else opt.name
        opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)  # increment run

    # DDP mode
    opt.total_batch_size = opt.batch_size
    device = select_device(opt.device, batch_size=opt.batch_size)
    if opt.local_rank != -1:
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
        opt.batch_size = opt.total_batch_size // opt.world_size
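
        # Each DDP process now trains on total_batch_size / world_size images, e.g.
        # `python -m torch.distributed.launch --nproc_per_node 2 train.py --batch-size 64`
        # runs two processes with 32 images each; rank 0 alone handles the logging,
        # testing and checkpointing inside train().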

    # Hyperparameters
    with open(opt.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps

    # Train
    logger.info(opt)
    try:
        import wandb
    except ImportError:
        wandb = None
        prefix = colorstr('wandb: ')
        logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
    if not opt.evolve:
        tb_writer = None  # init loggers
        if opt.global_rank in [-1, 0]:
            logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
            tb_writer = SummaryWriter(opt.save_dir)  # Tensorboard
        train(hyp, opt, device, tb_writer, wandb)

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                'mixup': (1, 0.0, 1.0)}  # image mixup (probability)
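
        # Each meta entry is (mutation gain, lower limit, upper limit): a gain of 0
        # freezes the hyperparameter (e.g. 'fliplr' and 'fl_gamma' never mutate), a gain
        # of 1 mutates at full sigma, and the limits clip the result below. 'anchors'
        # uses gain 2 so the anchor count can move more aggressively per generation.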

        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
        opt.notest, opt.nosave = True, True  # only test/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
        if opt.bucket:
            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

        for _ in range(300):  # generations to evolve
            if Path('evolve.txt').exists():  # if evolve.txt exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt('evolve.txt', ndmin=2)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min()  # weights
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([x[0] for x in meta.values()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 7] * v[i])  # mutate
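
                # evolve.txt rows are the 7 result columns [P, R, mAP@.5, mAP@.5:.95,
                # val_box, val_obj, val_cls] followed by the hyp values, so x[i + 7]
                # skips the metrics and v[i] scales the parent's i-th hyperparameter by
                # a clipped random gain (applied with probability mp = 0.8, sigma s = 0.2).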

            # Constrain to limits
            for k, v in meta.items():
                hyp[k] = max(hyp[k], v[1])  # lower limit
                hyp[k] = min(hyp[k], v[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device, wandb=wandb)

            # Write mutation results
            print_mutation(hyp.copy(), results, yaml_file, opt.bucket)

        # Plot results
        plot_evolution(yaml_file)
        print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
              f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')