You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

311 lines
12KB

  1. # YOLOv5 PyTorch utils
  2. import datetime
  3. import logging
  4. import math
  5. import os
  6. import platform
  7. import subprocess
  8. import time
  9. from contextlib import contextmanager
  10. from copy import deepcopy
  11. from pathlib import Path
  12. import torch
  13. import torch.backends.cudnn as cudnn
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import torchvision
  17. try:
  18. import thop # for FLOPs computation
  19. except ImportError:
  20. thop = None
  21. logger = logging.getLogger(__name__)
  22. @contextmanager
  23. def torch_distributed_zero_first(local_rank: int):
  24. """
  25. Decorator to make all processes in distributed training wait for each local_master to do something.
  26. """
  27. if local_rank not in [-1, 0]:
  28. torch.distributed.barrier()
  29. yield
  30. if local_rank == 0:
  31. torch.distributed.barrier()
  32. def init_torch_seeds(seed=0):
  33. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  34. torch.manual_seed(seed)
  35. if seed == 0: # slower, more reproducible
  36. cudnn.benchmark, cudnn.deterministic = False, True
  37. else: # faster, less reproducible
  38. cudnn.benchmark, cudnn.deterministic = True, False
  39. def date_modified(path=__file__):
  40. # return human-readable file modification date, i.e. '2021-3-26'
  41. t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
  42. return f'{t.year}-{t.month}-{t.day}'
  43. def git_describe(path=Path(__file__).parent): # path must be a directory
  44. # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  45. s = f'git -C {path} describe --tags --long --always'
  46. try:
  47. return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
  48. except subprocess.CalledProcessError as e:
  49. return '' # not a git repository
  50. def select_device(device='', batch_size=None):
  51. # device = 'cpu' or '0' or '0,1,2,3'
  52. s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
  53. cpu = device.lower() == 'cpu'
  54. if cpu:
  55. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  56. elif device: # non-cpu device requested
  57. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  58. assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
  59. cuda = not cpu and torch.cuda.is_available()
  60. if cuda:
  61. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  62. n = len(devices) # device count
  63. if n > 1 and batch_size: # check batch_size is divisible by device_count
  64. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  65. space = ' ' * (len(s) + 1)
  66. for i, d in enumerate(devices):
  67. p = torch.cuda.get_device_properties(i)
  68. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
  69. else:
  70. s += 'CPU\n'
  71. logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  72. return torch.device('cuda:0' if cuda else 'cpu')
  73. def time_synchronized():
  74. # pytorch-accurate time
  75. if torch.cuda.is_available():
  76. torch.cuda.synchronize()
  77. return time.time()
  78. def profile(x, ops, n=100, device=None):
  79. # profile a pytorch module or list of modules. Example usage:
  80. # x = torch.randn(16, 3, 640, 640) # input
  81. # m1 = lambda x: x * torch.sigmoid(x)
  82. # m2 = nn.SiLU()
  83. # profile(x, [m1, m2], n=100) # profile speed over 100 iterations
  84. device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  85. x = x.to(device)
  86. x.requires_grad = True
  87. print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
  88. print(f"\n{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
  89. for m in ops if isinstance(ops, list) else [ops]:
  90. m = m.to(device) if hasattr(m, 'to') else m # device
  91. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
  92. dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
  93. try:
  94. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  95. except:
  96. flops = 0
  97. for _ in range(n):
  98. t[0] = time_synchronized()
  99. y = m(x)
  100. t[1] = time_synchronized()
  101. try:
  102. _ = y.sum().backward()
  103. t[2] = time_synchronized()
  104. except: # no backward method
  105. t[2] = float('nan')
  106. dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
  107. dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
  108. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  109. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  110. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  111. print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
  112. def is_parallel(model):
  113. # Returns True if model is of type DP or DDP
  114. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  115. def de_parallel(model):
  116. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  117. return model.module if is_parallel(model) else model
  118. def intersect_dicts(da, db, exclude=()):
  119. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  120. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  121. def initialize_weights(model):
  122. for m in model.modules():
  123. t = type(m)
  124. if t is nn.Conv2d:
  125. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  126. elif t is nn.BatchNorm2d:
  127. m.eps = 1e-3
  128. m.momentum = 0.03
  129. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  130. m.inplace = True
  131. def find_modules(model, mclass=nn.Conv2d):
  132. # Finds layer indices matching module class 'mclass'
  133. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  134. def sparsity(model):
  135. # Return global model sparsity
  136. a, b = 0., 0.
  137. for p in model.parameters():
  138. a += p.numel()
  139. b += (p == 0).sum()
  140. return b / a
  141. def prune(model, amount=0.3):
  142. # Prune model to requested global sparsity
  143. import torch.nn.utils.prune as prune
  144. print('Pruning model... ', end='')
  145. for name, m in model.named_modules():
  146. if isinstance(m, nn.Conv2d):
  147. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  148. prune.remove(m, 'weight') # make permanent
  149. print(' %.3g global sparsity' % sparsity(model))
  150. def fuse_conv_and_bn(conv, bn):
  151. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  152. fusedconv = nn.Conv2d(conv.in_channels,
  153. conv.out_channels,
  154. kernel_size=conv.kernel_size,
  155. stride=conv.stride,
  156. padding=conv.padding,
  157. groups=conv.groups,
  158. bias=True).requires_grad_(False).to(conv.weight.device)
  159. # prepare filters
  160. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  161. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  162. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  163. # prepare spatial bias
  164. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  165. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  166. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  167. return fusedconv
  168. def model_info(model, verbose=False, img_size=640):
  169. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  170. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  171. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  172. if verbose:
  173. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  174. for i, (name, p) in enumerate(model.named_parameters()):
  175. name = name.replace('module_list.', '')
  176. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  177. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  178. try: # FLOPs
  179. from thop import profile
  180. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  181. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  182. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  183. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  184. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  185. except (ImportError, Exception):
  186. fs = ''
  187. logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  188. def load_classifier(name='resnet101', n=2):
  189. # Loads a pretrained model reshaped to n-class output
  190. model = torchvision.models.__dict__[name](pretrained=True)
  191. # ResNet model properties
  192. # input_size = [3, 224, 224]
  193. # input_space = 'RGB'
  194. # input_range = [0, 1]
  195. # mean = [0.485, 0.456, 0.406]
  196. # std = [0.229, 0.224, 0.225]
  197. # Reshape output to n classes
  198. filters = model.fc.weight.shape[1]
  199. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  200. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  201. model.fc.out_features = n
  202. return model
  203. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  204. # scales img(bs,3,y,x) by ratio constrained to gs-multiple
  205. if ratio == 1.0:
  206. return img
  207. else:
  208. h, w = img.shape[2:]
  209. s = (int(h * ratio), int(w * ratio)) # new size
  210. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  211. if not same_shape: # pad/crop img
  212. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  213. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  214. def copy_attr(a, b, include=(), exclude=()):
  215. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  216. for k, v in b.__dict__.items():
  217. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  218. continue
  219. else:
  220. setattr(a, k, v)
  221. class ModelEMA:
  222. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  223. Keep a moving average of everything in the model state_dict (parameters and buffers).
  224. This is intended to allow functionality like
  225. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  226. A smoothed version of the weights is necessary for some training schemes to perform well.
  227. This class is sensitive where it is initialized in the sequence of model init,
  228. GPU assignment and distributed training wrappers.
  229. """
  230. def __init__(self, model, decay=0.9999, updates=0):
  231. # Create EMA
  232. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  233. # if next(model.parameters()).device.type != 'cpu':
  234. # self.ema.half() # FP16 EMA
  235. self.updates = updates # number of EMA updates
  236. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  237. for p in self.ema.parameters():
  238. p.requires_grad_(False)
  239. def update(self, model):
  240. # Update EMA parameters
  241. with torch.no_grad():
  242. self.updates += 1
  243. d = self.decay(self.updates)
  244. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  245. for k, v in self.ema.state_dict().items():
  246. if v.dtype.is_floating_point:
  247. v *= d
  248. v += (1. - d) * msd[k].detach()
  249. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  250. # Update EMA attributes
  251. copy_attr(self.ema, model, include, exclude)