You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

443 satır
21KB

  1. import argparse
  2. import torch.distributed as dist
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. import torch.optim.lr_scheduler as lr_scheduler
  6. import torch.utils.data
  7. from torch.utils.tensorboard import SummaryWriter
  8. import test # import test.py to get mAP after each epoch
  9. from models.yolo import Model
  10. from utils import google_utils
  11. from utils.datasets import *
  12. from utils.utils import *
  13. mixed_precision = True
  14. try: # Mixed precision training https://github.com/NVIDIA/apex
  15. from apex import amp
  16. except:
  17. print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
  18. mixed_precision = False # not installed
  19. wdir = 'weights' + os.sep # weights dir
  20. os.makedirs(wdir, exist_ok=True)
  21. last = wdir + 'last.pt'
  22. best = wdir + 'best.pt'
  23. results_file = 'results.txt'
  24. # Hyperparameters
  25. hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
  26. 'momentum': 0.937, # SGD momentum
  27. 'weight_decay': 5e-4, # optimizer weight decay
  28. 'giou': 0.05, # giou loss gain
  29. 'cls': 0.58, # cls loss gain
  30. 'cls_pw': 1.0, # cls BCELoss positive_weight
  31. 'obj': 1.0, # obj loss gain (*=img_size/320 if img_size != 320)
  32. 'obj_pw': 1.0, # obj BCELoss positive_weight
  33. 'iou_t': 0.20, # iou training threshold
  34. 'anchor_t': 4.0, # anchor-multiple threshold
  35. 'fl_gamma': 0.0, # focal loss gamma (efficientDet default is gamma=1.5)
  36. 'hsv_h': 0.014, # image HSV-Hue augmentation (fraction)
  37. 'hsv_s': 0.68, # image HSV-Saturation augmentation (fraction)
  38. 'hsv_v': 0.36, # image HSV-Value augmentation (fraction)
  39. 'degrees': 0.0, # image rotation (+/- deg)
  40. 'translate': 0.0, # image translation (+/- fraction)
  41. 'scale': 0.5, # image scale (+/- gain)
  42. 'shear': 0.0} # image shear (+/- deg)
  43. print(hyp)
  44. # Overwrite hyp with hyp*.txt (optional)
  45. f = glob.glob('hyp*.txt')
  46. if f:
  47. print('Using %s' % f[0])
  48. for k, v in zip(hyp.keys(), np.loadtxt(f[0])):
  49. hyp[k] = v
  50. # Print focal loss if gamma > 0
  51. if hyp['fl_gamma']:
  52. print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])
  53. def train(hyp):
  54. epochs = opt.epochs # 300
  55. batch_size = opt.batch_size # 64
  56. weights = opt.weights # initial training weights
  57. # Configure
  58. init_seeds(1)
  59. with open(opt.data) as f:
  60. data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
  61. train_path = data_dict['train']
  62. test_path = data_dict['val']
  63. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  64. # Remove previous results
  65. for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
  66. os.remove(f)
  67. # Create model
  68. model = Model(opt.cfg).to(device)
  69. assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
  70. model.names = data_dict['names']
  71. # Image sizes
  72. gs = int(max(model.stride)) # grid size (max stride)
  73. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  74. # Optimizer
  75. nbs = 64 # nominal batch size
  76. accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
  77. hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
  78. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  79. for k, v in model.named_parameters():
  80. if v.requires_grad:
  81. if '.bias' in k:
  82. pg2.append(v) # biases
  83. elif '.weight' in k and '.bn' not in k:
  84. pg1.append(v) # apply weight decay
  85. else:
  86. pg0.append(v) # all else
  87. optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
  88. optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  89. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  90. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  91. print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  92. del pg0, pg1, pg2
  93. # Load Model
  94. google_utils.attempt_download(weights)
  95. start_epoch, best_fitness = 0, 0.0
  96. if weights.endswith('.pt'): # pytorch format
  97. ckpt = torch.load(weights, map_location=device) # load checkpoint
  98. # load model
  99. try:
  100. ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
  101. if model.state_dict()[k].shape == v.shape} # to FP32, filter
  102. model.load_state_dict(ckpt['model'], strict=False)
  103. except KeyError as e:
  104. s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
  105. "Please delete or update %s and try again, or use --weights '' to train from scatch." \
  106. % (opt.weights, opt.cfg, opt.weights, opt.weights)
  107. raise KeyError(s) from e
  108. # load optimizer
  109. if ckpt['optimizer'] is not None:
  110. optimizer.load_state_dict(ckpt['optimizer'])
  111. best_fitness = ckpt['best_fitness']
  112. # load results
  113. if ckpt.get('training_results') is not None:
  114. with open(results_file, 'w') as file:
  115. file.write(ckpt['training_results']) # write results.txt
  116. # epochs
  117. start_epoch = ckpt['epoch'] + 1
  118. if epochs < start_epoch:
  119. print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  120. (opt.weights, ckpt['epoch'], epochs))
  121. epochs += ckpt['epoch'] # finetune additional epochs
  122. del ckpt
  123. # Mixed precision training https://github.com/NVIDIA/apex
  124. if mixed_precision:
  125. model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
  126. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  127. lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
  128. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  129. scheduler.last_epoch = start_epoch - 1 # do not move
  130. # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
  131. # plot_lr_scheduler(optimizer, scheduler, epochs)
  132. # Initialize distributed training
  133. if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
  134. dist.init_process_group(backend='nccl', # distributed backend
  135. init_method='tcp://127.0.0.1:9999', # init method
  136. world_size=1, # number of nodes
  137. rank=0) # node rank
  138. model = torch.nn.parallel.DistributedDataParallel(model)
  139. # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
  140. # Trainloader
  141. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  142. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
  143. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  144. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)
  145. # Testloader
  146. testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
  147. hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]
  148. # Model parameters
  149. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  150. model.nc = nc # attach number of classes to model
  151. model.hyp = hyp # attach hyperparameters to model
  152. model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
  153. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  154. # Class frequency
  155. labels = np.concatenate(dataset.labels, 0)
  156. c = torch.tensor(labels[:, 0]) # classes
  157. # cf = torch.bincount(c.long(), minlength=nc) + 1.
  158. # model._initialize_biases(cf.to(device))
  159. if tb_writer:
  160. plot_labels(labels)
  161. tb_writer.add_histogram('classes', c, 0)
  162. # Check anchors
  163. if not opt.noautoanchor:
  164. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  165. # Exponential moving average
  166. ema = torch_utils.ModelEMA(model)
  167. # Start training
  168. t0 = time.time()
  169. nb = len(dataloader) # number of batches
  170. n_burn = max(3 * nb, 1e3) # burn-in iterations, max(3 epochs, 1k iterations)
  171. maps = np.zeros(nc) # mAP per class
  172. results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
  173. print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
  174. print('Using %g dataloader workers' % dataloader.num_workers)
  175. print('Starting training for %g epochs...' % epochs)
  176. # torch.autograd.set_detect_anomaly(True)
  177. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  178. model.train()
  179. # Update image weights (optional)
  180. if dataset.image_weights:
  181. w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  182. image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
  183. dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n) # rand weighted idx
  184. # Update mosaic border
  185. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  186. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  187. mloss = torch.zeros(4, device=device) # mean losses
  188. print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
  189. pbar = tqdm(enumerate(dataloader), total=nb) # progress bar
  190. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  191. ni = i + nb * epoch # number integrated batches (since train start)
  192. imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
  193. # Burn-in
  194. if ni <= n_burn:
  195. xi = [0, n_burn] # x interp
  196. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
  197. accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
  198. for j, x in enumerate(optimizer.param_groups):
  199. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  200. x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  201. if 'momentum' in x:
  202. x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
  203. # Multi-scale
  204. if opt.multi_scale:
  205. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  206. sf = sz / max(imgs.shape[2:]) # scale factor
  207. if sf != 1:
  208. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  209. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  210. # Forward
  211. pred = model(imgs)
  212. # Loss
  213. loss, loss_items = compute_loss(pred, targets.to(device), model)
  214. if not torch.isfinite(loss):
  215. print('WARNING: non-finite loss, ending training ', loss_items)
  216. return results
  217. # Backward
  218. if mixed_precision:
  219. with amp.scale_loss(loss, optimizer) as scaled_loss:
  220. scaled_loss.backward()
  221. else:
  222. loss.backward()
  223. # Optimize
  224. if ni % accumulate == 0:
  225. optimizer.step()
  226. optimizer.zero_grad()
  227. ema.update(model)
  228. # Print
  229. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  230. mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  231. s = ('%10s' * 2 + '%10.4g' * 6) % (
  232. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  233. pbar.set_description(s)
  234. # Plot
  235. if ni < 3:
  236. f = 'train_batch%g.jpg' % ni # filename
  237. result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
  238. if tb_writer and result is not None:
  239. tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  240. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  241. # end batch ------------------------------------------------------------------------------------------------
  242. # Scheduler
  243. scheduler.step()
  244. # mAP
  245. ema.update_attr(model)
  246. final_epoch = epoch + 1 == epochs
  247. if not opt.notest or final_epoch: # Calculate mAP
  248. results, maps, times = test.test(opt.data,
  249. batch_size=batch_size,
  250. imgsz=imgsz_test,
  251. save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
  252. model=ema.ema,
  253. single_cls=opt.single_cls,
  254. dataloader=testloader)
  255. # Write
  256. with open(results_file, 'a') as f:
  257. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
  258. if len(opt.name) and opt.bucket:
  259. os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))
  260. # Tensorboard
  261. if tb_writer:
  262. tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
  263. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
  264. 'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
  265. for x, tag in zip(list(mloss[:-1]) + list(results), tags):
  266. tb_writer.add_scalar(tag, x, epoch)
  267. # Update best mAP
  268. fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
  269. if fi > best_fitness:
  270. best_fitness = fi
  271. # Save model
  272. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  273. if save:
  274. with open(results_file, 'r') as f: # create checkpoint
  275. ckpt = {'epoch': epoch,
  276. 'best_fitness': best_fitness,
  277. 'training_results': f.read(),
  278. 'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
  279. 'optimizer': None if final_epoch else optimizer.state_dict()}
  280. # Save last, best and delete
  281. torch.save(ckpt, last)
  282. if (best_fitness == fi) and not final_epoch:
  283. torch.save(ckpt, best)
  284. del ckpt
  285. # end epoch ----------------------------------------------------------------------------------------------------
  286. # end training
  287. n = opt.name
  288. if len(n):
  289. n = '_' + n if not n.isnumeric() else n
  290. fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
  291. for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
  292. if os.path.exists(f1):
  293. os.rename(f1, f2) # rename
  294. ispt = f2.endswith('.pt') # is *.pt
  295. strip_optimizer(f2) if ispt else None # strip optimizer
  296. os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
  297. if not opt.evolve:
  298. plot_results() # save as results.png
  299. print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  300. dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
  301. torch.cuda.empty_cache()
  302. return results
  303. if __name__ == '__main__':
  304. check_git_status()
  305. parser = argparse.ArgumentParser()
  306. parser.add_argument('--epochs', type=int, default=300)
  307. parser.add_argument('--batch-size', type=int, default=16)
  308. parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='*.cfg path')
  309. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
  310. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
  311. parser.add_argument('--rect', action='store_true', help='rectangular training')
  312. parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
  313. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  314. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  315. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  316. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  317. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  318. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  319. parser.add_argument('--weights', type=str, default='', help='initial weights path')
  320. parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
  321. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  322. parser.add_argument('--adam', action='store_true', help='use adam optimizer')
  323. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%')
  324. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  325. opt = parser.parse_args()
  326. opt.weights = last if opt.resume and not opt.weights else opt.weights
  327. opt.cfg = check_file(opt.cfg) # check file
  328. opt.data = check_file(opt.data) # check file
  329. print(opt)
  330. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  331. device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
  332. if device.type == 'cpu':
  333. mixed_precision = False
  334. # Train
  335. if not opt.evolve:
  336. tb_writer = SummaryWriter(comment=opt.name)
  337. print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
  338. train(hyp)
  339. # Evolve hyperparameters (optional)
  340. else:
  341. tb_writer = None
  342. opt.notest, opt.nosave = True, True # only test/save final epoch
  343. if opt.bucket:
  344. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  345. for _ in range(10): # generations to evolve
  346. if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
  347. # Select parent(s)
  348. parent = 'single' # parent selection method: 'single' or 'weighted'
  349. x = np.loadtxt('evolve.txt', ndmin=2)
  350. n = min(5, len(x)) # number of previous results to consider
  351. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  352. w = fitness(x) - fitness(x).min() # weights
  353. if parent == 'single' or len(x) == 1:
  354. # x = x[random.randint(0, n - 1)] # random selection
  355. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  356. elif parent == 'weighted':
  357. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  358. # Mutate
  359. mp, s = 0.9, 0.2 # mutation probability, sigma
  360. npr = np.random
  361. npr.seed(int(time.time()))
  362. g = np.array([1, 1, 1, 1, 1, 1, 1, 0, .1, 1, 0, 1, 1, 1, 1, 1, 1, 1]) # gains
  363. ng = len(g)
  364. v = np.ones(ng)
  365. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  366. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  367. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  368. hyp[k] = x[i + 7] * v[i] # mutate
  369. # Clip to limits
  370. keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale', 'fl_gamma']
  371. limits = [(1e-5, 1e-2), (0.00, 0.70), (0.60, 0.98), (0, 0.001), (0, .9), (0, .9), (0, .9), (0, .9), (0, 3)]
  372. for k, v in zip(keys, limits):
  373. hyp[k] = np.clip(hyp[k], v[0], v[1])
  374. # Train mutation
  375. results = train(hyp.copy())
  376. # Write mutation results
  377. print_mutation(hyp, results, opt.bucket)
  378. # Plot results
  379. # plot_evolution_results(hyp)