Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

626 lines
33KB

import argparse
import logging
import math
import os
import random
import time
import warnings
from copy import deepcopy
from pathlib import Path
from threading import Thread

import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import test  # import test.py to get mAP after each epoch
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
  36. logger = logging.getLogger(__name__)
  37. def train(hyp, opt, device, tb_writer=None):
  38. logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  39. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  40. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  41. # Directories
  42. wdir = save_dir / 'weights'
  43. wdir.mkdir(parents=True, exist_ok=True) # make dir
  44. last = wdir / 'last.pt'
  45. best = wdir / 'best.pt'
  46. results_file = save_dir / 'results.txt'
  47. # Save run settings
  48. with open(save_dir / 'hyp.yaml', 'w') as f:
  49. yaml.safe_dump(hyp, f, sort_keys=False)
  50. with open(save_dir / 'opt.yaml', 'w') as f:
  51. yaml.safe_dump(vars(opt), f, sort_keys=False)
  52. # Configure
  53. plots = not opt.evolve # create plots
  54. cuda = device.type != 'cpu'
  55. init_seeds(2 + rank)
  56. with open(opt.data) as f:
  57. data_dict = yaml.safe_load(f) # data dict
  58. # Logging- Doing this before checking the dataset. Might update data_dict
  59. loggers = {'wandb': None} # loggers dict
  60. if rank in [-1, 0]:
  61. opt.hyp = hyp # add hyperparameters
  62. run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
  63. wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
  64. loggers['wandb'] = wandb_logger.wandb
  65. data_dict = wandb_logger.data_dict
  66. if wandb_logger.wandb:
  67. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
  68. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  69. names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  70. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  71. is_coco = opt.data.endswith('coco.yaml') and nc == 80 # COCO dataset
  72. # Model
  73. pretrained = weights.endswith('.pt')
  74. if pretrained:
  75. with torch_distributed_zero_first(rank):
  76. weights = attempt_download(weights) # download if not found locally
  77. ckpt = torch.load(weights, map_location=device) # load checkpoint
  78. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  79. exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
  80. state_dict = ckpt['model'].float().state_dict() # to FP32
  81. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  82. model.load_state_dict(state_dict, strict=False) # load
  83. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  84. else:
  85. model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  86. with torch_distributed_zero_first(rank):
  87. check_dataset(data_dict) # check
  88. train_path = data_dict['train']
  89. test_path = data_dict['val']
  90. # Freeze
  91. freeze = [] # parameter names to freeze (full or partial)
  92. for k, v in model.named_parameters():
  93. v.requires_grad = True # train all layers
  94. if any(x in k for x in freeze):
  95. print('freezing %s' % k)
  96. v.requires_grad = False
  97. # Optimizer
  98. nbs = 64 # nominal batch size
  99. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  100. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  101. logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  102. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  103. for k, v in model.named_modules():
  104. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  105. pg2.append(v.bias) # biases
  106. if isinstance(v, nn.BatchNorm2d):
  107. pg0.append(v.weight) # no decay
  108. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  109. pg1.append(v.weight) # apply decay
  110. if opt.adam:
  111. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  112. else:
  113. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  114. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  115. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  116. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  117. del pg0, pg1, pg2
  118. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  119. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  120. if opt.linear_lr:
  121. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  122. else:
  123. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  124. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  125. # plot_lr_scheduler(optimizer, scheduler, epochs)
  126. # EMA
  127. ema = ModelEMA(model) if rank in [-1, 0] else None
  128. # Resume
  129. start_epoch, best_fitness = 0, 0.0
  130. if pretrained:
  131. # Optimizer
  132. if ckpt['optimizer'] is not None:
  133. optimizer.load_state_dict(ckpt['optimizer'])
  134. best_fitness = ckpt['best_fitness']
  135. # EMA
  136. if ema and ckpt.get('ema'):
  137. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  138. ema.updates = ckpt['updates']
  139. # Results
  140. if ckpt.get('training_results') is not None:
  141. results_file.write_text(ckpt['training_results']) # write results.txt
  142. # Epochs
  143. start_epoch = ckpt['epoch'] + 1
  144. if opt.resume:
  145. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  146. if epochs < start_epoch:
  147. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  148. (weights, ckpt['epoch'], epochs))
  149. epochs += ckpt['epoch'] # finetune additional epochs
  150. del ckpt, state_dict
  151. # Image sizes
  152. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  153. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  154. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  155. # DP mode
  156. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  157. model = torch.nn.DataParallel(model)
  158. # SyncBatchNorm
  159. if opt.sync_bn and cuda and rank != -1:
  160. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  161. logger.info('Using SyncBatchNorm()')
  162. # Trainloader
  163. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  164. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  165. world_size=opt.world_size, workers=opt.workers,
  166. image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
  167. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  168. nb = len(dataloader) # number of batches
  169. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  170. # Process 0
  171. if rank in [-1, 0]:
  172. testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
  173. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
  174. world_size=opt.world_size, workers=opt.workers,
  175. pad=0.5, prefix=colorstr('val: '))[0]
  176. if not opt.resume:
  177. labels = np.concatenate(dataset.labels, 0)
  178. c = torch.tensor(labels[:, 0]) # classes
  179. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  180. # model._initialize_biases(cf.to(device))
  181. if plots:
  182. plot_labels(labels, names, save_dir, loggers)
  183. if tb_writer:
  184. tb_writer.add_histogram('classes', c, 0)
  185. # Anchors
  186. if not opt.noautoanchor:
  187. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  188. model.half().float() # pre-reduce anchor precision
  189. # DDP mode
  190. if cuda and rank != -1:
  191. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
  192. # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
  193. find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
  194. # Model parameters
  195. hyp['box'] *= 3. / nl # scale to layers
  196. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  197. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  198. hyp['label_smoothing'] = opt.label_smoothing
  199. model.nc = nc # attach number of classes to model
  200. model.hyp = hyp # attach hyperparameters to model
  201. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  202. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  203. model.names = names
  204. # Start training
  205. t0 = time.time()
  206. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  207. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  208. maps = np.zeros(nc) # mAP per class
  209. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  210. scheduler.last_epoch = start_epoch - 1 # do not move
  211. scaler = amp.GradScaler(enabled=cuda)
  212. compute_loss = ComputeLoss(model) # init loss class
  213. logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
  214. f'Using {dataloader.num_workers} dataloader workers\n'
  215. f'Logging results to {save_dir}\n'
  216. f'Starting training for {epochs} epochs...')
  217. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  218. model.train()
  219. # Update image weights (optional)
  220. if opt.image_weights:
  221. # Generate indices
  222. if rank in [-1, 0]:
  223. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  224. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  225. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  226. # Broadcast if DDP
  227. if rank != -1:
  228. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  229. dist.broadcast(indices, 0)
  230. if rank != 0:
  231. dataset.indices = indices.cpu().numpy()
  232. # Update mosaic border
  233. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  234. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  235. mloss = torch.zeros(4, device=device) # mean losses
  236. if rank != -1:
  237. dataloader.sampler.set_epoch(epoch)
  238. pbar = enumerate(dataloader)
  239. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
  240. if rank in [-1, 0]:
  241. pbar = tqdm(pbar, total=nb) # progress bar
  242. optimizer.zero_grad()
  243. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  244. ni = i + nb * epoch # number integrated batches (since train start)
  245. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  246. # Warmup
  247. if ni <= nw:
  248. xi = [0, nw] # x interp
  249. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  250. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  251. for j, x in enumerate(optimizer.param_groups):
  252. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  253. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  254. if 'momentum' in x:
  255. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  256. # Multi-scale
  257. if opt.multi_scale:
  258. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  259. sf = sz / max(imgs.shape[2:]) # scale factor
  260. if sf != 1:
  261. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  262. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  263. # Forward
  264. with amp.autocast(enabled=cuda):
  265. pred = model(imgs) # forward
  266. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  267. if rank != -1:
  268. loss *= opt.world_size # gradient averaged between devices in DDP mode
  269. if opt.quad:
  270. loss *= 4.
  271. # Backward
  272. scaler.scale(loss).backward()
  273. # Optimize
  274. if ni % accumulate == 0:
  275. scaler.step(optimizer) # optimizer.step
  276. scaler.update()
  277. optimizer.zero_grad()
  278. if ema:
  279. ema.update(model)
  280. # Print
  281. if rank in [-1, 0]:
  282. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  283. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  284. s = ('%10s' * 2 + '%10.4g' * 6) % (
  285. f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
  286. pbar.set_description(s)
  287. # Plot
  288. if plots and ni < 3:
  289. f = save_dir / f'train_batch{ni}.jpg' # filename
  290. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  291. if tb_writer and ni == 0:
  292. with warnings.catch_warnings():
  293. warnings.simplefilter('ignore') # suppress jit trace warning
  294. tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # graph
  295. elif plots and ni == 10 and wandb_logger.wandb:
  296. wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
  297. save_dir.glob('train*.jpg') if x.exists()]})
  298. # end batch ------------------------------------------------------------------------------------------------
  299. # end epoch ----------------------------------------------------------------------------------------------------
  300. # Scheduler
  301. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  302. scheduler.step()
  303. # DDP process 0 or single-GPU
  304. if rank in [-1, 0]:
  305. # mAP
  306. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  307. final_epoch = epoch + 1 == epochs
  308. if not opt.notest or final_epoch: # Calculate mAP
  309. wandb_logger.current_epoch = epoch + 1
  310. results, maps, times = test.test(data_dict,
  311. batch_size=batch_size * 2,
  312. imgsz=imgsz_test,
  313. model=ema.ema,
  314. single_cls=opt.single_cls,
  315. dataloader=testloader,
  316. save_dir=save_dir,
  317. save_json=is_coco and final_epoch,
  318. verbose=nc < 50 and final_epoch,
  319. plots=plots and final_epoch,
  320. wandb_logger=wandb_logger,
  321. compute_loss=compute_loss,
  322. is_coco=is_coco)
  323. # Write
  324. with open(results_file, 'a') as f:
  325. f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
  326. # Log
  327. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  328. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  329. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  330. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  331. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  332. if tb_writer:
  333. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  334. if wandb_logger.wandb:
  335. wandb_logger.log({tag: x}) # W&B
  336. # Update best mAP
  337. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  338. if fi > best_fitness:
  339. best_fitness = fi
  340. wandb_logger.end_epoch(best_result=best_fitness == fi)
  341. # Save model
  342. if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
  343. ckpt = {'epoch': epoch,
  344. 'best_fitness': best_fitness,
  345. 'training_results': results_file.read_text(),
  346. 'model': deepcopy(de_parallel(model)).half(),
  347. 'ema': deepcopy(ema.ema).half(),
  348. 'updates': ema.updates,
  349. 'optimizer': optimizer.state_dict(),
  350. 'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
  351. # Save last, best and delete
  352. torch.save(ckpt, last)
  353. if best_fitness == fi:
  354. torch.save(ckpt, best)
  355. if wandb_logger.wandb:
  356. if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
  357. wandb_logger.log_model(
  358. last.parent, opt, epoch, fi, best_model=best_fitness == fi)
  359. del ckpt
  360. # end epoch ----------------------------------------------------------------------------------------------------
  361. # end training
  362. if rank in [-1, 0]:
  363. logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
  364. if plots:
  365. plot_results(save_dir=save_dir) # save as results.png
  366. if wandb_logger.wandb:
  367. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  368. wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
  369. if (save_dir / f).exists()]})
  370. if not opt.evolve:
  371. if is_coco: # COCO dataset
  372. for m in [last, best] if best.exists() else [last]: # speed, mAP tests
  373. results, _, _ = test.test(opt.data,
  374. batch_size=batch_size * 2,
  375. imgsz=imgsz_test,
  376. conf_thres=0.001,
  377. iou_thres=0.7,
  378. model=attempt_load(m, device).half(),
  379. single_cls=opt.single_cls,
  380. dataloader=testloader,
  381. save_dir=save_dir,
  382. save_json=True,
  383. plots=False,
  384. is_coco=is_coco)
  385. # Strip optimizers
  386. for f in last, best:
  387. if f.exists():
  388. strip_optimizer(f) # strip optimizers
  389. if wandb_logger.wandb: # Log the stripped model
  390. wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
  391. name='run_' + wandb_logger.wandb_run.id + '_model',
  392. aliases=['latest', 'best', 'stripped'])
  393. wandb_logger.finish_run()
  394. else:
  395. dist.destroy_process_group()
  396. torch.cuda.empty_cache()
  397. return results
  398. if __name__ == '__main__':
  399. parser = argparse.ArgumentParser()
  400. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  401. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  402. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  403. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  404. parser.add_argument('--epochs', type=int, default=300)
  405. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  406. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  407. parser.add_argument('--rect', action='store_true', help='rectangular training')
  408. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  409. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  410. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  411. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  412. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  413. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  414. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  415. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  416. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  417. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  418. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  419. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  420. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  421. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  422. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  423. parser.add_argument('--project', default='runs/train', help='save to project/name')
  424. parser.add_argument('--entity', default=None, help='W&B entity')
  425. parser.add_argument('--name', default='exp', help='save to project/name')
  426. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  427. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  428. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  429. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  430. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  431. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  432. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
  433. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
  434. opt = parser.parse_args()
  435. # Set DDP variables
  436. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  437. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  438. set_logging(opt.global_rank)
  439. if opt.global_rank in [-1, 0]:
  440. check_git_status()
  441. check_requirements(exclude=('pycocotools', 'thop'))
  442. # Resume
  443. wandb_run = check_wandb_resume(opt)
  444. if opt.resume and not wandb_run: # resume an interrupted run
  445. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  446. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  447. apriori = opt.global_rank, opt.local_rank
  448. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  449. opt = argparse.Namespace(**yaml.safe_load(f)) # replace
  450. opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
  451. '', ckpt, True, opt.total_batch_size, *apriori # reinstate
  452. logger.info('Resuming training from %s' % ckpt)
  453. else:
  454. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  455. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  456. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  457. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  458. opt.name = 'evolve' if opt.evolve else opt.name
  459. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve))
  460. # DDP mode
  461. opt.total_batch_size = opt.batch_size
  462. device = select_device(opt.device, batch_size=opt.batch_size)
  463. if opt.local_rank != -1:
  464. assert torch.cuda.device_count() > opt.local_rank
  465. torch.cuda.set_device(opt.local_rank)
  466. device = torch.device('cuda', opt.local_rank)
  467. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  468. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  469. assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
  470. opt.batch_size = opt.total_batch_size // opt.world_size
  471. # Hyperparameters
  472. with open(opt.hyp) as f:
  473. hyp = yaml.safe_load(f) # load hyps
  474. # Train
  475. logger.info(opt)
  476. if not opt.evolve:
  477. tb_writer = None # init loggers
  478. if opt.global_rank in [-1, 0]:
  479. prefix = colorstr('tensorboard: ')
  480. logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
  481. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  482. train(hyp, opt, device, tb_writer)
  483. # Evolve hyperparameters (optional)
  484. else:
  485. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  486. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  487. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  488. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  489. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  490. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  491. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  492. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  493. 'box': (1, 0.02, 0.2), # box loss gain
  494. 'cls': (1, 0.2, 4.0), # cls loss gain
  495. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  496. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  497. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  498. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  499. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  500. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  501. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  502. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  503. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  504. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  505. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  506. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  507. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  508. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  509. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  510. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  511. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  512. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  513. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  514. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  515. opt.notest, opt.nosave = True, True # only test/save final epoch
  516. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  517. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  518. if opt.bucket:
  519. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  520. for _ in range(300): # generations to evolve
  521. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  522. # Select parent(s)
  523. parent = 'single' # parent selection method: 'single' or 'weighted'
  524. x = np.loadtxt('evolve.txt', ndmin=2)
  525. n = min(5, len(x)) # number of previous results to consider
  526. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  527. w = fitness(x) - fitness(x).min() # weights
  528. if parent == 'single' or len(x) == 1:
  529. # x = x[random.randint(0, n - 1)] # random selection
  530. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  531. elif parent == 'weighted':
  532. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  533. # Mutate
  534. mp, s = 0.8, 0.2 # mutation probability, sigma
  535. npr = np.random
  536. npr.seed(int(time.time()))
  537. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  538. ng = len(meta)
  539. v = np.ones(ng)
  540. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  541. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  542. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  543. hyp[k] = float(x[i + 7] * v[i]) # mutate
  544. # Constrain to limits
  545. for k, v in meta.items():
  546. hyp[k] = max(hyp[k], v[1]) # lower limit
  547. hyp[k] = min(hyp[k], v[2]) # upper limit
  548. hyp[k] = round(hyp[k], 5) # significant digits
  549. # Train mutation
  550. results = train(hyp.copy(), opt, device)
  551. # Write mutation results
  552. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  553. # Plot results
  554. plot_evolution(yaml_file)
  555. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  556. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')