No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

torch_utils.py 8.3KB

hace 4 años
hace 4 años
hace 4 años
hace 4 años
hace 4 años
hace 4 años
hace 4 años
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import math
  2. import os
  3. import time
  4. from copy import deepcopy
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. def init_seeds(seed=0):
  10. torch.manual_seed(seed)
  11. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  12. if seed == 0: # slower, more reproducible
  13. cudnn.deterministic = True
  14. cudnn.benchmark = False
  15. else: # faster, less reproducible
  16. cudnn.deterministic = False
  17. cudnn.benchmark = True
  18. def select_device(device='', apex=False, batch_size=None):
  19. # device = 'cpu' or '0' or '0,1,2,3'
  20. cpu_request = device.lower() == 'cpu'
  21. if device and not cpu_request: # if device requested other than 'cpu'
  22. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  23. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  24. cuda = False if cpu_request else torch.cuda.is_available()
  25. if cuda:
  26. c = 1024 ** 2 # bytes to MB
  27. ng = torch.cuda.device_count()
  28. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  29. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  30. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  31. s = 'Using CUDA ' + ('Apex ' if apex else '') # apex for mixed precision https://github.com/NVIDIA/apex
  32. for i in range(0, ng):
  33. if i == 1:
  34. s = ' ' * len(s)
  35. print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
  36. (s, i, x[i].name, x[i].total_memory / c))
  37. else:
  38. print('Using CPU')
  39. print('') # skip a line
  40. return torch.device('cuda:0' if cuda else 'cpu')
  41. def time_synchronized():
  42. torch.cuda.synchronize() if torch.cuda.is_available() else None
  43. return time.time()
  44. def initialize_weights(model):
  45. for m in model.modules():
  46. t = type(m)
  47. if t is nn.Conv2d:
  48. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  49. elif t is nn.BatchNorm2d:
  50. m.eps = 1e-4
  51. m.momentum = 0.03
  52. elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  53. m.inplace = True
  54. def find_modules(model, mclass=nn.Conv2d):
  55. # finds layer indices matching module class 'mclass'
  56. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  57. def fuse_conv_and_bn(conv, bn):
  58. # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  59. with torch.no_grad():
  60. # init
  61. fusedconv = torch.nn.Conv2d(conv.in_channels,
  62. conv.out_channels,
  63. kernel_size=conv.kernel_size,
  64. stride=conv.stride,
  65. padding=conv.padding,
  66. bias=True)
  67. # prepare filters
  68. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  69. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  70. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  71. # prepare spatial bias
  72. if conv.bias is not None:
  73. b_conv = conv.bias
  74. else:
  75. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
  76. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  77. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  78. return fusedconv
  79. def model_info(model, verbose=False):
  80. # Plots a line-by-line description of a PyTorch model
  81. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  82. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  83. if verbose:
  84. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  85. for i, (name, p) in enumerate(model.named_parameters()):
  86. name = name.replace('module_list.', '')
  87. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  88. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  89. try: # FLOPS
  90. from thop import profile
  91. macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
  92. fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
  93. except:
  94. fs = ''
  95. print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
  96. def load_classifier(name='resnet101', n=2):
  97. # Loads a pretrained model reshaped to n-class output
  98. import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
  99. model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
  100. # Display model properties
  101. for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']:
  102. print(x + ' =', eval(x))
  103. # Reshape output to n classes
  104. filters = model.last_linear.weight.shape[1]
  105. model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
  106. model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
  107. model.last_linear.out_features = n
  108. return model
  109. def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
  110. # scales img(bs,3,y,x) by ratio
  111. h, w = img.shape[2:]
  112. s = (int(h * ratio), int(w * ratio)) # new size
  113. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  114. if not same_shape: # pad/crop img
  115. gs = 32 # (pixels) grid size
  116. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  117. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  118. class ModelEMA:
  119. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  120. Keep a moving average of everything in the model state_dict (parameters and buffers).
  121. This is intended to allow functionality like
  122. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  123. A smoothed version of the weights is necessary for some training schemes to perform well.
  124. E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
  125. RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
  126. smoothing of weights to match results. Pay attention to the decay constant you are using
  127. relative to your update count per epoch.
  128. To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
  129. disable validation of the EMA weights. Validation will have to be done manually in a separate
  130. process, or after the training stops converging.
  131. This class is sensitive where it is initialized in the sequence of model init,
  132. GPU assignment and distributed training wrappers.
  133. I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
  134. """
  135. def __init__(self, model, decay=0.9999, device=''):
  136. # make a copy of the model for accumulating moving average of weights
  137. self.ema = deepcopy(model)
  138. self.ema.eval()
  139. self.updates = 0 # number of EMA updates
  140. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  141. self.device = device # perform ema on different device from model if set
  142. if device:
  143. self.ema.to(device=device)
  144. for p in self.ema.parameters():
  145. p.requires_grad_(False)
  146. def update(self, model):
  147. self.updates += 1
  148. d = self.decay(self.updates)
  149. with torch.no_grad():
  150. if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
  151. msd, esd = model.module.state_dict(), self.ema.module.state_dict()
  152. else:
  153. msd, esd = model.state_dict(), self.ema.state_dict()
  154. for k, v in esd.items():
  155. if v.dtype.is_floating_point:
  156. v *= d
  157. v += (1. - d) * msd[k].detach()
  158. def update_attr(self, model):
  159. # Assign attributes (which may change during training)
  160. for k in model.__dict__.keys():
  161. if not k.startswith('_'):
  162. setattr(self.ema, k, getattr(model, k))