You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

203 line
8.3KB

  1. import math
  2. import os
  3. import time
  4. from copy import deepcopy
  5. import torch
  6. import torch.backends.cudnn as cudnn
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torchvision.models as models
  10. def init_seeds(seed=0):
  11. torch.manual_seed(seed)
  12. # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
  13. if seed == 0: # slower, more reproducible
  14. cudnn.deterministic = True
  15. cudnn.benchmark = False
  16. else: # faster, less reproducible
  17. cudnn.deterministic = False
  18. cudnn.benchmark = True
  19. def select_device(device='', apex=False, batch_size=None):
  20. # device = 'cpu' or '0' or '0,1,2,3'
  21. cpu_request = device.lower() == 'cpu'
  22. if device and not cpu_request: # if device requested other than 'cpu'
  23. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
  24. assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device # check availablity
  25. cuda = False if cpu_request else torch.cuda.is_available()
  26. if cuda:
  27. c = 1024 ** 2 # bytes to MB
  28. ng = torch.cuda.device_count()
  29. if ng > 1 and batch_size: # check that batch_size is compatible with device_count
  30. assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
  31. x = [torch.cuda.get_device_properties(i) for i in range(ng)]
  32. s = 'Using CUDA ' + ('Apex ' if apex else '') # apex for mixed precision https://github.com/NVIDIA/apex
  33. for i in range(0, ng):
  34. if i == 1:
  35. s = ' ' * len(s)
  36. print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
  37. (s, i, x[i].name, x[i].total_memory / c))
  38. else:
  39. print('Using CPU')
  40. print('') # skip a line
  41. return torch.device('cuda:0' if cuda else 'cpu')
  42. def time_synchronized():
  43. torch.cuda.synchronize() if torch.cuda.is_available() else None
  44. return time.time()
  45. def initialize_weights(model):
  46. for m in model.modules():
  47. t = type(m)
  48. if t is nn.Conv2d:
  49. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  50. elif t is nn.BatchNorm2d:
  51. m.eps = 1e-4
  52. m.momentum = 0.03
  53. elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
  54. m.inplace = True
  55. def find_modules(model, mclass=nn.Conv2d):
  56. # finds layer indices matching module class 'mclass'
  57. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  58. def fuse_conv_and_bn(conv, bn):
  59. # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  60. with torch.no_grad():
  61. # init
  62. fusedconv = torch.nn.Conv2d(conv.in_channels,
  63. conv.out_channels,
  64. kernel_size=conv.kernel_size,
  65. stride=conv.stride,
  66. padding=conv.padding,
  67. bias=True)
  68. # prepare filters
  69. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  70. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  71. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
  72. # prepare spatial bias
  73. if conv.bias is not None:
  74. b_conv = conv.bias
  75. else:
  76. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
  77. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  78. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  79. return fusedconv
  80. def model_info(model, verbose=False):
  81. # Plots a line-by-line description of a PyTorch model
  82. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  83. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  84. if verbose:
  85. print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
  86. for i, (name, p) in enumerate(model.named_parameters()):
  87. name = name.replace('module_list.', '')
  88. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  89. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  90. try: # FLOPS
  91. from thop import profile
  92. macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
  93. fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
  94. except:
  95. fs = ''
  96. print('Model Summary: %g layers, %g parameters, %g gradients%s' % (len(list(model.parameters())), n_p, n_g, fs))
  97. def load_classifier(name='resnet101', n=2):
  98. # Loads a pretrained model reshaped to n-class output
  99. model = models.__dict__[name](pretrained=True)
  100. # Display model properties
  101. input_size = [3, 224, 224]
  102. input_space = 'RGB'
  103. input_range = [0, 1]
  104. mean = [0.485, 0.456, 0.406]
  105. std = [0.229, 0.224, 0.225]
  106. for x in [input_size, input_space, input_range, mean, std]:
  107. print(x + ' =', eval(x))
  108. # Reshape output to n classes
  109. filters = model.fc.weight.shape[1]
  110. model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
  111. model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
  112. model.fc.out_features = n
  113. return model
  114. def scale_img(img, ratio=1.0, same_shape=False): # img(16,3,256,416), r=ratio
  115. # scales img(bs,3,y,x) by ratio
  116. h, w = img.shape[2:]
  117. s = (int(h * ratio), int(w * ratio)) # new size
  118. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  119. if not same_shape: # pad/crop img
  120. gs = 32 # (pixels) grid size
  121. h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
  122. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  123. class ModelEMA:
  124. """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
  125. Keep a moving average of everything in the model state_dict (parameters and buffers).
  126. This is intended to allow functionality like
  127. https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  128. A smoothed version of the weights is necessary for some training schemes to perform well.
  129. E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
  130. RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
  131. smoothing of weights to match results. Pay attention to the decay constant you are using
  132. relative to your update count per epoch.
  133. To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
  134. disable validation of the EMA weights. Validation will have to be done manually in a separate
  135. process, or after the training stops converging.
  136. This class is sensitive where it is initialized in the sequence of model init,
  137. GPU assignment and distributed training wrappers.
  138. I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
  139. """
  140. def __init__(self, model, decay=0.9999, device=''):
  141. # make a copy of the model for accumulating moving average of weights
  142. self.ema = deepcopy(model)
  143. self.ema.eval()
  144. self.updates = 0 # number of EMA updates
  145. self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs)
  146. self.device = device # perform ema on different device from model if set
  147. if device:
  148. self.ema.to(device=device)
  149. for p in self.ema.parameters():
  150. p.requires_grad_(False)
  151. def update(self, model):
  152. self.updates += 1
  153. d = self.decay(self.updates)
  154. with torch.no_grad():
  155. if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
  156. msd, esd = model.module.state_dict(), self.ema.module.state_dict()
  157. else:
  158. msd, esd = model.state_dict(), self.ema.state_dict()
  159. for k, v in esd.items():
  160. if v.dtype.is_floating_point:
  161. v *= d
  162. v += (1. - d) * msd[k].detach()
  163. def update_attr(self, model):
  164. # Assign attributes (which may change during training)
  165. for k in model.__dict__.keys():
  166. if not k.startswith('_'):
  167. setattr(self.ema, k, getattr(model, k))