Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

311 lines
13KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. PyTorch utils
  4. """
  5. import math
  6. import os
  7. import platform
  8. import subprocess
  9. import time
  10. import warnings
  11. from contextlib import contextmanager
  12. from copy import deepcopy
  13. import torch
  14. import torch.distributed as dist
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from utils.general import LOGGER, file_update_date, git_describe
  18. try:
  19. import thop # for FLOPs computation
  20. except ImportError:
  21. thop = None
  22. # Suppress PyTorch warnings
  23. warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
  24. @contextmanager
  25. def torch_distributed_zero_first(local_rank: int):
  26. # Decorator to make all processes in distributed training wait for each local_master to do something
  27. if local_rank not in [-1, 0]:
  28. dist.barrier(device_ids=[local_rank])
  29. yield
  30. if local_rank == 0:
  31. dist.barrier(device_ids=[0])
  32. def device_count():
  33. # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux.
  34. assert platform.system() == 'Linux', 'device_count() function only works on Linux'
  35. try:
  36. cmd = 'nvidia-smi -L | wc -l'
  37. return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
  38. except Exception:
  39. return 0
  40. def select_device(device='', batch_size=0, newline=True):
  41. # device = 'cpu' or '0' or '0,1,2,3'
  42. s = f'YOLOv5 🚀 {git_describe() or file_update_date()} torch {torch.__version__} ' # string
  43. device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
  44. cpu = device == 'cpu'
  45. if cpu:
  46. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  47. elif device: # non-cpu device requested
  48. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
  49. assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
  50. f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
  51. cuda = not cpu and torch.cuda.is_available()
  52. if cuda:
  53. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  54. n = len(devices) # device count
  55. if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
  56. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  57. space = ' ' * (len(s) + 1)
  58. for i, d in enumerate(devices):
  59. p = torch.cuda.get_device_properties(i)
  60. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
  61. else:
  62. s += 'CPU\n'
  63. if not newline:
  64. s = s.rstrip()
  65. LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  66. return torch.device('cuda:0' if cuda else 'cpu')
  67. def time_sync():
  68. # PyTorch-accurate time
  69. if torch.cuda.is_available():
  70. torch.cuda.synchronize()
  71. return time.time()
  72. def profile(input, ops, n=10, device=None):
  73. # YOLOv5 speed/memory/FLOPs profiler
  74. #
  75. # Usage:
  76. # input = torch.randn(16, 3, 640, 640)
  77. # m1 = lambda x: x * torch.sigmoid(x)
  78. # m2 = nn.SiLU()
  79. # profile(input, [m1, m2], n=100) # profile over 100 iterations
  80. results = []
  81. device = device or select_device()
  82. print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  83. f"{'input':>24s}{'output':>24s}")
  84. for x in input if isinstance(input, list) else [input]:
  85. x = x.to(device)
  86. x.requires_grad = True
  87. for m in ops if isinstance(ops, list) else [ops]:
  88. m = m.to(device) if hasattr(m, 'to') else m # device
  89. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  90. tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
  91. try:
  92. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  93. except Exception:
  94. flops = 0
  95. try:
  96. for _ in range(n):
  97. t[0] = time_sync()
  98. y = m(x)
  99. t[1] = time_sync()
  100. try:
  101. _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
  102. t[2] = time_sync()
  103. except Exception: # no backward method
  104. # print(e) # for debug
  105. t[2] = float('nan')
  106. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  107. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  108. mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
  109. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  110. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  111. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  112. print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
  113. results.append([p, flops, mem, tf, tb, s_in, s_out])
  114. except Exception as e:
  115. print(e)
  116. results.append(None)
  117. torch.cuda.empty_cache()
  118. return results
  119. def is_parallel(model):
  120. # Returns True if model is of type DP or DDP
  121. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  122. def de_parallel(model):
  123. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  124. return model.module if is_parallel(model) else model
  125. def initialize_weights(model):
  126. for m in model.modules():
  127. t = type(m)
  128. if t is nn.Conv2d:
  129. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  130. elif t is nn.BatchNorm2d:
  131. m.eps = 1e-3
  132. m.momentum = 0.03
  133. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  134. m.inplace = True
  135. def find_modules(model, mclass=nn.Conv2d):
  136. # Finds layer indices matching module class 'mclass'
  137. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  138. def sparsity(model):
  139. # Return global model sparsity
  140. a, b = 0, 0
  141. for p in model.parameters():
  142. a += p.numel()
  143. b += (p == 0).sum()
  144. return b / a
  145. def prune(model, amount=0.3):
  146. # Prune model to requested global sparsity
  147. import torch.nn.utils.prune as prune
  148. print('Pruning model... ', end='')
  149. for name, m in model.named_modules():
  150. if isinstance(m, nn.Conv2d):
  151. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  152. prune.remove(m, 'weight') # make permanent
  153. print(' %.3g global sparsity' % sparsity(model))
  154. def fuse_conv_and_bn(conv, bn):
  155. # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  156. fusedconv = nn.Conv2d(conv.in_channels,
  157. conv.out_channels,
  158. kernel_size=conv.kernel_size,
  159. stride=conv.stride,
  160. padding=conv.padding,
  161. groups=conv.groups,
  162. bias=True).requires_grad_(False).to(conv.weight.device)
  163. # Prepare filters
  164. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  165. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  166. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  167. # Prepare spatial bias
  168. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  169. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  170. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  171. return fusedconv
  172. def model_info(model, verbose=False, img_size=640):
  173. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  174. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  175. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  176. if verbose:
  177. print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
  178. for i, (name, p) in enumerate(model.named_parameters()):
  179. name = name.replace('module_list.', '')
  180. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  181. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  182. try: # FLOPs
  183. from thop import profile
  184. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  185. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  186. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  187. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  188. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  189. except (ImportError, Exception):
  190. fs = ''
  191. LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  192. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  193. # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
  194. if ratio == 1.0:
  195. return img
  196. else:
  197. h, w = img.shape[2:]
  198. s = (int(h * ratio), int(w * ratio)) # new size
  199. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  200. if not same_shape: # pad/crop img
  201. h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
  202. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  203. def copy_attr(a, b, include=(), exclude=()):
  204. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  205. for k, v in b.__dict__.items():
  206. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  207. continue
  208. else:
  209. setattr(a, k, v)
  210. class EarlyStopping:
  211. # YOLOv5 simple early stopper
  212. def __init__(self, patience=30):
  213. self.best_fitness = 0.0 # i.e. mAP
  214. self.best_epoch = 0
  215. self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
  216. self.possible_stop = False # possible stop may occur next epoch
  217. def __call__(self, epoch, fitness):
  218. if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
  219. self.best_epoch = epoch
  220. self.best_fitness = fitness
  221. delta = epoch - self.best_epoch # epochs without improvement
  222. self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
  223. stop = delta >= self.patience # stop training if patience exceeded
  224. if stop:
  225. LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
  226. f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
  227. f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
  228. f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
  229. return stop
  230. class ModelEMA:
  231. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  232. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  233. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  234. """
  235. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  236. # Create EMA
  237. self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
  238. # if next(model.parameters()).device.type != 'cpu':
  239. # self.ema.half() # FP16 EMA
  240. self.updates = updates # number of EMA updates
  241. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  242. for p in self.ema.parameters():
  243. p.requires_grad_(False)
  244. def update(self, model):
  245. # Update EMA parameters
  246. with torch.no_grad():
  247. self.updates += 1
  248. d = self.decay(self.updates)
  249. msd = de_parallel(model).state_dict() # model state_dict
  250. for k, v in self.ema.state_dict().items():
  251. if v.dtype.is_floating_point:
  252. v *= d
  253. v += (1 - d) * msd[k].detach()
  254. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  255. # Update EMA attributes
  256. copy_attr(self.ema, model, include, exclude)