No puede seleccionar más de 25 temas Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

195 líneas
8.1KB

  1. import math
  2. import os
  3. import time
  4. from copy import deepcopy
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. def init_seeds(seed=0):
  10. torch.manual_seed(seed)
  11. # Reduce randomness (may be slower on Tesla GPUs) # https://pytorch.org/docs/stable/notes/randomness.html
  12. if seed == 0:
  13. cudnn.deterministic = False
  14. cudnn.benchmark = True
  15. def select_device(device='', apex=False, batch_size=None):
  16. # device = 'cpu' or '0' or '0,1,2,3'
  17. cpu_request = device.lower() == 'cpu'
  18. if device and not cpu_request: # if device requested other than 'cpu'
  19. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  20. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  21. cuda = False if cpu_request else torch.cuda.is_available()
  22. if cuda:
  23. c = 1024 ** 2 # bytes to MB
  24. ng = torch.cuda.device_count()
  25. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  26. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  27. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  28. s = 'Using CUDA ' + ('Apex ' if apex else '') # apex for mixed precision https://github.com/NVIDIA/apex
  29. for i in range(0, ng):
  30. if i == 1:
  31. s = ' ' * len(s)
  32. print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
  33. (s, i, x[i].name, x[i].total_memory / c))
  34. else:
  35. print('Using CPU')
  36. print('') # skip a line
  37. return torch.device('cuda:0' if cuda else 'cpu')
  38. def time_synchronized():
  39. torch.cuda.synchronize() if torch.cuda.is_available() else None
  40. return time.time()
  41. def initialize_weights(model):
  42. for m in model.modules():
  43. t = type(m)
  44. if t is nn.Conv2d:
  45. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  46. elif t is nn.BatchNorm2d:
  47. m.eps = 1e-4
  48. m.momentum = 0.03
  49. elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  50. m.inplace = True
  51. def find_modules(model, mclass=nn.Conv2d):
  52. # finds layer indices matching module class 'mclass'
  53. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  54. def fuse_conv_and_bn(conv, bn):
  55. # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  56. with torch.no_grad():
  57. # init
  58. fusedconv = torch.nn.Conv2d(conv.in_channels,
  59. conv.out_channels,
  60. kernel_size=conv.kernel_size,
  61. stride=conv.stride,
  62. padding=conv.padding,
  63. bias=True)
  64. # prepare filters
  65. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  66. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  67. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  68. # prepare spatial bias
  69. if conv.bias is not None:
  70. b_conv = conv.bias
  71. else:
  72. b_conv = torch.zeros(conv.weight.size(0))
  73. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  74. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  75. return fusedconv
  76. def model_info(model, verbose=False):
  77. # Plots a line-by-line description of a PyTorch model
  78. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  79. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  80. if verbose:
  81. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  82. for i, (name, p) in enumerate(model.named_parameters()):
  83. name = name.replace('module_list.', '')
  84. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  85. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  86. try: # FLOPS
  87. from thop import profile
  88. macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
  89. fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
  90. except:
  91. fs = ''
  92. print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
  93. def load_classifier(name='resnet101', n=2):
  94. # Loads a pretrained model reshaped to n-class output
  95. import pretrainedmodels # https://github.com/Cadene/pretrained-models.pytorch#torchvision
  96. model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained='imagenet')
  97. # Display model properties
  98. for x in ['model.input_size', 'model.input_space', 'model.input_range', 'model.mean', 'model.std']:
  99. print(x + ' =', eval(x))
  100. # Reshape output to n classes
  101. filters = model.last_linear.weight.shape[1]
  102. model.last_linear.bias = torch.nn.Parameter(torch.zeros(n))
  103. model.last_linear.weight = torch.nn.Parameter(torch.zeros(n, filters))
  104. model.last_linear.out_features = n
  105. return model
  106. def scale_img(img, ratio=1.0, same_shape=True): # img(16,3,256,416), r=ratio
  107. # scales img(bs,3,y,x) by ratio
  108. h, w = img.shape[2:]
  109. s = (int(h * ratio), int(w * ratio)) # new size
  110. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  111. if not same_shape: # pad/crop img
  112. gs = 64 # (pixels) grid size
  113. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  114. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  115. class ModelEMA:
  116. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  117. Keep a moving average of everything in the model state_dict (parameters and buffers).
  118. This is intended to allow functionality like
  119. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  120. A smoothed version of the weights is necessary for some training schemes to perform well.
  121. E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
  122. RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
  123. smoothing of weights to match results. Pay attention to the decay constant you are using
  124. relative to your update count per epoch.
  125. To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
  126. disable validation of the EMA weights. Validation will have to be done manually in a separate
  127. process, or after the training stops converging.
  128. This class is sensitive where it is initialized in the sequence of model init,
  129. GPU assignment and distributed training wrappers.
  130. I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
  131. """
  132. def __init__(self, model, decay=0.9999, device=''):
  133. # make a copy of the model for accumulating moving average of weights
  134. self.ema = deepcopy(model)
  135. self.ema.eval()
  136. self.updates = 0 # number of EMA updates
  137. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  138. self.device = device # perform ema on different device from model if set
  139. if device:
  140. self.ema.to(device=device)
  141. for p in self.ema.parameters():
  142. p.requires_grad_(False)
  143. def update(self, model):
  144. self.updates += 1
  145. d = self.decay(self.updates)
  146. with torch.no_grad():
  147. if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
  148. msd, esd = model.module.state_dict(), self.ema.module.state_dict()
  149. else:
  150. msd, esd = model.state_dict(), self.ema.state_dict()
  151. for k, v in esd.items():
  152. if v.dtype.is_floating_point:
  153. v *= d
  154. v += (1. - d) * msd[k].detach()
  155. def update_attr(self, model):
  156. # Assign attributes (which may change during training)
  157. for k in model.__dict__.keys():
  158. if not k.startswith('_'):
  159. setattr(self.ema, k, getattr(model, k))