Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

447 lines
21KB

  1. import argparse
  2. import torch.distributed as dist
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. import torch.optim.lr_scheduler as lr_scheduler
  6. import yaml
  7. from torch.utils.tensorboard import SummaryWriter
  8. import test # import test.py to get mAP after each epoch
  9. from models.yolo import Model
  10. from utils.datasets import *
  11. from utils.utils import *
  12. mixed_precision = True
  13. try: # Mixed precision training https://github.com/NVIDIA/apex
  14. from apex import amp
  15. except:
  16. print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
  17. mixed_precision = False # not installed
  18. wdir = 'weights' + os.sep # weights dir
  19. last = wdir + 'last.pt'
  20. best = wdir + 'best.pt'
  21. results_file = 'results.txt'
  22. # Hyperparameters
  23. hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
  24. 'momentum': 0.937, # SGD momentum
  25. 'weight_decay': 5e-4, # optimizer weight decay
  26. 'giou': 0.05, # giou loss gain
  27. 'cls': 0.58, # cls loss gain
  28. 'cls_pw': 1.0, # cls BCELoss positive_weight
  29. 'obj': 1.0, # obj loss gain (*=img_size/320 if img_size != 320)
  30. 'obj_pw': 1.0, # obj BCELoss positive_weight
  31. 'iou_t': 0.20, # iou training threshold
  32. 'anchor_t': 4.0, # anchor-multiple threshold
  33. 'fl_gamma': 0.0, # focal loss gamma (efficientDet default is gamma=1.5)
  34. 'hsv_h': 0.014, # image HSV-Hue augmentation (fraction)
  35. 'hsv_s': 0.68, # image HSV-Saturation augmentation (fraction)
  36. 'hsv_v': 0.36, # image HSV-Value augmentation (fraction)
  37. 'degrees': 0.0, # image rotation (+/- deg)
  38. 'translate': 0.0, # image translation (+/- fraction)
  39. 'scale': 0.5, # image scale (+/- gain)
  40. 'shear': 0.0} # image shear (+/- deg)
  41. print(hyp)
  42. # Overwrite hyp with hyp*.txt (optional)
  43. f = glob.glob('hyp*.txt')
  44. if f:
  45. print('Using %s' % f[0])
  46. for k, v in zip(hyp.keys(), np.loadtxt(f[0])):
  47. hyp[k] = v
  48. # Print focal loss if gamma > 0
  49. if hyp['fl_gamma']:
  50. print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
  51. def train(hyp):
  52. epochs = opt.epochs # 300
  53. batch_size = opt.batch_size # 64
  54. weights = opt.weights # initial training weights
  55. # Configure
  56. init_seeds(1)
  57. with open(opt.data) as f:
  58. data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
  59. train_path = data_dict['train']
  60. test_path = data_dict['val']
  61. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  62. # Remove previous results
  63. for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
  64. os.remove(f)
  65. # Create model
  66. model = Model(opt.cfg).to(device)
  67. assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
  68. # Image sizes
  69. gs = int(max(model.stride)) # grid size (max stride)
  70. if any(x % gs != 0 for x in opt.img_size):
  71. print('WARNING: --img-size %g,%g must be multiple of %s max stride %g' % (*opt.img_size, opt.cfg, gs))
  72. imgsz, imgsz_test = [make_divisible(x, gs) for x in opt.img_size] # image sizes (train, test)
  73. # Optimizer
  74. nbs = 64 # nominal batch size
  75. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  76. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  77. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  78. for k, v in model.named_parameters():
  79. if v.requires_grad:
  80. if '.bias' in k:
  81. pg2.append(v) # biases
  82. elif '.weight' in k and '.bn' not in k:
  83. pg1.append(v) # apply weight decay
  84. else:
  85. pg0.append(v) # all else
  86. optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
  87. optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  88. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  89. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  90. print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  91. del pg0, pg1, pg2
  92. # Load Model
  93. google_utils.attempt_download(weights)
  94. start_epoch, best_fitness = 0, 0.0
  95. if weights.endswith('.pt'): # pytorch format
  96. ckpt = torch.load(weights, map_location=device) # load checkpoint
  97. # load model
  98. try:
  99. ckpt['model'] = \
  100. {k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
  101. model.load_state_dict(ckpt['model'], strict=False)
  102. except KeyError as e:
  103. s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
  104. % (opt.weights, opt.cfg, opt.weights)
  105. raise KeyError(s) from e
  106. # load optimizer
  107. if ckpt['optimizer'] is not None:
  108. optimizer.load_state_dict(ckpt['optimizer'])
  109. best_fitness = ckpt['best_fitness']
  110. # load results
  111. if ckpt.get('training_results') is not None:
  112. with open(results_file, 'w') as file:
  113. file.write(ckpt['training_results']) # write results.txt
  114. start_epoch = ckpt['epoch'] + 1
  115. del ckpt
  116. # Mixed precision training https://github.com/NVIDIA/apex
  117. if mixed_precision:
  118. model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
  119. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  120. lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
  121. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  122. scheduler.last_epoch = start_epoch - 1 # do not move
  123. # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
  124. # plot_lr_scheduler(optimizer, scheduler, epochs)
  125. # Initialize distributed training
  126. if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
  127. dist.init_process_group(backend='nccl', # distributed backend
  128. init_method='tcp://127.0.0.1:9999', # init method
  129. world_size=1, # number of nodes
  130. rank=0) # node rank
  131. model = torch.nn.parallel.DistributedDataParallel(model)
  132. # Dataset
  133. dataset = LoadImagesAndLabels(train_path, imgsz, batch_size,
  134. augment=True,
  135. hyp=hyp, # augmentation hyperparameters
  136. rect=opt.rect, # rectangular training
  137. cache_images=opt.cache_images,
  138. single_cls=opt.single_cls)
  139. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  140. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
  141. # Dataloader
  142. batch_size = min(batch_size, len(dataset))
  143. nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
  144. dataloader = torch.utils.data.DataLoader(dataset,
  145. batch_size=batch_size,
  146. num_workers=nw,
  147. shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
  148. pin_memory=True,
  149. collate_fn=dataset.collate_fn)
  150. # Testloader
  151. testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
  152. hyp=hyp,
  153. rect=True,
  154. cache_images=opt.cache_images,
  155. single_cls=opt.single_cls),
  156. batch_size=batch_size,
  157. num_workers=nw,
  158. pin_memory=True,
  159. collate_fn=dataset.collate_fn)
  160. # Model parameters
  161. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  162. model.nc = nc # attach number of classes to model
  163. model.hyp = hyp # attach hyperparameters to model
  164. model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
  165. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  166. model.names = data_dict['names']
  167. # class frequency
  168. labels = np.concatenate(dataset.labels, 0)
  169. c = torch.tensor(labels[:, 0]) # classes
  170. # cf = torch.bincount(c.long(), minlength=nc) + 1.
  171. # model._initialize_biases(cf.to(device))
  172. plot_labels(labels)
  173. tb_writer.add_histogram('classes', c, 0)
  174. # Exponential moving average
  175. ema = torch_utils.ModelEMA(model)
  176. # Start training
  177. t0 = time.time()
  178. nb = len(dataloader) # number of batches
  179. n_burn = max(3 * nb, 1e3) # burn-in iterations, max(3 epochs, 1k iterations)
  180. maps = np.zeros(nc) # mAP per class
  181. results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
  182. print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
  183. print('Using %g dataloader workers' % nw)
  184. print('Starting training for %g epochs...' % epochs)
  185. # torch.autograd.set_detect_anomaly(True)
  186. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  187. model.train()
  188. # Update image weights (optional)
  189. if dataset.image_weights:
  190. w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  191. image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
  192. dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
  193. mloss = torch.zeros(4, device=device) # mean losses
  194. print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
  195. pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
  196. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  197. ni = i + nb * epoch # number integrated batches (since train start)
  198. imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
  199. # Burn-in
  200. if ni <= n_burn:
  201. xi = [0, n_burn] # x interp
  202. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
  203. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  204. for j, x in enumerate(optimizer.param_groups):
  205. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  206. x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  207. if 'momentum' in x:
  208. x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
  209. # Multi-scale
  210. if opt.multi_scale:
  211. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  212. sf = sz / max(imgs.shape[2:]) # scale factor
  213. if sf != 1:
  214. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  215. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  216. # Forward
  217. pred = model(imgs)
  218. # Loss
  219. loss, loss_items = compute_loss(pred, targets.to(device), model)
  220. if not torch.isfinite(loss):
  221. print('WARNING: non-finite loss, ending training ', loss_items)
  222. return results
  223. # Backward
  224. if mixed_precision:
  225. with amp.scale_loss(loss, optimizer) as scaled_loss:
  226. scaled_loss.backward()
  227. else:
  228. loss.backward()
  229. # Optimize
  230. if ni % accumulate == 0:
  231. optimizer.step()
  232. optimizer.zero_grad()
  233. ema.update(model)
  234. # Print
  235. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  236. mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  237. s = ('%10s' * 2 + '%10.4g' * 6) % (
  238. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  239. pbar.set_description(s)
  240. # Plot
  241. if ni < 3:
  242. f = 'train_batch%g.jpg' % i # filename
  243. res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
  244. if tb_writer:
  245. tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
  246. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  247. # end batch ------------------------------------------------------------------------------------------------
  248. # Scheduler
  249. scheduler.step()
  250. # mAP
  251. ema.update_attr(model)
  252. final_epoch = epoch + 1 == epochs
  253. if not opt.notest or final_epoch: # Calculate mAP
  254. results, maps, times = test.test(opt.data,
  255. batch_size=batch_size,
  256. imgsz=imgsz_test,
  257. save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
  258. model=ema.ema,
  259. single_cls=opt.single_cls,
  260. dataloader=testloader,
  261. fast=ni < n_burn)
  262. # Write
  263. with open(results_file, 'a') as f:
  264. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
  265. if len(opt.name) and opt.bucket:
  266. os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))
  267. # Tensorboard
  268. if tb_writer:
  269. tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
  270. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
  271. 'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
  272. for x, tag in zip(list(mloss[:-1]) + list(results), tags):
  273. tb_writer.add_scalar(tag, x, epoch)
  274. # Update best mAP
  275. fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
  276. if fi > best_fitness:
  277. best_fitness = fi
  278. # Save model
  279. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  280. if save:
  281. with open(results_file, 'r') as f: # create checkpoint
  282. ckpt = {'epoch': epoch,
  283. 'best_fitness': best_fitness,
  284. 'training_results': f.read(),
  285. 'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
  286. 'optimizer': None if final_epoch else optimizer.state_dict()}
  287. # Save last, best and delete
  288. torch.save(ckpt, last)
  289. if (best_fitness == fi) and not final_epoch:
  290. torch.save(ckpt, best)
  291. del ckpt
  292. # end epoch ----------------------------------------------------------------------------------------------------
  293. # end training
  294. n = opt.name
  295. if len(n):
  296. n = '_' + n if not n.isnumeric() else n
  297. fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
  298. for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
  299. if os.path.exists(f1):
  300. os.rename(f1, f2) # rename
  301. ispt = f2.endswith('.pt') # is *.pt
  302. strip_optimizer(f2) if ispt else None # strip optimizer
  303. os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
  304. if not opt.evolve:
  305. plot_results() # save as results.png
  306. print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  307. dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
  308. torch.cuda.empty_cache()
  309. return results
  310. if __name__ == '__main__':
  311. parser = argparse.ArgumentParser()
  312. parser.add_argument('--epochs', type=int, default=300)
  313. parser.add_argument('--batch-size', type=int, default=16)
  314. parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='*.cfg path')
  315. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
  316. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
  317. parser.add_argument('--rect', action='store_true', help='rectangular training')
  318. parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
  319. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  320. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  321. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  322. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  323. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  324. parser.add_argument('--weights', type=str, default='', help='initial weights path')
  325. parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
  326. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  327. parser.add_argument('--adam', action='store_true', help='use adam optimizer')
  328. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
  329. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  330. opt = parser.parse_args()
  331. opt.weights = last if opt.resume else opt.weights
  332. opt.cfg = glob.glob('./**/' + opt.cfg, recursive=True)[0] # find file
  333. opt.data = glob.glob('./**/' + opt.data, recursive=True)[0] # find file
  334. print(opt)
  335. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  336. device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
  337. # check_git_status()
  338. if device.type == 'cpu':
  339. mixed_precision = False
  340. # Train
  341. if not opt.evolve:
  342. tb_writer = SummaryWriter(comment=opt.name)
  343. print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
  344. train(hyp)
  345. # Evolve hyperparameters (optional)
  346. else:
  347. tb_writer = None
  348. opt.notest, opt.nosave = True, True # only test/save final epoch
  349. if opt.bucket:
  350. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  351. for _ in range(10): # generations to evolve
  352. if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
  353. # Select parent(s)
  354. parent = 'single' # parent selection method: 'single' or 'weighted'
  355. x = np.loadtxt('evolve.txt', ndmin=2)
  356. n = min(5, len(x)) # number of previous results to consider
  357. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  358. w = fitness(x) - fitness(x).min() # weights
  359. if parent == 'single' or len(x) == 1:
  360. # x = x[random.randint(0, n - 1)] # random selection
  361. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  362. elif parent == 'weighted':
  363. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  364. # Mutate
  365. mp, s = 0.9, 0.2 # mutation probability, sigma
  366. npr = np.random
  367. npr.seed(int(time.time()))
  368. g = np.array([1, 1, 1, 1, 1, 1, 1, 0, .1, 1, 0, 1, 1, 1, 1, 1, 1, 1]) # gains
  369. ng = len(g)
  370. v = np.ones(ng)
  371. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  372. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  373. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  374. hyp[k] = x[i + 7] * v[i] # mutate
  375. # Clip to limits
  376. keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale', 'fl_gamma']
  377. limits = [(1e-5, 1e-2), (0.00, 0.70), (0.60, 0.98), (0, 0.001), (0, .9), (0, .9), (0, .9), (0, .9), (0, 3)]
  378. for k, v in zip(keys, limits):
  379. hyp[k] = np.clip(hyp[k], v[0], v[1])
  380. # Train mutation
  381. results = train(hyp.copy())
  382. # Write mutation results
  383. print_mutation(hyp, results, opt.bucket)
  384. # Plot results
  385. # plot_evolution_results(hyp)