You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

316 lines
13KB

  1. # YOLOv5 PyTorch utils
  2. import datetime
  3. import math
  4. import os
  5. import platform
  6. import subprocess
  7. import time
  8. from contextlib import contextmanager
  9. from copy import deepcopy
  10. from pathlib import Path
  11. from loguru import logger
  12. import torch
  13. import torch.backends.cudnn as cudnn
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import torchvision
  17. try:
  18. import thop # for FLOPS computation
  19. except ImportError:
  20. thop = None
  21. @contextmanager
  22. def torch_distributed_zero_first(local_rank: int):
  23. """
  24. Decorator to make all processes in distributed training wait for each local_master to do something.
  25. """
  26. if local_rank not in [-1, 0]:
  27. torch.distributed.barrier()
  28. yield
  29. if local_rank == 0:
  30. torch.distributed.barrier()
  31. def init_torch_seeds(seed=0):
  32. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  33. torch.manual_seed(seed)
  34. if seed == 0: # slower, more reproducible
  35. cudnn.benchmark, cudnn.deterministic = False, True
  36. else: # faster, less reproducible
  37. cudnn.benchmark, cudnn.deterministic = True, False
  38. # 文件最后修改时间
  39. def date_modified(path=__file__):
  40. # return human-readable file modification date, i.e. '2021-3-26'
  41. t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
  42. return f'{t.year}-{t.month}-{t.day}'
  43. def git_describe(path=Path(__file__).parent): # path must be a directory
  44. # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  45. # -C 指定GIT仓库的路径
  46. # describe: 命令名,表示获取最近的Git标签信息。
  47. # tags: 选项,表示只考虑标签。
  48. # long: 选项,表示使用完整的Git SHA-1哈希值来描述提交。
  49. # always: 选项,表示如果没有Git标签,则使用Git哈希值来描述提交。
  50. s = f'git -C {path} describe --tags --long --always'
  51. try:
  52. return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
  53. except subprocess.CalledProcessError as e:
  54. return ''
  55. def select_device(device='0'):
  56. logger.info("当前torch版本: {}", torch.__version__)
  57. # 设置环境变量
  58. os.environ['CUDA_VISIBLE_DEVICES'] = device
  59. assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'
  60. return torch.device('cuda:%s' % device)
  61. # def select_device(device='', batch_size=None):
  62. # # device = 'cpu' or '0' or '0,1,2,3'
  63. # s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
  64. # cpu = device.lower() == 'cpu'
  65. # if cpu:
  66. # os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  67. # elif device: # non-cpu device requested
  68. # os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  69. # assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
  70. #
  71. # cuda = not cpu and torch.cuda.is_available()
  72. # if cuda:
  73. # n = torch.cuda.device_count()
  74. # if n > 1 and batch_size: # check that batch_size is compatible with device_count
  75. # assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  76. # space = ' ' * len(s)
  77. # for i, d in enumerate(device.split(',') if device else range(n)):
  78. # p = torch.cuda.get_device_properties(i)
  79. # s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
  80. # else:
  81. # s += 'CPU\n'
  82. # logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  83. # return torch.device('cuda:0' if cuda else 'cpu')
  84. def time_synchronized():
  85. # pytorch-accurate time
  86. if torch.cuda.is_available():
  87. torch.cuda.synchronize()
  88. return time.time()
  89. def profile(x, ops, n=100, device=None):
  90. # profile a pytorch module or list of modules. Example usage:
  91. # x = torch.randn(16, 3, 640, 640) # input
  92. # m1 = lambda x: x * torch.sigmoid(x)
  93. # m2 = nn.SiLU()
  94. # profile(x, [m1, m2], n=100) # profile speed over 100 iterations
  95. device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  96. x = x.to(device)
  97. x.requires_grad = True
  98. print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
  99. print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
  100. for m in ops if isinstance(ops, list) else [ops]:
  101. m = m.to(device) if hasattr(m, 'to') else m # device
  102. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
  103. dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
  104. try:
  105. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
  106. except:
  107. flops = 0
  108. for _ in range(n):
  109. t[0] = time_synchronized()
  110. y = m(x)
  111. t[1] = time_synchronized()
  112. try:
  113. _ = y.sum().backward()
  114. t[2] = time_synchronized()
  115. except: # no backward method
  116. t[2] = float('nan')
  117. dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
  118. dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
  119. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  120. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  121. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  122. print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
  123. def is_parallel(model):
  124. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  125. def intersect_dicts(da, db, exclude=()):
  126. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  127. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  128. def initialize_weights(model):
  129. for m in model.modules():
  130. t = type(m)
  131. if t is nn.Conv2d:
  132. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  133. elif t is nn.BatchNorm2d:
  134. m.eps = 1e-3
  135. m.momentum = 0.03
  136. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  137. m.inplace = True
  138. def find_modules(model, mclass=nn.Conv2d):
  139. # Finds layer indices matching module class 'mclass'
  140. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  141. def sparsity(model):
  142. # Return global model sparsity
  143. a, b = 0., 0.
  144. for p in model.parameters():
  145. a += p.numel()
  146. b += (p == 0).sum()
  147. return b / a
  148. def prune(model, amount=0.3):
  149. # Prune model to requested global sparsity
  150. import torch.nn.utils.prune as prune
  151. print('Pruning model... ', end='')
  152. for name, m in model.named_modules():
  153. if isinstance(m, nn.Conv2d):
  154. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  155. prune.remove(m, 'weight') # make permanent
  156. print(' %.3g global sparsity' % sparsity(model))
  157. def fuse_conv_and_bn(conv, bn):
  158. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  159. fusedconv = nn.Conv2d(conv.in_channels,
  160. conv.out_channels,
  161. kernel_size=conv.kernel_size,
  162. stride=conv.stride,
  163. padding=conv.padding,
  164. groups=conv.groups,
  165. bias=True).requires_grad_(False).to(conv.weight.device)
  166. # prepare filters
  167. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  168. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  169. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  170. # prepare spatial bias
  171. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  172. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  173. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  174. return fusedconv
  175. def model_info(model, verbose=False, img_size=640):
  176. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  177. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  178. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  179. if verbose:
  180. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  181. for i, (name, p) in enumerate(model.named_parameters()):
  182. name = name.replace('module_list.', '')
  183. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  184. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  185. try: # FLOPS
  186. from thop import profile
  187. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  188. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  189. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS
  190. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  191. fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS
  192. except (ImportError, Exception):
  193. fs = ''
  194. logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  195. def load_classifier(name='resnet101', n=2):
  196. # Loads a pretrained model reshaped to n-class output
  197. model = torchvision.models.__dict__[name](pretrained=True)
  198. # ResNet model properties
  199. # input_size = [3, 224, 224]
  200. # input_space = 'RGB'
  201. # input_range = [0, 1]
  202. # mean = [0.485, 0.456, 0.406]
  203. # std = [0.229, 0.224, 0.225]
  204. # Reshape output to n classes
  205. filters = model.fc.weight.shape[1]
  206. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  207. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  208. model.fc.out_features = n
  209. return model
  210. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  211. # scales img(bs,3,y,x) by ratio constrained to gs-multiple
  212. if ratio == 1.0:
  213. return img
  214. else:
  215. h, w = img.shape[2:]
  216. s = (int(h * ratio), int(w * ratio)) # new size
  217. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  218. if not same_shape: # pad/crop img
  219. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  220. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  221. def copy_attr(a, b, include=(), exclude=()):
  222. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  223. for k, v in b.__dict__.items():
  224. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  225. continue
  226. else:
  227. setattr(a, k, v)
  228. class ModelEMA:
  229. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  230. Keep a moving average of everything in the model state_dict (parameters and buffers).
  231. This is intended to allow functionality like
  232. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  233. A smoothed version of the weights is necessary for some training schemes to perform well.
  234. This class is sensitive where it is initialized in the sequence of model init,
  235. GPU assignment and distributed training wrappers.
  236. """
  237. def __init__(self, model, decay=0.9999, updates=0):
  238. # Create EMA
  239. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  240. # if next(model.parameters()).device.type != 'cpu':
  241. # self.ema.half() # FP16 EMA
  242. self.updates = updates # number of EMA updates
  243. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  244. for p in self.ema.parameters():
  245. p.requires_grad_(False)
  246. def update(self, model):
  247. # Update EMA parameters
  248. with torch.no_grad():
  249. self.updates += 1
  250. d = self.decay(self.updates)
  251. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  252. for k, v in self.ema.state_dict().items():
  253. if v.dtype.is_floating_point:
  254. v *= d
  255. v += (1. - d) * msd[k].detach()
  256. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  257. # Update EMA attributes
  258. copy_attr(self.ema, model, include, exclude)