Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

604 lines
31KB

  1. import argparse
  2. import logging
  3. import math
  4. import os
  5. import random
  6. import time
  7. from pathlib import Path
  8. from threading import Thread
  9. import numpy as np
  10. import torch.distributed as dist
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. import torch.optim as optim
  14. import torch.optim.lr_scheduler as lr_scheduler
  15. import torch.utils.data
  16. import yaml
  17. from torch.cuda import amp
  18. from torch.nn.parallel import DistributedDataParallel as DDP
  19. from torch.utils.tensorboard import SummaryWriter
  20. from tqdm import tqdm
  21. import test # import test.py to get mAP after each epoch
  22. from models.experimental import attempt_load
  23. from models.yolo import Model
  24. from utils.autoanchor import check_anchors
  25. from utils.datasets import create_dataloader
  26. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  27. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  28. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  29. from utils.google_utils import attempt_download
  30. from utils.loss import ComputeLoss
  31. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  32. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
logger = logging.getLogger(__name__)  # module-level logger used by train() and by the __main__ script block
  34. def train(hyp, opt, device, tb_writer=None, wandb=None):
  35. logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  36. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  37. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  38. # Directories
  39. wdir = save_dir / 'weights'
  40. wdir.mkdir(parents=True, exist_ok=True) # make dir
  41. last = wdir / 'last.pt'
  42. best = wdir / 'best.pt'
  43. results_file = save_dir / 'results.txt'
  44. # Save run settings
  45. with open(save_dir / 'hyp.yaml', 'w') as f:
  46. yaml.dump(hyp, f, sort_keys=False)
  47. with open(save_dir / 'opt.yaml', 'w') as f:
  48. yaml.dump(vars(opt), f, sort_keys=False)
  49. # Configure
  50. plots = not opt.evolve # create plots
  51. cuda = device.type != 'cpu'
  52. init_seeds(2 + rank)
  53. with open(opt.data) as f:
  54. data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
  55. with torch_distributed_zero_first(rank):
  56. check_dataset(data_dict) # check
  57. train_path = data_dict['train']
  58. test_path = data_dict['val']
  59. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  60. names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  61. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  62. # Model
  63. pretrained = weights.endswith('.pt')
  64. if pretrained:
  65. with torch_distributed_zero_first(rank):
  66. attempt_download(weights) # download if not found locally
  67. ckpt = torch.load(weights, map_location=device) # load checkpoint
  68. if hyp.get('anchors'):
  69. ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
  70. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  71. exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
  72. state_dict = ckpt['model'].float().state_dict() # to FP32
  73. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  74. model.load_state_dict(state_dict, strict=False) # load
  75. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  76. else:
  77. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  78. # Freeze
  79. freeze = [] # parameter names to freeze (full or partial)
  80. for k, v in model.named_parameters():
  81. v.requires_grad = True # train all layers
  82. if any(x in k for x in freeze):
  83. print('freezing %s' % k)
  84. v.requires_grad = False
  85. # Optimizer
  86. nbs = 64 # nominal batch size
  87. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  88. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  89. logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  90. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  91. for k, v in model.named_modules():
  92. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  93. pg2.append(v.bias) # biases
  94. if isinstance(v, nn.BatchNorm2d):
  95. pg0.append(v.weight) # no decay
  96. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  97. pg1.append(v.weight) # apply decay
  98. if opt.adam:
  99. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  100. else:
  101. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  102. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  103. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  104. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  105. del pg0, pg1, pg2
  106. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  107. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  108. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  109. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  110. # plot_lr_scheduler(optimizer, scheduler, epochs)
  111. # Logging
  112. if rank in [-1, 0] and wandb and wandb.run is None:
  113. opt.hyp = hyp # add hyperparameters
  114. wandb_run = wandb.init(config=opt, resume="allow",
  115. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  116. name=save_dir.stem,
  117. id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
  118. loggers = {'wandb': wandb} # loggers dict
  119. # Resume
  120. start_epoch, best_fitness = 0, 0.0
  121. if pretrained:
  122. # Optimizer
  123. if ckpt['optimizer'] is not None:
  124. optimizer.load_state_dict(ckpt['optimizer'])
  125. best_fitness = ckpt['best_fitness']
  126. # Results
  127. if ckpt.get('training_results') is not None:
  128. with open(results_file, 'w') as file:
  129. file.write(ckpt['training_results']) # write results.txt
  130. # Epochs
  131. start_epoch = ckpt['epoch'] + 1
  132. if opt.resume:
  133. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  134. if epochs < start_epoch:
  135. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  136. (weights, ckpt['epoch'], epochs))
  137. epochs += ckpt['epoch'] # finetune additional epochs
  138. del ckpt, state_dict
  139. # Image sizes
  140. gs = int(model.stride.max()) # grid size (max stride)
  141. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  142. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  143. # DP mode
  144. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  145. model = torch.nn.DataParallel(model)
  146. # SyncBatchNorm
  147. if opt.sync_bn and cuda and rank != -1:
  148. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  149. logger.info('Using SyncBatchNorm()')
  150. # EMA
  151. ema = ModelEMA(model) if rank in [-1, 0] else None
  152. # DDP mode
  153. if cuda and rank != -1:
  154. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
  155. # Trainloader
  156. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  157. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  158. world_size=opt.world_size, workers=opt.workers,
  159. image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
  160. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  161. nb = len(dataloader) # number of batches
  162. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  163. # Process 0
  164. if rank in [-1, 0]:
  165. ema.updates = start_epoch * nb // accumulate # set EMA updates
  166. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader
  167. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
  168. world_size=opt.world_size, workers=opt.workers,
  169. pad=0.5, prefix=colorstr('val: '))[0]
  170. if not opt.resume:
  171. labels = np.concatenate(dataset.labels, 0)
  172. c = torch.tensor(labels[:, 0]) # classes
  173. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  174. # model._initialize_biases(cf.to(device))
  175. if plots:
  176. plot_labels(labels, save_dir, loggers)
  177. if tb_writer:
  178. tb_writer.add_histogram('classes', c, 0)
  179. # Anchors
  180. if not opt.noautoanchor:
  181. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  182. # Model parameters
  183. hyp['box'] *= 3. / nl # scale to layers
  184. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  185. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  186. model.nc = nc # attach number of classes to model
  187. model.hyp = hyp # attach hyperparameters to model
  188. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  189. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  190. model.names = names
  191. # Start training
  192. t0 = time.time()
  193. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  194. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  195. maps = np.zeros(nc) # mAP per class
  196. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  197. scheduler.last_epoch = start_epoch - 1 # do not move
  198. scaler = amp.GradScaler(enabled=cuda)
  199. compute_loss = ComputeLoss(model) # init loss class
  200. logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
  201. f'Using {dataloader.num_workers} dataloader workers\n'
  202. f'Logging results to {save_dir}\n'
  203. f'Starting training for {epochs} epochs...')
  204. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  205. model.train()
  206. # Update image weights (optional)
  207. if opt.image_weights:
  208. # Generate indices
  209. if rank in [-1, 0]:
  210. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  211. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  212. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  213. # Broadcast if DDP
  214. if rank != -1:
  215. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  216. dist.broadcast(indices, 0)
  217. if rank != 0:
  218. dataset.indices = indices.cpu().numpy()
  219. # Update mosaic border
  220. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  221. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  222. mloss = torch.zeros(4, device=device) # mean losses
  223. if rank != -1:
  224. dataloader.sampler.set_epoch(epoch)
  225. pbar = enumerate(dataloader)
  226. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
  227. if rank in [-1, 0]:
  228. pbar = tqdm(pbar, total=nb) # progress bar
  229. optimizer.zero_grad()
  230. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  231. ni = i + nb * epoch # number integrated batches (since train start)
  232. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  233. # Warmup
  234. if ni <= nw:
  235. xi = [0, nw] # x interp
  236. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  237. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  238. for j, x in enumerate(optimizer.param_groups):
  239. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  240. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  241. if 'momentum' in x:
  242. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  243. # Multi-scale
  244. if opt.multi_scale:
  245. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  246. sf = sz / max(imgs.shape[2:]) # scale factor
  247. if sf != 1:
  248. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  249. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  250. # Forward
  251. with amp.autocast(enabled=cuda):
  252. pred = model(imgs) # forward
  253. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  254. if rank != -1:
  255. loss *= opt.world_size # gradient averaged between devices in DDP mode
  256. if opt.quad:
  257. loss *= 4.
  258. # Backward
  259. scaler.scale(loss).backward()
  260. # Optimize
  261. if ni % accumulate == 0:
  262. scaler.step(optimizer) # optimizer.step
  263. scaler.update()
  264. optimizer.zero_grad()
  265. if ema:
  266. ema.update(model)
  267. # Print
  268. if rank in [-1, 0]:
  269. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  270. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  271. s = ('%10s' * 2 + '%10.4g' * 6) % (
  272. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  273. pbar.set_description(s)
  274. # Plot
  275. if plots and ni < 3:
  276. f = save_dir / f'train_batch{ni}.jpg' # filename
  277. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  278. # if tb_writer:
  279. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  280. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  281. elif plots and ni == 10 and wandb:
  282. wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
  283. if x.exists()]})
  284. # end batch ------------------------------------------------------------------------------------------------
  285. # end epoch ----------------------------------------------------------------------------------------------------
  286. # Scheduler
  287. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  288. scheduler.step()
  289. # DDP process 0 or single-GPU
  290. if rank in [-1, 0]:
  291. # mAP
  292. if ema:
  293. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  294. final_epoch = epoch + 1 == epochs
  295. if not opt.notest or final_epoch: # Calculate mAP
  296. results, maps, times = test.test(opt.data,
  297. batch_size=total_batch_size,
  298. imgsz=imgsz_test,
  299. model=ema.ema,
  300. single_cls=opt.single_cls,
  301. dataloader=testloader,
  302. save_dir=save_dir,
  303. plots=plots and final_epoch,
  304. log_imgs=opt.log_imgs if wandb else 0,
  305. compute_loss=compute_loss)
  306. # Write
  307. with open(results_file, 'a') as f:
  308. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  309. if len(opt.name) and opt.bucket:
  310. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  311. # Log
  312. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  313. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  314. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  315. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  316. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  317. if tb_writer:
  318. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  319. if wandb:
  320. wandb.log({tag: x}) # W&B
  321. # Update best mAP
  322. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  323. if fi > best_fitness:
  324. best_fitness = fi
  325. # Save model
  326. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  327. if save:
  328. with open(results_file, 'r') as f: # create checkpoint
  329. ckpt = {'epoch': epoch,
  330. 'best_fitness': best_fitness,
  331. 'training_results': f.read(),
  332. 'model': ema.ema,
  333. 'optimizer': None if final_epoch else optimizer.state_dict(),
  334. 'wandb_id': wandb_run.id if wandb else None}
  335. # Save last, best and delete
  336. torch.save(ckpt, last)
  337. if best_fitness == fi:
  338. torch.save(ckpt, best)
  339. del ckpt
  340. # end epoch ----------------------------------------------------------------------------------------------------
  341. # end training
  342. if rank in [-1, 0]:
  343. # Strip optimizers
  344. final = best if best.exists() else last # final model
  345. for f in [last, best]:
  346. if f.exists():
  347. strip_optimizer(f) # strip optimizers
  348. if opt.bucket:
  349. os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
  350. # Plots
  351. if plots:
  352. plot_results(save_dir=save_dir) # save as results.png
  353. if wandb:
  354. files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
  355. wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
  356. if (save_dir / f).exists()]})
  357. if opt.log_artifacts:
  358. wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
  359. # Test best.pt
  360. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  361. if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
  362. for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
  363. results, _, _ = test.test(opt.data,
  364. batch_size=total_batch_size,
  365. imgsz=imgsz_test,
  366. conf_thres=conf,
  367. iou_thres=iou,
  368. model=attempt_load(final, device).half(),
  369. single_cls=opt.single_cls,
  370. dataloader=testloader,
  371. save_dir=save_dir,
  372. save_json=save_json,
  373. plots=False)
  374. else:
  375. dist.destroy_process_group()
  376. wandb.run.finish() if wandb and wandb.run else None
  377. torch.cuda.empty_cache()
  378. return results
  379. if __name__ == '__main__':
  380. parser = argparse.ArgumentParser()
  381. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  382. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  383. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  384. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  385. parser.add_argument('--epochs', type=int, default=300)
  386. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  387. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  388. parser.add_argument('--rect', action='store_true', help='rectangular training')
  389. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  390. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  391. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  392. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  393. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  394. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  395. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  396. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  397. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  398. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  399. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  400. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  401. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  402. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  403. parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
  404. parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
  405. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  406. parser.add_argument('--project', default='runs/train', help='save to project/name')
  407. parser.add_argument('--name', default='exp', help='save to project/name')
  408. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  409. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  410. opt = parser.parse_args()
  411. # Set DDP variables
  412. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  413. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  414. set_logging(opt.global_rank)
  415. if opt.global_rank in [-1, 0]:
  416. check_git_status()
  417. check_requirements()
  418. # Resume
  419. if opt.resume: # resume an interrupted run
  420. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  421. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  422. apriori = opt.global_rank, opt.local_rank
  423. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  424. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
  425. opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate
  426. logger.info('Resuming training from %s' % ckpt)
  427. else:
  428. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  429. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  430. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  431. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  432. opt.name = 'evolve' if opt.evolve else opt.name
  433. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
  434. # DDP mode
  435. opt.total_batch_size = opt.batch_size
  436. device = select_device(opt.device, batch_size=opt.batch_size)
  437. if opt.local_rank != -1:
  438. assert torch.cuda.device_count() > opt.local_rank
  439. torch.cuda.set_device(opt.local_rank)
  440. device = torch.device('cuda', opt.local_rank)
  441. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  442. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  443. opt.batch_size = opt.total_batch_size // opt.world_size
  444. # Hyperparameters
  445. with open(opt.hyp) as f:
  446. hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
  447. # Train
  448. logger.info(opt)
  449. try:
  450. import wandb
  451. except ImportError:
  452. wandb = None
  453. prefix = colorstr('wandb: ')
  454. logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
  455. if not opt.evolve:
  456. tb_writer = None # init loggers
  457. if opt.global_rank in [-1, 0]:
  458. logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
  459. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  460. train(hyp, opt, device, tb_writer, wandb)
  461. # Evolve hyperparameters (optional)
  462. else:
  463. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  464. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  465. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  466. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  467. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  468. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  469. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  470. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  471. 'box': (1, 0.02, 0.2), # box loss gain
  472. 'cls': (1, 0.2, 4.0), # cls loss gain
  473. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  474. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  475. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  476. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  477. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  478. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  479. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  480. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  481. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  482. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  483. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  484. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  485. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  486. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  487. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  488. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  489. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  490. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  491. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  492. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  493. opt.notest, opt.nosave = True, True # only test/save final epoch
  494. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  495. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  496. if opt.bucket:
  497. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  498. for _ in range(300): # generations to evolve
  499. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  500. # Select parent(s)
  501. parent = 'single' # parent selection method: 'single' or 'weighted'
  502. x = np.loadtxt('evolve.txt', ndmin=2)
  503. n = min(5, len(x)) # number of previous results to consider
  504. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  505. w = fitness(x) - fitness(x).min() # weights
  506. if parent == 'single' or len(x) == 1:
  507. # x = x[random.randint(0, n - 1)] # random selection
  508. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  509. elif parent == 'weighted':
  510. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  511. # Mutate
  512. mp, s = 0.8, 0.2 # mutation probability, sigma
  513. npr = np.random
  514. npr.seed(int(time.time()))
  515. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  516. ng = len(meta)
  517. v = np.ones(ng)
  518. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  519. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  520. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  521. hyp[k] = float(x[i + 7] * v[i]) # mutate
  522. # Constrain to limits
  523. for k, v in meta.items():
  524. hyp[k] = max(hyp[k], v[1]) # lower limit
  525. hyp[k] = min(hyp[k], v[2]) # upper limit
  526. hyp[k] = round(hyp[k], 5) # significant digits
  527. # Train mutation
  528. results = train(hyp.copy(), opt, device, wandb=wandb)
  529. # Write mutation results
  530. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  531. # Plot results
  532. plot_evolution(yaml_file)
  533. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  534. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')