Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

313 Zeilen
13KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. PyTorch utils
  4. """
  5. import math
  6. import os
  7. import platform
  8. import subprocess
  9. import time
  10. import warnings
  11. from contextlib import contextmanager
  12. from copy import deepcopy
  13. from pathlib import Path
  14. import torch
  15. import torch.distributed as dist
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from utils.general import LOGGER, file_update_date, git_describe
  19. try:
  20. import thop # for FLOPs computation
  21. except ImportError:
  22. thop = None
  23. # Suppress PyTorch warnings
  24. warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
  25. @contextmanager
  26. def torch_distributed_zero_first(local_rank: int):
  27. # Decorator to make all processes in distributed training wait for each local_master to do something
  28. if local_rank not in [-1, 0]:
  29. dist.barrier(device_ids=[local_rank])
  30. yield
  31. if local_rank == 0:
  32. dist.barrier(device_ids=[0])
  33. def device_count():
  34. # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux.
  35. assert platform.system() == 'Linux', 'device_count() function only works on Linux'
  36. try:
  37. cmd = 'nvidia-smi -L | wc -l'
  38. return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
  39. except Exception:
  40. return 0
  41. def select_device(device='', batch_size=0, newline=True):
  42. # device = 'cpu' or '0' or '0,1,2,3'
  43. s = f'YOLOv5 🚀 {git_describe() or file_update_date()} torch {torch.__version__} ' # string
  44. device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
  45. cpu = device == 'cpu'
  46. if cpu:
  47. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  48. elif device: # non-cpu device requested
  49. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
  50. assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
  51. f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
  52. cuda = not cpu and torch.cuda.is_available()
  53. if cuda:
  54. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  55. n = len(devices) # device count
  56. if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
  57. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  58. space = ' ' * (len(s) + 1)
  59. for i, d in enumerate(devices):
  60. p = torch.cuda.get_device_properties(i)
  61. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
  62. else:
  63. s += 'CPU\n'
  64. if not newline:
  65. s = s.rstrip()
  66. LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  67. return torch.device('cuda:0' if cuda else 'cpu')
  68. def time_sync():
  69. # PyTorch-accurate time
  70. if torch.cuda.is_available():
  71. torch.cuda.synchronize()
  72. return time.time()
  73. def profile(input, ops, n=10, device=None):
  74. # YOLOv5 speed/memory/FLOPs profiler
  75. #
  76. # Usage:
  77. # input = torch.randn(16, 3, 640, 640)
  78. # m1 = lambda x: x * torch.sigmoid(x)
  79. # m2 = nn.SiLU()
  80. # profile(input, [m1, m2], n=100) # profile over 100 iterations
  81. results = []
  82. device = device or select_device()
  83. print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  84. f"{'input':>24s}{'output':>24s}")
  85. for x in input if isinstance(input, list) else [input]:
  86. x = x.to(device)
  87. x.requires_grad = True
  88. for m in ops if isinstance(ops, list) else [ops]:
  89. m = m.to(device) if hasattr(m, 'to') else m # device
  90. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  91. tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
  92. try:
  93. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
  94. except Exception:
  95. flops = 0
  96. try:
  97. for _ in range(n):
  98. t[0] = time_sync()
  99. y = m(x)
  100. t[1] = time_sync()
  101. try:
  102. _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
  103. t[2] = time_sync()
  104. except Exception: # no backward method
  105. # print(e) # for debug
  106. t[2] = float('nan')
  107. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  108. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  109. mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
  110. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  111. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  112. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  113. print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
  114. results.append([p, flops, mem, tf, tb, s_in, s_out])
  115. except Exception as e:
  116. print(e)
  117. results.append(None)
  118. torch.cuda.empty_cache()
  119. return results
  120. def is_parallel(model):
  121. # Returns True if model is of type DP or DDP
  122. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  123. def de_parallel(model):
  124. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  125. return model.module if is_parallel(model) else model
  126. def initialize_weights(model):
  127. for m in model.modules():
  128. t = type(m)
  129. if t is nn.Conv2d:
  130. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  131. elif t is nn.BatchNorm2d:
  132. m.eps = 1e-3
  133. m.momentum = 0.03
  134. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  135. m.inplace = True
  136. def find_modules(model, mclass=nn.Conv2d):
  137. # Finds layer indices matching module class 'mclass'
  138. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  139. def sparsity(model):
  140. # Return global model sparsity
  141. a, b = 0, 0
  142. for p in model.parameters():
  143. a += p.numel()
  144. b += (p == 0).sum()
  145. return b / a
  146. def prune(model, amount=0.3):
  147. # Prune model to requested global sparsity
  148. import torch.nn.utils.prune as prune
  149. print('Pruning model... ', end='')
  150. for name, m in model.named_modules():
  151. if isinstance(m, nn.Conv2d):
  152. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  153. prune.remove(m, 'weight') # make permanent
  154. print(' %.3g global sparsity' % sparsity(model))
  155. def fuse_conv_and_bn(conv, bn):
  156. # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  157. fusedconv = nn.Conv2d(conv.in_channels,
  158. conv.out_channels,
  159. kernel_size=conv.kernel_size,
  160. stride=conv.stride,
  161. padding=conv.padding,
  162. groups=conv.groups,
  163. bias=True).requires_grad_(False).to(conv.weight.device)
  164. # Prepare filters
  165. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  166. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  167. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  168. # Prepare spatial bias
  169. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  170. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  171. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  172. return fusedconv
  173. def model_info(model, verbose=False, img_size=640):
  174. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  175. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  176. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  177. if verbose:
  178. print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
  179. for i, (name, p) in enumerate(model.named_parameters()):
  180. name = name.replace('module_list.', '')
  181. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  182. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  183. try: # FLOPs
  184. from thop import profile
  185. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  186. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  187. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  188. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  189. fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
  190. except (ImportError, Exception):
  191. fs = ''
  192. name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
  193. LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  194. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  195. # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
  196. if ratio == 1.0:
  197. return img
  198. else:
  199. h, w = img.shape[2:]
  200. s = (int(h * ratio), int(w * ratio)) # new size
  201. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  202. if not same_shape: # pad/crop img
  203. h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
  204. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  205. def copy_attr(a, b, include=(), exclude=()):
  206. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  207. for k, v in b.__dict__.items():
  208. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  209. continue
  210. else:
  211. setattr(a, k, v)
  212. class EarlyStopping:
  213. # YOLOv5 simple early stopper
  214. def __init__(self, patience=30):
  215. self.best_fitness = 0.0 # i.e. mAP
  216. self.best_epoch = 0
  217. self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
  218. self.possible_stop = False # possible stop may occur next epoch
  219. def __call__(self, epoch, fitness):
  220. if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
  221. self.best_epoch = epoch
  222. self.best_fitness = fitness
  223. delta = epoch - self.best_epoch # epochs without improvement
  224. self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
  225. stop = delta >= self.patience # stop training if patience exceeded
  226. if stop:
  227. LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
  228. f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
  229. f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
  230. f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
  231. return stop
  232. class ModelEMA:
  233. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  234. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  235. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  236. """
  237. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  238. # Create EMA
  239. self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
  240. # if next(model.parameters()).device.type != 'cpu':
  241. # self.ema.half() # FP16 EMA
  242. self.updates = updates # number of EMA updates
  243. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  244. for p in self.ema.parameters():
  245. p.requires_grad_(False)
  246. def update(self, model):
  247. # Update EMA parameters
  248. with torch.no_grad():
  249. self.updates += 1
  250. d = self.decay(self.updates)
  251. msd = de_parallel(model).state_dict() # model state_dict
  252. for k, v in self.ema.state_dict().items():
  253. if v.dtype.is_floating_point:
  254. v *= d
  255. v += (1 - d) * msd[k].detach()
  256. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  257. # Update EMA attributes
  258. copy_attr(self.ema, model, include, exclude)