Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Auto-batch utils
  4. """
  5. from copy import deepcopy
  6. import numpy as np
  7. import torch
  8. from torch.cuda import amp
  9. from utils.general import LOGGER, colorstr
  10. from utils.torch_utils import profile
  11. def check_train_batch_size(model, imgsz=640):
  12. # Check YOLOv5 training batch size
  13. with amp.autocast():
  14. return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
  15. def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
  16. # Automatically estimate best batch size to use `fraction` of available CUDA memory
  17. # Usage:
  18. # import torch
  19. # from utils.autobatch import autobatch
  20. # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
  21. # print(autobatch(model))
  22. prefix = colorstr('AutoBatch: ')
  23. LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
  24. device = next(model.parameters()).device # get model device
  25. if device.type == 'cpu':
  26. LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
  27. return batch_size
  28. d = str(device).upper() # 'CUDA:0'
  29. properties = torch.cuda.get_device_properties(device) # device properties
  30. t = properties.total_memory / 1024 ** 3 # (GiB)
  31. r = torch.cuda.memory_reserved(device) / 1024 ** 3 # (GiB)
  32. a = torch.cuda.memory_allocated(device) / 1024 ** 3 # (GiB)
  33. f = t - (r + a) # free inside reserved
  34. LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
  35. batch_sizes = [1, 2, 4, 8, 16]
  36. try:
  37. img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
  38. y = profile(img, model, n=3, device=device)
  39. except Exception as e:
  40. LOGGER.warning(f'{prefix}{e}')
  41. y = [x[2] for x in y if x] # memory [2]
  42. batch_sizes = batch_sizes[:len(y)]
  43. p = np.polyfit(batch_sizes, y, deg=1) # first degree polynomial fit
  44. b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
  45. LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%)')
  46. return b