You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

317 lines
13KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. PyTorch utils
  4. """
  5. import math
  6. import os
  7. import platform
  8. import subprocess
  9. import time
  10. import warnings
  11. from contextlib import contextmanager
  12. from copy import deepcopy
  13. from pathlib import Path
  14. import torch
  15. import torch.distributed as dist
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from utils.general import LOGGER, file_date, git_describe
  19. try:
  20. import thop # for FLOPs computation
  21. except ImportError:
  22. thop = None
  23. # Suppress PyTorch warnings
  24. warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
  25. @contextmanager
  26. def torch_distributed_zero_first(local_rank: int):
  27. # Decorator to make all processes in distributed training wait for each local_master to do something
  28. if local_rank not in [-1, 0]:
  29. dist.barrier(device_ids=[local_rank])
  30. yield
  31. if local_rank == 0:
  32. dist.barrier(device_ids=[0])
  33. def device_count():
  34. # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows
  35. assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
  36. try:
  37. cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows
  38. return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
  39. except Exception:
  40. return 0
  41. def select_device(device='', batch_size=0, newline=True):
  42. # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
  43. s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
  44. device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
  45. cpu = device == 'cpu'
  46. mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
  47. if cpu or mps:
  48. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  49. elif device: # non-cpu device requested
  50. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
  51. assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
  52. f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
  53. if not cpu and torch.cuda.is_available(): # prefer GPU if available
  54. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  55. n = len(devices) # device count
  56. if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
  57. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  58. space = ' ' * (len(s) + 1)
  59. for i, d in enumerate(devices):
  60. p = torch.cuda.get_device_properties(i)
  61. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
  62. arg = 'cuda:0'
  63. elif not cpu and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available(): # prefer MPS if available
  64. s += 'MPS\n'
  65. arg = 'mps'
  66. else: # revert to CPU
  67. s += 'CPU\n'
  68. arg = 'cpu'
  69. if not newline:
  70. s = s.rstrip()
  71. LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  72. return torch.device(arg)
  73. def time_sync():
  74. # PyTorch-accurate time
  75. if torch.cuda.is_available():
  76. torch.cuda.synchronize()
  77. return time.time()
  78. def profile(input, ops, n=10, device=None):
  79. # YOLOv5 speed/memory/FLOPs profiler
  80. #
  81. # Usage:
  82. # input = torch.randn(16, 3, 640, 640)
  83. # m1 = lambda x: x * torch.sigmoid(x)
  84. # m2 = nn.SiLU()
  85. # profile(input, [m1, m2], n=100) # profile over 100 iterations
  86. results = []
  87. if not isinstance(device, torch.device):
  88. device = select_device(device)
  89. print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  90. f"{'input':>24s}{'output':>24s}")
  91. for x in input if isinstance(input, list) else [input]:
  92. x = x.to(device)
  93. x.requires_grad = True
  94. for m in ops if isinstance(ops, list) else [ops]:
  95. m = m.to(device) if hasattr(m, 'to') else m # device
  96. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  97. tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
  98. try:
  99. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  100. except Exception:
  101. flops = 0
  102. try:
  103. for _ in range(n):
  104. t[0] = time_sync()
  105. y = m(x)
  106. t[1] = time_sync()
  107. try:
  108. _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
  109. t[2] = time_sync()
  110. except Exception: # no backward method
  111. # print(e) # for debug
  112. t[2] = float('nan')
  113. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  114. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  115. mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
  116. s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
  117. p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
  118. print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
  119. results.append([p, flops, mem, tf, tb, s_in, s_out])
  120. except Exception as e:
  121. print(e)
  122. results.append(None)
  123. torch.cuda.empty_cache()
  124. return results
  125. def is_parallel(model):
  126. # Returns True if model is of type DP or DDP
  127. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  128. def de_parallel(model):
  129. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  130. return model.module if is_parallel(model) else model
  131. def initialize_weights(model):
  132. for m in model.modules():
  133. t = type(m)
  134. if t is nn.Conv2d:
  135. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  136. elif t is nn.BatchNorm2d:
  137. m.eps = 1e-3
  138. m.momentum = 0.03
  139. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  140. m.inplace = True
  141. def find_modules(model, mclass=nn.Conv2d):
  142. # Finds layer indices matching module class 'mclass'
  143. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  144. def sparsity(model):
  145. # Return global model sparsity
  146. a, b = 0, 0
  147. for p in model.parameters():
  148. a += p.numel()
  149. b += (p == 0).sum()
  150. return b / a
  151. def prune(model, amount=0.3):
  152. # Prune model to requested global sparsity
  153. import torch.nn.utils.prune as prune
  154. print('Pruning model... ', end='')
  155. for name, m in model.named_modules():
  156. if isinstance(m, nn.Conv2d):
  157. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  158. prune.remove(m, 'weight') # make permanent
  159. print(' %.3g global sparsity' % sparsity(model))
  160. def fuse_conv_and_bn(conv, bn):
  161. # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  162. fusedconv = nn.Conv2d(conv.in_channels,
  163. conv.out_channels,
  164. kernel_size=conv.kernel_size,
  165. stride=conv.stride,
  166. padding=conv.padding,
  167. groups=conv.groups,
  168. bias=True).requires_grad_(False).to(conv.weight.device)
  169. # Prepare filters
  170. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  171. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  172. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  173. # Prepare spatial bias
  174. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  175. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  176. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  177. return fusedconv
  178. def model_info(model, verbose=False, img_size=640):
  179. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  180. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  181. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  182. if verbose:
  183. print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
  184. for i, (name, p) in enumerate(model.named_parameters()):
  185. name = name.replace('module_list.', '')
  186. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  187. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  188. try: # FLOPs
  189. from thop import profile
  190. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  191. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  192. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  193. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  194. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  195. except Exception:
  196. fs = ''
  197. name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
  198. LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  199. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  200. # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
  201. if ratio == 1.0:
  202. return img
  203. h, w = img.shape[2:]
  204. s = (int(h * ratio), int(w * ratio)) # new size
  205. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  206. if not same_shape: # pad/crop img
  207. h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
  208. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  209. def copy_attr(a, b, include=(), exclude=()):
  210. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  211. for k, v in b.__dict__.items():
  212. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  213. continue
  214. else:
  215. setattr(a, k, v)
  216. class EarlyStopping:
  217. # YOLOv5 simple early stopper
  218. def __init__(self, patience=30):
  219. self.best_fitness = 0.0 # i.e. mAP
  220. self.best_epoch = 0
  221. self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
  222. self.possible_stop = False # possible stop may occur next epoch
  223. def __call__(self, epoch, fitness):
  224. if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
  225. self.best_epoch = epoch
  226. self.best_fitness = fitness
  227. delta = epoch - self.best_epoch # epochs without improvement
  228. self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
  229. stop = delta >= self.patience # stop training if patience exceeded
  230. if stop:
  231. LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
  232. f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
  233. f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
  234. f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
  235. return stop
  236. class ModelEMA:
  237. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  238. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  239. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  240. """
  241. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  242. # Create EMA
  243. self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
  244. # if next(model.parameters()).device.type != 'cpu':
  245. # self.ema.half() # FP16 EMA
  246. self.updates = updates # number of EMA updates
  247. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  248. for p in self.ema.parameters():
  249. p.requires_grad_(False)
  250. def update(self, model):
  251. # Update EMA parameters
  252. with torch.no_grad():
  253. self.updates += 1
  254. d = self.decay(self.updates)
  255. msd = de_parallel(model).state_dict() # model state_dict
  256. for k, v in self.ema.state_dict().items():
  257. if v.dtype.is_floating_point:
  258. v *= d
  259. v += (1 - d) * msd[k].detach()
  260. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  261. # Update EMA attributes
  262. copy_attr(self.ema, model, include, exclude)