Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

624 lignes
33KB

  1. import argparse
  2. import logging
  3. import math
  4. import os
  5. import random
  6. import time
  7. from copy import deepcopy
  8. from pathlib import Path
  9. from threading import Thread
  10. import numpy as np
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.optim as optim
  15. import torch.optim.lr_scheduler as lr_scheduler
  16. import torch.utils.data
  17. import yaml
  18. from torch.cuda import amp
  19. from torch.nn.parallel import DistributedDataParallel as DDP
  20. from torch.utils.tensorboard import SummaryWriter
  21. from tqdm import tqdm
  22. import test # import test.py to get mAP after each epoch
  23. from models.experimental import attempt_load
  24. from models.yolo import Model
  25. from utils.autoanchor import check_anchors
  26. from utils.datasets import create_dataloader
  27. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  28. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  29. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  30. from utils.google_utils import attempt_download
  31. from utils.loss import ComputeLoss
  32. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  33. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
  34. from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
  35. logger = logging.getLogger(__name__)
  36. def train(hyp, opt, device, tb_writer=None):
  37. logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  38. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  39. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  40. # Directories
  41. wdir = save_dir / 'weights'
  42. wdir.mkdir(parents=True, exist_ok=True) # make dir
  43. last = wdir / 'last.pt'
  44. best = wdir / 'best.pt'
  45. results_file = save_dir / 'results.txt'
  46. # Save run settings
  47. with open(save_dir / 'hyp.yaml', 'w') as f:
  48. yaml.dump(hyp, f, sort_keys=False)
  49. with open(save_dir / 'opt.yaml', 'w') as f:
  50. yaml.dump(vars(opt), f, sort_keys=False)
  51. # Configure
  52. plots = not opt.evolve # create plots
  53. cuda = device.type != 'cpu'
  54. init_seeds(2 + rank)
  55. with open(opt.data) as f:
  56. data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
  57. is_coco = opt.data.endswith('coco.yaml')
  58. # Logging- Doing this before checking the dataset. Might update data_dict
  59. loggers = {'wandb': None} # loggers dict
  60. if rank in [-1, 0]:
  61. opt.hyp = hyp # add hyperparameters
  62. run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
  63. wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
  64. loggers['wandb'] = wandb_logger.wandb
  65. data_dict = wandb_logger.data_dict
  66. if wandb_logger.wandb:
  67. weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
  68. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  69. names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  70. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  71. # Model
  72. pretrained = weights.endswith('.pt')
  73. if pretrained:
  74. with torch_distributed_zero_first(rank):
  75. attempt_download(weights) # download if not found locally
  76. ckpt = torch.load(weights, map_location=device) # load checkpoint
  77. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  78. exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
  79. state_dict = ckpt['model'].float().state_dict() # to FP32
  80. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  81. model.load_state_dict(state_dict, strict=False) # load
  82. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  83. else:
  84. model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
  85. with torch_distributed_zero_first(rank):
  86. check_dataset(data_dict) # check
  87. train_path = data_dict['train']
  88. test_path = data_dict['val']
  89. # Freeze
  90. freeze = [] # parameter names to freeze (full or partial)
  91. for k, v in model.named_parameters():
  92. v.requires_grad = True # train all layers
  93. if any(x in k for x in freeze):
  94. print('freezing %s' % k)
  95. v.requires_grad = False
  96. # Optimizer
  97. nbs = 64 # nominal batch size
  98. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  99. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  100. logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  101. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  102. for k, v in model.named_modules():
  103. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  104. pg2.append(v.bias) # biases
  105. if isinstance(v, nn.BatchNorm2d):
  106. pg0.append(v.weight) # no decay
  107. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  108. pg1.append(v.weight) # apply decay
  109. if opt.adam:
  110. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  111. else:
  112. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  113. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  114. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  115. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  116. del pg0, pg1, pg2
  117. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  118. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  119. if opt.linear_lr:
  120. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  121. else:
  122. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  123. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  124. # plot_lr_scheduler(optimizer, scheduler, epochs)
  125. # EMA
  126. ema = ModelEMA(model) if rank in [-1, 0] else None
  127. # Resume
  128. start_epoch, best_fitness = 0, 0.0
  129. if pretrained:
  130. # Optimizer
  131. if ckpt['optimizer'] is not None:
  132. optimizer.load_state_dict(ckpt['optimizer'])
  133. best_fitness = ckpt['best_fitness']
  134. # EMA
  135. if ema and ckpt.get('ema'):
  136. ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
  137. ema.updates = ckpt['updates']
  138. # Results
  139. if ckpt.get('training_results') is not None:
  140. results_file.write_text(ckpt['training_results']) # write results.txt
  141. # Epochs
  142. start_epoch = ckpt['epoch'] + 1
  143. if opt.resume:
  144. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  145. if epochs < start_epoch:
  146. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  147. (weights, ckpt['epoch'], epochs))
  148. epochs += ckpt['epoch'] # finetune additional epochs
  149. del ckpt, state_dict
  150. # Image sizes
  151. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  152. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  153. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  154. # DP mode
  155. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  156. model = torch.nn.DataParallel(model)
  157. # SyncBatchNorm
  158. if opt.sync_bn and cuda and rank != -1:
  159. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  160. logger.info('Using SyncBatchNorm()')
  161. # Trainloader
  162. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  163. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  164. world_size=opt.world_size, workers=opt.workers,
  165. image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
  166. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  167. nb = len(dataloader) # number of batches
  168. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  169. # Process 0
  170. if rank in [-1, 0]:
  171. testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
  172. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
  173. world_size=opt.world_size, workers=opt.workers,
  174. pad=0.5, prefix=colorstr('val: '))[0]
  175. if not opt.resume:
  176. labels = np.concatenate(dataset.labels, 0)
  177. c = torch.tensor(labels[:, 0]) # classes
  178. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  179. # model._initialize_biases(cf.to(device))
  180. if plots:
  181. plot_labels(labels, names, save_dir, loggers)
  182. if tb_writer:
  183. tb_writer.add_histogram('classes', c, 0)
  184. # Anchors
  185. if not opt.noautoanchor:
  186. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  187. model.half().float() # pre-reduce anchor precision
  188. # DDP mode
  189. if cuda and rank != -1:
  190. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
  191. # Model parameters
  192. hyp['box'] *= 3. / nl # scale to layers
  193. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  194. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  195. hyp['label_smoothing'] = opt.label_smoothing
  196. model.nc = nc # attach number of classes to model
  197. model.hyp = hyp # attach hyperparameters to model
  198. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  199. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  200. model.names = names
  201. # Start training
  202. t0 = time.time()
  203. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  204. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  205. maps = np.zeros(nc) # mAP per class
  206. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  207. scheduler.last_epoch = start_epoch - 1 # do not move
  208. scaler = amp.GradScaler(enabled=cuda)
  209. compute_loss = ComputeLoss(model) # init loss class
  210. logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
  211. f'Using {dataloader.num_workers} dataloader workers\n'
  212. f'Logging results to {save_dir}\n'
  213. f'Starting training for {epochs} epochs...')
  214. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  215. model.train()
  216. # Update image weights (optional)
  217. if opt.image_weights:
  218. # Generate indices
  219. if rank in [-1, 0]:
  220. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  221. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  222. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  223. # Broadcast if DDP
  224. if rank != -1:
  225. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  226. dist.broadcast(indices, 0)
  227. if rank != 0:
  228. dataset.indices = indices.cpu().numpy()
  229. # Update mosaic border
  230. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  231. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  232. mloss = torch.zeros(4, device=device) # mean losses
  233. if rank != -1:
  234. dataloader.sampler.set_epoch(epoch)
  235. pbar = enumerate(dataloader)
  236. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
  237. if rank in [-1, 0]:
  238. pbar = tqdm(pbar, total=nb) # progress bar
  239. optimizer.zero_grad()
  240. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  241. ni = i + nb * epoch # number integrated batches (since train start)
  242. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  243. # Warmup
  244. if ni <= nw:
  245. xi = [0, nw] # x interp
  246. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  247. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  248. for j, x in enumerate(optimizer.param_groups):
  249. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  250. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  251. if 'momentum' in x:
  252. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  253. # Multi-scale
  254. if opt.multi_scale:
  255. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  256. sf = sz / max(imgs.shape[2:]) # scale factor
  257. if sf != 1:
  258. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  259. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  260. # Forward
  261. with amp.autocast(enabled=cuda):
  262. pred = model(imgs) # forward
  263. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  264. if rank != -1:
  265. loss *= opt.world_size # gradient averaged between devices in DDP mode
  266. if opt.quad:
  267. loss *= 4.
  268. # Backward
  269. scaler.scale(loss).backward()
  270. # Optimize
  271. if ni % accumulate == 0:
  272. scaler.step(optimizer) # optimizer.step
  273. scaler.update()
  274. optimizer.zero_grad()
  275. if ema:
  276. ema.update(model)
  277. # Print
  278. if rank in [-1, 0]:
  279. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  280. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  281. s = ('%10s' * 2 + '%10.4g' * 6) % (
  282. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  283. pbar.set_description(s)
  284. # Plot
  285. if plots and ni < 3:
  286. f = save_dir / f'train_batch{ni}.jpg' # filename
  287. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  288. # if tb_writer:
  289. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  290. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  291. elif plots and ni == 10 and wandb_logger.wandb:
  292. wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
  293. save_dir.glob('train*.jpg') if x.exists()]})
  294. # end batch ------------------------------------------------------------------------------------------------
  295. # end epoch ----------------------------------------------------------------------------------------------------
  296. # Scheduler
  297. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  298. scheduler.step()
  299. # DDP process 0 or single-GPU
  300. if rank in [-1, 0]:
  301. # mAP
  302. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  303. final_epoch = epoch + 1 == epochs
  304. if not opt.notest or final_epoch: # Calculate mAP
  305. wandb_logger.current_epoch = epoch + 1
  306. results, maps, times = test.test(data_dict,
  307. batch_size=batch_size * 2,
  308. imgsz=imgsz_test,
  309. model=ema.ema,
  310. single_cls=opt.single_cls,
  311. dataloader=testloader,
  312. save_dir=save_dir,
  313. verbose=nc < 50 and final_epoch,
  314. plots=plots and final_epoch,
  315. wandb_logger=wandb_logger,
  316. compute_loss=compute_loss,
  317. is_coco=is_coco)
  318. # Write
  319. with open(results_file, 'a') as f:
  320. f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss
  321. if len(opt.name) and opt.bucket:
  322. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  323. # Log
  324. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  325. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  326. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  327. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  328. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  329. if tb_writer:
  330. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  331. if wandb_logger.wandb:
  332. wandb_logger.log({tag: x}) # W&B
  333. # Update best mAP
  334. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  335. if fi > best_fitness:
  336. best_fitness = fi
  337. wandb_logger.end_epoch(best_result=best_fitness == fi)
  338. # Save model
  339. if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
  340. ckpt = {'epoch': epoch,
  341. 'best_fitness': best_fitness,
  342. 'training_results': results_file.read_text(),
  343. 'model': deepcopy(model.module if is_parallel(model) else model).half(),
  344. 'ema': deepcopy(ema.ema).half(),
  345. 'updates': ema.updates,
  346. 'optimizer': optimizer.state_dict(),
  347. 'wandb_id': wandb_logger.wandb_run.id if wandb_logger.wandb else None}
  348. # Save last, best and delete
  349. torch.save(ckpt, last)
  350. if best_fitness == fi:
  351. torch.save(ckpt, best)
  352. if wandb_logger.wandb:
  353. if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
  354. wandb_logger.log_model(
  355. last.parent, opt, epoch, fi, best_model=best_fitness == fi)
  356. del ckpt
  357. # end epoch ----------------------------------------------------------------------------------------------------
  358. # end training
  359. if rank in [-1, 0]:
  360. # Plots
  361. if plots:
  362. plot_results(save_dir=save_dir) # save as results.png
  363. if wandb_logger.wandb:
  364. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  365. wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
  366. if (save_dir / f).exists()]})
  367. # Test best.pt
  368. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  369. if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
  370. for m in (last, best) if best.exists() else (last): # speed, mAP tests
  371. results, _, _ = test.test(opt.data,
  372. batch_size=batch_size * 2,
  373. imgsz=imgsz_test,
  374. conf_thres=0.001,
  375. iou_thres=0.7,
  376. model=attempt_load(m, device).half(),
  377. single_cls=opt.single_cls,
  378. dataloader=testloader,
  379. save_dir=save_dir,
  380. save_json=True,
  381. plots=False,
  382. is_coco=is_coco)
  383. # Strip optimizers
  384. final = best if best.exists() else last # final model
  385. for f in last, best:
  386. if f.exists():
  387. strip_optimizer(f) # strip optimizers
  388. if opt.bucket:
  389. os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
  390. if wandb_logger.wandb and not opt.evolve: # Log the stripped model
  391. wandb_logger.wandb.log_artifact(str(final), type='model',
  392. name='run_' + wandb_logger.wandb_run.id + '_model',
  393. aliases=['last', 'best', 'stripped'])
  394. wandb_logger.finish_run()
  395. else:
  396. dist.destroy_process_group()
  397. torch.cuda.empty_cache()
  398. return results
  399. if __name__ == '__main__':
  400. parser = argparse.ArgumentParser()
  401. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  402. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  403. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  404. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  405. parser.add_argument('--epochs', type=int, default=300)
  406. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  407. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  408. parser.add_argument('--rect', action='store_true', help='rectangular training')
  409. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  410. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  411. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  412. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  413. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  414. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  415. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  416. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  417. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  418. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  419. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  420. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  421. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  422. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  423. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  424. parser.add_argument('--project', default='runs/train', help='save to project/name')
  425. parser.add_argument('--entity', default=None, help='W&B entity')
  426. parser.add_argument('--name', default='exp', help='save to project/name')
  427. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  428. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  429. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  430. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  431. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
  432. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
  433. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
  434. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
  435. opt = parser.parse_args()
  436. # Set DDP variables
  437. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  438. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  439. set_logging(opt.global_rank)
  440. if opt.global_rank in [-1, 0]:
  441. check_git_status()
  442. check_requirements()
  443. # Resume
  444. wandb_run = check_wandb_resume(opt)
  445. if opt.resume and not wandb_run: # resume an interrupted run
  446. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  447. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  448. apriori = opt.global_rank, opt.local_rank
  449. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  450. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace
  451. opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate
  452. logger.info('Resuming training from %s' % ckpt)
  453. else:
  454. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  455. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  456. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  457. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  458. opt.name = 'evolve' if opt.evolve else opt.name
  459. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
  460. # DDP mode
  461. opt.total_batch_size = opt.batch_size
  462. device = select_device(opt.device, batch_size=opt.batch_size)
  463. if opt.local_rank != -1:
  464. assert torch.cuda.device_count() > opt.local_rank
  465. torch.cuda.set_device(opt.local_rank)
  466. device = torch.device('cuda', opt.local_rank)
  467. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  468. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  469. opt.batch_size = opt.total_batch_size // opt.world_size
  470. # Hyperparameters
  471. with open(opt.hyp) as f:
  472. hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
  473. # Train
  474. logger.info(opt)
  475. if not opt.evolve:
  476. tb_writer = None # init loggers
  477. if opt.global_rank in [-1, 0]:
  478. prefix = colorstr('tensorboard: ')
  479. logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
  480. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  481. train(hyp, opt, device, tb_writer)
  482. # Evolve hyperparameters (optional)
  483. else:
  484. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  485. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  486. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  487. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  488. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  489. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  490. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  491. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  492. 'box': (1, 0.02, 0.2), # box loss gain
  493. 'cls': (1, 0.2, 4.0), # cls loss gain
  494. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  495. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  496. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  497. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  498. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  499. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  500. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  501. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  502. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  503. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  504. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  505. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  506. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  507. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  508. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  509. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  510. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  511. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  512. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  513. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  514. opt.notest, opt.nosave = True, True # only test/save final epoch
  515. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  516. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  517. if opt.bucket:
  518. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  519. for _ in range(300): # generations to evolve
  520. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  521. # Select parent(s)
  522. parent = 'single' # parent selection method: 'single' or 'weighted'
  523. x = np.loadtxt('evolve.txt', ndmin=2)
  524. n = min(5, len(x)) # number of previous results to consider
  525. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  526. w = fitness(x) - fitness(x).min() # weights
  527. if parent == 'single' or len(x) == 1:
  528. # x = x[random.randint(0, n - 1)] # random selection
  529. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  530. elif parent == 'weighted':
  531. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  532. # Mutate
  533. mp, s = 0.8, 0.2 # mutation probability, sigma
  534. npr = np.random
  535. npr.seed(int(time.time()))
  536. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  537. ng = len(meta)
  538. v = np.ones(ng)
  539. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  540. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  541. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  542. hyp[k] = float(x[i + 7] * v[i]) # mutate
  543. # Constrain to limits
  544. for k, v in meta.items():
  545. hyp[k] = max(hyp[k], v[1]) # lower limit
  546. hyp[k] = min(hyp[k], v[2]) # upper limit
  547. hyp[k] = round(hyp[k], 5) # significant digits
  548. # Train mutation
  549. results = train(hyp.copy(), opt, device)
  550. # Write mutation results
  551. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  552. # Plot results
  553. plot_evolution(yaml_file)
  554. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  555. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')