Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen, dürfen Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

torch_utils.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # YOLOv5 PyTorch utils
  2. import datetime # 时间模块 基于time进行了封装 更高级
  3. import logging # 日志功能生成模块
  4. import math # 数学函数模块
  5. import os # 与操作系统进行交互的模块
  6. import platform # 提供获取操作系统相关信息的模块
  7. import subprocess # 子进程定义及操作的模块
  8. import time # 时间模块 更底层
  9. from contextlib import contextmanager # 用于进行上下文管理的模块
  10. from copy import deepcopy # 实现深度复制的模块
  11. from pathlib import Path # Path将str转换为Path对象 使字符串路径易于操作的模块
  12. import torch
  13. import torch.backends.cudnn as cudnn
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import torchvision
  17. try:
  18. import thop # 用于Pytorch模型的FLOPS计算工具模块
  19. except ImportError:
  20. thop = None
  21. logger = logging.getLogger(__name__)
  22. @contextmanager
  23. def torch_distributed_zero_first(local_rank: int):
  24. """
  25. Decorator to make all processes in distributed training wait for each local_master to do something.
  26. """
  27. if local_rank not in [-1, 0]:
  28. torch.distributed.barrier()
  29. yield
  30. if local_rank == 0:
  31. torch.distributed.barrier()
  32. def init_torch_seeds(seed=0):
  33. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  34. torch.manual_seed(seed)
  35. if seed == 0: # slower, more reproducible
  36. cudnn.benchmark, cudnn.deterministic = False, True
  37. else: # faster, less reproducible
  38. cudnn.benchmark, cudnn.deterministic = True, False
  39. def date_modified(path=__file__):
  40. # return human-readable file modification date, i.e. '2021-3-26'
  41. t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
  42. return f'{t.year}-{t.month}-{t.day}'
  43. def git_describe(path=Path(__file__).parent): # path must be a directory
  44. # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  45. s = f'git -C {path} describe --tags --long --always'
  46. try:
  47. return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
  48. except subprocess.CalledProcessError as e:
  49. return '' # not a git repository
  50. # 这个函数才是主角,用于自动选择本机模型训练的设备,并输出日志信息。
  51. def select_device(device='', batch_size=None):
  52. # device = 'cpu' or '0' or '0,1,2,3'
  53. """广泛用于train.py、test.py、detect.py等文件中
  54. 用于选择模型训练的设备 并输出日志信息
  55. :params device: 输入的设备 device = 'cpu' or '0' or '0,1,2,3'
  56. :params batch_size: 一个批次的图片个数
  57. """
  58. # git_describe(): 返回当前文件父文件的描述信息(yolov5) date_modified(): 返回当前文件的修改日期
  59. # s: 之后要加入logger日志的显示信息
  60. s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
  61. # 如果device输入为cpu cpu=True device.lower(): 将device字符串全部转为小写字母
  62. cpu = device.lower() == 'cpu'
  63. if cpu:
  64. # 如果cpu=True 就强制(force)使用cpu 令torch.cuda.is_available() = False
  65. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  66. elif device: # non-cpu device requested
  67. # 如果输入device不为空 device=GPU 直接设置 CUDA environment variable = device 加入CUDA可用设备
  68. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  69. # 检查cuda的可用性 如果不可用则终止程序
  70. assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
  71. # 输入device为空 自行根据计算机情况选择相应设备 先看GPU 没有就CPU
  72. # 如果cuda可用 且 输入device != cpu 则 cuda=True 反正cuda=False
  73. cuda = not cpu and torch.cuda.is_available()
  74. if cuda:
  75. n = torch.cuda.device_count()
  76. if n > 1 and batch_size: # check that batch_size is compatible with device_count
  77. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  78. space = ' ' * len(s)
  79. for i, d in enumerate(device.split(',') if device else range(n)):
  80. # p: 每个可用显卡的相关属性
  81. p = torch.cuda.get_device_properties(i)
  82. # 显示信息s加上每张显卡的属性信息
  83. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
  84. else:
  85. s += 'CPU\n'
  86. logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
  87. # 如果cuda可用就返回第一张显卡的的名称 如: GeForce RTX 2060 反之返回CPU对应的名称
  88. return torch.device('cuda:0' if cuda else 'cpu')
  89. def time_synchronized():
  90. # pytorch-accurate time
  91. if torch.cuda.is_available():
  92. torch.cuda.synchronize()
  93. return time.time()
  94. def profile(x, ops, n=100, device=None):
  95. # profile a pytorch module or list of modules. Example usage:
  96. # x = torch.randn(16, 3, 640, 640) # input
  97. # m1 = lambda x: x * torch.sigmoid(x)
  98. # m2 = nn.SiLU()
  99. # profile(x, [m1, m2], n=100) # profile speed over 100 iterations
  100. device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  101. x = x.to(device)
  102. x.requires_grad = True
  103. print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
  104. print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
  105. for m in ops if isinstance(ops, list) else [ops]:
  106. m = m.to(device) if hasattr(m, 'to') else m # device
  107. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type
  108. dtf, dtb, t = 0., 0., [0., 0., 0.] # dt forward, backward
  109. try:
  110. flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS
  111. except:
  112. flops = 0
  113. for _ in range(n):
  114. t[0] = time_synchronized()
  115. y = m(x)
  116. t[1] = time_synchronized()
  117. try:
  118. _ = y.sum().backward()
  119. t[2] = time_synchronized()
  120. except: # no backward method
  121. t[2] = float('nan')
  122. dtf += (t[1] - t[0]) * 1000 / n # ms per op forward
  123. dtb += (t[2] - t[1]) * 1000 / n # ms per op backward
  124. s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
  125. s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
  126. p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters
  127. print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
  128. def is_parallel(model):
  129. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  130. def intersect_dicts(da, db, exclude=()):
  131. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  132. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  133. def initialize_weights(model):
  134. for m in model.modules():
  135. t = type(m)
  136. if t is nn.Conv2d:
  137. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  138. elif t is nn.BatchNorm2d:
  139. m.eps = 1e-3
  140. m.momentum = 0.03
  141. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  142. m.inplace = True
  143. def find_modules(model, mclass=nn.Conv2d):
  144. # Finds layer indices matching module class 'mclass'
  145. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  146. def sparsity(model):
  147. # Return global model sparsity
  148. a, b = 0., 0.
  149. for p in model.parameters():
  150. a += p.numel()
  151. b += (p == 0).sum()
  152. return b / a
  153. def prune(model, amount=0.3):
  154. # Prune model to requested global sparsity
  155. import torch.nn.utils.prune as prune
  156. print('Pruning model... ', end='')
  157. for name, m in model.named_modules():
  158. if isinstance(m, nn.Conv2d):
  159. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  160. prune.remove(m, 'weight') # make permanent
  161. print(' %.3g global sparsity' % sparsity(model))
  162. def fuse_conv_and_bn(conv, bn):
  163. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  164. fusedconv = nn.Conv2d(conv.in_channels,
  165. conv.out_channels,
  166. kernel_size=conv.kernel_size,
  167. stride=conv.stride,
  168. padding=conv.padding,
  169. groups=conv.groups,
  170. bias=True).requires_grad_(False).to(conv.weight.device)
  171. # prepare filters
  172. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  173. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  174. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  175. # prepare spatial bias
  176. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  177. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  178. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  179. return fusedconv
  180. def model_info(model, verbose=False, img_size=640):
  181. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  182. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  183. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  184. if verbose:
  185. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  186. for i, (name, p) in enumerate(model.named_parameters()):
  187. name = name.replace('module_list.', '')
  188. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  189. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  190. try: # FLOPS
  191. from thop import profile
  192. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
  193. img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
  194. flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS
  195. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  196. fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPS
  197. except (ImportError, Exception):
  198. fs = ''
  199. logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  200. def load_classifier(name='resnet101', n=2):
  201. # Loads a pretrained model reshaped to n-class output
  202. model = torchvision.models.__dict__[name](pretrained=True)
  203. # ResNet model properties
  204. # input_size = [3, 224, 224]
  205. # input_space = 'RGB'
  206. # input_range = [0, 1]
  207. # mean = [0.485, 0.456, 0.406]
  208. # std = [0.229, 0.224, 0.225]
  209. # Reshape output to n classes
  210. filters = model.fc.weight.shape[1]
  211. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  212. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  213. model.fc.out_features = n
  214. return model
  215. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  216. # scales img(bs,3,y,x) by ratio constrained to gs-multiple
  217. if ratio == 1.0:
  218. return img
  219. else:
  220. h, w = img.shape[2:]
  221. s = (int(h * ratio), int(w * ratio)) # new size
  222. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  223. if not same_shape: # pad/crop img
  224. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  225. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  226. def copy_attr(a, b, include=(), exclude=()):
  227. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  228. for k, v in b.__dict__.items():
  229. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  230. continue
  231. else:
  232. setattr(a, k, v)
  233. class ModelEMA:
  234. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  235. Keep a moving average of everything in the model state_dict (parameters and buffers).
  236. This is intended to allow functionality like
  237. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  238. A smoothed version of the weights is necessary for some training schemes to perform well.
  239. This class is sensitive where it is initialized in the sequence of model init,
  240. GPU assignment and distributed training wrappers.
  241. """
  242. def __init__(self, model, decay=0.9999, updates=0):
  243. # Create EMA
  244. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  245. # if next(model.parameters()).device.type != 'cpu':
  246. # self.ema.half() # FP16 EMA
  247. self.updates = updates # number of EMA updates
  248. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  249. for p in self.ema.parameters():
  250. p.requires_grad_(False)
  251. def update(self, model):
  252. # Update EMA parameters
  253. with torch.no_grad():
  254. self.updates += 1
  255. d = self.decay(self.updates)
  256. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  257. for k, v in self.ema.state_dict().items():
  258. if v.dtype.is_floating_point:
  259. v *= d
  260. v += (1. - d) * msd[k].detach()
  261. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  262. # Update EMA attributes
  263. copy_attr(self.ema, model, include, exclude)