Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

torch_utils.py 9.3KB

pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
pirms 4 gadiem
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # PyTorch utils
  2. import logging
  3. import math
  4. import os
  5. import time
  6. from contextlib import contextmanager
  7. from copy import deepcopy
  8. import torch
  9. import torch.backends.cudnn as cudnn
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. import torchvision
# Module-level logger; select_device() and model_info() report through it.
logger = logging.getLogger(__name__)
  14. @contextmanager
  15. def torch_distributed_zero_first(local_rank: int):
  16. """
  17. Decorator to make all processes in distributed training wait for each local_master to do something.
  18. """
  19. if local_rank not in [-1, 0]:
  20. torch.distributed.barrier()
  21. yield
  22. if local_rank == 0:
  23. torch.distributed.barrier()
  24. def init_torch_seeds(seed=0):
  25. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  26. torch.manual_seed(seed)
  27. if seed == 0: # slower, more reproducible
  28. cudnn.deterministic = True
  29. cudnn.benchmark = False
  30. else: # faster, less reproducible
  31. cudnn.deterministic = False
  32. cudnn.benchmark = True
  33. def select_device(device='', batch_size=None):
  34. # device = 'cpu' or '0' or '0,1,2,3'
  35. cpu_request = device.lower() == 'cpu'
  36. if device and not cpu_request: # if device requested other than 'cpu'
  37. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  38. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  39. cuda = False if cpu_request else torch.cuda.is_available()
  40. if cuda:
  41. c = 1024 ** 2 # bytes to MB
  42. ng = torch.cuda.device_count()
  43. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  44. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  45. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  46. s = f'Using torch {torch.__version__} '
  47. for i in range(0, ng):
  48. if i == 1:
  49. s = ' ' * len(s)
  50. logger.info("%sCUDA:%g (%s, %dMB)" % (s, i, x[i].name, x[i].total_memory / c))
  51. else:
  52. logger.info(f'Using torch {torch.__version__} CPU')
  53. logger.info('') # skip a line
  54. return torch.device('cuda:0' if cuda else 'cpu')
  55. def time_synchronized():
  56. torch.cuda.synchronize() if torch.cuda.is_available() else None
  57. return time.time()
  58. def is_parallel(model):
  59. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  60. def intersect_dicts(da, db, exclude=()):
  61. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  62. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  63. def initialize_weights(model):
  64. for m in model.modules():
  65. t = type(m)
  66. if t is nn.Conv2d:
  67. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  68. elif t is nn.BatchNorm2d:
  69. m.eps = 1e-3
  70. m.momentum = 0.03
  71. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  72. m.inplace = True
  73. def find_modules(model, mclass=nn.Conv2d):
  74. # Finds layer indices matching module class 'mclass'
  75. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  76. def sparsity(model):
  77. # Return global model sparsity
  78. a, b = 0., 0.
  79. for p in model.parameters():
  80. a += p.numel()
  81. b += (p == 0).sum()
  82. return b / a
  83. def prune(model, amount=0.3):
  84. # Prune model to requested global sparsity
  85. import torch.nn.utils.prune as prune
  86. print('Pruning model... ', end='')
  87. for name, m in model.named_modules():
  88. if isinstance(m, nn.Conv2d):
  89. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  90. prune.remove(m, 'weight') # make permanent
  91. print(' %.3g global sparsity' % sparsity(model))
  92. def fuse_conv_and_bn(conv, bn):
  93. # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  94. fusedconv = nn.Conv2d(conv.in_channels,
  95. conv.out_channels,
  96. kernel_size=conv.kernel_size,
  97. stride=conv.stride,
  98. padding=conv.padding,
  99. groups=conv.groups,
  100. bias=True).requires_grad_(False).to(conv.weight.device)
  101. # prepare filters
  102. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  103. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  104. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  105. # prepare spatial bias
  106. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  107. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  108. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  109. return fusedconv
  110. def model_info(model, verbose=False, img_size=640):
  111. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  112. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  113. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  114. if verbose:
  115. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  116. for i, (name, p) in enumerate(model.named_parameters()):
  117. name = name.replace('module_list.', '')
  118. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  119. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  120. try: # FLOPS
  121. from thop import profile
  122. stride = int(model.stride.max())
  123. flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, stride, stride),), verbose=False)[0] / 1E9 * 2
  124. img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
  125. fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
  126. except (ImportError, Exception):
  127. fs = ''
  128. logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
  129. def load_classifier(name='resnet101', n=2):
  130. # Loads a pretrained model reshaped to n-class output
  131. model = torchvision.models.__dict__[name](pretrained=True)
  132. # ResNet model properties
  133. # input_size = [3, 224, 224]
  134. # input_space = 'RGB'
  135. # input_range = [0, 1]
  136. # mean = [0.485, 0.456, 0.406]
  137. # std = [0.229, 0.224, 0.225]
  138. # Reshape output to n classes
  139. filters = model.fc.weight.shape[1]
  140. model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
  141. model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  142. model.fc.out_features = n
  143. return model
  144. def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
  145. # scales img(bs,3,y,x) by ratio
  146. if ratio == 1.0:
  147. return img
  148. else:
  149. h, w = img.shape[2:]
  150. s = (int(h * ratio), int(w * ratio)) # new size
  151. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  152. if not same_shape: # pad/crop img
  153. gs = 32 # (pixels) grid size
  154. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  155. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  156. def copy_attr(a, b, include=(), exclude=()):
  157. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  158. for k, v in b.__dict__.items():
  159. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  160. continue
  161. else:
  162. setattr(a, k, v)
  163. class ModelEMA:
  164. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  165. Keep a moving average of everything in the model state_dict (parameters and buffers).
  166. This is intended to allow functionality like
  167. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  168. A smoothed version of the weights is necessary for some training schemes to perform well.
  169. This class is sensitive where it is initialized in the sequence of model init,
  170. GPU assignment and distributed training wrappers.
  171. """
  172. def __init__(self, model, decay=0.9999, updates=0):
  173. # Create EMA
  174. self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
  175. # if next(model.parameters()).device.type != 'cpu':
  176. # self.ema.half() # FP16 EMA
  177. self.updates = updates # number of EMA updates
  178. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  179. for p in self.ema.parameters():
  180. p.requires_grad_(False)
  181. def update(self, model):
  182. # Update EMA parameters
  183. with torch.no_grad():
  184. self.updates += 1
  185. d = self.decay(self.updates)
  186. msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
  187. for k, v in self.ema.state_dict().items():
  188. if v.dtype.is_floating_point:
  189. v *= d
  190. v += (1. - d) * msd[k].detach()
  191. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  192. # Update EMA attributes
  193. copy_attr(self.ema, model, include, exclude)