You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

312 lines
12KB

  1. # YOLOv5 PyTorch utils
  2. import datetime
  3. import logging
  4. import math
  5. import os
  6. import platform
  7. import subprocess
  8. import time
  9. from contextlib import contextmanager
  10. from copy import deepcopy
  11. from pathlib import Path
  12. import torch
  13. import torch.backends.cudnn as cudnn
  14. import torch.distributed as dist
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. import torchvision
  18. try:
  19. import thop # for FLOPs computation
  20. except ImportError:
  21. thop = None
  22. logger = logging.getLogger(__name__)
  23. @contextmanager
  24. def torch_distributed_zero_first(local_rank: int):
  25. """
  26. Decorator to make all processes in distributed training wait for each local_master to do something.
  27. """
  28. if local_rank not in [-1, 0]:
  29. dist.barrier()
  30. yield
  31. if local_rank == 0:
  32. dist.barrier()
  33. def init_torch_seeds(seed=0):
  34. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  35. torch.manual_seed(seed)
  36. if seed == 0: # slower, more reproducible
  37. cudnn.benchmark, cudnn.deterministic = False, True
  38. else: # faster, less reproducible
  39. cudnn.benchmark, cudnn.deterministic = True, False
  40. def date_modified(path=__file__):
  41. # return human-readable file modification date, i.e. '2021-3-26'
  42. t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
  43. return f'{t.year}-{t.month}-{t.day}'
  44. def git_describe(path=Path(__file__).parent): # path must be a directory
  45. # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  46. s = f'git -C {path} describe --tags --long --always'
  47. try:
  48. return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
  49. except subprocess.CalledProcessError as e:
  50. return '' # not a git repository
  51. def select_device(device='', batch_size=None):
  52. # device = 'cpu' or '0' or '0,1,2,3'
  53. s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
  54. cpu = device.lower() == 'cpu'
  55. if cpu:
  56. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  57. elif device: # non-cpu device requested
  58. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  59. assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
  60. cuda = not cpu and torch.cuda.is_available()
  61. if cuda:
  62. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  63. n = len(devices) # device count
  64. if n > 1 and batch_size: # check batch_size is divisible by device_count
  65. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  66. space = ' ' * (len(s) + 1)
  67. for i, d in enumerate(devices):
  68. p = torch.cuda.get_device_properties(i)
  69. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
  70. else:
  71. s += 'CPU\n'
  72. logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  73. return torch.device('cuda:0' if cuda else 'cpu')
  74. def time_synchronized():
  75. # pytorch-accurate time
  76. if torch.cuda.is_available():
  77. torch.cuda.synchronize()
  78. return time.time()
  79. def profile(x, ops, n=100, device=None):
  80. # profile a pytorch module or list of modules. Example usage:
  81. # x = torch.randn(16, 3, 640, 640) # input
  82. # m1 = lambda x: x * torch.sigmoid(x)
  83. # m2 = nn.SiLU()
  84. # profile(x, [m1, m2], n=100) # profile speed over 100 iterations
  85. device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  86. x = x.to(device)
  87. x.requires_grad = True
  88. print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
  89. print(f"\n{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
  90. for m in ops if isinstance(ops, list) else [ops]:
  91. m = m.to(device) if hasattr(m, 'to') else m # device
  92. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
  93. dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
  94. try:
  95. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  96. except:
  97. flops = 0
  98. for _ in range(n):
  99. t[0] = time_synchronized()
  100. y = m(x)
  101. t[1] = time_synchronized()
  102. try:
  103. _ = y.sum().backward()
  104. t[2] = time_synchronized()
  105. except: # no backward method
  106. t[2] = float('nan')
  107. dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
  108. dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
  109. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  110. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  111. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  112. print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
  113. def is_parallel(model):
  114. # Returns True if model is of type DP or DDP
  115. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  116. def de_parallel(model):
  117. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  118. return model.module if is_parallel(model) else model
  119. def intersect_dicts(da, db, exclude=()):
  120. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  121. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  122. def initialize_weights(model):
  123. for m in model.modules():
  124. t = type(m)
  125. if t is nn.Conv2d:
  126. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  127. elif t is nn.BatchNorm2d:
  128. m.eps = 1e-3
  129. m.momentum = 0.03
  130. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  131. m.inplace = True
  132. def find_modules(model, mclass=nn.Conv2d):
  133. # Finds layer indices matching module class 'mclass'
  134. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  135. def sparsity(model):
  136. # Return global model sparsity
  137. a, b = 0., 0.
  138. for p in model.parameters():
  139. a += p.numel()
  140. b += (p == 0).sum()
  141. return b / a
  142. def prune(model, amount=0.3):
  143. # Prune model to requested global sparsity
  144. import torch.nn.utils.prune as prune
  145. print('Pruning model... ', end='')
  146. for name, m in model.named_modules():
  147. if isinstance(m, nn.Conv2d):
  148. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  149. prune.remove(m, 'weight') # make permanent
  150. print(' %.3g global sparsity' % sparsity(model))
  151. def fuse_conv_and_bn(conv, bn):
  152. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  153. fusedconv = nn.Conv2d(conv.in_channels,
  154. conv.out_channels,
  155. kernel_size=conv.kernel_size,
  156. stride=conv.stride,
  157. padding=conv.padding,
  158. groups=conv.groups,
  159. bias=True).requires_grad_(False).to(conv.weight.device)
  160. # prepare filters
  161. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  162. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  163. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  164. # prepare spatial bias
  165. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  166. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  167. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  168. return fusedconv
  169. def model_info(model, verbose=False, img_size=640):
  170. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  171. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  172. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  173. if verbose:
  174. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  175. for i, (name, p) in enumerate(model.named_parameters()):
  176. name = name.replace('module_list.', '')
  177. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  178. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  179. try: # FLOPs
  180. from thop import profile
  181. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  182. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  183. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  184. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  185. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  186. except (ImportError, Exception):
  187. fs = ''
  188. logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  189. def load_classifier(name='resnet101', n=2):
  190. # Loads a pretrained model reshaped to n-class output
  191. model = torchvision.models.__dict__[name](pretrained=True)
  192. # ResNet model properties
  193. # input_size = [3, 224, 224]
  194. # input_space = 'RGB'
  195. # input_range = [0, 1]
  196. # mean = [0.485, 0.456, 0.406]
  197. # std = [0.229, 0.224, 0.225]
  198. # Reshape output to n classes
  199. filters = model.fc.weight.shape[1]
  200. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  201. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  202. model.fc.out_features = n
  203. return model
  204. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  205. # scales img(bs,3,y,x) by ratio constrained to gs-multiple
  206. if ratio == 1.0:
  207. return img
  208. else:
  209. h, w = img.shape[2:]
  210. s = (int(h * ratio), int(w * ratio)) # new size
  211. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  212. if not same_shape: # pad/crop img
  213. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  214. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  215. def copy_attr(a, b, include=(), exclude=()):
  216. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  217. for k, v in b.__dict__.items():
  218. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  219. continue
  220. else:
  221. setattr(a, k, v)
  222. class ModelEMA:
  223. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  224. Keep a moving average of everything in the model state_dict (parameters and buffers).
  225. This is intended to allow functionality like
  226. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  227. A smoothed version of the weights is necessary for some training schemes to perform well.
  228. This class is sensitive where it is initialized in the sequence of model init,
  229. GPU assignment and distributed training wrappers.
  230. """
  231. def __init__(self, model, decay=0.9999, updates=0):
  232. # Create EMA
  233. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  234. # if next(model.parameters()).device.type != 'cpu':
  235. # self.ema.half() # FP16 EMA
  236. self.updates = updates # number of EMA updates
  237. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  238. for p in self.ema.parameters():
  239. p.requires_grad_(False)
  240. def update(self, model):
  241. # Update EMA parameters
  242. with torch.no_grad():
  243. self.updates += 1
  244. d = self.decay(self.updates)
  245. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  246. for k, v in self.ema.state_dict().items():
  247. if v.dtype.is_floating_point:
  248. v *= d
  249. v += (1. - d) * msd[k].detach()
  250. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  251. # Update EMA attributes
  252. copy_attr(self.ema, model, include, exclude)