# PyTorch utils

import logging
import math
import os
import subprocess
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision

try:
    import thop  # for FLOPS computation
except ImportError:
    thop = None

logger = logging.getLogger(__name__)


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager to make all processes in distributed training wait for each local_master to do something.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()
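

# Illustrative usage sketch (assumes a hypothetical create_dataset helper and a
# local_rank variable from the DDP launcher): rank 0 builds/downloads the dataset
# cache while the other ranks wait at the barrier, then reuse the finished cache.
#
#     with torch_distributed_zero_first(local_rank):
#         dataset = create_dataset(path)  # rank 0 does the work once; others load the result
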
def init_torch_seeds(seed=0):
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(seed)
    if seed == 0:  # slower, more reproducible
        cudnn.benchmark, cudnn.deterministic = False, True
    else:  # faster, less reproducible
        cudnn.benchmark, cudnn.deterministic = True, False


def git_describe():
    # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    if Path('.git').exists():
        return subprocess.check_output('git describe --tags --long --always', shell=True).decode('utf-8')[:-1]
    else:
        return ''
def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'YOLOv5 {git_describe()} torch {torch.__version__} '  # string
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        n = torch.cuda.device_count()
        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)
        for i, d in enumerate(device.split(',') if device else range(n)):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    logger.info(s)  # skip a line
    return torch.device('cuda:0' if cuda else 'cpu')
def time_synchronized():
    # PyTorch-accurate time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
def profile(x, ops, n=100, device=None):
    # profile a pytorch module or list of modules. Example usage:
    #     x = torch.randn(16, 3, 640, 640)  # input
    #     m1 = lambda x: x * torch.sigmoid(x)
    #     m2 = nn.SiLU()
    #     profile(x, [m1, m2], n=100)  # profile speed over 100 iterations

    device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    x = x.to(device)
    x.requires_grad = True
    print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
    print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
    for m in ops if isinstance(ops, list) else [ops]:
        m = m.to(device) if hasattr(m, 'to') else m  # device
        m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
        dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, dt backward, timestamps
        try:
            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
        except Exception:  # thop unavailable or op unsupported
            flops = 0
        for _ in range(n):
            t[0] = time_synchronized()
            y = m(x)
            t[1] = time_synchronized()
            try:
                _ = y.sum().backward()
                t[2] = time_synchronized()
            except Exception:  # no backward method
                t[2] = float('nan')
            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
        s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
        s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
        p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0  # parameters
        print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
def is_parallel(model):
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
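

# Illustrative usage sketch (assumes a YOLOv5-style checkpoint dict with a 'model'
# entry and a hypothetical 'weights.pt' path): keep only the pretrained weights whose
# names and shapes match the new model, e.g. when fine-tuning with a different class count.
#
#     ckpt = torch.load('weights.pt', map_location=device)
#     sd = intersect_dicts(ckpt['model'].float().state_dict(), model.state_dict(), exclude=['anchor'])
#     model.load_state_dict(sd, strict=False)  # load only the intersecting subset
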
def initialize_weights(model):
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
            m.inplace = True


def find_modules(model, mclass=nn.Conv2d):
    # Finds layer indices matching module class 'mclass'
    return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]


def sparsity(model):
    # Return global model sparsity
    a, b = 0., 0.
    for p in model.parameters():
        a += p.numel()
        b += (p == 0).sum()
    return b / a


def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    print('Pruning model... ', end='')
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune.remove(m, 'weight')  # make permanent
    print(' %.3g global sparsity' % sparsity(model))
def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

    # prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
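

# Illustrative self-check sketch: in eval mode BatchNorm is a fixed affine transform,
# so the fused layer should reproduce conv->bn to within floating-point tolerance.
# The single training-mode forward pass just gives BN non-trivial running stats.
#
#     conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
#     bn = nn.BatchNorm2d(8)
#     bn(torch.randn(16, 8, 32, 32))  # populate running_mean/running_var
#     conv.eval(), bn.eval()
#     fused = fuse_conv_and_bn(conv, bn)
#     x = torch.randn(1, 3, 32, 32)
#     assert torch.allclose(bn(conv(x)), fused(x), atol=1e-4)
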
def model_info(model, verbose=False, img_size=640):
    # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPS
        from thop import profile
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPS
        img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPS
    except Exception:  # thop missing or model incompatible with the dummy forward pass
        fs = ''

    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
def load_classifier(name='resnet101', n=2):
    # Loads a pretrained model reshaped to n-class output
    model = torchvision.models.__dict__[name](pretrained=True)

    # ResNet model properties
    # input_size = [3, 224, 224]
    # input_space = 'RGB'
    # input_range = [0, 1]
    # mean = [0.485, 0.456, 0.406]
    # std = [0.229, 0.224, 0.225]

    # Reshape output to n classes
    filters = model.fc.weight.shape[1]
    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
    model.fc.out_features = n
    return model
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
    if ratio == 1.0:
        return img
    else:
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # new size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
        if not same_shape:  # pad/crop img
            h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
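

# Illustrative usage sketch: scale a batch for multi-scale inference while keeping the
# output shape divisible by the model stride gs.
#
#     img = torch.zeros(16, 3, 256, 416)
#     out = scale_img(img, ratio=0.83, same_shape=False, gs=32)
#     print(out.shape)  # torch.Size([16, 3, 224, 352]), padded up to the nearest gs-multiple
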
def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)
class ModelEMA:
    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        # Create EMA
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
        # if next(model.parameters()).device.type != 'cpu':
        #     self.ema.half()  # FP16 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1. - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
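

# Illustrative usage sketch (compute_loss, evaluate and the dataloader are placeholders):
# create the EMA once, after the model is on its device and before any DDP wrapping,
# call update() after each optimizer step, and evaluate/save ema.ema instead of the raw model.
#
#     ema = ModelEMA(model)
#     for imgs, targets in dataloader:
#         loss = compute_loss(model(imgs), targets)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         ema.update(model)  # accumulate the moving average of the weights
#     results = evaluate(ema.ema)  # validate with the smoothed weights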