Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

321 lines
13KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. PyTorch utils
  4. """
  5. import datetime
  6. import math
  7. import os
  8. import platform
  9. import subprocess
  10. import time
  11. from contextlib import contextmanager
  12. from copy import deepcopy
  13. from pathlib import Path
  14. import torch
  15. import torch.distributed as dist
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from utils.general import LOGGER
  19. try:
  20. import thop # for FLOPs computation
  21. except ImportError:
  22. thop = None
  @contextmanager
  def torch_distributed_zero_first(local_rank: int):
      """
      Context manager that lets the local master process (rank 0) run the ``with`` body before the other ranks.

      Ranks other than -1 and 0 block on a barrier before entering the body; rank 0 enters immediately and calls
      the matching barrier once its body has finished. local_rank == -1 (presumably the non-distributed case)
      calls no barrier at all.
      """
      must_wait = local_rank not in (-1, 0)
      if must_wait:
          dist.barrier(device_ids=[local_rank])  # block until rank 0 reaches its barrier below
      yield
      if local_rank == 0:
          dist.barrier(device_ids=[0])  # release the ranks waiting above
  def date_modified(path=__file__):
      # Return the last-modification date of `path` as 'YYYY-M-D' without zero padding, i.e. '2021-3-26'
      stamp = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
      year, month, day = stamp.year, stamp.month, stamp.day
      return f'{year}-{month}-{day}'
  def git_describe(path=Path(__file__).parent):  # path must be a directory
      """Return a human-readable git description of ``path``, i.e. 'v5.0-5-g3e25f1e', or '' on failure.

      See https://git-scm.com/docs/git-describe. git is invoked with an argument list (no shell), so a path
      containing spaces or shell metacharacters is passed to git verbatim instead of being split or executed.
      Returns '' when git fails (e.g. ``path`` is not inside a repository or does not exist) or git is not installed.
      """
      cmd = ['git', '-C', str(path), 'describe', '--tags', '--long', '--always']
      try:
          # stderr is discarded so warnings can never leak into the returned description
          out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
      except (subprocess.CalledProcessError, OSError):  # not a git repository / git executable missing
          return ''
      return out.decode().strip()  # drop the trailing newline (and any CR) robustly
  def select_device(device='', batch_size=0, newline=True):
      """Select the torch device to run on, log a one-line environment summary, and return the device.

      Args:
          device: 'cpu', '' (CUDA if available, else CPU), or CUDA indices such as '0' or '0,1,2,3';
              a 'cuda:' prefix is accepted and stripped.
          batch_size: when > 0 and several CUDA devices are selected, must be a multiple of the device count.
          newline: if False, trailing whitespace/newline is stripped from the logged summary.

      Returns:
          torch.device('cuda:0') when CUDA is used, otherwise torch.device('cpu').

      Raises:
          AssertionError: CUDA requested but unavailable, a device index out of range, or batch_size not divisible.
          ValueError: a requested device index is not an integer.
      """
      s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} '  # string
      device = str(device).strip().lower().replace('cuda:', '')  # to string, 'cuda:0' to '0'
      cpu = device == 'cpu'
      if cpu:
          os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
      elif device:  # non-cpu device requested
          nd = torch.cuda.device_count()  # number of CUDA devices
          assert torch.cuda.is_available(), 'CUDA is not available, use `--device cpu` or do not pass a --device'
          # Compare indices numerically: max() over the raw strings is lexicographic ('9' > '10').
          max_index = max(int(i) for i in device.split(','))
          assert nd > max_index, f'Invalid `--device {device}` request, valid devices are 0 - {nd - 1}'
          os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable (must be after asserts)

      cuda = not cpu and torch.cuda.is_available()
      if cuda:
          devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
          n = len(devices)  # device count
          if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
              assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
          space = ' ' * (len(s) + 1)
          for i, d in enumerate(devices):
              # after CUDA_VISIBLE_DEVICES is set, visible device i corresponds to requested index d
              p = torch.cuda.get_device_properties(i)
              s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2:.0f}MiB)\n"  # bytes to MB
      else:
          s += 'CPU\n'

      if not newline:
          s = s.rstrip()
      LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
      return torch.device('cuda:0' if cuda else 'cpu')
  def time_sync():
      # PyTorch-accurate wall-clock time: CUDA kernels run asynchronously, so wait for them before reading the clock
      cuda_present = torch.cuda.is_available()
      if cuda_present:
          torch.cuda.synchronize()
      now = time.time()
      return now
  def profile(input, ops, n=10, device=None):
      """YOLOv5 speed/memory/FLOPs profiler.

      For every (input, op) pair, runs the op ``n`` times and prints and records: parameter count, GFLOPs (via
      thop, 0 when unavailable), reserved CUDA memory in GB, mean forward and backward time in ms, and the
      input/output shapes.

      Usage:
          input = torch.randn(16, 3, 640, 640)
          m1 = lambda x: x * torch.sigmoid(x)
          m2 = nn.SiLU()
          profile(input, [m1, m2], n=100)  # profile over 100 iterations

      Args:
          input: a tensor or a list of tensors.
          ops: a callable / nn.Module, or a list of them.
          n: number of timed iterations per (input, op) pair.
          device: device to run on; defaults to select_device().

      Returns:
          list with one entry per pair: [params, GFLOPs, mem_GB, fwd_ms, bwd_ms, in_shape, out_shape],
          or None for a pair whose run raised an exception.
      """
      results = []
      device = device or select_device()
      print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
            f"{'input':>24s}{'output':>24s}")

      for x in input if isinstance(input, list) else [input]:
          x = x.to(device)
          x.requires_grad = True
          for m in ops if isinstance(ops, list) else [ops]:
              m = m.to(device) if hasattr(m, 'to') else m  # device
              m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
              tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
              try:
                  flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
              except Exception:  # thop missing (None) or op unsupported; never swallow KeyboardInterrupt
                  flops = 0

              try:
                  for _ in range(n):
                      t[0] = time_sync()
                      y = m(x)
                      t[1] = time_sync()
                      try:
                          _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                          t[2] = time_sync()
                      except Exception:  # no backward method
                          t[2] = float('nan')
                      tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                      tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                  mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                  s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
                  s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
                  p = sum(param.numel() for param in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                  print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                  results.append([p, flops, mem, tf, tb, s_in, s_out])
              except Exception as e:
                  print(e)
                  results.append(None)
              torch.cuda.empty_cache()
      return results
  def is_parallel(model):
      """Return True if ``model`` is wrapped in DataParallel or DistributedDataParallel (including subclasses).

      isinstance() is used instead of an exact type match so custom subclasses of DP/DDP, which also expose the
      wrapped network as ``.module``, are recognized too.
      """
      return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
  def de_parallel(model):
      # De-parallelize a model: unwrap DP/DDP (as judged by is_parallel) to the single-GPU module, else return as-is
      if is_parallel(model):
          return model.module
      return model
  def initialize_weights(model):
      """Apply YOLOv5 default settings to the sub-modules of ``model`` in place.

      BatchNorm2d layers get eps=1e-3 and momentum=0.03; Hardswish/LeakyReLU/ReLU/ReLU6/SiLU activations are set
      to inplace mode. Conv2d weights keep PyTorch's default initialization (a kaiming_normal_ call is
      intentionally disabled). Matching uses the exact type, so subclasses are not touched.
      """
      inplace_acts = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
      for module in model.modules():
          kind = type(module)
          if kind is nn.Conv2d:
              continue  # nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
          if kind is nn.BatchNorm2d:
              module.eps = 1e-3
              module.momentum = 0.03
          elif kind in inplace_acts:
              module.inplace = True
  def find_modules(model, mclass=nn.Conv2d):
      # Return the indices of entries in model.module_list that are instances of `mclass`
      indices = []
      for index, layer in enumerate(model.module_list):
          if isinstance(layer, mclass):
              indices.append(index)
      return indices
  def sparsity(model):
      """Return the global sparsity of ``model``: the fraction of all parameter elements that are exactly zero.

      The result is a 0-dim tensor. A model without parameters (e.g. nn.ReLU()) yields tensor(0.) instead of
      raising ZeroDivisionError.
      """
      total, zeros = 0, 0
      for p in model.parameters():
          total += p.numel()
          zeros += (p == 0).sum()
      if total == 0:  # nothing to measure; avoid 0 / 0
          return torch.tensor(0.0)
      return zeros / total
  def prune(model, amount=0.3):
      """Prune the weight of every nn.Conv2d in ``model`` with L1-unstructured pruning and print the result.

      ``amount`` is forwarded to torch.nn.utils.prune.l1_unstructured (a float fraction or an int count of
      weights per layer). The pruning mask is made permanent, then sparsity(model) is printed.
      """
      import torch.nn.utils.prune as torch_prune  # aliased so it does not shadow this function's own name

      print('Pruning model... ', end='')
      conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]
      for conv in conv_layers:
          torch_prune.l1_unstructured(conv, name='weight', amount=amount)  # prune
          torch_prune.remove(conv, 'weight')  # make permanent
      print(' %.3g global sparsity' % sparsity(model))
  def fuse_conv_and_bn(conv, bn):
      """Fold a BatchNorm2d into the preceding Conv2d and return the fused, gradient-free Conv2d.

      Fusion (https://tehnokv.com/posts/fusing-batchnorm-and-conv/), with s = gamma / sqrt(running_var + eps):
          W_fused = diag(s) @ W_conv
          b_fused = diag(s) @ b_conv + (beta - s * running_mean)
      The fused layer reproduces bn(conv(x)) using bn's running statistics (eval-mode BatchNorm). dilation and
      padding_mode are copied from ``conv`` so dilated or non-zero-padded convolutions fuse correctly.
      """
      fusedconv = nn.Conv2d(conv.in_channels,
                            conv.out_channels,
                            kernel_size=conv.kernel_size,
                            stride=conv.stride,
                            padding=conv.padding,
                            dilation=conv.dilation,
                            groups=conv.groups,
                            bias=True,
                            padding_mode=conv.padding_mode).requires_grad_(False).to(conv.weight.device)

      with torch.no_grad():  # pure weight arithmetic: keep the fused parameters free of autograd history
          # prepare filters
          w_conv = conv.weight.clone().view(conv.out_channels, -1)
          w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
          fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

          # prepare spatial bias
          b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
          b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
          fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

      return fusedconv
  def model_info(model, verbose=False, img_size=640):
      """Log a one-line model summary: module count, parameter count, trainable-parameter count and GFLOPs.

      Args:
          model: the nn.Module to summarize.
          verbose: if True, also print a per-parameter table (name, requires_grad, numel, shape, mean, std).
          img_size: int, or (h, w) as a list or tuple, used to scale the stride-sized FLOPs estimate,
              i.e. img_size=640 or img_size=[640, 320].

      The GFLOPs field is best-effort. It is omitted when thop is not installed, or when the model lacks what the
      estimate needs (e.g. a ``yaml`` dict attribute) or cannot be traced.
      """
      n_p = sum(x.numel() for x in model.parameters())  # number parameters
      n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
      if verbose:
          print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
          for i, (name, p) in enumerate(model.named_parameters()):
              name = name.replace('module_list.', '')
              print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                    (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

      try:  # FLOPs
          from thop import profile as thop_profile  # aliased: avoid shadowing this module's own profile()
          stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
          img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
          flops = thop_profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
          # accept (h, w) as list or tuple; a scalar is expanded to a square size
          img_size = list(img_size) if isinstance(img_size, (list, tuple)) else [img_size, img_size]
          fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPs
      except Exception:  # ImportError included; any failure simply omits the GFLOPs field
          fs = ''

      LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
      """Resize a batch img(bs,3,y,x) by ``ratio`` and pad/crop the result.

      With same_shape=True the resized image is padded (or cropped, for ratio > 1) back to the input's (y, x).
      Otherwise it is padded up to the next multiple of ``gs``. Padding uses value 0.447 (imagenet mean), applied
      on the right/bottom. ratio == 1.0 returns ``img`` itself.
      """
      if ratio == 1.0:
          return img
      h0, w0 = img.shape[2:]
      new_h, new_w = int(h0 * ratio), int(w0 * ratio)
      resized = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)
      if same_shape:
          target_h, target_w = h0, w0
      else:
          target_h, target_w = (math.ceil(dim * ratio / gs) * gs for dim in (h0, w0))
      return F.pad(resized, [0, target_w - new_w, 0, target_h - new_h], value=0.447)
  def copy_attr(a, b, include=(), exclude=()):
      """Copy the instance attributes of ``b`` (its ``__dict__``) onto ``a``.

      Names starting with '_' are never copied. If ``include`` is non-empty, only names in it are copied.
      Names in ``exclude`` are always skipped.
      """
      for name, value in b.__dict__.items():
          is_private = name.startswith('_')
          not_included = len(include) > 0 and name not in include
          if is_private or not_included or name in exclude:
              continue
          setattr(a, name, value)
  class EarlyStopping:
      # YOLOv5 simple early stopper: signals a stop once fitness has not improved for `patience` epochs

      def __init__(self, patience=30):
          self.best_fitness = 0.0  # best fitness seen so far, i.e. mAP
          self.best_epoch = 0  # epoch at which best_fitness was recorded
          self.patience = patience or float('inf')  # epochs to wait after fitness stops improving; 0 disables
          self.possible_stop = False  # True when a stop may occur next epoch

      def __call__(self, epoch, fitness):
          # Record `fitness` for `epoch`; return True when training should stop.
          improved = fitness >= self.best_fitness  # >= so an early zero-fitness stage still counts as progress
          if improved:
              self.best_epoch, self.best_fitness = epoch, fitness
          stale_epochs = epoch - self.best_epoch  # epochs without improvement
          self.possible_stop = stale_epochs >= (self.patience - 1)
          should_stop = stale_epochs >= self.patience
          if should_stop:
              LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                          f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                          f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                          f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
          return should_stop
  class ModelEMA:
      """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
      Keep a moving average of everything in the model state_dict (parameters and buffers).
      This is intended to allow functionality like
      https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
      A smoothed version of the weights is necessary for some training schemes to perform well.
      This class is sensitive where it is initialized in the sequence of model init,
      GPU assignment and distributed training wrappers.
      """

      def __init__(self, model, decay=0.9999, updates=0, tau=2000):
          """Create the EMA copy of ``model``.

          Args:
              model: model to track, optionally DP/DDP-wrapped. The unwrapped model is deep-copied, put in eval
                  mode, and gradients are disabled on the copy.
              decay: asymptotic EMA decay factor.
              updates: number of EMA updates already performed (e.g. when resuming).
              tau: warm-up time constant, in updates. The effective decay is decay * (1 - exp(-updates / tau)), so
                  early updates follow the model closely. Must be > 0; 2000 was previously hard-coded.
          """
          self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
          # if next(model.parameters()).device.type != 'cpu':
          #     self.ema.half()  # FP16 EMA
          self.updates = updates  # number of EMA updates
          self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
          for p in self.ema.parameters():
              p.requires_grad_(False)

      def update(self, model):
          # Update EMA parameters: ema = d * ema + (1 - d) * model, for every floating-point state_dict entry
          with torch.no_grad():
              self.updates += 1
              d = self.decay(self.updates)

              msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
              # state_dict() tensors share storage with the EMA module, so the in-place ops below update it
              for k, v in self.ema.state_dict().items():
                  if v.dtype.is_floating_point:
                      v *= d
                      v += (1 - d) * msd[k].detach()

      def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
          # Update EMA attributes: copy public (non-underscore) attributes of `model` onto the EMA copy
          copy_attr(self.ema, model, include, exclude)