Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

torch_utils.py 11KB

il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. # PyTorch utils
  2. import logging
  3. import math
  4. import os
  5. import time
  6. from contextlib import contextmanager
  7. from copy import deepcopy
  8. import torch
  9. import torch.backends.cudnn as cudnn
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. import torchvision
  13. try:
  14. import thop # for FLOPS computation
  15. except ImportError:
  16. thop = None
  # Module-wide logger used by select_device() and model_info().
  logger = logging.getLogger(name=__name__)
  @contextmanager
  def torch_distributed_zero_first(local_rank: int):
      """Context manager that lets the local master (rank 0) run the wrapped block first.

      Ranks other than -1 and 0 wait at ``torch.distributed.barrier()`` before
      entering the block. Rank 0 enters immediately and reaches the barrier only
      after its block has finished, which releases the waiting ranks.
      ``local_rank == -1`` (conventionally: not distributed) never touches the barrier.
      """
      waits_for_master = local_rank not in (-1, 0)
      is_master = local_rank == 0
      if waits_for_master:
          torch.distributed.barrier()
      yield
      if is_master:  # master finished: meet the ranks blocked above
          torch.distributed.barrier()
  def init_torch_seeds(seed=0):
      """Seed torch's RNG and choose the cuDNN speed/reproducibility trade-off.

      ``seed == 0`` selects deterministic cuDNN kernels with autotuning disabled
      (slower, more reproducible). Any other seed enables autotuning and allows
      non-deterministic kernels (faster, less reproducible).
      See https://pytorch.org/docs/stable/notes/randomness.html
      """
      torch.manual_seed(seed)
      reproducible = seed == 0
      cudnn.deterministic = reproducible
      cudnn.benchmark = not reproducible
  def select_device(device='', batch_size=None):
      """Choose the torch device to run on and log the selection.

      Args:
          device: '' (auto: CUDA if available, else CPU), 'cpu', or CUDA device ids
              such as '0' or '0,1,2,3'. Leading/trailing whitespace, letter case and
              a 'cuda:' prefix are ignored. None is treated like ''. Non-string
              values (e.g. an int id or a torch.device) are converted with str().
          batch_size: optional. When more than one GPU is visible it must be a
              multiple of the GPU count.

      Returns:
          torch.device('cuda:0') if CUDA is used, otherwise torch.device('cpu').

      Raises:
          AssertionError: specific CUDA device(s) were requested but CUDA is not
              available, or batch_size is not a multiple of the GPU count. These are
              raised explicitly (not via ``assert``) so the checks survive
              ``python -O``. The AssertionError type is kept for existing callers.

      Side effect: for a non-CPU request, sets the CUDA_VISIBLE_DEVICES environment
      variable to the requested ids.
      """
      device = '' if device is None else str(device).strip().lower().replace('cuda:', '')
      cpu_request = device == 'cpu'
      if device and not cpu_request:  # specific GPU id(s) requested
          os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
          if not torch.cuda.is_available():  # check availability
              raise AssertionError('CUDA unavailable, invalid device %s requested' % device)

      cuda = not cpu_request and torch.cuda.is_available()
      if cuda:
          mb = 1024 ** 2  # bytes to MB
          ng = torch.cuda.device_count()
          if ng > 1 and batch_size and batch_size % ng != 0:  # batch must split evenly across GPUs
              raise AssertionError('batch-size %g not multiple of GPU count %g' % (batch_size, ng))
          prefix = f'Using torch {torch.__version__} '
          for i in range(ng):
              if i == 1:
                  prefix = ' ' * len(prefix)  # indent later GPUs under the first line
              props = torch.cuda.get_device_properties(i)
              logger.info('%sCUDA:%g (%s, %dMB)', prefix, i, props.name, props.total_memory / mb)
      else:
          logger.info('Using torch %s CPU', torch.__version__)
      logger.info('')  # skip a line
      return torch.device('cuda:0' if cuda else 'cpu')
  def time_synchronized():
      """Return ``time.time()`` after waiting for queued CUDA work to finish.

      When CUDA is available, ``torch.cuda.synchronize()`` is called first. This
      makes the timestamp reflect completed GPU work rather than just kernel launch.
      """
      if torch.cuda.is_available():
          torch.cuda.synchronize()
      return time.time()
  def profile(x, ops, n=100, device=None):
      """Benchmark one or more ops on input ``x`` and print a results table.

      For each op the table prints:
          - parameter count (0 for non-nn.Module callables);
          - GFLOPS via ``thop`` (0 if thop is missing or cannot trace the op);
          - mean forward and backward time in ms over ``n`` iterations (backward
            time is NaN when ``y.sum().backward()`` fails);
          - input and output shapes.

      Example:
          x = torch.randn(16, 3, 640, 640)  # input
          m1 = lambda x: x * torch.sigmoid(x)
          m2 = nn.SiLU()
          profile(x, [m1, m2], n=100)  # profile speed over 100 iterations

      Args:
          x: input tensor. It is moved to ``device`` and a detached copy with
              ``requires_grad=True`` is used, so the caller's tensor is not modified.
          ops: a callable / nn.Module, or a list of them.
          n: number of timed iterations per op.
          device: torch.device or device string. Defaults to 'cuda:0' when CUDA is
              available, else 'cpu'.
      """
      if not device:
          device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      elif isinstance(device, str):
          device = torch.device(device)
      x = x.to(device).detach().requires_grad_(True)  # fresh leaf; caller's tensor untouched
      print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
      print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
      for m in ops if isinstance(ops, list) else [ops]:
          m = m.to(device) if hasattr(m, 'to') else m  # device
          m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
          dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
          try:
              # Profile a copy: thop attaches extra state to the module it traces.
              flops = thop.profile(deepcopy(m), inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
          except Exception:  # thop unavailable (None) or op not traceable
              flops = 0
          for _ in range(n):
              t[0] = time_synchronized()
              y = m(x)
              t[1] = time_synchronized()
              try:
                  _ = y.sum().backward()
                  t[2] = time_synchronized()
              except Exception:  # no backward method
                  t[2] = float('nan')
              dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
              dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
          s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list'
          s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list'
          p = sum(q.numel() for q in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
          print(f'{p:12.4g}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}')
  def is_parallel(model):
      """Return True if ``model`` is exactly a DataParallel or DistributedDataParallel wrapper.

      The check uses the exact type, so subclasses of these wrappers are not matched.
      """
      wrapper_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
      return any(type(model) is wrapper for wrapper in wrapper_types)
  def intersect_dicts(da, db, exclude=()):
      """Return the entries of ``da`` that are compatible with ``db``.

      An entry is kept when its key is present in ``db``, contains none of the
      substrings in ``exclude``, and its value's ``.shape`` equals that of
      ``db[key]``. Kept values come from ``da``, in ``da``'s order.
      """
      shared = {}
      for key, value in da.items():
          if key not in db:
              continue
          if any(pattern in key for pattern in exclude):
              continue
          if value.shape != db[key].shape:
              continue
          shared[key] = value
      return shared
  def initialize_weights(model):
      """Adjust selected submodules of ``model`` in place.

      - ``nn.BatchNorm2d``: set ``eps = 1e-3`` and ``momentum = 0.03``.
      - ``nn.Hardswish``, ``nn.LeakyReLU``, ``nn.ReLU``, ``nn.ReLU6``: set ``inplace = True``.
      - ``nn.Conv2d`` and everything else: left untouched (PyTorch default init).

      Matching is by exact type, so subclasses are not affected.
      """
      inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6)
      for module in model.modules():
          kind = type(module)
          if kind is nn.BatchNorm2d:
              module.eps = 1e-3
              module.momentum = 0.03
          elif kind in inplace_activations:
              module.inplace = True
  def find_modules(model, mclass=nn.Conv2d):
      """Return the indices of entries in ``model.module_list`` that are instances of ``mclass``."""
      indices = []
      for idx, layer in enumerate(model.module_list):
          if isinstance(layer, mclass):
              indices.append(idx)
      return indices
  def sparsity(model):
      """Return the global fraction of parameter elements in ``model`` that are exactly zero.

      The result is a 0-dim tensor. A model without parameters returns
      ``torch.tensor(0.)`` instead of dividing by zero.
      """
      total, zeros = 0., 0.
      for p in model.parameters():
          total += p.numel()
          zeros += (p == 0).sum()
      if total == 0:  # no parameters: nothing to be sparse, avoid ZeroDivisionError
          return torch.tensor(0.)
      return zeros / total
  def prune(model, amount=0.3):
      """Apply L1-unstructured pruning to every Conv2d weight in ``model`` and make it permanent.

      ``amount`` is passed to ``torch.nn.utils.prune.l1_unstructured``. Pruning is
      applied per Conv2d layer. Afterwards the model-wide sparsity is printed.
      """
      import torch.nn.utils.prune as torch_prune  # aliased so it does not shadow this function
      print('Pruning model... ', end='')
      conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]
      for conv in conv_layers:
          torch_prune.l1_unstructured(conv, name='weight', amount=amount)  # prune
          torch_prune.remove(conv, 'weight')  # make permanent
      print(' %.3g global sparsity' % sparsity(model))
  def fuse_conv_and_bn(conv, bn):
      """Fold a BatchNorm2d into the preceding Conv2d and return one equivalent Conv2d.

      Based on https://tehnokv.com/posts/fusing-batchnorm-and-conv/. The result uses
      ``bn``'s running statistics, so it reproduces ``bn(conv(x))`` with ``bn`` in
      eval mode.

      The fused layer:
          - copies kernel size, stride, padding, dilation, groups and padding mode
            from ``conv``;
          - always has a bias;
          - is placed on ``conv.weight``'s device;
          - has ``requires_grad`` disabled.
      """
      fusedconv = nn.Conv2d(conv.in_channels,
                            conv.out_channels,
                            kernel_size=conv.kernel_size,
                            stride=conv.stride,
                            padding=conv.padding,
                            dilation=conv.dilation,  # omitting these silently breaks dilated /
                            padding_mode=conv.padding_mode,  # non-zero-padded convolutions
                            groups=conv.groups,
                            bias=True).requires_grad_(False).to(conv.weight.device)

      # Build the fused params outside autograd so they are plain leaf tensors,
      # not graph outputs of bn.weight / conv.weight.
      with torch.no_grad():
          # prepare filters: W = diag(gamma / sqrt(var + eps)) @ W_conv
          w_conv = conv.weight.clone().view(conv.out_channels, -1)
          w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
          fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

          # prepare spatial bias: b = diag(...) @ b_conv + beta - gamma * mean / sqrt(var + eps)
          b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
          b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
          fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

      return fusedconv
  def model_info(model, verbose=False, img_size=640):
      """Log a one-line model summary: layers, parameters, trainable parameters, GFLOPS.

      Args:
          model: nn.Module to summarize.
          verbose: if True, also print a per-parameter table (name, requires_grad,
              numel, shape, mean, std).
          img_size: an int, or an [h, w] / (h, w) sequence, used to scale the
              GFLOPS estimate (e.g. 640 or [640, 320]).

      The GFLOPS part is best-effort. It needs ``thop`` and a ``model.yaml``
      mapping (read for the input channel count, default 3). On any failure it is
      simply omitted from the summary.
      """
      n_p = sum(x.numel() for x in model.parameters())  # number parameters
      n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
      if verbose:
          print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
          for i, (name, p) in enumerate(model.named_parameters()):
              name = name.replace('module_list.', '')
              print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                    (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

      try:  # FLOPS
          from thop import profile as thop_profile  # aliased: don't shadow this module's profile()
          stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
          img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
          flops = thop_profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPS
          if not isinstance(img_size, (list, tuple)):
              img_size = [img_size, img_size]  # expand if int/float
          fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # img_size GFLOPS
      except Exception:  # thop missing, no model.yaml, untraceable model, ...
          fs = ''
      logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  def load_classifier(name='resnet101', n=2):
      """Load a pretrained torchvision model and resize its ``fc`` head to ``n`` outputs.

      ``name`` is looked up in ``torchvision.models.__dict__`` and built with
      ``pretrained=True`` (weights may be downloaded). The model must expose an
      ``fc`` layer, as ResNets do. The new head's weight and bias start at zero.

      Reference input properties for the ResNet models:
          input_size = [3, 224, 224], input_space = 'RGB', input_range = [0, 1],
          mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]
      """
      builder = torchvision.models.__dict__[name]
      model = builder(pretrained=True)
      head = model.fc
      in_features = head.weight.shape[1]
      # Reshape output to n classes
      head.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
      head.weight = nn.Parameter(torch.zeros(n, in_features), requires_grad=True)
      head.out_features = n
      return model
  def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416), r=ratio
      """Resize an image batch by ``ratio`` and pad/crop the result.

      Args:
          img: image batch, indexed here as (bs, 3, h, w), e.g. (16, 3, 256, 416).
          ratio: scale factor. 1.0 returns ``img`` itself, unchanged.
          same_shape: if True, pad (ratio < 1) or crop (ratio > 1) back to the
              original (h, w). Otherwise pad/crop to the scaled size rounded up to
              a multiple of ``gs``.
          gs: grid size in pixels for that rounding (default 32, the previously
              hard-coded value).

      Padded pixels use the constant 0.447 (commented in the code as the ImageNet mean).
      """
      if ratio == 1.0:
          return img
      h, w = img.shape[2:]
      s = (int(h * ratio), int(w * ratio))  # new size
      img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
      if not same_shape:  # pad/crop img to a gs-multiple
          h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
      # negative pad amounts crop
      return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
  def copy_attr(a, b, include=(), exclude=()):
      """Copy instance attributes from ``b`` onto ``a``.

      Attributes whose names start with '_' or appear in ``exclude`` are skipped.
      When ``include`` is non-empty, only names listed in it are copied.
      """
      for attr_name, attr_value in b.__dict__.items():
          if attr_name.startswith('_') or attr_name in exclude:
              continue
          if len(include) and attr_name not in include:
              continue
          setattr(a, attr_name, attr_value)
  class ModelEMA:
      """Exponential moving average (EMA) of a model's state_dict (parameters and buffers).

      Adapted from https://github.com/rwightman/pytorch-image-models, in the spirit of
      https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
      A smoothed copy of the weights is necessary for some training schemes to
      perform well.

      Create it at the right point in the setup sequence. It deep-copies the model
      it is given (unwrapping DataParallel / DistributedDataParallel). Where it is
      created relative to model init, GPU placement and distributed wrapping
      therefore matters.
      """

      def __init__(self, model, decay=0.9999, updates=0):
          source = model.module if is_parallel(model) else model
          self.ema = deepcopy(source).eval()  # EMA copy, kept in eval mode
          self.updates = updates  # number of EMA updates applied so far
          # Effective decay ramps from 0 toward `decay` as updates accumulate
          # (exponential ramp, scale 2000 updates) to help early epochs.
          self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
          for param in self.ema.parameters():
              param.requires_grad_(False)

      def update(self, model):
          """Blend ``model``'s current floating-point state into the EMA copy.

          ``self.updates`` is incremented first and ``d = self.decay(self.updates)``.
          Each floating-point state entry then becomes ``d * ema + (1 - d) * model``.
          Non-floating entries (e.g. integer buffers) are left untouched.
          """
          with torch.no_grad():
              self.updates += 1
              d = self.decay(self.updates)
              live = model.module if is_parallel(model) else model
              msd = live.state_dict()  # model state_dict
              for key, ema_val in self.ema.state_dict().items():
                  if not ema_val.dtype.is_floating_point:
                      continue
                  ema_val *= d
                  ema_val += (1. - d) * msd[key].detach()

      def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
          """Copy ``model``'s attributes onto the EMA copy via ``copy_attr``.

          By default the 'process_group' and 'reducer' attributes are skipped.
          """
          copy_attr(self.ema, model, include, exclude)