Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

328 lines
13KB

  1. # YOLOv5 PyTorch utils
  2. import datetime
  3. import logging
  4. import os
  5. import platform
  6. import subprocess
  7. import time
  8. from contextlib import contextmanager
  9. from copy import deepcopy
  10. from pathlib import Path
  11. import math
  12. import torch
  13. import torch.backends.cudnn as cudnn
  14. import torch.distributed as dist
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. import torchvision
  18. try:
  19. import thop # for FLOPs computation
  20. except ImportError:
  21. thop = None
  22. LOGGER = logging.getLogger(__name__)
  23. @contextmanager
  24. def torch_distributed_zero_first(local_rank: int):
  25. """
  26. Decorator to make all processes in distributed training wait for each local_master to do something.
  27. """
  28. if local_rank not in [-1, 0]:
  29. dist.barrier()
  30. yield
  31. if local_rank == 0:
  32. dist.barrier()
  33. def init_torch_seeds(seed=0):
  34. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  35. torch.manual_seed(seed)
  36. if seed == 0: # slower, more reproducible
  37. cudnn.benchmark, cudnn.deterministic = False, True
  38. else: # faster, less reproducible
  39. cudnn.benchmark, cudnn.deterministic = True, False
  40. def date_modified(path=__file__):
  41. # return human-readable file modification date, i.e. '2021-3-26'
  42. t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
  43. return f'{t.year}-{t.month}-{t.day}'
  44. def git_describe(path=Path(__file__).parent): # path must be a directory
  45. # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  46. s = f'git -C {path} describe --tags --long --always'
  47. try:
  48. return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
  49. except subprocess.CalledProcessError as e:
  50. return '' # not a git repository
  51. def select_device(device='', batch_size=None):
  52. # device = 'cpu' or '0' or '0,1,2,3'
  53. s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
  54. device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
  55. cpu = device == 'cpu'
  56. if cpu:
  57. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  58. elif device: # non-cpu device requested
  59. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  60. assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
  61. cuda = not cpu and torch.cuda.is_available()
  62. if cuda:
  63. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  64. n = len(devices) # device count
  65. if n > 1 and batch_size: # check batch_size is divisible by device_count
  66. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  67. space = ' ' * (len(s) + 1)
  68. for i, d in enumerate(devices):
  69. p = torch.cuda.get_device_properties(i)
  70. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
  71. else:
  72. s += 'CPU\n'
  73. LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  74. return torch.device('cuda:0' if cuda else 'cpu')
  75. def time_sync():
  76. # pytorch-accurate time
  77. if torch.cuda.is_available():
  78. torch.cuda.synchronize()
  79. return time.time()
  80. def profile(input, ops, n=10, device=None):
  81. # YOLOv5 speed/memory/FLOPs profiler
  82. #
  83. # Usage:
  84. # input = torch.randn(16, 3, 640, 640)
  85. # m1 = lambda x: x * torch.sigmoid(x)
  86. # m2 = nn.SiLU()
  87. # profile(input, [m1, m2], n=100) # profile over 100 iterations
  88. results = []
  89. logging.basicConfig(format="%(message)s", level=logging.INFO)
  90. device = device or select_device()
  91. print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  92. f"{'input':>24s}{'output':>24s}")
  93. for x in input if isinstance(input, list) else [input]:
  94. x = x.to(device)
  95. x.requires_grad = True
  96. for m in ops if isinstance(ops, list) else [ops]:
  97. m = m.to(device) if hasattr(m, 'to') else m # device
  98. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  99. tf, tb, t = 0., 0., [0., 0., 0.] # dt forward, backward
  100. try:
  101. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  102. except:
  103. flops = 0
  104. try:
  105. for _ in range(n):
  106. t[0] = time_sync()
  107. y = m(x)
  108. t[1] = time_sync()
  109. try:
  110. _ = (sum([yi.sum() for yi in y]) if isinstance(y, list) else y).sum().backward()
  111. t[2] = time_sync()
  112. except Exception as e: # no backward method
  113. print(e)
  114. t[2] = float('nan')
  115. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  116. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  117. mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
  118. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  119. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  120. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  121. print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
  122. results.append([p, flops, mem, tf, tb, s_in, s_out])
  123. except Exception as e:
  124. print(e)
  125. results.append(None)
  126. torch.cuda.empty_cache()
  127. return results
  128. def is_parallel(model):
  129. # Returns True if model is of type DP or DDP
  130. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  131. def de_parallel(model):
  132. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  133. return model.module if is_parallel(model) else model
  134. def intersect_dicts(da, db, exclude=()):
  135. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  136. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  137. def initialize_weights(model):
  138. for m in model.modules():
  139. t = type(m)
  140. if t is nn.Conv2d:
  141. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  142. elif t is nn.BatchNorm2d:
  143. m.eps = 1e-3
  144. m.momentum = 0.03
  145. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  146. m.inplace = True
  147. def find_modules(model, mclass=nn.Conv2d):
  148. # Finds layer indices matching module class 'mclass'
  149. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  150. def sparsity(model):
  151. # Return global model sparsity
  152. a, b = 0., 0.
  153. for p in model.parameters():
  154. a += p.numel()
  155. b += (p == 0).sum()
  156. return b / a
  157. def prune(model, amount=0.3):
  158. # Prune model to requested global sparsity
  159. import torch.nn.utils.prune as prune
  160. print('Pruning model... ', end='')
  161. for name, m in model.named_modules():
  162. if isinstance(m, nn.Conv2d):
  163. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  164. prune.remove(m, 'weight') # make permanent
  165. print(' %.3g global sparsity' % sparsity(model))
  166. def fuse_conv_and_bn(conv, bn):
  167. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  168. fusedconv = nn.Conv2d(conv.in_channels,
  169. conv.out_channels,
  170. kernel_size=conv.kernel_size,
  171. stride=conv.stride,
  172. padding=conv.padding,
  173. groups=conv.groups,
  174. bias=True).requires_grad_(False).to(conv.weight.device)
  175. # prepare filters
  176. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  177. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  178. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  179. # prepare spatial bias
  180. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  181. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  182. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  183. return fusedconv
  184. def model_info(model, verbose=False, img_size=640):
  185. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  186. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  187. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  188. if verbose:
  189. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  190. for i, (name, p) in enumerate(model.named_parameters()):
  191. name = name.replace('module_list.', '')
  192. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  193. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  194. try: # FLOPs
  195. from thop import profile
  196. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  197. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  198. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  199. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  200. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  201. except (ImportError, Exception):
  202. fs = ''
  203. LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  204. def load_classifier(name='resnet101', n=2):
  205. # Loads a pretrained model reshaped to n-class output
  206. model = torchvision.models.__dict__[name](pretrained=True)
  207. # ResNet model properties
  208. # input_size = [3, 224, 224]
  209. # input_space = 'RGB'
  210. # input_range = [0, 1]
  211. # mean = [0.485, 0.456, 0.406]
  212. # std = [0.229, 0.224, 0.225]
  213. # Reshape output to n classes
  214. filters = model.fc.weight.shape[1]
  215. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  216. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  217. model.fc.out_features = n
  218. return model
  219. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  220. # scales img(bs,3,y,x) by ratio constrained to gs-multiple
  221. if ratio == 1.0:
  222. return img
  223. else:
  224. h, w = img.shape[2:]
  225. s = (int(h * ratio), int(w * ratio)) # new size
  226. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  227. if not same_shape: # pad/crop img
  228. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  229. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  230. def copy_attr(a, b, include=(), exclude=()):
  231. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  232. for k, v in b.__dict__.items():
  233. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  234. continue
  235. else:
  236. setattr(a, k, v)
  237. class ModelEMA:
  238. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  239. Keep a moving average of everything in the model state_dict (parameters and buffers).
  240. This is intended to allow functionality like
  241. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  242. A smoothed version of the weights is necessary for some training schemes to perform well.
  243. This class is sensitive where it is initialized in the sequence of model init,
  244. GPU assignment and distributed training wrappers.
  245. """
  246. def __init__(self, model, decay=0.9999, updates=0):
  247. # Create EMA
  248. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  249. # if next(model.parameters()).device.type != 'cpu':
  250. # self.ema.half() # FP16 EMA
  251. self.updates = updates # number of EMA updates
  252. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  253. for p in self.ema.parameters():
  254. p.requires_grad_(False)
  255. def update(self, model):
  256. # Update EMA parameters
  257. with torch.no_grad():
  258. self.updates += 1
  259. d = self.decay(self.updates)
  260. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  261. for k, v in self.ema.state_dict().items():
  262. if v.dtype.is_floating_point:
  263. v *= d
  264. v += (1. - d) * msd[k].detach()
  265. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  266. # Update EMA attributes
  267. copy_attr(self.ema, model, include, exclude)