Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

313 lines
12KB

  1. # YOLOv5 PyTorch utils
  2. import datetime
  3. import logging
  4. import os
  5. import platform
  6. import subprocess
  7. import time
  8. from contextlib import contextmanager
  9. from copy import deepcopy
  10. from pathlib import Path
  11. import math
  12. import torch
  13. import torch.backends.cudnn as cudnn
  14. import torch.distributed as dist
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. import torchvision
  18. try:
  19. import thop # for FLOPs computation
  20. except ImportError:
  21. thop = None
# Module-level logger; select_device() and model_info() emit their summaries through it.
LOGGER = logging.getLogger(__name__)
  23. @contextmanager
  24. def torch_distributed_zero_first(local_rank: int):
  25. """
  26. Decorator to make all processes in distributed training wait for each local_master to do something.
  27. """
  28. if local_rank not in [-1, 0]:
  29. dist.barrier()
  30. yield
  31. if local_rank == 0:
  32. dist.barrier()
  33. def init_torch_seeds(seed=0):
  34. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  35. torch.manual_seed(seed)
  36. if seed == 0: # slower, more reproducible
  37. cudnn.benchmark, cudnn.deterministic = False, True
  38. else: # faster, less reproducible
  39. cudnn.benchmark, cudnn.deterministic = True, False
  40. def date_modified(path=__file__):
  41. # return human-readable file modification date, i.e. '2021-3-26'
  42. t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
  43. return f'{t.year}-{t.month}-{t.day}'
  44. def git_describe(path=Path(__file__).parent): # path must be a directory
  45. # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  46. s = f'git -C {path} describe --tags --long --always'
  47. try:
  48. return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
  49. except subprocess.CalledProcessError as e:
  50. return '' # not a git repository
  51. def select_device(device='', batch_size=None):
  52. # device = 'cpu' or '0' or '0,1,2,3'
  53. s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
  54. device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
  55. cpu = device == 'cpu'
  56. if cpu:
  57. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  58. elif device: # non-cpu device requested
  59. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  60. assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
  61. cuda = not cpu and torch.cuda.is_available()
  62. if cuda:
  63. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  64. n = len(devices) # device count
  65. if n > 1 and batch_size: # check batch_size is divisible by device_count
  66. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  67. space = ' ' * (len(s) + 1)
  68. for i, d in enumerate(devices):
  69. p = torch.cuda.get_device_properties(i)
  70. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
  71. else:
  72. s += 'CPU\n'
  73. LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  74. return torch.device('cuda:0' if cuda else 'cpu')
  75. def time_sync():
  76. # pytorch-accurate time
  77. if torch.cuda.is_available():
  78. torch.cuda.synchronize()
  79. return time.time()
  80. def profile(x, ops, n=100, device=None):
  81. # profile a pytorch module or list of modules. Example usage:
  82. # x = torch.randn(16, 3, 640, 640) # input
  83. # m1 = lambda x: x * torch.sigmoid(x)
  84. # m2 = nn.SiLU()
  85. # profile(x, [m1, m2], n=100) # profile speed over 100 iterations
  86. device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  87. x = x.to(device)
  88. x.requires_grad = True
  89. print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
  90. print(f"\n{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
  91. for m in ops if isinstance(ops, list) else [ops]:
  92. m = m.to(device) if hasattr(m, 'to') else m # device
  93. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
  94. dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
  95. try:
  96. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  97. except:
  98. flops = 0
  99. for _ in range(n):
  100. t[0] = time_sync()
  101. y = m(x)
  102. t[1] = time_sync()
  103. try:
  104. _ = y.sum().backward()
  105. t[2] = time_sync()
  106. except: # no backward method
  107. t[2] = float('nan')
  108. dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
  109. dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
  110. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  111. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  112. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  113. print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
  114. def is_parallel(model):
  115. # Returns True if model is of type DP or DDP
  116. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  117. def de_parallel(model):
  118. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  119. return model.module if is_parallel(model) else model
  120. def intersect_dicts(da, db, exclude=()):
  121. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  122. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  123. def initialize_weights(model):
  124. for m in model.modules():
  125. t = type(m)
  126. if t is nn.Conv2d:
  127. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  128. elif t is nn.BatchNorm2d:
  129. m.eps = 1e-3
  130. m.momentum = 0.03
  131. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  132. m.inplace = True
  133. def find_modules(model, mclass=nn.Conv2d):
  134. # Finds layer indices matching module class 'mclass'
  135. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  136. def sparsity(model):
  137. # Return global model sparsity
  138. a, b = 0., 0.
  139. for p in model.parameters():
  140. a += p.numel()
  141. b += (p == 0).sum()
  142. return b / a
  143. def prune(model, amount=0.3):
  144. # Prune model to requested global sparsity
  145. import torch.nn.utils.prune as prune
  146. print('Pruning model... ', end='')
  147. for name, m in model.named_modules():
  148. if isinstance(m, nn.Conv2d):
  149. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  150. prune.remove(m, 'weight') # make permanent
  151. print(' %.3g global sparsity' % sparsity(model))
  152. def fuse_conv_and_bn(conv, bn):
  153. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  154. fusedconv = nn.Conv2d(conv.in_channels,
  155. conv.out_channels,
  156. kernel_size=conv.kernel_size,
  157. stride=conv.stride,
  158. padding=conv.padding,
  159. groups=conv.groups,
  160. bias=True).requires_grad_(False).to(conv.weight.device)
  161. # prepare filters
  162. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  163. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  164. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  165. # prepare spatial bias
  166. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  167. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  168. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  169. return fusedconv
  170. def model_info(model, verbose=False, img_size=640):
  171. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  172. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  173. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  174. if verbose:
  175. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  176. for i, (name, p) in enumerate(model.named_parameters()):
  177. name = name.replace('module_list.', '')
  178. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  179. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  180. try: # FLOPs
  181. from thop import profile
  182. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  183. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  184. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  185. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  186. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  187. except (ImportError, Exception):
  188. fs = ''
  189. LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  190. def load_classifier(name='resnet101', n=2):
  191. # Loads a pretrained model reshaped to n-class output
  192. model = torchvision.models.__dict__[name](pretrained=True)
  193. # ResNet model properties
  194. # input_size = [3, 224, 224]
  195. # input_space = 'RGB'
  196. # input_range = [0, 1]
  197. # mean = [0.485, 0.456, 0.406]
  198. # std = [0.229, 0.224, 0.225]
  199. # Reshape output to n classes
  200. filters = model.fc.weight.shape[1]
  201. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  202. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  203. model.fc.out_features = n
  204. return model
  205. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  206. # scales img(bs,3,y,x) by ratio constrained to gs-multiple
  207. if ratio == 1.0:
  208. return img
  209. else:
  210. h, w = img.shape[2:]
  211. s = (int(h * ratio), int(w * ratio)) # new size
  212. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  213. if not same_shape: # pad/crop img
  214. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  215. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  216. def copy_attr(a, b, include=(), exclude=()):
  217. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  218. for k, v in b.__dict__.items():
  219. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  220. continue
  221. else:
  222. setattr(a, k, v)
  223. class ModelEMA:
  224. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  225. Keep a moving average of everything in the model state_dict (parameters and buffers).
  226. This is intended to allow functionality like
  227. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  228. A smoothed version of the weights is necessary for some training schemes to perform well.
  229. This class is sensitive where it is initialized in the sequence of model init,
  230. GPU assignment and distributed training wrappers.
  231. """
  232. def __init__(self, model, decay=0.9999, updates=0):
  233. # Create EMA
  234. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  235. # if next(model.parameters()).device.type != 'cpu':
  236. # self.ema.half() # FP16 EMA
  237. self.updates = updates # number of EMA updates
  238. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  239. for p in self.ema.parameters():
  240. p.requires_grad_(False)
  241. def update(self, model):
  242. # Update EMA parameters
  243. with torch.no_grad():
  244. self.updates += 1
  245. d = self.decay(self.updates)
  246. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  247. for k, v in self.ema.state_dict().items():
  248. if v.dtype.is_floating_point:
  249. v *= d
  250. v += (1. - d) * msd[k].detach()
  251. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  252. # Update EMA attributes
  253. copy_attr(self.ema, model, include, exclude)