You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

319 lines
13KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. PyTorch utils
  4. """
  5. import datetime
  6. import math
  7. import os
  8. import platform
  9. import subprocess
  10. import time
  11. from contextlib import contextmanager
  12. from copy import deepcopy
  13. from pathlib import Path
  14. import torch
  15. import torch.distributed as dist
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from utils.general import LOGGER
  19. try:
  20. import thop # for FLOPs computation
  21. except ImportError:
  22. thop = None
  23. @contextmanager
  24. def torch_distributed_zero_first(local_rank: int):
  25. """
  26. Decorator to make all processes in distributed training wait for each local_master to do something.
  27. """
  28. if local_rank not in [-1, 0]:
  29. dist.barrier(device_ids=[local_rank])
  30. yield
  31. if local_rank == 0:
  32. dist.barrier(device_ids=[0])
  33. def date_modified(path=__file__):
  34. # return human-readable file modification date, i.e. '2021-3-26'
  35. t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
  36. return f'{t.year}-{t.month}-{t.day}'
  37. def git_describe(path=Path(__file__).parent): # path must be a directory
  38. # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  39. s = f'git -C {path} describe --tags --long --always'
  40. try:
  41. return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
  42. except subprocess.CalledProcessError as e:
  43. return '' # not a git repository
  44. def select_device(device='', batch_size=None, newline=True):
  45. # device = 'cpu' or '0' or '0,1,2,3'
  46. s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
  47. device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
  48. cpu = device == 'cpu'
  49. if cpu:
  50. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  51. elif device: # non-cpu device requested
  52. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  53. assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
  54. cuda = not cpu and torch.cuda.is_available()
  55. if cuda:
  56. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  57. n = len(devices) # device count
  58. if n > 1 and batch_size: # check batch_size is divisible by device_count
  59. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  60. space = ' ' * (len(s) + 1)
  61. for i, d in enumerate(devices):
  62. p = torch.cuda.get_device_properties(i)
  63. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2:.0f}MiB)\n" # bytes to MB
  64. else:
  65. s += 'CPU\n'
  66. if not newline:
  67. s = s.rstrip()
  68. LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  69. return torch.device('cuda:0' if cuda else 'cpu')
  70. def time_sync():
  71. # pytorch-accurate time
  72. if torch.cuda.is_available():
  73. torch.cuda.synchronize()
  74. return time.time()
  75. def profile(input, ops, n=10, device=None):
  76. # YOLOv5 speed/memory/FLOPs profiler
  77. #
  78. # Usage:
  79. # input = torch.randn(16, 3, 640, 640)
  80. # m1 = lambda x: x * torch.sigmoid(x)
  81. # m2 = nn.SiLU()
  82. # profile(input, [m1, m2], n=100) # profile over 100 iterations
  83. results = []
  84. device = device or select_device()
  85. print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  86. f"{'input':>24s}{'output':>24s}")
  87. for x in input if isinstance(input, list) else [input]:
  88. x = x.to(device)
  89. x.requires_grad = True
  90. for m in ops if isinstance(ops, list) else [ops]:
  91. m = m.to(device) if hasattr(m, 'to') else m # device
  92. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  93. tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
  94. try:
  95. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  96. except:
  97. flops = 0
  98. try:
  99. for _ in range(n):
  100. t[0] = time_sync()
  101. y = m(x)
  102. t[1] = time_sync()
  103. try:
  104. _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
  105. t[2] = time_sync()
  106. except Exception as e: # no backward method
  107. # print(e) # for debug
  108. t[2] = float('nan')
  109. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  110. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  111. mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
  112. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  113. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  114. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  115. print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
  116. results.append([p, flops, mem, tf, tb, s_in, s_out])
  117. except Exception as e:
  118. print(e)
  119. results.append(None)
  120. torch.cuda.empty_cache()
  121. return results
  122. def is_parallel(model):
  123. # Returns True if model is of type DP or DDP
  124. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  125. def de_parallel(model):
  126. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  127. return model.module if is_parallel(model) else model
  128. def initialize_weights(model):
  129. for m in model.modules():
  130. t = type(m)
  131. if t is nn.Conv2d:
  132. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  133. elif t is nn.BatchNorm2d:
  134. m.eps = 1e-3
  135. m.momentum = 0.03
  136. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  137. m.inplace = True
  138. def find_modules(model, mclass=nn.Conv2d):
  139. # Finds layer indices matching module class 'mclass'
  140. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  141. def sparsity(model):
  142. # Return global model sparsity
  143. a, b = 0, 0
  144. for p in model.parameters():
  145. a += p.numel()
  146. b += (p == 0).sum()
  147. return b / a
  148. def prune(model, amount=0.3):
  149. # Prune model to requested global sparsity
  150. import torch.nn.utils.prune as prune
  151. print('Pruning model... ', end='')
  152. for name, m in model.named_modules():
  153. if isinstance(m, nn.Conv2d):
  154. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  155. prune.remove(m, 'weight') # make permanent
  156. print(' %.3g global sparsity' % sparsity(model))
  157. def fuse_conv_and_bn(conv, bn):
  158. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  159. fusedconv = nn.Conv2d(conv.in_channels,
  160. conv.out_channels,
  161. kernel_size=conv.kernel_size,
  162. stride=conv.stride,
  163. padding=conv.padding,
  164. groups=conv.groups,
  165. bias=True).requires_grad_(False).to(conv.weight.device)
  166. # prepare filters
  167. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  168. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  169. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  170. # prepare spatial bias
  171. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  172. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  173. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  174. return fusedconv
  175. def model_info(model, verbose=False, img_size=640):
  176. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  177. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  178. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  179. if verbose:
  180. print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
  181. for i, (name, p) in enumerate(model.named_parameters()):
  182. name = name.replace('module_list.', '')
  183. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  184. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  185. try: # FLOPs
  186. from thop import profile
  187. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  188. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  189. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  190. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  191. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  192. except (ImportError, Exception):
  193. fs = ''
  194. LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  195. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  196. # scales img(bs,3,y,x) by ratio constrained to gs-multiple
  197. if ratio == 1.0:
  198. return img
  199. else:
  200. h, w = img.shape[2:]
  201. s = (int(h * ratio), int(w * ratio)) # new size
  202. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  203. if not same_shape: # pad/crop img
  204. h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
  205. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  206. def copy_attr(a, b, include=(), exclude=()):
  207. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  208. for k, v in b.__dict__.items():
  209. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  210. continue
  211. else:
  212. setattr(a, k, v)
  213. class EarlyStopping:
  214. # YOLOv5 simple early stopper
  215. def __init__(self, patience=30):
  216. self.best_fitness = 0.0 # i.e. mAP
  217. self.best_epoch = 0
  218. self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
  219. self.possible_stop = False # possible stop may occur next epoch
  220. def __call__(self, epoch, fitness):
  221. if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
  222. self.best_epoch = epoch
  223. self.best_fitness = fitness
  224. delta = epoch - self.best_epoch # epochs without improvement
  225. self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
  226. stop = delta >= self.patience # stop training if patience exceeded
  227. if stop:
  228. LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
  229. f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
  230. f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
  231. f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
  232. return stop
  233. class ModelEMA:
  234. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  235. Keep a moving average of everything in the model state_dict (parameters and buffers).
  236. This is intended to allow functionality like
  237. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  238. A smoothed version of the weights is necessary for some training schemes to perform well.
  239. This class is sensitive where it is initialized in the sequence of model init,
  240. GPU assignment and distributed training wrappers.
  241. """
  242. def __init__(self, model, decay=0.9999, updates=0):
  243. # Create EMA
  244. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  245. # if next(model.parameters()).device.type != 'cpu':
  246. # self.ema.half() # FP16 EMA
  247. self.updates = updates # number of EMA updates
  248. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  249. for p in self.ema.parameters():
  250. p.requires_grad_(False)
  251. def update(self, model):
  252. # Update EMA parameters
  253. with torch.no_grad():
  254. self.updates += 1
  255. d = self.decay(self.updates)
  256. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  257. for k, v in self.ema.state_dict().items():
  258. if v.dtype.is_floating_point:
  259. v *= d
  260. v += (1 - d) * msd[k].detach()
  261. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  262. # Update EMA attributes
  263. copy_attr(self.ema, model, include, exclude)