You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

615 lines
32KB

  1. import argparse
  2. import logging
  3. import math
  4. import os
  5. import random
  6. import time
  7. from copy import deepcopy
  8. from pathlib import Path
  9. from threading import Thread
  10. import numpy as np
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.optim as optim
  15. import torch.optim.lr_scheduler as lr_scheduler
  16. import torch.utils.data
  17. import yaml
  18. from torch.cuda import amp
  19. from torch.nn.parallel import DistributedDataParallel as DDP
  20. from torch.utils.tensorboard import SummaryWriter
  21. from tqdm import tqdm
  22. import test # import test.py to get mAP after each epoch
  23. from models.experimental import attempt_load
  24. from models.yolo import Model
  25. from utils.autoanchor import check_anchors
  26. from utils.datasets import create_dataloader
  27. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  28. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  29. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  30. from utils.google_utils import attempt_download
  31. from utils.loss import ComputeLoss
  32. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  33. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
  34. logger = logging.getLogger(__name__)
  35. def train(hyp, opt, device, tb_writer=None, wandb=None):
  36. logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  37. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  38. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  39. # Directories
  40. wdir = save_dir / 'weights'
  41. wdir.mkdir(parents=True, exist_ok=True) # make dir
  42. last = wdir / 'last.pt'
  43. best = wdir / 'best.pt'
  44. results_file = save_dir / 'results.txt'
  45. # Save run settings
  46. with open(save_dir / 'hyp.yaml', 'w') as f:
  47. yaml.dump(hyp, f, sort_keys=False)
  48. with open(save_dir / 'opt.yaml', 'w') as f:
  49. yaml.dump(vars(opt), f, sort_keys=False)
  50. # Configure
  51. plots = not opt.evolve # create plots
  52. cuda = device.type != 'cpu'
  53. init_seeds(2 + rank)
  54. with open(opt.data) as f:
  55. data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
  56. with torch_distributed_zero_first(rank):
  57. check_dataset(data_dict) # check
  58. train_path = data_dict['train']
  59. test_path = data_dict['val']
  60. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  61. names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  62. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  63. # Model
  64. pretrained = weights.endswith('.pt')
  65. if pretrained:
  66. with torch_distributed_zero_first(rank):
  67. attempt_download(weights) # download if not found locally
  68. ckpt = torch.load(weights, map_location=device) # load checkpoint
  69. if hyp.get('anchors'):
  70. ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
  71. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  72. exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
  73. state_dict = ckpt['model'].float().state_dict() # to FP32
  74. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  75. model.load_state_dict(state_dict, strict=False) # load
  76. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  77. else:
  78. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  79. # Freeze
  80. freeze = [] # parameter names to freeze (full or partial)
  81. for k, v in model.named_parameters():
  82. v.requires_grad = True # train all layers
  83. if any(x in k for x in freeze):
  84. print('freezing %s' % k)
  85. v.requires_grad = False
  86. # Optimizer
  87. nbs = 64 # nominal batch size
  88. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  89. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  90. logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  91. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  92. for k, v in model.named_modules():
  93. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  94. pg2.append(v.bias) # biases
  95. if isinstance(v, nn.BatchNorm2d):
  96. pg0.append(v.weight) # no decay
  97. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  98. pg1.append(v.weight) # apply decay
  99. if opt.adam:
  100. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  101. else:
  102. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  103. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  104. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  105. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  106. del pg0, pg1, pg2
  107. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  108. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  109. if opt.linear_lr:
  110. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  111. else:
  112. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  113. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  114. # plot_lr_scheduler(optimizer, scheduler, epochs)
  115. # Logging
  116. if rank in [-1, 0] and wandb and wandb.run is None:
  117. opt.hyp = hyp # add hyperparameters
  118. wandb_run = wandb.init(config=opt, resume="allow",
  119. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  120. name=save_dir.stem,
  121. entity=opt.entity,
  122. id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
  123. loggers = {'wandb': wandb} # loggers dict
  124. # EMA
  125. ema = ModelEMA(model) if rank in [-1, 0] else None
  126. # Resume
  127. start_epoch, best_fitness = 0, 0.0
  128. if pretrained:
  129. # Optimizer
  130. if ckpt['optimizer'] is not None:
  131. optimizer.load_state_dict(ckpt['optimizer'])
  132. best_fitness = ckpt['best_fitness']
  133. # EMA
  134. if ema and ckpt.get('ema'):
  135. ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
  136. ema.updates = ckpt['ema'][1]
  137. # Results
  138. if ckpt.get('training_results') is not None:
  139. results_file.write_text(ckpt['training_results']) # write results.txt
  140. # Epochs
  141. start_epoch = ckpt['epoch'] + 1
  142. if opt.resume:
  143. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  144. if epochs < start_epoch:
  145. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  146. (weights, ckpt['epoch'], epochs))
  147. epochs += ckpt['epoch'] # finetune additional epochs
  148. del ckpt, state_dict
  149. # Image sizes
  150. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  151. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  152. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  153. # DP mode
  154. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  155. model = torch.nn.DataParallel(model)
  156. # SyncBatchNorm
  157. if opt.sync_bn and cuda and rank != -1:
  158. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  159. logger.info('Using SyncBatchNorm()')
  160. # DDP mode
  161. if cuda and rank != -1:
  162. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
  163. # Trainloader
  164. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  165. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  166. world_size=opt.world_size, workers=opt.workers,
  167. image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
  168. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  169. nb = len(dataloader) # number of batches
  170. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  171. # Process 0
  172. if rank in [-1, 0]:
  173. testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
  174. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
  175. world_size=opt.world_size, workers=opt.workers,
  176. pad=0.5, prefix=colorstr('val: '))[0]
  177. if not opt.resume:
  178. labels = np.concatenate(dataset.labels, 0)
  179. c = torch.tensor(labels[:, 0]) # classes
  180. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  181. # model._initialize_biases(cf.to(device))
  182. if plots:
  183. plot_labels(labels, save_dir, loggers)
  184. if tb_writer:
  185. tb_writer.add_histogram('classes', c, 0)
  186. # Anchors
  187. if not opt.noautoanchor:
  188. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  189. # Model parameters
  190. hyp['box'] *= 3. / nl # scale to layers
  191. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  192. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  193. model.nc = nc # attach number of classes to model
  194. model.hyp = hyp # attach hyperparameters to model
  195. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  196. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  197. model.names = names
  198. # Start training
  199. t0 = time.time()
  200. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  201. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  202. maps = np.zeros(nc) # mAP per class
  203. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  204. scheduler.last_epoch = start_epoch - 1 # do not move
  205. scaler = amp.GradScaler(enabled=cuda)
  206. compute_loss = ComputeLoss(model) # init loss class
  207. logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
  208. f'Using {dataloader.num_workers} dataloader workers\n'
  209. f'Logging results to {save_dir}\n'
  210. f'Starting training for {epochs} epochs...')
  211. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  212. model.train()
  213. # Update image weights (optional)
  214. if opt.image_weights:
  215. # Generate indices
  216. if rank in [-1, 0]:
  217. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  218. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  219. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  220. # Broadcast if DDP
  221. if rank != -1:
  222. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  223. dist.broadcast(indices, 0)
  224. if rank != 0:
  225. dataset.indices = indices.cpu().numpy()
  226. # Update mosaic border
  227. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  228. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  229. mloss = torch.zeros(4, device=device) # mean losses
  230. if rank != -1:
  231. dataloader.sampler.set_epoch(epoch)
  232. pbar = enumerate(dataloader)
  233. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
  234. if rank in [-1, 0]:
  235. pbar = tqdm(pbar, total=nb) # progress bar
  236. optimizer.zero_grad()
  237. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  238. ni = i + nb * epoch # number integrated batches (since train start)
  239. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  240. # Warmup
  241. if ni <= nw:
  242. xi = [0, nw] # x interp
  243. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  244. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  245. for j, x in enumerate(optimizer.param_groups):
  246. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  247. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  248. if 'momentum' in x:
  249. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  250. # Multi-scale
  251. if opt.multi_scale:
  252. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  253. sf = sz / max(imgs.shape[2:]) # scale factor
  254. if sf != 1:
  255. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  256. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  257. # Forward
  258. with amp.autocast(enabled=cuda):
  259. pred = model(imgs) # forward
  260. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  261. if rank != -1:
  262. loss *= opt.world_size # gradient averaged between devices in DDP mode
  263. if opt.quad:
  264. loss *= 4.
  265. # Backward
  266. scaler.scale(loss).backward()
  267. # Optimize
  268. if ni % accumulate == 0:
  269. scaler.step(optimizer) # optimizer.step
  270. scaler.update()
  271. optimizer.zero_grad()
  272. if ema:
  273. ema.update(model)
  274. # Print
  275. if rank in [-1, 0]:
  276. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  277. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  278. s = ('%10s' * 2 + '%10.4g' * 6) % (
  279. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  280. pbar.set_description(s)
  281. # Plot
  282. if plots and ni < 3:
  283. f = save_dir / f'train_batch{ni}.jpg' # filename
  284. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  285. # if tb_writer:
  286. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  287. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  288. elif plots and ni == 10 and wandb:
  289. wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
  290. if x.exists()]}, commit=False)
  291. # end batch ------------------------------------------------------------------------------------------------
  292. # end epoch ----------------------------------------------------------------------------------------------------
  293. # Scheduler
  294. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  295. scheduler.step()
  296. # DDP process 0 or single-GPU
  297. if rank in [-1, 0]:
  298. # mAP
  299. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  300. final_epoch = epoch + 1 == epochs
  301. if not opt.notest or final_epoch: # Calculate mAP
  302. results, maps, times = test.test(opt.data,
  303. batch_size=batch_size * 2,
  304. imgsz=imgsz_test,
  305. model=ema.ema,
  306. single_cls=opt.single_cls,
  307. dataloader=testloader,
  308. save_dir=save_dir,
  309. verbose=nc < 50 and final_epoch,
  310. plots=plots and final_epoch,
  311. log_imgs=opt.log_imgs if wandb else 0,
  312. compute_loss=compute_loss)
  313. # Write
  314. with open(results_file, 'a') as f:
  315. f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
  316. if len(opt.name) and opt.bucket:
  317. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  318. # Log
  319. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  320. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  321. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  322. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  323. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  324. if tb_writer:
  325. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  326. if wandb:
  327. wandb.log({tag: x}, step=epoch, commit=tag == tags[-1]) # W&B
  328. # Update best mAP
  329. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  330. if fi > best_fitness:
  331. best_fitness = fi
  332. # Save model
  333. if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
  334. ckpt = {'epoch': epoch,
  335. 'best_fitness': best_fitness,
  336. 'training_results': results_file.read_text(),
  337. 'model': ema.ema if final_epoch else deepcopy(
  338. model.module if is_parallel(model) else model).half(),
  339. 'ema': (deepcopy(ema.ema).half(), ema.updates),
  340. 'optimizer': optimizer.state_dict(),
  341. 'wandb_id': wandb_run.id if wandb else None}
  342. # Save last, best and delete
  343. torch.save(ckpt, last)
  344. if best_fitness == fi:
  345. torch.save(ckpt, best)
  346. del ckpt
  347. # end epoch ----------------------------------------------------------------------------------------------------
  348. # end training
  349. if rank in [-1, 0]:
  350. # Strip optimizers
  351. final = best if best.exists() else last # final model
  352. for f in last, best:
  353. if f.exists():
  354. strip_optimizer(f)
  355. if opt.bucket:
  356. os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
  357. # Plots
  358. if plots:
  359. plot_results(save_dir=save_dir) # save as results.png
  360. if wandb:
  361. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  362. wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
  363. if (save_dir / f).exists()]})
  364. if opt.log_artifacts:
  365. wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
  366. # Test best.pt
  367. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  368. if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
  369. for m in (last, best) if best.exists() else (last): # speed, mAP tests
  370. results, _, _ = test.test(opt.data,
  371. batch_size=batch_size * 2,
  372. imgsz=imgsz_test,
  373. conf_thres=0.001,
  374. iou_thres=0.7,
  375. model=attempt_load(m, device).half(),
  376. single_cls=opt.single_cls,
  377. dataloader=testloader,
  378. save_dir=save_dir,
  379. save_json=True,
  380. plots=False)
  381. else:
  382. dist.destroy_process_group()
  383. wandb.run.finish() if wandb and wandb.run else None
  384. torch.cuda.empty_cache()
  385. return results
  386. if __name__ == '__main__':
  387. parser = argparse.ArgumentParser()
  388. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  389. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  390. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  391. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  392. parser.add_argument('--epochs', type=int, default=300)
  393. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  394. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  395. parser.add_argument('--rect', action='store_true', help='rectangular training')
  396. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  397. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  398. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  399. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  400. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  401. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  402. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  403. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  404. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  405. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  406. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  407. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  408. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  409. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  410. parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
  411. parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
  412. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  413. parser.add_argument('--project', default='runs/train', help='save to project/name')
  414. parser.add_argument('--entity', default=None, help='W&B entity')
  415. parser.add_argument('--name', default='exp', help='save to project/name')
  416. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  417. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  418. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  419. opt = parser.parse_args()
  420. # Set DDP variables
  421. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  422. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  423. set_logging(opt.global_rank)
  424. if opt.global_rank in [-1, 0]:
  425. check_git_status()
  426. check_requirements()
  427. # Resume
  428. if opt.resume: # resume an interrupted run
  429. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  430. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  431. apriori = opt.global_rank, opt.local_rank
  432. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  433. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace
  434. opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate
  435. logger.info('Resuming training from %s' % ckpt)
  436. else:
  437. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  438. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  439. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  440. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  441. opt.name = 'evolve' if opt.evolve else opt.name
  442. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
  443. # DDP mode
  444. opt.total_batch_size = opt.batch_size
  445. device = select_device(opt.device, batch_size=opt.batch_size)
  446. if opt.local_rank != -1:
  447. assert torch.cuda.device_count() > opt.local_rank
  448. torch.cuda.set_device(opt.local_rank)
  449. device = torch.device('cuda', opt.local_rank)
  450. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  451. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  452. opt.batch_size = opt.total_batch_size // opt.world_size
  453. # Hyperparameters
  454. with open(opt.hyp) as f:
  455. hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
  456. # Train
  457. logger.info(opt)
  458. try:
  459. import wandb
  460. except ImportError:
  461. wandb = None
  462. prefix = colorstr('wandb: ')
  463. logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
  464. if not opt.evolve:
  465. tb_writer = None # init loggers
  466. if opt.global_rank in [-1, 0]:
  467. logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
  468. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  469. train(hyp, opt, device, tb_writer, wandb)
  470. # Evolve hyperparameters (optional)
  471. else:
  472. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  473. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  474. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  475. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  476. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  477. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  478. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  479. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  480. 'box': (1, 0.02, 0.2), # box loss gain
  481. 'cls': (1, 0.2, 4.0), # cls loss gain
  482. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  483. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  484. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  485. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  486. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  487. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  488. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  489. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  490. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  491. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  492. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  493. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  494. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  495. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  496. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  497. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  498. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  499. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  500. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  501. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  502. opt.notest, opt.nosave = True, True # only test/save final epoch
  503. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  504. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  505. if opt.bucket:
  506. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  507. for _ in range(300): # generations to evolve
  508. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  509. # Select parent(s)
  510. parent = 'single' # parent selection method: 'single' or 'weighted'
  511. x = np.loadtxt('evolve.txt', ndmin=2)
  512. n = min(5, len(x)) # number of previous results to consider
  513. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  514. w = fitness(x) - fitness(x).min() # weights
  515. if parent == 'single' or len(x) == 1:
  516. # x = x[random.randint(0, n - 1)] # random selection
  517. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  518. elif parent == 'weighted':
  519. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  520. # Mutate
  521. mp, s = 0.8, 0.2 # mutation probability, sigma
  522. npr = np.random
  523. npr.seed(int(time.time()))
  524. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  525. ng = len(meta)
  526. v = np.ones(ng)
  527. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  528. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  529. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  530. hyp[k] = float(x[i + 7] * v[i]) # mutate
  531. # Constrain to limits
  532. for k, v in meta.items():
  533. hyp[k] = max(hyp[k], v[1]) # lower limit
  534. hyp[k] = min(hyp[k], v[2]) # upper limit
  535. hyp[k] = round(hyp[k], 5) # significant digits
  536. # Train mutation
  537. results = train(hyp.copy(), opt, device, wandb=wandb)
  538. # Write mutation results
  539. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  540. # Plot results
  541. plot_evolution(yaml_file)
  542. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  543. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')